test_plot.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572
  1. import numpy as np
  2. import pytest
  3. from sklearn.datasets import load_iris
  4. from sklearn.model_selection import (
  5. LearningCurveDisplay,
  6. ValidationCurveDisplay,
  7. learning_curve,
  8. validation_curve,
  9. )
  10. from sklearn.tree import DecisionTreeClassifier
  11. from sklearn.utils import shuffle
  12. from sklearn.utils._testing import assert_allclose, assert_array_equal
  13. @pytest.fixture
  14. def data():
  15. return shuffle(*load_iris(return_X_y=True), random_state=0)
  16. @pytest.mark.parametrize(
  17. "params, err_type, err_msg",
  18. [
  19. ({"std_display_style": "invalid"}, ValueError, "Unknown std_display_style:"),
  20. ({"score_type": "invalid"}, ValueError, "Unknown score_type:"),
  21. ],
  22. )
  23. @pytest.mark.parametrize(
  24. "CurveDisplay, specific_params",
  25. [
  26. (ValidationCurveDisplay, {"param_name": "max_depth", "param_range": [1, 3, 5]}),
  27. (LearningCurveDisplay, {"train_sizes": [0.3, 0.6, 0.9]}),
  28. ],
  29. )
  30. def test_curve_display_parameters_validation(
  31. pyplot, data, params, err_type, err_msg, CurveDisplay, specific_params
  32. ):
  33. """Check that we raise a proper error when passing invalid parameters."""
  34. X, y = data
  35. estimator = DecisionTreeClassifier(random_state=0)
  36. with pytest.raises(err_type, match=err_msg):
  37. CurveDisplay.from_estimator(estimator, X, y, **specific_params, **params)
  38. def test_learning_curve_display_default_usage(pyplot, data):
  39. """Check the default usage of the LearningCurveDisplay class."""
  40. X, y = data
  41. estimator = DecisionTreeClassifier(random_state=0)
  42. train_sizes = [0.3, 0.6, 0.9]
  43. display = LearningCurveDisplay.from_estimator(
  44. estimator, X, y, train_sizes=train_sizes
  45. )
  46. import matplotlib as mpl
  47. assert display.errorbar_ is None
  48. assert isinstance(display.lines_, list)
  49. for line in display.lines_:
  50. assert isinstance(line, mpl.lines.Line2D)
  51. assert isinstance(display.fill_between_, list)
  52. for fill in display.fill_between_:
  53. assert isinstance(fill, mpl.collections.PolyCollection)
  54. assert fill.get_alpha() == 0.5
  55. assert display.score_name == "Score"
  56. assert display.ax_.get_xlabel() == "Number of samples in the training set"
  57. assert display.ax_.get_ylabel() == "Score"
  58. _, legend_labels = display.ax_.get_legend_handles_labels()
  59. assert legend_labels == ["Train", "Test"]
  60. train_sizes_abs, train_scores, test_scores = learning_curve(
  61. estimator, X, y, train_sizes=train_sizes
  62. )
  63. assert_array_equal(display.train_sizes, train_sizes_abs)
  64. assert_allclose(display.train_scores, train_scores)
  65. assert_allclose(display.test_scores, test_scores)
  66. def test_validation_curve_display_default_usage(pyplot, data):
  67. """Check the default usage of the ValidationCurveDisplay class."""
  68. X, y = data
  69. estimator = DecisionTreeClassifier(random_state=0)
  70. param_name, param_range = "max_depth", [1, 3, 5]
  71. display = ValidationCurveDisplay.from_estimator(
  72. estimator, X, y, param_name=param_name, param_range=param_range
  73. )
  74. import matplotlib as mpl
  75. assert display.errorbar_ is None
  76. assert isinstance(display.lines_, list)
  77. for line in display.lines_:
  78. assert isinstance(line, mpl.lines.Line2D)
  79. assert isinstance(display.fill_between_, list)
  80. for fill in display.fill_between_:
  81. assert isinstance(fill, mpl.collections.PolyCollection)
  82. assert fill.get_alpha() == 0.5
  83. assert display.score_name == "Score"
  84. assert display.ax_.get_xlabel() == f"{param_name}"
  85. assert display.ax_.get_ylabel() == "Score"
  86. _, legend_labels = display.ax_.get_legend_handles_labels()
  87. assert legend_labels == ["Train", "Test"]
  88. train_scores, test_scores = validation_curve(
  89. estimator, X, y, param_name=param_name, param_range=param_range
  90. )
  91. assert_array_equal(display.param_range, param_range)
  92. assert_allclose(display.train_scores, train_scores)
  93. assert_allclose(display.test_scores, test_scores)
  94. @pytest.mark.parametrize(
  95. "CurveDisplay, specific_params",
  96. [
  97. (ValidationCurveDisplay, {"param_name": "max_depth", "param_range": [1, 3, 5]}),
  98. (LearningCurveDisplay, {"train_sizes": [0.3, 0.6, 0.9]}),
  99. ],
  100. )
  101. def test_curve_display_negate_score(pyplot, data, CurveDisplay, specific_params):
  102. """Check the behaviour of the `negate_score` parameter calling `from_estimator` and
  103. `plot`.
  104. """
  105. X, y = data
  106. estimator = DecisionTreeClassifier(max_depth=1, random_state=0)
  107. negate_score = False
  108. display = CurveDisplay.from_estimator(
  109. estimator, X, y, **specific_params, negate_score=negate_score
  110. )
  111. positive_scores = display.lines_[0].get_data()[1]
  112. assert (positive_scores >= 0).all()
  113. assert display.ax_.get_ylabel() == "Score"
  114. negate_score = True
  115. display = CurveDisplay.from_estimator(
  116. estimator, X, y, **specific_params, negate_score=negate_score
  117. )
  118. negative_scores = display.lines_[0].get_data()[1]
  119. assert (negative_scores <= 0).all()
  120. assert_allclose(negative_scores, -positive_scores)
  121. assert display.ax_.get_ylabel() == "Negative score"
  122. negate_score = False
  123. display = CurveDisplay.from_estimator(
  124. estimator, X, y, **specific_params, negate_score=negate_score
  125. )
  126. assert display.ax_.get_ylabel() == "Score"
  127. display.plot(negate_score=not negate_score)
  128. assert display.ax_.get_ylabel() == "Score"
  129. assert (display.lines_[0].get_data()[1] < 0).all()
  130. @pytest.mark.parametrize(
  131. "score_name, ylabel", [(None, "Score"), ("Accuracy", "Accuracy")]
  132. )
  133. @pytest.mark.parametrize(
  134. "CurveDisplay, specific_params",
  135. [
  136. (ValidationCurveDisplay, {"param_name": "max_depth", "param_range": [1, 3, 5]}),
  137. (LearningCurveDisplay, {"train_sizes": [0.3, 0.6, 0.9]}),
  138. ],
  139. )
  140. def test_curve_display_score_name(
  141. pyplot, data, score_name, ylabel, CurveDisplay, specific_params
  142. ):
  143. """Check that we can overwrite the default score name shown on the y-axis."""
  144. X, y = data
  145. estimator = DecisionTreeClassifier(random_state=0)
  146. display = CurveDisplay.from_estimator(
  147. estimator, X, y, **specific_params, score_name=score_name
  148. )
  149. assert display.ax_.get_ylabel() == ylabel
  150. X, y = data
  151. estimator = DecisionTreeClassifier(max_depth=1, random_state=0)
  152. display = CurveDisplay.from_estimator(
  153. estimator, X, y, **specific_params, score_name=score_name
  154. )
  155. assert display.score_name == ylabel
  156. @pytest.mark.parametrize("std_display_style", (None, "errorbar"))
  157. def test_learning_curve_display_score_type(pyplot, data, std_display_style):
  158. """Check the behaviour of setting the `score_type` parameter."""
  159. X, y = data
  160. estimator = DecisionTreeClassifier(random_state=0)
  161. train_sizes = [0.3, 0.6, 0.9]
  162. train_sizes_abs, train_scores, test_scores = learning_curve(
  163. estimator, X, y, train_sizes=train_sizes
  164. )
  165. score_type = "train"
  166. display = LearningCurveDisplay.from_estimator(
  167. estimator,
  168. X,
  169. y,
  170. train_sizes=train_sizes,
  171. score_type=score_type,
  172. std_display_style=std_display_style,
  173. )
  174. _, legend_label = display.ax_.get_legend_handles_labels()
  175. assert legend_label == ["Train"]
  176. if std_display_style is None:
  177. assert len(display.lines_) == 1
  178. assert display.errorbar_ is None
  179. x_data, y_data = display.lines_[0].get_data()
  180. else:
  181. assert display.lines_ is None
  182. assert len(display.errorbar_) == 1
  183. x_data, y_data = display.errorbar_[0].lines[0].get_data()
  184. assert_array_equal(x_data, train_sizes_abs)
  185. assert_allclose(y_data, train_scores.mean(axis=1))
  186. score_type = "test"
  187. display = LearningCurveDisplay.from_estimator(
  188. estimator,
  189. X,
  190. y,
  191. train_sizes=train_sizes,
  192. score_type=score_type,
  193. std_display_style=std_display_style,
  194. )
  195. _, legend_label = display.ax_.get_legend_handles_labels()
  196. assert legend_label == ["Test"]
  197. if std_display_style is None:
  198. assert len(display.lines_) == 1
  199. assert display.errorbar_ is None
  200. x_data, y_data = display.lines_[0].get_data()
  201. else:
  202. assert display.lines_ is None
  203. assert len(display.errorbar_) == 1
  204. x_data, y_data = display.errorbar_[0].lines[0].get_data()
  205. assert_array_equal(x_data, train_sizes_abs)
  206. assert_allclose(y_data, test_scores.mean(axis=1))
  207. score_type = "both"
  208. display = LearningCurveDisplay.from_estimator(
  209. estimator,
  210. X,
  211. y,
  212. train_sizes=train_sizes,
  213. score_type=score_type,
  214. std_display_style=std_display_style,
  215. )
  216. _, legend_label = display.ax_.get_legend_handles_labels()
  217. assert legend_label == ["Train", "Test"]
  218. if std_display_style is None:
  219. assert len(display.lines_) == 2
  220. assert display.errorbar_ is None
  221. x_data_train, y_data_train = display.lines_[0].get_data()
  222. x_data_test, y_data_test = display.lines_[1].get_data()
  223. else:
  224. assert display.lines_ is None
  225. assert len(display.errorbar_) == 2
  226. x_data_train, y_data_train = display.errorbar_[0].lines[0].get_data()
  227. x_data_test, y_data_test = display.errorbar_[1].lines[0].get_data()
  228. assert_array_equal(x_data_train, train_sizes_abs)
  229. assert_allclose(y_data_train, train_scores.mean(axis=1))
  230. assert_array_equal(x_data_test, train_sizes_abs)
  231. assert_allclose(y_data_test, test_scores.mean(axis=1))
  232. @pytest.mark.parametrize("std_display_style", (None, "errorbar"))
  233. def test_validation_curve_display_score_type(pyplot, data, std_display_style):
  234. """Check the behaviour of setting the `score_type` parameter."""
  235. X, y = data
  236. estimator = DecisionTreeClassifier(random_state=0)
  237. param_name, param_range = "max_depth", [1, 3, 5]
  238. train_scores, test_scores = validation_curve(
  239. estimator, X, y, param_name=param_name, param_range=param_range
  240. )
  241. score_type = "train"
  242. display = ValidationCurveDisplay.from_estimator(
  243. estimator,
  244. X,
  245. y,
  246. param_name=param_name,
  247. param_range=param_range,
  248. score_type=score_type,
  249. std_display_style=std_display_style,
  250. )
  251. _, legend_label = display.ax_.get_legend_handles_labels()
  252. assert legend_label == ["Train"]
  253. if std_display_style is None:
  254. assert len(display.lines_) == 1
  255. assert display.errorbar_ is None
  256. x_data, y_data = display.lines_[0].get_data()
  257. else:
  258. assert display.lines_ is None
  259. assert len(display.errorbar_) == 1
  260. x_data, y_data = display.errorbar_[0].lines[0].get_data()
  261. assert_array_equal(x_data, param_range)
  262. assert_allclose(y_data, train_scores.mean(axis=1))
  263. score_type = "test"
  264. display = ValidationCurveDisplay.from_estimator(
  265. estimator,
  266. X,
  267. y,
  268. param_name=param_name,
  269. param_range=param_range,
  270. score_type=score_type,
  271. std_display_style=std_display_style,
  272. )
  273. _, legend_label = display.ax_.get_legend_handles_labels()
  274. assert legend_label == ["Test"]
  275. if std_display_style is None:
  276. assert len(display.lines_) == 1
  277. assert display.errorbar_ is None
  278. x_data, y_data = display.lines_[0].get_data()
  279. else:
  280. assert display.lines_ is None
  281. assert len(display.errorbar_) == 1
  282. x_data, y_data = display.errorbar_[0].lines[0].get_data()
  283. assert_array_equal(x_data, param_range)
  284. assert_allclose(y_data, test_scores.mean(axis=1))
  285. score_type = "both"
  286. display = ValidationCurveDisplay.from_estimator(
  287. estimator,
  288. X,
  289. y,
  290. param_name=param_name,
  291. param_range=param_range,
  292. score_type=score_type,
  293. std_display_style=std_display_style,
  294. )
  295. _, legend_label = display.ax_.get_legend_handles_labels()
  296. assert legend_label == ["Train", "Test"]
  297. if std_display_style is None:
  298. assert len(display.lines_) == 2
  299. assert display.errorbar_ is None
  300. x_data_train, y_data_train = display.lines_[0].get_data()
  301. x_data_test, y_data_test = display.lines_[1].get_data()
  302. else:
  303. assert display.lines_ is None
  304. assert len(display.errorbar_) == 2
  305. x_data_train, y_data_train = display.errorbar_[0].lines[0].get_data()
  306. x_data_test, y_data_test = display.errorbar_[1].lines[0].get_data()
  307. assert_array_equal(x_data_train, param_range)
  308. assert_allclose(y_data_train, train_scores.mean(axis=1))
  309. assert_array_equal(x_data_test, param_range)
  310. assert_allclose(y_data_test, test_scores.mean(axis=1))
  311. @pytest.mark.parametrize(
  312. "CurveDisplay, specific_params, expected_xscale",
  313. [
  314. (
  315. ValidationCurveDisplay,
  316. {"param_name": "max_depth", "param_range": np.arange(1, 5)},
  317. "linear",
  318. ),
  319. (LearningCurveDisplay, {"train_sizes": np.linspace(0.1, 0.9, num=5)}, "linear"),
  320. (
  321. ValidationCurveDisplay,
  322. {
  323. "param_name": "max_depth",
  324. "param_range": np.round(np.logspace(0, 2, num=5)).astype(np.int64),
  325. },
  326. "log",
  327. ),
  328. (LearningCurveDisplay, {"train_sizes": np.logspace(-1, 0, num=5)}, "log"),
  329. ],
  330. )
  331. def test_curve_display_xscale_auto(
  332. pyplot, data, CurveDisplay, specific_params, expected_xscale
  333. ):
  334. """Check the behaviour of the x-axis scaling depending on the data provided."""
  335. X, y = data
  336. estimator = DecisionTreeClassifier(random_state=0)
  337. display = CurveDisplay.from_estimator(estimator, X, y, **specific_params)
  338. assert display.ax_.get_xscale() == expected_xscale
  339. @pytest.mark.parametrize(
  340. "CurveDisplay, specific_params",
  341. [
  342. (ValidationCurveDisplay, {"param_name": "max_depth", "param_range": [1, 3, 5]}),
  343. (LearningCurveDisplay, {"train_sizes": [0.3, 0.6, 0.9]}),
  344. ],
  345. )
  346. def test_curve_display_std_display_style(pyplot, data, CurveDisplay, specific_params):
  347. """Check the behaviour of the parameter `std_display_style`."""
  348. X, y = data
  349. estimator = DecisionTreeClassifier(random_state=0)
  350. import matplotlib as mpl
  351. std_display_style = None
  352. display = CurveDisplay.from_estimator(
  353. estimator,
  354. X,
  355. y,
  356. **specific_params,
  357. std_display_style=std_display_style,
  358. )
  359. assert len(display.lines_) == 2
  360. for line in display.lines_:
  361. assert isinstance(line, mpl.lines.Line2D)
  362. assert display.errorbar_ is None
  363. assert display.fill_between_ is None
  364. _, legend_label = display.ax_.get_legend_handles_labels()
  365. assert len(legend_label) == 2
  366. std_display_style = "fill_between"
  367. display = CurveDisplay.from_estimator(
  368. estimator,
  369. X,
  370. y,
  371. **specific_params,
  372. std_display_style=std_display_style,
  373. )
  374. assert len(display.lines_) == 2
  375. for line in display.lines_:
  376. assert isinstance(line, mpl.lines.Line2D)
  377. assert display.errorbar_ is None
  378. assert len(display.fill_between_) == 2
  379. for fill_between in display.fill_between_:
  380. assert isinstance(fill_between, mpl.collections.PolyCollection)
  381. _, legend_label = display.ax_.get_legend_handles_labels()
  382. assert len(legend_label) == 2
  383. std_display_style = "errorbar"
  384. display = CurveDisplay.from_estimator(
  385. estimator,
  386. X,
  387. y,
  388. **specific_params,
  389. std_display_style=std_display_style,
  390. )
  391. assert display.lines_ is None
  392. assert len(display.errorbar_) == 2
  393. for errorbar in display.errorbar_:
  394. assert isinstance(errorbar, mpl.container.ErrorbarContainer)
  395. assert display.fill_between_ is None
  396. _, legend_label = display.ax_.get_legend_handles_labels()
  397. assert len(legend_label) == 2
  398. @pytest.mark.parametrize(
  399. "CurveDisplay, specific_params",
  400. [
  401. (ValidationCurveDisplay, {"param_name": "max_depth", "param_range": [1, 3, 5]}),
  402. (LearningCurveDisplay, {"train_sizes": [0.3, 0.6, 0.9]}),
  403. ],
  404. )
  405. def test_curve_display_plot_kwargs(pyplot, data, CurveDisplay, specific_params):
  406. """Check the behaviour of the different plotting keyword arguments: `line_kw`,
  407. `fill_between_kw`, and `errorbar_kw`."""
  408. X, y = data
  409. estimator = DecisionTreeClassifier(random_state=0)
  410. std_display_style = "fill_between"
  411. line_kw = {"color": "red"}
  412. fill_between_kw = {"color": "red", "alpha": 1.0}
  413. display = CurveDisplay.from_estimator(
  414. estimator,
  415. X,
  416. y,
  417. **specific_params,
  418. std_display_style=std_display_style,
  419. line_kw=line_kw,
  420. fill_between_kw=fill_between_kw,
  421. )
  422. assert display.lines_[0].get_color() == "red"
  423. assert_allclose(
  424. display.fill_between_[0].get_facecolor(),
  425. [[1.0, 0.0, 0.0, 1.0]], # trust me, it's red
  426. )
  427. std_display_style = "errorbar"
  428. errorbar_kw = {"color": "red"}
  429. display = CurveDisplay.from_estimator(
  430. estimator,
  431. X,
  432. y,
  433. **specific_params,
  434. std_display_style=std_display_style,
  435. errorbar_kw=errorbar_kw,
  436. )
  437. assert display.errorbar_[0].lines[0].get_color() == "red"
  438. # TODO(1.5): to be removed
  439. def test_learning_curve_display_deprecate_log_scale(data, pyplot):
  440. """Check that we warn for the deprecated parameter `log_scale`."""
  441. X, y = data
  442. estimator = DecisionTreeClassifier(random_state=0)
  443. with pytest.warns(FutureWarning, match="`log_scale` parameter is deprecated"):
  444. display = LearningCurveDisplay.from_estimator(
  445. estimator, X, y, train_sizes=[0.3, 0.6, 0.9], log_scale=True
  446. )
  447. assert display.ax_.get_xscale() == "log"
  448. assert display.ax_.get_yscale() == "linear"
  449. with pytest.warns(FutureWarning, match="`log_scale` parameter is deprecated"):
  450. display = LearningCurveDisplay.from_estimator(
  451. estimator, X, y, train_sizes=[0.3, 0.6, 0.9], log_scale=False
  452. )
  453. assert display.ax_.get_xscale() == "linear"
  454. assert display.ax_.get_yscale() == "linear"
  455. @pytest.mark.parametrize(
  456. "param_range, xscale",
  457. [([5, 10, 15], "linear"), ([-50, 5, 50, 500], "symlog"), ([5, 50, 500], "log")],
  458. )
  459. def test_validation_curve_xscale_from_param_range_provided_as_a_list(
  460. pyplot, data, param_range, xscale
  461. ):
  462. """Check the induced xscale from the provided param_range values."""
  463. X, y = data
  464. estimator = DecisionTreeClassifier(random_state=0)
  465. param_name = "max_depth"
  466. display = ValidationCurveDisplay.from_estimator(
  467. estimator,
  468. X,
  469. y,
  470. param_name=param_name,
  471. param_range=param_range,
  472. )
  473. assert display.ax_.get_xscale() == xscale