- """
- Testing for export functions of decision trees (sklearn.tree.export).
- """
- from io import StringIO
- from re import finditer, search
- from textwrap import dedent
- import numpy as np
- import pytest
- from numpy.random import RandomState
- from sklearn.base import is_classifier
- from sklearn.ensemble import GradientBoostingClassifier
- from sklearn.exceptions import NotFittedError
- from sklearn.tree import (
- DecisionTreeClassifier,
- DecisionTreeRegressor,
- export_graphviz,
- export_text,
- plot_tree,
- )
- # toy sample
- X = [[-2, -1], [-1, -1], [-1, -2], [1, 1], [1, 2], [2, 1]]
- y = [-1, -1, -1, 1, 1, 1]
- y2 = [[-1, 1], [-1, 1], [-1, 1], [1, 2], [1, 2], [1, 3]]
- w = [1, 1, 1, 0.5, 0.5, 0.5]
- y_degraded = [1, 1, 1, 1, 1, 1]
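# y2 is a two-output target used together with the sample weights w;
# y_degraded is a single-class target that yields a single-node tree.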


def test_graphviz_toy():
    # Check correctness of export_graphviz
    clf = DecisionTreeClassifier(
        max_depth=3, min_samples_split=2, criterion="gini", random_state=2
    )
    clf.fit(X, y)

    # Test export code
    contents1 = export_graphviz(clf, out_file=None)
    contents2 = (
        "digraph Tree {\n"
        'node [shape=box, fontname="helvetica"] ;\n'
        'edge [fontname="helvetica"] ;\n'
        '0 [label="x[0] <= 0.0\\ngini = 0.5\\nsamples = 6\\n'
        'value = [3, 3]"] ;\n'
        '1 [label="gini = 0.0\\nsamples = 3\\nvalue = [3, 0]"] ;\n'
        "0 -> 1 [labeldistance=2.5, labelangle=45, "
        'headlabel="True"] ;\n'
        '2 [label="gini = 0.0\\nsamples = 3\\nvalue = [0, 3]"] ;\n'
        "0 -> 2 [labeldistance=2.5, labelangle=-45, "
        'headlabel="False"] ;\n'
        "}"
    )
    assert contents1 == contents2

    # Test plot_options
    contents1 = export_graphviz(
        clf,
        filled=True,
        impurity=False,
        proportion=True,
        special_characters=True,
        rounded=True,
        out_file=None,
        fontname="sans",
    )
    contents2 = (
        "digraph Tree {\n"
        'node [shape=box, style="filled, rounded", color="black", '
        'fontname="sans"] ;\n'
        'edge [fontname="sans"] ;\n'
        "0 [label=<x<SUB>0</SUB> &le; 0.0<br/>samples = 100.0%<br/>"
        'value = [0.5, 0.5]>, fillcolor="#ffffff"] ;\n'
        "1 [label=<samples = 50.0%<br/>value = [1.0, 0.0]>, "
        'fillcolor="#e58139"] ;\n'
        "0 -> 1 [labeldistance=2.5, labelangle=45, "
        'headlabel="True"] ;\n'
        "2 [label=<samples = 50.0%<br/>value = [0.0, 1.0]>, "
        'fillcolor="#399de5"] ;\n'
        "0 -> 2 [labeldistance=2.5, labelangle=-45, "
        'headlabel="False"] ;\n'
        "}"
    )
    assert contents1 == contents2

    # Test max_depth
    contents1 = export_graphviz(clf, max_depth=0, class_names=True, out_file=None)
    contents2 = (
        "digraph Tree {\n"
        'node [shape=box, fontname="helvetica"] ;\n'
        'edge [fontname="helvetica"] ;\n'
        '0 [label="x[0] <= 0.0\\ngini = 0.5\\nsamples = 6\\n'
        'value = [3, 3]\\nclass = y[0]"] ;\n'
        '1 [label="(...)"] ;\n'
        "0 -> 1 ;\n"
        '2 [label="(...)"] ;\n'
        "0 -> 2 ;\n"
        "}"
    )
    assert contents1 == contents2

    # Test max_depth with plot_options
    contents1 = export_graphviz(
        clf, max_depth=0, filled=True, out_file=None, node_ids=True
    )
    contents2 = (
        "digraph Tree {\n"
        'node [shape=box, style="filled", color="black", '
        'fontname="helvetica"] ;\n'
        'edge [fontname="helvetica"] ;\n'
        '0 [label="node #0\\nx[0] <= 0.0\\ngini = 0.5\\n'
        'samples = 6\\nvalue = [3, 3]", fillcolor="#ffffff"] ;\n'
        '1 [label="(...)", fillcolor="#C0C0C0"] ;\n'
        "0 -> 1 ;\n"
        '2 [label="(...)", fillcolor="#C0C0C0"] ;\n'
        "0 -> 2 ;\n"
        "}"
    )
    assert contents1 == contents2

    # Test multi-output with weighted samples
    clf = DecisionTreeClassifier(
        max_depth=2, min_samples_split=2, criterion="gini", random_state=2
    )
    clf = clf.fit(X, y2, sample_weight=w)

    contents1 = export_graphviz(clf, filled=True, impurity=False, out_file=None)
    contents2 = (
        "digraph Tree {\n"
        'node [shape=box, style="filled", color="black", '
        'fontname="helvetica"] ;\n'
        'edge [fontname="helvetica"] ;\n'
        '0 [label="x[0] <= 0.0\\nsamples = 6\\n'
        "value = [[3.0, 1.5, 0.0]\\n"
        '[3.0, 1.0, 0.5]]", fillcolor="#ffffff"] ;\n'
        '1 [label="samples = 3\\nvalue = [[3, 0, 0]\\n'
        '[3, 0, 0]]", fillcolor="#e58139"] ;\n'
        "0 -> 1 [labeldistance=2.5, labelangle=45, "
        'headlabel="True"] ;\n'
        '2 [label="x[0] <= 1.5\\nsamples = 3\\n'
        "value = [[0.0, 1.5, 0.0]\\n"
        '[0.0, 1.0, 0.5]]", fillcolor="#f1bd97"] ;\n'
        "0 -> 2 [labeldistance=2.5, labelangle=-45, "
        'headlabel="False"] ;\n'
        '3 [label="samples = 2\\nvalue = [[0, 1, 0]\\n'
        '[0, 1, 0]]", fillcolor="#e58139"] ;\n'
        "2 -> 3 ;\n"
        '4 [label="samples = 1\\nvalue = [[0.0, 0.5, 0.0]\\n'
        '[0.0, 0.0, 0.5]]", fillcolor="#e58139"] ;\n'
        "2 -> 4 ;\n"
        "}"
    )
    assert contents1 == contents2
    # Test regression output with plot_options
    clf = DecisionTreeRegressor(
        max_depth=3, min_samples_split=2, criterion="squared_error", random_state=2
    )
    clf.fit(X, y)

    contents1 = export_graphviz(
        clf,
        filled=True,
        leaves_parallel=True,
        out_file=None,
        rotate=True,
        rounded=True,
        fontname="sans",
    )
    contents2 = (
        "digraph Tree {\n"
        'node [shape=box, style="filled, rounded", color="black", '
        'fontname="sans"] ;\n'
        "graph [ranksep=equally, splines=polyline] ;\n"
        'edge [fontname="sans"] ;\n'
        "rankdir=LR ;\n"
        '0 [label="x[0] <= 0.0\\nsquared_error = 1.0\\nsamples = 6\\n'
        'value = 0.0", fillcolor="#f2c09c"] ;\n'
        '1 [label="squared_error = 0.0\\nsamples = 3\\nvalue = -1.0", '
        'fillcolor="#ffffff"] ;\n'
        "0 -> 1 [labeldistance=2.5, labelangle=-45, "
        'headlabel="True"] ;\n'
        '2 [label="squared_error = 0.0\\nsamples = 3\\nvalue = 1.0", '
        'fillcolor="#e58139"] ;\n'
        "0 -> 2 [labeldistance=2.5, labelangle=45, "
        'headlabel="False"] ;\n'
        "{rank=same ; 0} ;\n"
        "{rank=same ; 1; 2} ;\n"
        "}"
    )
    assert contents1 == contents2
    # Test classifier with degraded learning set
    clf = DecisionTreeClassifier(max_depth=3)
    clf.fit(X, y_degraded)

    contents1 = export_graphviz(clf, filled=True, out_file=None)
    contents2 = (
        "digraph Tree {\n"
        'node [shape=box, style="filled", color="black", '
        'fontname="helvetica"] ;\n'
        'edge [fontname="helvetica"] ;\n'
        '0 [label="gini = 0.0\\nsamples = 6\\nvalue = 6.0", '
        'fillcolor="#ffffff"] ;\n'
        "}"
    )
    assert contents1 == contents2


@pytest.mark.parametrize("constructor", [list, np.array])
def test_graphviz_feature_class_names_array_support(constructor):
    # Check that export_graphviz treats feature names
    # and class names correctly and supports arrays
    clf = DecisionTreeClassifier(
        max_depth=3, min_samples_split=2, criterion="gini", random_state=2
    )
    clf.fit(X, y)

    # Test with feature_names
    contents1 = export_graphviz(
        clf, feature_names=constructor(["feature0", "feature1"]), out_file=None
    )
    contents2 = (
        "digraph Tree {\n"
        'node [shape=box, fontname="helvetica"] ;\n'
        'edge [fontname="helvetica"] ;\n'
        '0 [label="feature0 <= 0.0\\ngini = 0.5\\nsamples = 6\\n'
        'value = [3, 3]"] ;\n'
        '1 [label="gini = 0.0\\nsamples = 3\\nvalue = [3, 0]"] ;\n'
        "0 -> 1 [labeldistance=2.5, labelangle=45, "
        'headlabel="True"] ;\n'
        '2 [label="gini = 0.0\\nsamples = 3\\nvalue = [0, 3]"] ;\n'
        "0 -> 2 [labeldistance=2.5, labelangle=-45, "
        'headlabel="False"] ;\n'
        "}"
    )
    assert contents1 == contents2

    # Test with class_names
    contents1 = export_graphviz(
        clf, class_names=constructor(["yes", "no"]), out_file=None
    )
    contents2 = (
        "digraph Tree {\n"
        'node [shape=box, fontname="helvetica"] ;\n'
        'edge [fontname="helvetica"] ;\n'
        '0 [label="x[0] <= 0.0\\ngini = 0.5\\nsamples = 6\\n'
        'value = [3, 3]\\nclass = yes"] ;\n'
        '1 [label="gini = 0.0\\nsamples = 3\\nvalue = [3, 0]\\n'
        'class = yes"] ;\n'
        "0 -> 1 [labeldistance=2.5, labelangle=45, "
        'headlabel="True"] ;\n'
        '2 [label="gini = 0.0\\nsamples = 3\\nvalue = [0, 3]\\n'
        'class = no"] ;\n'
        "0 -> 2 [labeldistance=2.5, labelangle=-45, "
        'headlabel="False"] ;\n'
        "}"
    )
    assert contents1 == contents2


def test_graphviz_errors():
    # Check for errors of export_graphviz
    clf = DecisionTreeClassifier(max_depth=3, min_samples_split=2)

    # Check not-fitted decision tree error
    out = StringIO()
    with pytest.raises(NotFittedError):
        export_graphviz(clf, out)

    clf.fit(X, y)

    # Check if it errors when length of feature_names
    # mismatches with number of features
    message = "Length of feature_names, 1 does not match number of features, 2"
    with pytest.raises(ValueError, match=message):
        export_graphviz(clf, None, feature_names=["a"])

    message = "Length of feature_names, 3 does not match number of features, 2"
    with pytest.raises(ValueError, match=message):
        export_graphviz(clf, None, feature_names=["a", "b", "c"])

    # Check error when argument is not an estimator
    message = "is not an estimator instance"
    with pytest.raises(TypeError, match=message):
        export_graphviz(clf.fit(X, y).tree_)

    # Check class_names error
    out = StringIO()
    with pytest.raises(IndexError):
        export_graphviz(clf, out, class_names=[])


def test_friedman_mse_in_graphviz():
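    # friedman_mse is the split criterion used by gradient-boosted trees;
    # check that its name appears in the exported dot labels, both for a
    # bare regressor and for the trees inside a GradientBoostingClassifier.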
    clf = DecisionTreeRegressor(criterion="friedman_mse", random_state=0)
    clf.fit(X, y)
    dot_data = StringIO()
    export_graphviz(clf, out_file=dot_data)

    clf = GradientBoostingClassifier(n_estimators=2, random_state=0)
    clf.fit(X, y)
    for estimator in clf.estimators_:
        export_graphviz(estimator[0], out_file=dot_data)

    for finding in finditer(r"\[.*?samples.*?\]", dot_data.getvalue()):
        assert "friedman_mse" in finding.group()


def test_precision():
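    # Check that the `precision` parameter controls the number of decimals
    # printed for thresholds, impurities and (proportion) values.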
    rng_reg = RandomState(2)
    rng_clf = RandomState(8)
    for X, y, clf in zip(
        (rng_reg.random_sample((5, 2)), rng_clf.random_sample((1000, 4))),
        (rng_reg.random_sample((5,)), rng_clf.randint(2, size=(1000,))),
        (
            DecisionTreeRegressor(
                criterion="friedman_mse", random_state=0, max_depth=1
            ),
            DecisionTreeClassifier(max_depth=1, random_state=0),
        ),
    ):
        clf.fit(X, y)
        for precision in (4, 3):
            dot_data = export_graphviz(
                clf, out_file=None, precision=precision, proportion=True
            )

            # With the current random state, the impurity and the threshold
            # are rounded to exactly `precision` decimals, so they are
            # checked with strict equality. The reported proportion values
            # may have fewer decimals, so they are only checked with a
            # less-than-or-equal comparison.

            # check value
            for finding in finditer(r"value = \d+\.\d+", dot_data):
                assert len(search(r"\.\d+", finding.group()).group()) <= precision + 1

            # check impurity
            if is_classifier(clf):
                pattern = r"gini = \d+\.\d+"
            else:
                pattern = r"friedman_mse = \d+\.\d+"
            for finding in finditer(pattern, dot_data):
                assert len(search(r"\.\d+", finding.group()).group()) == precision + 1

            # check threshold
            for finding in finditer(r"<= \d+\.\d+", dot_data):
                assert len(search(r"\.\d+", finding.group()).group()) == precision + 1


def test_export_text_errors():
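    # Check the error messages raised when feature_names or class_names
    # do not match the fitted tree.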
    clf = DecisionTreeClassifier(max_depth=2, random_state=0)
    clf.fit(X, y)

    err_msg = "feature_names must contain 2 elements, got 1"
    with pytest.raises(ValueError, match=err_msg):
        export_text(clf, feature_names=["a"])

    err_msg = (
        "When `class_names` is an array, it should contain as"
        " many items as `decision_tree.classes_`. Got 1 while"
        " the tree was fitted with 2 classes."
    )
    with pytest.raises(ValueError, match=err_msg):
        export_text(clf, class_names=["a"])


def test_export_text():
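    # Check the plain-text report for a classifier, with and without
    # weights, with custom spacing, with truncation, and for single- and
    # multi-output regressors.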
    clf = DecisionTreeClassifier(max_depth=2, random_state=0)
    clf.fit(X, y)

    expected_report = dedent("""
    |--- feature_1 <= 0.00
    |   |--- class: -1
    |--- feature_1 >  0.00
    |   |--- class: 1
    """).lstrip()
    assert export_text(clf) == expected_report
    # testing that leaves at level 1 are not truncated
    assert export_text(clf, max_depth=0) == expected_report
    # testing that a max_depth above the tree depth does not truncate anything
    assert export_text(clf, max_depth=10) == expected_report

    expected_report = dedent("""
    |--- feature_1 <= 0.00
    |   |--- weights: [3.00, 0.00] class: -1
    |--- feature_1 >  0.00
    |   |--- weights: [0.00, 3.00] class: 1
    """).lstrip()
    assert export_text(clf, show_weights=True) == expected_report

    expected_report = dedent("""
    |- feature_1 <= 0.00
    | |- class: -1
    |- feature_1 >  0.00
    | |- class: 1
    """).lstrip()
    assert export_text(clf, spacing=1) == expected_report
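
    # A deeper tree exercises branch truncation: with max_depth=0, the
    # branch below the root's right child is summarized rather than printed.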
    X_l = [[-2, -1], [-1, -1], [-1, -2], [1, 1], [1, 2], [2, 1], [-1, 1]]
    y_l = [-1, -1, -1, 1, 1, 1, 2]
    clf = DecisionTreeClassifier(max_depth=4, random_state=0)
    clf.fit(X_l, y_l)
    expected_report = dedent("""
    |--- feature_1 <= 0.00
    |   |--- class: -1
    |--- feature_1 >  0.00
    |   |--- truncated branch of depth 2
    """).lstrip()
    assert export_text(clf, max_depth=0) == expected_report
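
    # A multi-output regressor reports a list with one value per output at
    # each leaf.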
    X_mo = [[-2, -1], [-1, -1], [-1, -2], [1, 1], [1, 2], [2, 1]]
    y_mo = [[-1, -1], [-1, -1], [-1, -1], [1, 1], [1, 1], [1, 1]]

    reg = DecisionTreeRegressor(max_depth=2, random_state=0)
    reg.fit(X_mo, y_mo)
    expected_report = dedent("""
    |--- feature_1 <= 0.0
    |   |--- value: [-1.0, -1.0]
    |--- feature_1 >  0.0
    |   |--- value: [1.0, 1.0]
    """).lstrip()
    assert export_text(reg, decimals=1) == expected_report
    assert export_text(reg, decimals=1, show_weights=True) == expected_report
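
    # Single-feature input, exported with a user-provided feature name.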
    X_single = [[-2], [-1], [-1], [1], [1], [2]]
    reg = DecisionTreeRegressor(max_depth=2, random_state=0)
    reg.fit(X_single, y_mo)
    expected_report = dedent("""
    |--- first <= 0.0
    |   |--- value: [-1.0, -1.0]
    |--- first >  0.0
    |   |--- value: [1.0, 1.0]
    """).lstrip()
    assert export_text(reg, decimals=1, feature_names=["first"]) == expected_report
    assert (
        export_text(reg, decimals=1, show_weights=True, feature_names=["first"])
        == expected_report
    )


@pytest.mark.parametrize("constructor", [list, np.array])
def test_export_text_feature_class_names_array_support(constructor):
    # Check that export_text treats feature names
    # and class names correctly and supports arrays
    clf = DecisionTreeClassifier(max_depth=2, random_state=0)
    clf.fit(X, y)

    expected_report = dedent("""
    |--- b <= 0.00
    |   |--- class: -1
    |--- b >  0.00
    |   |--- class: 1
    """).lstrip()
    assert export_text(clf, feature_names=constructor(["a", "b"])) == expected_report

    expected_report = dedent("""
    |--- feature_1 <= 0.00
    |   |--- class: cat
    |--- feature_1 >  0.00
    |   |--- class: dog
    """).lstrip()
    assert export_text(clf, class_names=constructor(["cat", "dog"])) == expected_report


def test_plot_tree_entropy(pyplot):
    # mostly smoke tests
    # Check correctness of plot_tree for criterion = entropy
    clf = DecisionTreeClassifier(
        max_depth=3, min_samples_split=2, criterion="entropy", random_state=2
    )
    clf.fit(X, y)

    # Test export code
    feature_names = ["first feat", "sepal_width"]
    nodes = plot_tree(clf, feature_names=feature_names)
    assert len(nodes) == 3
    assert (
        nodes[0].get_text()
        == "first feat <= 0.0\nentropy = 1.0\nsamples = 6\nvalue = [3, 3]"
    )
    assert nodes[1].get_text() == "entropy = 0.0\nsamples = 3\nvalue = [3, 0]"
    assert nodes[2].get_text() == "entropy = 0.0\nsamples = 3\nvalue = [0, 3]"


def test_plot_tree_gini(pyplot):
    # mostly smoke tests
    # Check correctness of plot_tree for criterion = gini
    clf = DecisionTreeClassifier(
        max_depth=3, min_samples_split=2, criterion="gini", random_state=2
    )
    clf.fit(X, y)

    # Test export code
    feature_names = ["first feat", "sepal_width"]
    nodes = plot_tree(clf, feature_names=feature_names)
    assert len(nodes) == 3
    assert (
        nodes[0].get_text()
        == "first feat <= 0.0\ngini = 0.5\nsamples = 6\nvalue = [3, 3]"
    )
    assert nodes[1].get_text() == "gini = 0.0\nsamples = 3\nvalue = [3, 0]"
    assert nodes[2].get_text() == "gini = 0.0\nsamples = 3\nvalue = [0, 3]"


def test_not_fitted_tree(pyplot):
    # Testing if not fitted tree throws the correct error
    clf = DecisionTreeRegressor()
    with pytest.raises(NotFittedError):
        plot_tree(clf)