  1. """
  2. Testing for export functions of decision trees (sklearn.tree.export).
  3. """
  4. from io import StringIO
  5. from re import finditer, search
  6. from textwrap import dedent
  7. import numpy as np
  8. import pytest
  9. from numpy.random import RandomState
  10. from sklearn.base import is_classifier
  11. from sklearn.ensemble import GradientBoostingClassifier
  12. from sklearn.exceptions import NotFittedError
  13. from sklearn.tree import (
  14. DecisionTreeClassifier,
  15. DecisionTreeRegressor,
  16. export_graphviz,
  17. export_text,
  18. plot_tree,
  19. )
  20. # toy sample
  21. X = [[-2, -1], [-1, -1], [-1, -2], [1, 1], [1, 2], [2, 1]]
  22. y = [-1, -1, -1, 1, 1, 1]
  23. y2 = [[-1, 1], [-1, 1], [-1, 1], [1, 2], [1, 2], [1, 3]]
  24. w = [1, 1, 1, 0.5, 0.5, 0.5]
  25. y_degraded = [1, 1, 1, 1, 1, 1]
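# The graphviz tests below compare against literal DOT strings, so they pin
# the exact export_graphviz output (node shapes, labels, fill colors, and
# edge annotations); any formatting change in the exporter shows up here.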
def test_graphviz_toy():
    # Check correctness of export_graphviz
    clf = DecisionTreeClassifier(
        max_depth=3, min_samples_split=2, criterion="gini", random_state=2
    )
    clf.fit(X, y)

    # Test export code
    contents1 = export_graphviz(clf, out_file=None)
    contents2 = (
        "digraph Tree {\n"
        'node [shape=box, fontname="helvetica"] ;\n'
        'edge [fontname="helvetica"] ;\n'
        '0 [label="x[0] <= 0.0\\ngini = 0.5\\nsamples = 6\\n'
        'value = [3, 3]"] ;\n'
        '1 [label="gini = 0.0\\nsamples = 3\\nvalue = [3, 0]"] ;\n'
        "0 -> 1 [labeldistance=2.5, labelangle=45, "
        'headlabel="True"] ;\n'
        '2 [label="gini = 0.0\\nsamples = 3\\nvalue = [0, 3]"] ;\n'
        "0 -> 2 [labeldistance=2.5, labelangle=-45, "
        'headlabel="False"] ;\n'
        "}"
    )
    assert contents1 == contents2

    # Test plot_options
    contents1 = export_graphviz(
        clf,
        filled=True,
        impurity=False,
        proportion=True,
        special_characters=True,
        rounded=True,
        out_file=None,
        fontname="sans",
    )
    contents2 = (
        "digraph Tree {\n"
        'node [shape=box, style="filled, rounded", color="black", '
        'fontname="sans"] ;\n'
        'edge [fontname="sans"] ;\n'
        "0 [label=<x<SUB>0</SUB> &le; 0.0<br/>samples = 100.0%<br/>"
        'value = [0.5, 0.5]>, fillcolor="#ffffff"] ;\n'
        "1 [label=<samples = 50.0%<br/>value = [1.0, 0.0]>, "
        'fillcolor="#e58139"] ;\n'
        "0 -> 1 [labeldistance=2.5, labelangle=45, "
        'headlabel="True"] ;\n'
        "2 [label=<samples = 50.0%<br/>value = [0.0, 1.0]>, "
        'fillcolor="#399de5"] ;\n'
        "0 -> 2 [labeldistance=2.5, labelangle=-45, "
        'headlabel="False"] ;\n'
        "}"
    )
    assert contents1 == contents2

    # Test max_depth
    contents1 = export_graphviz(clf, max_depth=0, class_names=True, out_file=None)
    contents2 = (
        "digraph Tree {\n"
        'node [shape=box, fontname="helvetica"] ;\n'
        'edge [fontname="helvetica"] ;\n'
        '0 [label="x[0] <= 0.0\\ngini = 0.5\\nsamples = 6\\n'
        'value = [3, 3]\\nclass = y[0]"] ;\n'
        '1 [label="(...)"] ;\n'
        "0 -> 1 ;\n"
        '2 [label="(...)"] ;\n'
        "0 -> 2 ;\n"
        "}"
    )
    assert contents1 == contents2

    # Test max_depth with plot_options
    contents1 = export_graphviz(
        clf, max_depth=0, filled=True, out_file=None, node_ids=True
    )
    contents2 = (
        "digraph Tree {\n"
        'node [shape=box, style="filled", color="black", '
        'fontname="helvetica"] ;\n'
        'edge [fontname="helvetica"] ;\n'
        '0 [label="node #0\\nx[0] <= 0.0\\ngini = 0.5\\n'
        'samples = 6\\nvalue = [3, 3]", fillcolor="#ffffff"] ;\n'
        '1 [label="(...)", fillcolor="#C0C0C0"] ;\n'
        "0 -> 1 ;\n"
        '2 [label="(...)", fillcolor="#C0C0C0"] ;\n'
        "0 -> 2 ;\n"
        "}"
    )
    assert contents1 == contents2

    # Test multi-output with weighted samples
    clf = DecisionTreeClassifier(
        max_depth=2, min_samples_split=2, criterion="gini", random_state=2
    )
    clf = clf.fit(X, y2, sample_weight=w)

    contents1 = export_graphviz(clf, filled=True, impurity=False, out_file=None)
    contents2 = (
        "digraph Tree {\n"
        'node [shape=box, style="filled", color="black", '
        'fontname="helvetica"] ;\n'
        'edge [fontname="helvetica"] ;\n'
        '0 [label="x[0] <= 0.0\\nsamples = 6\\n'
        "value = [[3.0, 1.5, 0.0]\\n"
        '[3.0, 1.0, 0.5]]", fillcolor="#ffffff"] ;\n'
        '1 [label="samples = 3\\nvalue = [[3, 0, 0]\\n'
        '[3, 0, 0]]", fillcolor="#e58139"] ;\n'
        "0 -> 1 [labeldistance=2.5, labelangle=45, "
        'headlabel="True"] ;\n'
        '2 [label="x[0] <= 1.5\\nsamples = 3\\n'
        "value = [[0.0, 1.5, 0.0]\\n"
        '[0.0, 1.0, 0.5]]", fillcolor="#f1bd97"] ;\n'
        "0 -> 2 [labeldistance=2.5, labelangle=-45, "
        'headlabel="False"] ;\n'
        '3 [label="samples = 2\\nvalue = [[0, 1, 0]\\n'
        '[0, 1, 0]]", fillcolor="#e58139"] ;\n'
        "2 -> 3 ;\n"
        '4 [label="samples = 1\\nvalue = [[0.0, 0.5, 0.0]\\n'
        '[0.0, 0.0, 0.5]]", fillcolor="#e58139"] ;\n'
        "2 -> 4 ;\n"
        "}"
    )
    assert contents1 == contents2

    # Test regression output with plot_options
    clf = DecisionTreeRegressor(
        max_depth=3, min_samples_split=2, criterion="squared_error", random_state=2
    )
    clf.fit(X, y)

    contents1 = export_graphviz(
        clf,
        filled=True,
        leaves_parallel=True,
        out_file=None,
        rotate=True,
        rounded=True,
        fontname="sans",
    )
    contents2 = (
        "digraph Tree {\n"
        'node [shape=box, style="filled, rounded", color="black", '
        'fontname="sans"] ;\n'
        "graph [ranksep=equally, splines=polyline] ;\n"
        'edge [fontname="sans"] ;\n'
        "rankdir=LR ;\n"
        '0 [label="x[0] <= 0.0\\nsquared_error = 1.0\\nsamples = 6\\n'
        'value = 0.0", fillcolor="#f2c09c"] ;\n'
        '1 [label="squared_error = 0.0\\nsamples = 3\\nvalue = -1.0", '
        'fillcolor="#ffffff"] ;\n'
        "0 -> 1 [labeldistance=2.5, labelangle=-45, "
        'headlabel="True"] ;\n'
        '2 [label="squared_error = 0.0\\nsamples = 3\\nvalue = 1.0", '
        'fillcolor="#e58139"] ;\n'
        "0 -> 2 [labeldistance=2.5, labelangle=45, "
        'headlabel="False"] ;\n'
        "{rank=same ; 0} ;\n"
        "{rank=same ; 1; 2} ;\n"
        "}"
    )
    assert contents1 == contents2

    # Test classifier with degraded learning set
    clf = DecisionTreeClassifier(max_depth=3)
    clf.fit(X, y_degraded)

    contents1 = export_graphviz(clf, filled=True, out_file=None)
    contents2 = (
        "digraph Tree {\n"
        'node [shape=box, style="filled", color="black", '
        'fontname="helvetica"] ;\n'
        'edge [fontname="helvetica"] ;\n'
        '0 [label="gini = 0.0\\nsamples = 6\\nvalue = 6.0", '
        'fillcolor="#ffffff"] ;\n'
        "}"
    )
    assert contents1 == contents2


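# The parametrized constructor runs each check twice: once with plain Python
# lists and once with NumPy arrays, which must produce identical DOT output.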
@pytest.mark.parametrize("constructor", [list, np.array])
def test_graphviz_feature_class_names_array_support(constructor):
    # Check that export_graphviz treats feature names
    # and class names correctly and supports arrays
    clf = DecisionTreeClassifier(
        max_depth=3, min_samples_split=2, criterion="gini", random_state=2
    )
    clf.fit(X, y)

    # Test with feature_names
    contents1 = export_graphviz(
        clf, feature_names=constructor(["feature0", "feature1"]), out_file=None
    )
    contents2 = (
        "digraph Tree {\n"
        'node [shape=box, fontname="helvetica"] ;\n'
        'edge [fontname="helvetica"] ;\n'
        '0 [label="feature0 <= 0.0\\ngini = 0.5\\nsamples = 6\\n'
        'value = [3, 3]"] ;\n'
        '1 [label="gini = 0.0\\nsamples = 3\\nvalue = [3, 0]"] ;\n'
        "0 -> 1 [labeldistance=2.5, labelangle=45, "
        'headlabel="True"] ;\n'
        '2 [label="gini = 0.0\\nsamples = 3\\nvalue = [0, 3]"] ;\n'
        "0 -> 2 [labeldistance=2.5, labelangle=-45, "
        'headlabel="False"] ;\n'
        "}"
    )
    assert contents1 == contents2

    # Test with class_names
    contents1 = export_graphviz(
        clf, class_names=constructor(["yes", "no"]), out_file=None
    )
    contents2 = (
        "digraph Tree {\n"
        'node [shape=box, fontname="helvetica"] ;\n'
        'edge [fontname="helvetica"] ;\n'
        '0 [label="x[0] <= 0.0\\ngini = 0.5\\nsamples = 6\\n'
        'value = [3, 3]\\nclass = yes"] ;\n'
        '1 [label="gini = 0.0\\nsamples = 3\\nvalue = [3, 0]\\n'
        'class = yes"] ;\n'
        "0 -> 1 [labeldistance=2.5, labelangle=45, "
        'headlabel="True"] ;\n'
        '2 [label="gini = 0.0\\nsamples = 3\\nvalue = [0, 3]\\n'
        'class = no"] ;\n'
        "0 -> 2 [labeldistance=2.5, labelangle=-45, "
        'headlabel="False"] ;\n'
        "}"
    )
    assert contents1 == contents2


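# Error paths: unfitted estimator, feature_names length mismatch, passing a
# raw tree_ instead of an estimator, and an empty class_names sequence.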
def test_graphviz_errors():
    # Check for errors of export_graphviz
    clf = DecisionTreeClassifier(max_depth=3, min_samples_split=2)

    # Check not-fitted decision tree error
    out = StringIO()
    with pytest.raises(NotFittedError):
        export_graphviz(clf, out)

    clf.fit(X, y)

    # Check that it raises an error when the length of feature_names
    # does not match the number of features
    message = "Length of feature_names, 1 does not match number of features, 2"
    with pytest.raises(ValueError, match=message):
        export_graphviz(clf, None, feature_names=["a"])

    message = "Length of feature_names, 3 does not match number of features, 2"
    with pytest.raises(ValueError, match=message):
        export_graphviz(clf, None, feature_names=["a", "b", "c"])

    # Check error when argument is not an estimator
    message = "is not an estimator instance"
    with pytest.raises(TypeError, match=message):
        export_graphviz(clf.fit(X, y).tree_)

    # Check class_names error
    out = StringIO()
    with pytest.raises(IndexError):
        export_graphviz(clf, out, class_names=[])


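# The impurity shown in each node label is named after the criterion, so
# "friedman_mse" must appear in every node of both the standalone regressor
# and the trees inside the gradient boosting ensemble.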
def test_friedman_mse_in_graphviz():
    clf = DecisionTreeRegressor(criterion="friedman_mse", random_state=0)
    clf.fit(X, y)
    dot_data = StringIO()
    export_graphviz(clf, out_file=dot_data)

    clf = GradientBoostingClassifier(n_estimators=2, random_state=0)
    clf.fit(X, y)
    for estimator in clf.estimators_:
        export_graphviz(estimator[0], out_file=dot_data)

    for finding in finditer(r"\[.*?samples.*?\]", dot_data.getvalue()):
        assert "friedman_mse" in finding.group()


def test_precision():
    rng_reg = RandomState(2)
    rng_clf = RandomState(8)
    for X, y, clf in zip(
        (rng_reg.random_sample((5, 2)), rng_clf.random_sample((1000, 4))),
        (rng_reg.random_sample((5,)), rng_clf.randint(2, size=(1000,))),
        (
            DecisionTreeRegressor(
                criterion="friedman_mse", random_state=0, max_depth=1
            ),
            DecisionTreeClassifier(max_depth=1, random_state=0),
        ),
    ):
        clf.fit(X, y)
        for precision in (4, 3):
            dot_data = export_graphviz(
                clf, out_file=None, precision=precision, proportion=True
            )

            # With the current random state, the impurity and the threshold
            # are formatted with exactly the requested number of decimals,
            # so we check them with strict equality. The reported node
            # values may be rendered with fewer decimals, so only a
            # less-than-or-equal comparison is possible for them.

            # check value
            for finding in finditer(r"value = \d+\.\d+", dot_data):
                assert len(search(r"\.\d+", finding.group()).group()) <= precision + 1

            # check impurity
            if is_classifier(clf):
                pattern = r"gini = \d+\.\d+"
            else:
                pattern = r"friedman_mse = \d+\.\d+"
            for finding in finditer(pattern, dot_data):
                assert len(search(r"\.\d+", finding.group()).group()) == precision + 1

            # check threshold
            for finding in finditer(r"<= \d+\.\d+", dot_data):
                assert len(search(r"\.\d+", finding.group()).group()) == precision + 1


def test_export_text_errors():
    clf = DecisionTreeClassifier(max_depth=2, random_state=0)
    clf.fit(X, y)

    err_msg = "feature_names must contain 2 elements, got 1"
    with pytest.raises(ValueError, match=err_msg):
        export_text(clf, feature_names=["a"])

    err_msg = (
        "When `class_names` is an array, it should contain as"
        " many items as `decision_tree.classes_`. Got 1 while"
        " the tree was fitted with 2 classes."
    )
    with pytest.raises(ValueError, match=err_msg):
        export_text(clf, class_names=["a"])


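# The expected reports pin the exact ASCII layout of export_text: the
# indentation unit is "|" followed by `spacing` spaces, and ">" is padded
# with an extra space so thresholds align with the "<=" branch.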
def test_export_text():
    clf = DecisionTreeClassifier(max_depth=2, random_state=0)
    clf.fit(X, y)

    expected_report = dedent("""
    |--- feature_1 <= 0.00
    |   |--- class: -1
    |--- feature_1 >  0.00
    |   |--- class: 1
    """).lstrip()
    assert export_text(clf) == expected_report
    # testing that leaves at level 1 are not truncated
    assert export_text(clf, max_depth=0) == expected_report
    # testing that nothing is truncated when max_depth exceeds the tree depth
    assert export_text(clf, max_depth=10) == expected_report

    expected_report = dedent("""
    |--- feature_1 <= 0.00
    |   |--- weights: [3.00, 0.00] class: -1
    |--- feature_1 >  0.00
    |   |--- weights: [0.00, 3.00] class: 1
    """).lstrip()
    assert export_text(clf, show_weights=True) == expected_report

    expected_report = dedent("""
    |- feature_1 <= 0.00
    | |- class: -1
    |- feature_1 >  0.00
    | |- class: 1
    """).lstrip()
    assert export_text(clf, spacing=1) == expected_report

    X_l = [[-2, -1], [-1, -1], [-1, -2], [1, 1], [1, 2], [2, 1], [-1, 1]]
    y_l = [-1, -1, -1, 1, 1, 1, 2]
    clf = DecisionTreeClassifier(max_depth=4, random_state=0)
    clf.fit(X_l, y_l)
    expected_report = dedent("""
    |--- feature_1 <= 0.00
    |   |--- class: -1
    |--- feature_1 >  0.00
    |   |--- truncated branch of depth 2
    """).lstrip()
    assert export_text(clf, max_depth=0) == expected_report

    X_mo = [[-2, -1], [-1, -1], [-1, -2], [1, 1], [1, 2], [2, 1]]
    y_mo = [[-1, -1], [-1, -1], [-1, -1], [1, 1], [1, 1], [1, 1]]
    reg = DecisionTreeRegressor(max_depth=2, random_state=0)
    reg.fit(X_mo, y_mo)
    expected_report = dedent("""
    |--- feature_1 <= 0.0
    |   |--- value: [-1.0, -1.0]
    |--- feature_1 >  0.0
    |   |--- value: [1.0, 1.0]
    """).lstrip()
    assert export_text(reg, decimals=1) == expected_report
    assert export_text(reg, decimals=1, show_weights=True) == expected_report

    X_single = [[-2], [-1], [-1], [1], [1], [2]]
    reg = DecisionTreeRegressor(max_depth=2, random_state=0)
    reg.fit(X_single, y_mo)
    expected_report = dedent("""
    |--- first <= 0.0
    |   |--- value: [-1.0, -1.0]
    |--- first >  0.0
    |   |--- value: [1.0, 1.0]
    """).lstrip()
    assert export_text(reg, decimals=1, feature_names=["first"]) == expected_report
    assert (
        export_text(reg, decimals=1, show_weights=True, feature_names=["first"])
        == expected_report
    )


@pytest.mark.parametrize("constructor", [list, np.array])
def test_export_text_feature_class_names_array_support(constructor):
    # Check that export_text treats feature names
    # and class names correctly and supports arrays
    clf = DecisionTreeClassifier(max_depth=2, random_state=0)
    clf.fit(X, y)

    expected_report = dedent("""
    |--- b <= 0.00
    |   |--- class: -1
    |--- b >  0.00
    |   |--- class: 1
    """).lstrip()
    assert export_text(clf, feature_names=constructor(["a", "b"])) == expected_report

    expected_report = dedent("""
    |--- feature_1 <= 0.00
    |   |--- class: cat
    |--- feature_1 >  0.00
    |   |--- class: dog
    """).lstrip()
    assert export_text(clf, class_names=constructor(["cat", "dog"])) == expected_report


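# The plot_tree tests take the `pyplot` fixture, which provides matplotlib
# and skips the test when it is not installed.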
def test_plot_tree_entropy(pyplot):
    # mostly smoke tests
    # Check correctness of plot_tree for criterion = entropy
    clf = DecisionTreeClassifier(
        max_depth=3, min_samples_split=2, criterion="entropy", random_state=2
    )
    clf.fit(X, y)

    # Test export code
    feature_names = ["first feat", "sepal_width"]
    nodes = plot_tree(clf, feature_names=feature_names)
    assert len(nodes) == 3
    assert (
        nodes[0].get_text()
        == "first feat <= 0.0\nentropy = 1.0\nsamples = 6\nvalue = [3, 3]"
    )
    assert nodes[1].get_text() == "entropy = 0.0\nsamples = 3\nvalue = [3, 0]"
    assert nodes[2].get_text() == "entropy = 0.0\nsamples = 3\nvalue = [0, 3]"


def test_plot_tree_gini(pyplot):
    # mostly smoke tests
    # Check correctness of plot_tree for criterion = gini
    clf = DecisionTreeClassifier(
        max_depth=3, min_samples_split=2, criterion="gini", random_state=2
    )
    clf.fit(X, y)

    # Test export code
    feature_names = ["first feat", "sepal_width"]
    nodes = plot_tree(clf, feature_names=feature_names)
    assert len(nodes) == 3
    assert (
        nodes[0].get_text()
        == "first feat <= 0.0\ngini = 0.5\nsamples = 6\nvalue = [3, 3]"
    )
    assert nodes[1].get_text() == "gini = 0.0\nsamples = 3\nvalue = [3, 0]"
    assert nodes[2].get_text() == "gini = 0.0\nsamples = 3\nvalue = [0, 3]"


def test_not_fitted_tree(pyplot):
    # Check that plotting a tree that has not been fitted raises the
    # correct error
    clf = DecisionTreeRegressor()
    with pytest.raises(NotFittedError):
        plot_tree(clf)