_plot.py 34 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907
  1. import warnings
  2. import numpy as np
  3. from ..utils import check_matplotlib_support
  4. from ..utils._plotting import _interval_max_min_ratio, _validate_score_name
  5. from ._validation import learning_curve, validation_curve
  6. class _BaseCurveDisplay:
  7. def _plot_curve(
  8. self,
  9. x_data,
  10. *,
  11. ax=None,
  12. negate_score=False,
  13. score_name=None,
  14. score_type="test",
  15. log_scale="deprecated",
  16. std_display_style="fill_between",
  17. line_kw=None,
  18. fill_between_kw=None,
  19. errorbar_kw=None,
  20. ):
  21. check_matplotlib_support(f"{self.__class__.__name__}.plot")
  22. import matplotlib.pyplot as plt
  23. if ax is None:
  24. _, ax = plt.subplots()
  25. if negate_score:
  26. train_scores, test_scores = -self.train_scores, -self.test_scores
  27. else:
  28. train_scores, test_scores = self.train_scores, self.test_scores
  29. if std_display_style not in ("errorbar", "fill_between", None):
  30. raise ValueError(
  31. f"Unknown std_display_style: {std_display_style}. Should be one of"
  32. " 'errorbar', 'fill_between', or None."
  33. )
  34. if score_type not in ("test", "train", "both"):
  35. raise ValueError(
  36. f"Unknown score_type: {score_type}. Should be one of 'test', "
  37. "'train', or 'both'."
  38. )
  39. if score_type == "train":
  40. scores = {"Train": train_scores}
  41. elif score_type == "test":
  42. scores = {"Test": test_scores}
  43. else: # score_type == "both"
  44. scores = {"Train": train_scores, "Test": test_scores}
  45. if std_display_style in ("fill_between", None):
  46. # plot the mean score
  47. if line_kw is None:
  48. line_kw = {}
  49. self.lines_ = []
  50. for line_label, score in scores.items():
  51. self.lines_.append(
  52. *ax.plot(
  53. x_data,
  54. score.mean(axis=1),
  55. label=line_label,
  56. **line_kw,
  57. )
  58. )
  59. self.errorbar_ = None
  60. self.fill_between_ = None # overwritten below by fill_between
  61. if std_display_style == "errorbar":
  62. if errorbar_kw is None:
  63. errorbar_kw = {}
  64. self.errorbar_ = []
  65. for line_label, score in scores.items():
  66. self.errorbar_.append(
  67. ax.errorbar(
  68. x_data,
  69. score.mean(axis=1),
  70. score.std(axis=1),
  71. label=line_label,
  72. **errorbar_kw,
  73. )
  74. )
  75. self.lines_, self.fill_between_ = None, None
  76. elif std_display_style == "fill_between":
  77. if fill_between_kw is None:
  78. fill_between_kw = {}
  79. default_fill_between_kw = {"alpha": 0.5}
  80. fill_between_kw = {**default_fill_between_kw, **fill_between_kw}
  81. self.fill_between_ = []
  82. for line_label, score in scores.items():
  83. self.fill_between_.append(
  84. ax.fill_between(
  85. x_data,
  86. score.mean(axis=1) - score.std(axis=1),
  87. score.mean(axis=1) + score.std(axis=1),
  88. **fill_between_kw,
  89. )
  90. )
  91. score_name = self.score_name if score_name is None else score_name
  92. ax.legend()
  93. # TODO(1.5): to be removed
  94. if log_scale != "deprecated":
  95. warnings.warn(
  96. (
  97. "The `log_scale` parameter is deprecated as of version 1.3 "
  98. "and will be removed in 1.5. You can use display.ax_.set_xscale "
  99. "and display.ax_.set_yscale instead."
  100. ),
  101. FutureWarning,
  102. )
  103. xscale = "log" if log_scale else "linear"
  104. else:
  105. # We found that a ratio, smaller or bigger than 5, between the largest and
  106. # smallest gap of the x values is a good indicator to choose between linear
  107. # and log scale.
  108. if _interval_max_min_ratio(x_data) > 5:
  109. xscale = "symlog" if x_data.min() <= 0 else "log"
  110. else:
  111. xscale = "linear"
  112. ax.set_xscale(xscale)
  113. ax.set_ylabel(f"{score_name}")
  114. self.ax_ = ax
  115. self.figure_ = ax.figure
  116. class LearningCurveDisplay(_BaseCurveDisplay):
  117. """Learning Curve visualization.
  118. It is recommended to use
  119. :meth:`~sklearn.model_selection.LearningCurveDisplay.from_estimator` to
  120. create a :class:`~sklearn.model_selection.LearningCurveDisplay` instance.
  121. All parameters are stored as attributes.
  122. Read more in the :ref:`User Guide <visualizations>` for general information
  123. about the visualization API and
  124. :ref:`detailed documentation <learning_curve>` regarding the learning
  125. curve visualization.
  126. .. versionadded:: 1.2
  127. Parameters
  128. ----------
  129. train_sizes : ndarray of shape (n_unique_ticks,)
  130. Numbers of training examples that has been used to generate the
  131. learning curve.
  132. train_scores : ndarray of shape (n_ticks, n_cv_folds)
  133. Scores on training sets.
  134. test_scores : ndarray of shape (n_ticks, n_cv_folds)
  135. Scores on test set.
  136. score_name : str, default=None
  137. The name of the score used in `learning_curve`. It will override the name
  138. inferred from the `scoring` parameter. If `score` is `None`, we use `"Score"` if
  139. `negate_score` is `False` and `"Negative score"` otherwise. If `scoring` is a
  140. string or a callable, we infer the name. We replace `_` by spaces and capitalize
  141. the first letter. We remove `neg_` and replace it by `"Negative"` if
  142. `negate_score` is `False` or just remove it otherwise.
  143. Attributes
  144. ----------
  145. ax_ : matplotlib Axes
  146. Axes with the learning curve.
  147. figure_ : matplotlib Figure
  148. Figure containing the learning curve.
  149. errorbar_ : list of matplotlib Artist or None
  150. When the `std_display_style` is `"errorbar"`, this is a list of
  151. `matplotlib.container.ErrorbarContainer` objects. If another style is
  152. used, `errorbar_` is `None`.
  153. lines_ : list of matplotlib Artist or None
  154. When the `std_display_style` is `"fill_between"`, this is a list of
  155. `matplotlib.lines.Line2D` objects corresponding to the mean train and
  156. test scores. If another style is used, `line_` is `None`.
  157. fill_between_ : list of matplotlib Artist or None
  158. When the `std_display_style` is `"fill_between"`, this is a list of
  159. `matplotlib.collections.PolyCollection` objects. If another style is
  160. used, `fill_between_` is `None`.
  161. See Also
  162. --------
  163. sklearn.model_selection.learning_curve : Compute the learning curve.
  164. Examples
  165. --------
  166. >>> import matplotlib.pyplot as plt
  167. >>> from sklearn.datasets import load_iris
  168. >>> from sklearn.model_selection import LearningCurveDisplay, learning_curve
  169. >>> from sklearn.tree import DecisionTreeClassifier
  170. >>> X, y = load_iris(return_X_y=True)
  171. >>> tree = DecisionTreeClassifier(random_state=0)
  172. >>> train_sizes, train_scores, test_scores = learning_curve(
  173. ... tree, X, y)
  174. >>> display = LearningCurveDisplay(train_sizes=train_sizes,
  175. ... train_scores=train_scores, test_scores=test_scores, score_name="Score")
  176. >>> display.plot()
  177. <...>
  178. >>> plt.show()
  179. """
  180. def __init__(self, *, train_sizes, train_scores, test_scores, score_name=None):
  181. self.train_sizes = train_sizes
  182. self.train_scores = train_scores
  183. self.test_scores = test_scores
  184. self.score_name = score_name
  185. def plot(
  186. self,
  187. ax=None,
  188. *,
  189. negate_score=False,
  190. score_name=None,
  191. score_type="both",
  192. log_scale="deprecated",
  193. std_display_style="fill_between",
  194. line_kw=None,
  195. fill_between_kw=None,
  196. errorbar_kw=None,
  197. ):
  198. """Plot visualization.
  199. Parameters
  200. ----------
  201. ax : matplotlib Axes, default=None
  202. Axes object to plot on. If `None`, a new figure and axes is
  203. created.
  204. negate_score : bool, default=False
  205. Whether or not to negate the scores obtained through
  206. :func:`~sklearn.model_selection.learning_curve`. This is
  207. particularly useful when using the error denoted by `neg_*` in
  208. `scikit-learn`.
  209. score_name : str, default=None
  210. The name of the score used to decorate the y-axis of the plot. It will
  211. override the name inferred from the `scoring` parameter. If `score` is
  212. `None`, we use `"Score"` if `negate_score` is `False` and `"Negative score"`
  213. otherwise. If `scoring` is a string or a callable, we infer the name. We
  214. replace `_` by spaces and capitalize the first letter. We remove `neg_` and
  215. replace it by `"Negative"` if `negate_score` is
  216. `False` or just remove it otherwise.
  217. score_type : {"test", "train", "both"}, default="both"
  218. The type of score to plot. Can be one of `"test"`, `"train"`, or
  219. `"both"`.
  220. log_scale : bool, default="deprecated"
  221. Whether or not to use a logarithmic scale for the x-axis.
  222. .. deprecated:: 1.3
  223. `log_scale` is deprecated in 1.3 and will be removed in 1.5.
  224. Use `display.ax_.set_xscale` and `display.ax_.set_yscale` instead.
  225. std_display_style : {"errorbar", "fill_between"} or None, default="fill_between"
  226. The style used to display the score standard deviation around the
  227. mean score. If None, no standard deviation representation is
  228. displayed.
  229. line_kw : dict, default=None
  230. Additional keyword arguments passed to the `plt.plot` used to draw
  231. the mean score.
  232. fill_between_kw : dict, default=None
  233. Additional keyword arguments passed to the `plt.fill_between` used
  234. to draw the score standard deviation.
  235. errorbar_kw : dict, default=None
  236. Additional keyword arguments passed to the `plt.errorbar` used to
  237. draw mean score and standard deviation score.
  238. Returns
  239. -------
  240. display : :class:`~sklearn.model_selection.LearningCurveDisplay`
  241. Object that stores computed values.
  242. """
  243. self._plot_curve(
  244. self.train_sizes,
  245. ax=ax,
  246. negate_score=negate_score,
  247. score_name=score_name,
  248. score_type=score_type,
  249. log_scale=log_scale,
  250. std_display_style=std_display_style,
  251. line_kw=line_kw,
  252. fill_between_kw=fill_between_kw,
  253. errorbar_kw=errorbar_kw,
  254. )
  255. self.ax_.set_xlabel("Number of samples in the training set")
  256. return self
  257. @classmethod
  258. def from_estimator(
  259. cls,
  260. estimator,
  261. X,
  262. y,
  263. *,
  264. groups=None,
  265. train_sizes=np.linspace(0.1, 1.0, 5),
  266. cv=None,
  267. scoring=None,
  268. exploit_incremental_learning=False,
  269. n_jobs=None,
  270. pre_dispatch="all",
  271. verbose=0,
  272. shuffle=False,
  273. random_state=None,
  274. error_score=np.nan,
  275. fit_params=None,
  276. ax=None,
  277. negate_score=False,
  278. score_name=None,
  279. score_type="both",
  280. log_scale="deprecated",
  281. std_display_style="fill_between",
  282. line_kw=None,
  283. fill_between_kw=None,
  284. errorbar_kw=None,
  285. ):
  286. """Create a learning curve display from an estimator.
  287. Read more in the :ref:`User Guide <visualizations>` for general
  288. information about the visualization API and :ref:`detailed
  289. documentation <learning_curve>` regarding the learning curve
  290. visualization.
  291. Parameters
  292. ----------
  293. estimator : object type that implements the "fit" and "predict" methods
  294. An object of that type which is cloned for each validation.
  295. X : array-like of shape (n_samples, n_features)
  296. Training data, where `n_samples` is the number of samples and
  297. `n_features` is the number of features.
  298. y : array-like of shape (n_samples,) or (n_samples, n_outputs) or None
  299. Target relative to X for classification or regression;
  300. None for unsupervised learning.
  301. groups : array-like of shape (n_samples,), default=None
  302. Group labels for the samples used while splitting the dataset into
  303. train/test set. Only used in conjunction with a "Group" :term:`cv`
  304. instance (e.g., :class:`GroupKFold`).
  305. train_sizes : array-like of shape (n_ticks,), \
  306. default=np.linspace(0.1, 1.0, 5)
  307. Relative or absolute numbers of training examples that will be used
  308. to generate the learning curve. If the dtype is float, it is
  309. regarded as a fraction of the maximum size of the training set
  310. (that is determined by the selected validation method), i.e. it has
  311. to be within (0, 1]. Otherwise it is interpreted as absolute sizes
  312. of the training sets. Note that for classification the number of
  313. samples usually have to be big enough to contain at least one
  314. sample from each class.
  315. cv : int, cross-validation generator or an iterable, default=None
  316. Determines the cross-validation splitting strategy.
  317. Possible inputs for cv are:
  318. - None, to use the default 5-fold cross validation,
  319. - int, to specify the number of folds in a `(Stratified)KFold`,
  320. - :term:`CV splitter`,
  321. - An iterable yielding (train, test) splits as arrays of indices.
  322. For int/None inputs, if the estimator is a classifier and `y` is
  323. either binary or multiclass,
  324. :class:`~sklearn.model_selection.StratifiedKFold` is used. In all
  325. other cases, :class:`~sklearn.model_selection.KFold` is used. These
  326. splitters are instantiated with `shuffle=False` so the splits will
  327. be the same across calls.
  328. Refer :ref:`User Guide <cross_validation>` for the various
  329. cross-validation strategies that can be used here.
  330. scoring : str or callable, default=None
  331. A string (see :ref:`scoring_parameter`) or
  332. a scorer callable object / function with signature
  333. `scorer(estimator, X, y)` (see :ref:`scoring`).
  334. exploit_incremental_learning : bool, default=False
  335. If the estimator supports incremental learning, this will be
  336. used to speed up fitting for different training set sizes.
  337. n_jobs : int, default=None
  338. Number of jobs to run in parallel. Training the estimator and
  339. computing the score are parallelized over the different training
  340. and test sets. `None` means 1 unless in a
  341. :obj:`joblib.parallel_backend` context. `-1` means using all
  342. processors. See :term:`Glossary <n_jobs>` for more details.
  343. pre_dispatch : int or str, default='all'
  344. Number of predispatched jobs for parallel execution (default is
  345. all). The option can reduce the allocated memory. The str can
  346. be an expression like '2*n_jobs'.
  347. verbose : int, default=0
  348. Controls the verbosity: the higher, the more messages.
  349. shuffle : bool, default=False
  350. Whether to shuffle training data before taking prefixes of it
  351. based on`train_sizes`.
  352. random_state : int, RandomState instance or None, default=None
  353. Used when `shuffle` is True. Pass an int for reproducible
  354. output across multiple function calls.
  355. See :term:`Glossary <random_state>`.
  356. error_score : 'raise' or numeric, default=np.nan
  357. Value to assign to the score if an error occurs in estimator
  358. fitting. If set to 'raise', the error is raised. If a numeric value
  359. is given, FitFailedWarning is raised.
  360. fit_params : dict, default=None
  361. Parameters to pass to the fit method of the estimator.
  362. ax : matplotlib Axes, default=None
  363. Axes object to plot on. If `None`, a new figure and axes is
  364. created.
  365. negate_score : bool, default=False
  366. Whether or not to negate the scores obtained through
  367. :func:`~sklearn.model_selection.learning_curve`. This is
  368. particularly useful when using the error denoted by `neg_*` in
  369. `scikit-learn`.
  370. score_name : str, default=None
  371. The name of the score used to decorate the y-axis of the plot. It will
  372. override the name inferred from the `scoring` parameter. If `score` is
  373. `None`, we use `"Score"` if `negate_score` is `False` and `"Negative score"`
  374. otherwise. If `scoring` is a string or a callable, we infer the name. We
  375. replace `_` by spaces and capitalize the first letter. We remove `neg_` and
  376. replace it by `"Negative"` if `negate_score` is
  377. `False` or just remove it otherwise.
  378. score_type : {"test", "train", "both"}, default="both"
  379. The type of score to plot. Can be one of `"test"`, `"train"`, or
  380. `"both"`.
  381. log_scale : bool, default="deprecated"
  382. Whether or not to use a logarithmic scale for the x-axis.
  383. .. deprecated:: 1.3
  384. `log_scale` is deprecated in 1.3 and will be removed in 1.5.
  385. Use `display.ax_.xscale` and `display.ax_.yscale` instead.
  386. std_display_style : {"errorbar", "fill_between"} or None, default="fill_between"
  387. The style used to display the score standard deviation around the
  388. mean score. If `None`, no representation of the standard deviation
  389. is displayed.
  390. line_kw : dict, default=None
  391. Additional keyword arguments passed to the `plt.plot` used to draw
  392. the mean score.
  393. fill_between_kw : dict, default=None
  394. Additional keyword arguments passed to the `plt.fill_between` used
  395. to draw the score standard deviation.
  396. errorbar_kw : dict, default=None
  397. Additional keyword arguments passed to the `plt.errorbar` used to
  398. draw mean score and standard deviation score.
  399. Returns
  400. -------
  401. display : :class:`~sklearn.model_selection.LearningCurveDisplay`
  402. Object that stores computed values.
  403. Examples
  404. --------
  405. >>> import matplotlib.pyplot as plt
  406. >>> from sklearn.datasets import load_iris
  407. >>> from sklearn.model_selection import LearningCurveDisplay
  408. >>> from sklearn.tree import DecisionTreeClassifier
  409. >>> X, y = load_iris(return_X_y=True)
  410. >>> tree = DecisionTreeClassifier(random_state=0)
  411. >>> LearningCurveDisplay.from_estimator(tree, X, y)
  412. <...>
  413. >>> plt.show()
  414. """
  415. check_matplotlib_support(f"{cls.__name__}.from_estimator")
  416. score_name = _validate_score_name(score_name, scoring, negate_score)
  417. train_sizes, train_scores, test_scores = learning_curve(
  418. estimator,
  419. X,
  420. y,
  421. groups=groups,
  422. train_sizes=train_sizes,
  423. cv=cv,
  424. scoring=scoring,
  425. exploit_incremental_learning=exploit_incremental_learning,
  426. n_jobs=n_jobs,
  427. pre_dispatch=pre_dispatch,
  428. verbose=verbose,
  429. shuffle=shuffle,
  430. random_state=random_state,
  431. error_score=error_score,
  432. return_times=False,
  433. fit_params=fit_params,
  434. )
  435. viz = cls(
  436. train_sizes=train_sizes,
  437. train_scores=train_scores,
  438. test_scores=test_scores,
  439. score_name=score_name,
  440. )
  441. return viz.plot(
  442. ax=ax,
  443. negate_score=negate_score,
  444. score_type=score_type,
  445. log_scale=log_scale,
  446. std_display_style=std_display_style,
  447. line_kw=line_kw,
  448. fill_between_kw=fill_between_kw,
  449. errorbar_kw=errorbar_kw,
  450. )
  451. class ValidationCurveDisplay(_BaseCurveDisplay):
  452. """Validation Curve visualization.
  453. It is recommended to use
  454. :meth:`~sklearn.model_selection.ValidationCurveDisplay.from_estimator` to
  455. create a :class:`~sklearn.model_selection.ValidationCurveDisplay` instance.
  456. All parameters are stored as attributes.
  457. Read more in the :ref:`User Guide <visualizations>` for general information
  458. about the visualization API and :ref:`detailed documentation
  459. <validation_curve>` regarding the validation curve visualization.
  460. .. versionadded:: 1.3
  461. Parameters
  462. ----------
  463. param_name : str
  464. Name of the parameter that has been varied.
  465. param_range : array-like of shape (n_ticks,)
  466. The values of the parameter that have been evaluated.
  467. train_scores : ndarray of shape (n_ticks, n_cv_folds)
  468. Scores on training sets.
  469. test_scores : ndarray of shape (n_ticks, n_cv_folds)
  470. Scores on test set.
  471. score_name : str, default=None
  472. The name of the score used in `validation_curve`. It will override the name
  473. inferred from the `scoring` parameter. If `score` is `None`, we use `"Score"` if
  474. `negate_score` is `False` and `"Negative score"` otherwise. If `scoring` is a
  475. string or a callable, we infer the name. We replace `_` by spaces and capitalize
  476. the first letter. We remove `neg_` and replace it by `"Negative"` if
  477. `negate_score` is `False` or just remove it otherwise.
  478. Attributes
  479. ----------
  480. ax_ : matplotlib Axes
  481. Axes with the validation curve.
  482. figure_ : matplotlib Figure
  483. Figure containing the validation curve.
  484. errorbar_ : list of matplotlib Artist or None
  485. When the `std_display_style` is `"errorbar"`, this is a list of
  486. `matplotlib.container.ErrorbarContainer` objects. If another style is
  487. used, `errorbar_` is `None`.
  488. lines_ : list of matplotlib Artist or None
  489. When the `std_display_style` is `"fill_between"`, this is a list of
  490. `matplotlib.lines.Line2D` objects corresponding to the mean train and
  491. test scores. If another style is used, `line_` is `None`.
  492. fill_between_ : list of matplotlib Artist or None
  493. When the `std_display_style` is `"fill_between"`, this is a list of
  494. `matplotlib.collections.PolyCollection` objects. If another style is
  495. used, `fill_between_` is `None`.
  496. See Also
  497. --------
  498. sklearn.model_selection.validation_curve : Compute the validation curve.
  499. Examples
  500. --------
  501. >>> import numpy as np
  502. >>> import matplotlib.pyplot as plt
  503. >>> from sklearn.datasets import make_classification
  504. >>> from sklearn.model_selection import ValidationCurveDisplay, validation_curve
  505. >>> from sklearn.linear_model import LogisticRegression
  506. >>> X, y = make_classification(n_samples=1_000, random_state=0)
  507. >>> logistic_regression = LogisticRegression()
  508. >>> param_name, param_range = "C", np.logspace(-8, 3, 10)
  509. >>> train_scores, test_scores = validation_curve(
  510. ... logistic_regression, X, y, param_name=param_name, param_range=param_range
  511. ... )
  512. >>> display = ValidationCurveDisplay(
  513. ... param_name=param_name, param_range=param_range,
  514. ... train_scores=train_scores, test_scores=test_scores, score_name="Score"
  515. ... )
  516. >>> display.plot()
  517. <...>
  518. >>> plt.show()
  519. """
  520. def __init__(
  521. self, *, param_name, param_range, train_scores, test_scores, score_name=None
  522. ):
  523. self.param_name = param_name
  524. self.param_range = param_range
  525. self.train_scores = train_scores
  526. self.test_scores = test_scores
  527. self.score_name = score_name
  528. def plot(
  529. self,
  530. ax=None,
  531. *,
  532. negate_score=False,
  533. score_name=None,
  534. score_type="both",
  535. std_display_style="fill_between",
  536. line_kw=None,
  537. fill_between_kw=None,
  538. errorbar_kw=None,
  539. ):
  540. """Plot visualization.
  541. Parameters
  542. ----------
  543. ax : matplotlib Axes, default=None
  544. Axes object to plot on. If `None`, a new figure and axes is
  545. created.
  546. negate_score : bool, default=False
  547. Whether or not to negate the scores obtained through
  548. :func:`~sklearn.model_selection.validation_curve`. This is
  549. particularly useful when using the error denoted by `neg_*` in
  550. `scikit-learn`.
  551. score_name : str, default=None
  552. The name of the score used to decorate the y-axis of the plot. It will
  553. override the name inferred from the `scoring` parameter. If `score` is
  554. `None`, we use `"Score"` if `negate_score` is `False` and `"Negative score"`
  555. otherwise. If `scoring` is a string or a callable, we infer the name. We
  556. replace `_` by spaces and capitalize the first letter. We remove `neg_` and
  557. replace it by `"Negative"` if `negate_score` is
  558. `False` or just remove it otherwise.
  559. score_type : {"test", "train", "both"}, default="both"
  560. The type of score to plot. Can be one of `"test"`, `"train"`, or
  561. `"both"`.
  562. std_display_style : {"errorbar", "fill_between"} or None, default="fill_between"
  563. The style used to display the score standard deviation around the
  564. mean score. If None, no standard deviation representation is
  565. displayed.
  566. line_kw : dict, default=None
  567. Additional keyword arguments passed to the `plt.plot` used to draw
  568. the mean score.
  569. fill_between_kw : dict, default=None
  570. Additional keyword arguments passed to the `plt.fill_between` used
  571. to draw the score standard deviation.
  572. errorbar_kw : dict, default=None
  573. Additional keyword arguments passed to the `plt.errorbar` used to
  574. draw mean score and standard deviation score.
  575. Returns
  576. -------
  577. display : :class:`~sklearn.model_selection.ValidationCurveDisplay`
  578. Object that stores computed values.
  579. """
  580. self._plot_curve(
  581. self.param_range,
  582. ax=ax,
  583. negate_score=negate_score,
  584. score_name=score_name,
  585. score_type=score_type,
  586. log_scale="deprecated",
  587. std_display_style=std_display_style,
  588. line_kw=line_kw,
  589. fill_between_kw=fill_between_kw,
  590. errorbar_kw=errorbar_kw,
  591. )
  592. self.ax_.set_xlabel(f"{self.param_name}")
  593. return self
  594. @classmethod
  595. def from_estimator(
  596. cls,
  597. estimator,
  598. X,
  599. y,
  600. *,
  601. param_name,
  602. param_range,
  603. groups=None,
  604. cv=None,
  605. scoring=None,
  606. n_jobs=None,
  607. pre_dispatch="all",
  608. verbose=0,
  609. error_score=np.nan,
  610. fit_params=None,
  611. ax=None,
  612. negate_score=False,
  613. score_name=None,
  614. score_type="both",
  615. std_display_style="fill_between",
  616. line_kw=None,
  617. fill_between_kw=None,
  618. errorbar_kw=None,
  619. ):
  620. """Create a validation curve display from an estimator.
  621. Read more in the :ref:`User Guide <visualizations>` for general
  622. information about the visualization API and :ref:`detailed
  623. documentation <validation_curve>` regarding the validation curve
  624. visualization.
  625. Parameters
  626. ----------
  627. estimator : object type that implements the "fit" and "predict" methods
  628. An object of that type which is cloned for each validation.
  629. X : array-like of shape (n_samples, n_features)
  630. Training data, where `n_samples` is the number of samples and
  631. `n_features` is the number of features.
  632. y : array-like of shape (n_samples,) or (n_samples, n_outputs) or None
  633. Target relative to X for classification or regression;
  634. None for unsupervised learning.
  635. param_name : str
  636. Name of the parameter that will be varied.
  637. param_range : array-like of shape (n_values,)
  638. The values of the parameter that will be evaluated.
  639. groups : array-like of shape (n_samples,), default=None
  640. Group labels for the samples used while splitting the dataset into
  641. train/test set. Only used in conjunction with a "Group" :term:`cv`
  642. instance (e.g., :class:`GroupKFold`).
  643. cv : int, cross-validation generator or an iterable, default=None
  644. Determines the cross-validation splitting strategy.
  645. Possible inputs for cv are:
  646. - None, to use the default 5-fold cross validation,
  647. - int, to specify the number of folds in a `(Stratified)KFold`,
  648. - :term:`CV splitter`,
  649. - An iterable yielding (train, test) splits as arrays of indices.
  650. For int/None inputs, if the estimator is a classifier and `y` is
  651. either binary or multiclass,
  652. :class:`~sklearn.model_selection.StratifiedKFold` is used. In all
  653. other cases, :class:`~sklearn.model_selection.KFold` is used. These
  654. splitters are instantiated with `shuffle=False` so the splits will
  655. be the same across calls.
  656. Refer :ref:`User Guide <cross_validation>` for the various
  657. cross-validation strategies that can be used here.
  658. scoring : str or callable, default=None
  659. A string (see :ref:`scoring_parameter`) or
  660. a scorer callable object / function with signature
  661. `scorer(estimator, X, y)` (see :ref:`scoring`).
  662. n_jobs : int, default=None
  663. Number of jobs to run in parallel. Training the estimator and
  664. computing the score are parallelized over the different training
  665. and test sets. `None` means 1 unless in a
  666. :obj:`joblib.parallel_backend` context. `-1` means using all
  667. processors. See :term:`Glossary <n_jobs>` for more details.
  668. pre_dispatch : int or str, default='all'
  669. Number of predispatched jobs for parallel execution (default is
  670. all). The option can reduce the allocated memory. The str can
  671. be an expression like '2*n_jobs'.
  672. verbose : int, default=0
  673. Controls the verbosity: the higher, the more messages.
  674. error_score : 'raise' or numeric, default=np.nan
  675. Value to assign to the score if an error occurs in estimator
  676. fitting. If set to 'raise', the error is raised. If a numeric value
  677. is given, FitFailedWarning is raised.
  678. fit_params : dict, default=None
  679. Parameters to pass to the fit method of the estimator.
  680. ax : matplotlib Axes, default=None
  681. Axes object to plot on. If `None`, a new figure and axes is
  682. created.
  683. negate_score : bool, default=False
  684. Whether or not to negate the scores obtained through
  685. :func:`~sklearn.model_selection.validation_curve`. This is
  686. particularly useful when using the error denoted by `neg_*` in
  687. `scikit-learn`.
  688. score_name : str, default=None
  689. The name of the score used to decorate the y-axis of the plot. It will
  690. override the name inferred from the `scoring` parameter. If `score` is
  691. `None`, we use `"Score"` if `negate_score` is `False` and `"Negative score"`
  692. otherwise. If `scoring` is a string or a callable, we infer the name. We
  693. replace `_` by spaces and capitalize the first letter. We remove `neg_` and
  694. replace it by `"Negative"` if `negate_score` is
  695. `False` or just remove it otherwise.
  696. score_type : {"test", "train", "both"}, default="both"
  697. The type of score to plot. Can be one of `"test"`, `"train"`, or
  698. `"both"`.
  699. std_display_style : {"errorbar", "fill_between"} or None, default="fill_between"
  700. The style used to display the score standard deviation around the
  701. mean score. If `None`, no representation of the standard deviation
  702. is displayed.
  703. line_kw : dict, default=None
  704. Additional keyword arguments passed to the `plt.plot` used to draw
  705. the mean score.
  706. fill_between_kw : dict, default=None
  707. Additional keyword arguments passed to the `plt.fill_between` used
  708. to draw the score standard deviation.
  709. errorbar_kw : dict, default=None
  710. Additional keyword arguments passed to the `plt.errorbar` used to
  711. draw mean score and standard deviation score.
  712. Returns
  713. -------
  714. display : :class:`~sklearn.model_selection.ValidationCurveDisplay`
  715. Object that stores computed values.
  716. Examples
  717. --------
  718. >>> import numpy as np
  719. >>> import matplotlib.pyplot as plt
  720. >>> from sklearn.datasets import make_classification
  721. >>> from sklearn.model_selection import ValidationCurveDisplay
  722. >>> from sklearn.linear_model import LogisticRegression
  723. >>> X, y = make_classification(n_samples=1_000, random_state=0)
  724. >>> logistic_regression = LogisticRegression()
  725. >>> param_name, param_range = "C", np.logspace(-8, 3, 10)
  726. >>> ValidationCurveDisplay.from_estimator(
  727. ... logistic_regression, X, y, param_name=param_name,
  728. ... param_range=param_range,
  729. ... )
  730. <...>
  731. >>> plt.show()
  732. """
  733. check_matplotlib_support(f"{cls.__name__}.from_estimator")
  734. score_name = _validate_score_name(score_name, scoring, negate_score)
  735. train_scores, test_scores = validation_curve(
  736. estimator,
  737. X,
  738. y,
  739. param_name=param_name,
  740. param_range=param_range,
  741. groups=groups,
  742. cv=cv,
  743. scoring=scoring,
  744. n_jobs=n_jobs,
  745. pre_dispatch=pre_dispatch,
  746. verbose=verbose,
  747. error_score=error_score,
  748. fit_params=fit_params,
  749. )
  750. viz = cls(
  751. param_name=param_name,
  752. param_range=np.array(param_range, copy=False),
  753. train_scores=train_scores,
  754. test_scores=test_scores,
  755. score_name=score_name,
  756. )
  757. return viz.plot(
  758. ax=ax,
  759. negate_score=negate_score,
  760. score_type=score_type,
  761. std_display_style=std_display_style,
  762. line_kw=line_kw,
  763. fill_between_kw=fill_between_kw,
  764. errorbar_kw=errorbar_kw,
  765. )