| 1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495969798991001011021031041051061071081091101111121131141151161171181191201211221231241251261271281291301311321331341351361371381391401411421431441451461471481491501511521531541551561571581591601611621631641651661671681691701711721731741751761771781791801811821831841851861871881891901911921931941951961971981992002012022032042052062072082092102112122132142152162172182192202212222232242252262272282292302312322332342352362372382392402412422432442452462472482492502512522532542552562572582592602612622632642652662672682692702712722732742752762772782792802812822832842852862872882892902912922932942952962972982993003013023033043053063073083093103113123133143153163173183193203213223233243253263273283293303313323333343353363373383393403413423433443453463473483493503513523533543553563573583593603613623633643653663673683693703713723733743753763773783793803813823833843853863873883893903913923933943953963973983994004014024034044054064074084094104114124134144154164174184194204214224234244254264274284294304314324334344354364374384394404414424434444454464474484494504514524534544554564574584594604614624634644654664674684694704714724734744754764774784794804814824834844854864874884894904914924934944954964974984995005015025035045055065075085095105115125135145155165175185195205215225235245255265275285295305315325335345355365375385395405415425435445455465475485495505515525535545555565575585595605615625635645655665675685695705715725735745755765775785795805815825835845855865875885895905915925935945955965975985996006016026036046056066076086096106116126136146156166176186196206216226236246256266276286296306316326336346356366376386396406416426436446456466476486496506516526536546556566576586596606616626636646656666676686696706716726736746756766776786796806816826836846856866876886896906916926936946956966976986997007017027037047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126 |
- """
- This module defines export functions for decision trees.
- """
- # Authors: Gilles Louppe <g.louppe@gmail.com>
- # Peter Prettenhofer <peter.prettenhofer@gmail.com>
- # Brian Holt <bdholt1@gmail.com>
- # Noel Dawe <noel@dawe.me>
- # Satrajit Gosh <satrajit.ghosh@gmail.com>
- # Trevor Stephens <trev.stephens@gmail.com>
- # Li Li <aiki.nogard@gmail.com>
- # Giuseppe Vettigli <vettigli@gmail.com>
- # License: BSD 3 clause
- from collections.abc import Iterable
- from io import StringIO
- from numbers import Integral
- import numpy as np
- from ..base import is_classifier
- from ..utils._param_validation import HasMethods, Interval, StrOptions, validate_params
- from ..utils.validation import check_array, check_is_fitted
- from . import DecisionTreeClassifier, DecisionTreeRegressor, _criterion, _tree
- from ._reingold_tilford import Tree, buchheim
- def _color_brew(n):
- """Generate n colors with equally spaced hues.
- Parameters
- ----------
- n : int
- The number of colors required.
- Returns
- -------
- color_list : list, length n
- List of n tuples of form (R, G, B) being the components of each color.
- """
- color_list = []
- # Initialize saturation & value; calculate chroma & value shift
- s, v = 0.75, 0.9
- c = s * v
- m = v - c
- for h in np.arange(25, 385, 360.0 / n).astype(int):
- # Calculate some intermediate values
- h_bar = h / 60.0
- x = c * (1 - abs((h_bar % 2) - 1))
- # Initialize RGB with same hue & chroma as our color
- rgb = [
- (c, x, 0),
- (x, c, 0),
- (0, c, x),
- (0, x, c),
- (x, 0, c),
- (c, 0, x),
- (c, x, 0),
- ]
- r, g, b = rgb[int(h_bar)]
- # Shift the initial RGB values to match value and store
- rgb = [(int(255 * (r + m))), (int(255 * (g + m))), (int(255 * (b + m)))]
- color_list.append(rgb)
- return color_list
- class Sentinel:
- def __repr__(self):
- return '"tree.dot"'
- SENTINEL = Sentinel()
- @validate_params(
- {
- "decision_tree": [DecisionTreeClassifier, DecisionTreeRegressor],
- "max_depth": [Interval(Integral, 0, None, closed="left"), None],
- "feature_names": ["array-like", None],
- "class_names": ["array-like", "boolean", None],
- "label": [StrOptions({"all", "root", "none"})],
- "filled": ["boolean"],
- "impurity": ["boolean"],
- "node_ids": ["boolean"],
- "proportion": ["boolean"],
- "rounded": ["boolean"],
- "precision": [Interval(Integral, 0, None, closed="left"), None],
- "ax": "no_validation", # delegate validation to matplotlib
- "fontsize": [Interval(Integral, 0, None, closed="left"), None],
- },
- prefer_skip_nested_validation=True,
- )
- def plot_tree(
- decision_tree,
- *,
- max_depth=None,
- feature_names=None,
- class_names=None,
- label="all",
- filled=False,
- impurity=True,
- node_ids=False,
- proportion=False,
- rounded=False,
- precision=3,
- ax=None,
- fontsize=None,
- ):
- """Plot a decision tree.
- The sample counts that are shown are weighted with any sample_weights that
- might be present.
- The visualization is fit automatically to the size of the axis.
- Use the ``figsize`` or ``dpi`` arguments of ``plt.figure`` to control
- the size of the rendering.
- Read more in the :ref:`User Guide <tree>`.
- .. versionadded:: 0.21
- Parameters
- ----------
- decision_tree : decision tree regressor or classifier
- The decision tree to be plotted.
- max_depth : int, default=None
- The maximum depth of the representation. If None, the tree is fully
- generated.
- feature_names : array-like of str, default=None
- Names of each of the features.
- If None, generic names will be used ("x[0]", "x[1]", ...).
- class_names : array-like of str or True, default=None
- Names of each of the target classes in ascending numerical order.
- Only relevant for classification and not supported for multi-output.
- If ``True``, shows a symbolic representation of the class name.
- label : {'all', 'root', 'none'}, default='all'
- Whether to show informative labels for impurity, etc.
- Options include 'all' to show at every node, 'root' to show only at
- the top root node, or 'none' to not show at any node.
- filled : bool, default=False
- When set to ``True``, paint nodes to indicate majority class for
- classification, extremity of values for regression, or purity of node
- for multi-output.
- impurity : bool, default=True
- When set to ``True``, show the impurity at each node.
- node_ids : bool, default=False
- When set to ``True``, show the ID number on each node.
- proportion : bool, default=False
- When set to ``True``, change the display of 'values' and/or 'samples'
- to be proportions and percentages respectively.
- rounded : bool, default=False
- When set to ``True``, draw node boxes with rounded corners and use
- Helvetica fonts instead of Times-Roman.
- precision : int, default=3
- Number of digits of precision for floating point in the values of
- impurity, threshold and value attributes of each node.
- ax : matplotlib axis, default=None
- Axes to plot to. If None, use current axis. Any previous content
- is cleared.
- fontsize : int, default=None
- Size of text font. If None, determined automatically to fit figure.
- Returns
- -------
- annotations : list of artists
- List containing the artists for the annotation boxes making up the
- tree.
- Examples
- --------
- >>> from sklearn.datasets import load_iris
- >>> from sklearn import tree
- >>> clf = tree.DecisionTreeClassifier(random_state=0)
- >>> iris = load_iris()
- >>> clf = clf.fit(iris.data, iris.target)
- >>> tree.plot_tree(clf)
- [...]
- """
- check_is_fitted(decision_tree)
- exporter = _MPLTreeExporter(
- max_depth=max_depth,
- feature_names=feature_names,
- class_names=class_names,
- label=label,
- filled=filled,
- impurity=impurity,
- node_ids=node_ids,
- proportion=proportion,
- rounded=rounded,
- precision=precision,
- fontsize=fontsize,
- )
- return exporter.export(decision_tree, ax=ax)
- class _BaseTreeExporter:
- def __init__(
- self,
- max_depth=None,
- feature_names=None,
- class_names=None,
- label="all",
- filled=False,
- impurity=True,
- node_ids=False,
- proportion=False,
- rounded=False,
- precision=3,
- fontsize=None,
- ):
- self.max_depth = max_depth
- self.feature_names = feature_names
- self.class_names = class_names
- self.label = label
- self.filled = filled
- self.impurity = impurity
- self.node_ids = node_ids
- self.proportion = proportion
- self.rounded = rounded
- self.precision = precision
- self.fontsize = fontsize
- def get_color(self, value):
- # Find the appropriate color & intensity for a node
- if self.colors["bounds"] is None:
- # Classification tree
- color = list(self.colors["rgb"][np.argmax(value)])
- sorted_values = sorted(value, reverse=True)
- if len(sorted_values) == 1:
- alpha = 0.0
- else:
- alpha = (sorted_values[0] - sorted_values[1]) / (1 - sorted_values[1])
- else:
- # Regression tree or multi-output
- color = list(self.colors["rgb"][0])
- alpha = (value - self.colors["bounds"][0]) / (
- self.colors["bounds"][1] - self.colors["bounds"][0]
- )
- # compute the color as alpha against white
- color = [int(round(alpha * c + (1 - alpha) * 255, 0)) for c in color]
- # Return html color code in #RRGGBB format
- return "#%2x%2x%2x" % tuple(color)
- def get_fill_color(self, tree, node_id):
- # Fetch appropriate color for node
- if "rgb" not in self.colors:
- # Initialize colors and bounds if required
- self.colors["rgb"] = _color_brew(tree.n_classes[0])
- if tree.n_outputs != 1:
- # Find max and min impurities for multi-output
- self.colors["bounds"] = (np.min(-tree.impurity), np.max(-tree.impurity))
- elif tree.n_classes[0] == 1 and len(np.unique(tree.value)) != 1:
- # Find max and min values in leaf nodes for regression
- self.colors["bounds"] = (np.min(tree.value), np.max(tree.value))
- if tree.n_outputs == 1:
- node_val = tree.value[node_id][0, :] / tree.weighted_n_node_samples[node_id]
- if tree.n_classes[0] == 1:
- # Regression or degraded classification with single class
- node_val = tree.value[node_id][0, :]
- if isinstance(node_val, Iterable) and self.colors["bounds"] is not None:
- # Only unpack the float only for the regression tree case.
- # Classification tree requires an Iterable in `get_color`.
- node_val = node_val.item()
- else:
- # If multi-output color node by impurity
- node_val = -tree.impurity[node_id]
- return self.get_color(node_val)
- def node_to_str(self, tree, node_id, criterion):
- # Generate the node content string
- if tree.n_outputs == 1:
- value = tree.value[node_id][0, :]
- else:
- value = tree.value[node_id]
- # Should labels be shown?
- labels = (self.label == "root" and node_id == 0) or self.label == "all"
- characters = self.characters
- node_string = characters[-1]
- # Write node ID
- if self.node_ids:
- if labels:
- node_string += "node "
- node_string += characters[0] + str(node_id) + characters[4]
- # Write decision criteria
- if tree.children_left[node_id] != _tree.TREE_LEAF:
- # Always write node decision criteria, except for leaves
- if self.feature_names is not None:
- feature = self.feature_names[tree.feature[node_id]]
- else:
- feature = "x%s%s%s" % (
- characters[1],
- tree.feature[node_id],
- characters[2],
- )
- node_string += "%s %s %s%s" % (
- feature,
- characters[3],
- round(tree.threshold[node_id], self.precision),
- characters[4],
- )
- # Write impurity
- if self.impurity:
- if isinstance(criterion, _criterion.FriedmanMSE):
- criterion = "friedman_mse"
- elif isinstance(criterion, _criterion.MSE) or criterion == "squared_error":
- criterion = "squared_error"
- elif not isinstance(criterion, str):
- criterion = "impurity"
- if labels:
- node_string += "%s = " % criterion
- node_string += (
- str(round(tree.impurity[node_id], self.precision)) + characters[4]
- )
- # Write node sample count
- if labels:
- node_string += "samples = "
- if self.proportion:
- percent = (
- 100.0 * tree.n_node_samples[node_id] / float(tree.n_node_samples[0])
- )
- node_string += str(round(percent, 1)) + "%" + characters[4]
- else:
- node_string += str(tree.n_node_samples[node_id]) + characters[4]
- # Write node class distribution / regression value
- if self.proportion and tree.n_classes[0] != 1:
- # For classification this will show the proportion of samples
- value = value / tree.weighted_n_node_samples[node_id]
- if labels:
- node_string += "value = "
- if tree.n_classes[0] == 1:
- # Regression
- value_text = np.around(value, self.precision)
- elif self.proportion:
- # Classification
- value_text = np.around(value, self.precision)
- elif np.all(np.equal(np.mod(value, 1), 0)):
- # Classification without floating-point weights
- value_text = value.astype(int)
- else:
- # Classification with floating-point weights
- value_text = np.around(value, self.precision)
- # Strip whitespace
- value_text = str(value_text.astype("S32")).replace("b'", "'")
- value_text = value_text.replace("' '", ", ").replace("'", "")
- if tree.n_classes[0] == 1 and tree.n_outputs == 1:
- value_text = value_text.replace("[", "").replace("]", "")
- value_text = value_text.replace("\n ", characters[4])
- node_string += value_text + characters[4]
- # Write node majority class
- if (
- self.class_names is not None
- and tree.n_classes[0] != 1
- and tree.n_outputs == 1
- ):
- # Only done for single-output classification trees
- if labels:
- node_string += "class = "
- if self.class_names is not True:
- class_name = self.class_names[np.argmax(value)]
- else:
- class_name = "y%s%s%s" % (
- characters[1],
- np.argmax(value),
- characters[2],
- )
- node_string += class_name
- # Clean up any trailing newlines
- if node_string.endswith(characters[4]):
- node_string = node_string[: -len(characters[4])]
- return node_string + characters[5]
- class _DOTTreeExporter(_BaseTreeExporter):
- def __init__(
- self,
- out_file=SENTINEL,
- max_depth=None,
- feature_names=None,
- class_names=None,
- label="all",
- filled=False,
- leaves_parallel=False,
- impurity=True,
- node_ids=False,
- proportion=False,
- rotate=False,
- rounded=False,
- special_characters=False,
- precision=3,
- fontname="helvetica",
- ):
- super().__init__(
- max_depth=max_depth,
- feature_names=feature_names,
- class_names=class_names,
- label=label,
- filled=filled,
- impurity=impurity,
- node_ids=node_ids,
- proportion=proportion,
- rounded=rounded,
- precision=precision,
- )
- self.leaves_parallel = leaves_parallel
- self.out_file = out_file
- self.special_characters = special_characters
- self.fontname = fontname
- self.rotate = rotate
- # PostScript compatibility for special characters
- if special_characters:
- self.characters = ["#", "<SUB>", "</SUB>", "≤", "<br/>", ">", "<"]
- else:
- self.characters = ["#", "[", "]", "<=", "\\n", '"', '"']
- # The depth of each node for plotting with 'leaf' option
- self.ranks = {"leaves": []}
- # The colors to render each node with
- self.colors = {"bounds": None}
- def export(self, decision_tree):
- # Check length of feature_names before getting into the tree node
- # Raise error if length of feature_names does not match
- # n_features_in_ in the decision_tree
- if self.feature_names is not None:
- if len(self.feature_names) != decision_tree.n_features_in_:
- raise ValueError(
- "Length of feature_names, %d does not match number of features, %d"
- % (len(self.feature_names), decision_tree.n_features_in_)
- )
- # each part writes to out_file
- self.head()
- # Now recurse the tree and add node & edge attributes
- if isinstance(decision_tree, _tree.Tree):
- self.recurse(decision_tree, 0, criterion="impurity")
- else:
- self.recurse(decision_tree.tree_, 0, criterion=decision_tree.criterion)
- self.tail()
- def tail(self):
- # If required, draw leaf nodes at same depth as each other
- if self.leaves_parallel:
- for rank in sorted(self.ranks):
- self.out_file.write(
- "{rank=same ; " + "; ".join(r for r in self.ranks[rank]) + "} ;\n"
- )
- self.out_file.write("}")
- def head(self):
- self.out_file.write("digraph Tree {\n")
- # Specify node aesthetics
- self.out_file.write("node [shape=box")
- rounded_filled = []
- if self.filled:
- rounded_filled.append("filled")
- if self.rounded:
- rounded_filled.append("rounded")
- if len(rounded_filled) > 0:
- self.out_file.write(
- ', style="%s", color="black"' % ", ".join(rounded_filled)
- )
- self.out_file.write(', fontname="%s"' % self.fontname)
- self.out_file.write("] ;\n")
- # Specify graph & edge aesthetics
- if self.leaves_parallel:
- self.out_file.write("graph [ranksep=equally, splines=polyline] ;\n")
- self.out_file.write('edge [fontname="%s"] ;\n' % self.fontname)
- if self.rotate:
- self.out_file.write("rankdir=LR ;\n")
- def recurse(self, tree, node_id, criterion, parent=None, depth=0):
- if node_id == _tree.TREE_LEAF:
- raise ValueError("Invalid node_id %s" % _tree.TREE_LEAF)
- left_child = tree.children_left[node_id]
- right_child = tree.children_right[node_id]
- # Add node with description
- if self.max_depth is None or depth <= self.max_depth:
- # Collect ranks for 'leaf' option in plot_options
- if left_child == _tree.TREE_LEAF:
- self.ranks["leaves"].append(str(node_id))
- elif str(depth) not in self.ranks:
- self.ranks[str(depth)] = [str(node_id)]
- else:
- self.ranks[str(depth)].append(str(node_id))
- self.out_file.write(
- "%d [label=%s" % (node_id, self.node_to_str(tree, node_id, criterion))
- )
- if self.filled:
- self.out_file.write(
- ', fillcolor="%s"' % self.get_fill_color(tree, node_id)
- )
- self.out_file.write("] ;\n")
- if parent is not None:
- # Add edge to parent
- self.out_file.write("%d -> %d" % (parent, node_id))
- if parent == 0:
- # Draw True/False labels if parent is root node
- angles = np.array([45, -45]) * ((self.rotate - 0.5) * -2)
- self.out_file.write(" [labeldistance=2.5, labelangle=")
- if node_id == 1:
- self.out_file.write('%d, headlabel="True"]' % angles[0])
- else:
- self.out_file.write('%d, headlabel="False"]' % angles[1])
- self.out_file.write(" ;\n")
- if left_child != _tree.TREE_LEAF:
- self.recurse(
- tree,
- left_child,
- criterion=criterion,
- parent=node_id,
- depth=depth + 1,
- )
- self.recurse(
- tree,
- right_child,
- criterion=criterion,
- parent=node_id,
- depth=depth + 1,
- )
- else:
- self.ranks["leaves"].append(str(node_id))
- self.out_file.write('%d [label="(...)"' % node_id)
- if self.filled:
- # color cropped nodes grey
- self.out_file.write(', fillcolor="#C0C0C0"')
- self.out_file.write("] ;\n" % node_id)
- if parent is not None:
- # Add edge to parent
- self.out_file.write("%d -> %d ;\n" % (parent, node_id))
- class _MPLTreeExporter(_BaseTreeExporter):
- def __init__(
- self,
- max_depth=None,
- feature_names=None,
- class_names=None,
- label="all",
- filled=False,
- impurity=True,
- node_ids=False,
- proportion=False,
- rounded=False,
- precision=3,
- fontsize=None,
- ):
- super().__init__(
- max_depth=max_depth,
- feature_names=feature_names,
- class_names=class_names,
- label=label,
- filled=filled,
- impurity=impurity,
- node_ids=node_ids,
- proportion=proportion,
- rounded=rounded,
- precision=precision,
- )
- self.fontsize = fontsize
- # The depth of each node for plotting with 'leaf' option
- self.ranks = {"leaves": []}
- # The colors to render each node with
- self.colors = {"bounds": None}
- self.characters = ["#", "[", "]", "<=", "\n", "", ""]
- self.bbox_args = dict()
- if self.rounded:
- self.bbox_args["boxstyle"] = "round"
- self.arrow_args = dict(arrowstyle="<-")
- def _make_tree(self, node_id, et, criterion, depth=0):
- # traverses _tree.Tree recursively, builds intermediate
- # "_reingold_tilford.Tree" object
- name = self.node_to_str(et, node_id, criterion=criterion)
- if et.children_left[node_id] != _tree.TREE_LEAF and (
- self.max_depth is None or depth <= self.max_depth
- ):
- children = [
- self._make_tree(
- et.children_left[node_id], et, criterion, depth=depth + 1
- ),
- self._make_tree(
- et.children_right[node_id], et, criterion, depth=depth + 1
- ),
- ]
- else:
- return Tree(name, node_id)
- return Tree(name, node_id, *children)
- def export(self, decision_tree, ax=None):
- import matplotlib.pyplot as plt
- from matplotlib.text import Annotation
- if ax is None:
- ax = plt.gca()
- ax.clear()
- ax.set_axis_off()
- my_tree = self._make_tree(0, decision_tree.tree_, decision_tree.criterion)
- draw_tree = buchheim(my_tree)
- # important to make sure we're still
- # inside the axis after drawing the box
- # this makes sense because the width of a box
- # is about the same as the distance between boxes
- max_x, max_y = draw_tree.max_extents() + 1
- ax_width = ax.get_window_extent().width
- ax_height = ax.get_window_extent().height
- scale_x = ax_width / max_x
- scale_y = ax_height / max_y
- self.recurse(draw_tree, decision_tree.tree_, ax, max_x, max_y)
- anns = [ann for ann in ax.get_children() if isinstance(ann, Annotation)]
- # update sizes of all bboxes
- renderer = ax.figure.canvas.get_renderer()
- for ann in anns:
- ann.update_bbox_position_size(renderer)
- if self.fontsize is None:
- # get figure to data transform
- # adjust fontsize to avoid overlap
- # get max box width and height
- extents = [ann.get_bbox_patch().get_window_extent() for ann in anns]
- max_width = max([extent.width for extent in extents])
- max_height = max([extent.height for extent in extents])
- # width should be around scale_x in axis coordinates
- size = anns[0].get_fontsize() * min(
- scale_x / max_width, scale_y / max_height
- )
- for ann in anns:
- ann.set_fontsize(size)
- return anns
- def recurse(self, node, tree, ax, max_x, max_y, depth=0):
- import matplotlib.pyplot as plt
- kwargs = dict(
- bbox=self.bbox_args.copy(),
- ha="center",
- va="center",
- zorder=100 - 10 * depth,
- xycoords="axes fraction",
- arrowprops=self.arrow_args.copy(),
- )
- kwargs["arrowprops"]["edgecolor"] = plt.rcParams["text.color"]
- if self.fontsize is not None:
- kwargs["fontsize"] = self.fontsize
- # offset things by .5 to center them in plot
- xy = ((node.x + 0.5) / max_x, (max_y - node.y - 0.5) / max_y)
- if self.max_depth is None or depth <= self.max_depth:
- if self.filled:
- kwargs["bbox"]["fc"] = self.get_fill_color(tree, node.tree.node_id)
- else:
- kwargs["bbox"]["fc"] = ax.get_facecolor()
- if node.parent is None:
- # root
- ax.annotate(node.tree.label, xy, **kwargs)
- else:
- xy_parent = (
- (node.parent.x + 0.5) / max_x,
- (max_y - node.parent.y - 0.5) / max_y,
- )
- ax.annotate(node.tree.label, xy_parent, xy, **kwargs)
- for child in node.children:
- self.recurse(child, tree, ax, max_x, max_y, depth=depth + 1)
- else:
- xy_parent = (
- (node.parent.x + 0.5) / max_x,
- (max_y - node.parent.y - 0.5) / max_y,
- )
- kwargs["bbox"]["fc"] = "grey"
- ax.annotate("\n (...) \n", xy_parent, xy, **kwargs)
- @validate_params(
- {
- "decision_tree": "no_validation",
- "out_file": [str, None, HasMethods("write")],
- "max_depth": [Interval(Integral, 0, None, closed="left"), None],
- "feature_names": ["array-like", None],
- "class_names": ["array-like", "boolean", None],
- "label": [StrOptions({"all", "root", "none"})],
- "filled": ["boolean"],
- "leaves_parallel": ["boolean"],
- "impurity": ["boolean"],
- "node_ids": ["boolean"],
- "proportion": ["boolean"],
- "rotate": ["boolean"],
- "rounded": ["boolean"],
- "special_characters": ["boolean"],
- "precision": [Interval(Integral, 0, None, closed="left"), None],
- "fontname": [str],
- },
- prefer_skip_nested_validation=True,
- )
- def export_graphviz(
- decision_tree,
- out_file=None,
- *,
- max_depth=None,
- feature_names=None,
- class_names=None,
- label="all",
- filled=False,
- leaves_parallel=False,
- impurity=True,
- node_ids=False,
- proportion=False,
- rotate=False,
- rounded=False,
- special_characters=False,
- precision=3,
- fontname="helvetica",
- ):
- """Export a decision tree in DOT format.
- This function generates a GraphViz representation of the decision tree,
- which is then written into `out_file`. Once exported, graphical renderings
- can be generated using, for example::
- $ dot -Tps tree.dot -o tree.ps (PostScript format)
- $ dot -Tpng tree.dot -o tree.png (PNG format)
- The sample counts that are shown are weighted with any sample_weights that
- might be present.
- Read more in the :ref:`User Guide <tree>`.
- Parameters
- ----------
- decision_tree : object
- The decision tree estimator to be exported to GraphViz.
- out_file : object or str, default=None
- Handle or name of the output file. If ``None``, the result is
- returned as a string.
- .. versionchanged:: 0.20
- Default of out_file changed from "tree.dot" to None.
- max_depth : int, default=None
- The maximum depth of the representation. If None, the tree is fully
- generated.
- feature_names : array-like of shape (n_features,), default=None
- An array containing the feature names.
- If None, generic names will be used ("x[0]", "x[1]", ...).
- class_names : array-like of shape (n_classes,) or bool, default=None
- Names of each of the target classes in ascending numerical order.
- Only relevant for classification and not supported for multi-output.
- If ``True``, shows a symbolic representation of the class name.
- label : {'all', 'root', 'none'}, default='all'
- Whether to show informative labels for impurity, etc.
- Options include 'all' to show at every node, 'root' to show only at
- the top root node, or 'none' to not show at any node.
- filled : bool, default=False
- When set to ``True``, paint nodes to indicate majority class for
- classification, extremity of values for regression, or purity of node
- for multi-output.
- leaves_parallel : bool, default=False
- When set to ``True``, draw all leaf nodes at the bottom of the tree.
- impurity : bool, default=True
- When set to ``True``, show the impurity at each node.
- node_ids : bool, default=False
- When set to ``True``, show the ID number on each node.
- proportion : bool, default=False
- When set to ``True``, change the display of 'values' and/or 'samples'
- to be proportions and percentages respectively.
- rotate : bool, default=False
- When set to ``True``, orient tree left to right rather than top-down.
- rounded : bool, default=False
- When set to ``True``, draw node boxes with rounded corners.
- special_characters : bool, default=False
- When set to ``False``, ignore special characters for PostScript
- compatibility.
- precision : int, default=3
- Number of digits of precision for floating point in the values of
- impurity, threshold and value attributes of each node.
- fontname : str, default='helvetica'
- Name of font used to render text.
- Returns
- -------
- dot_data : str
- String representation of the input tree in GraphViz dot format.
- Only returned if ``out_file`` is None.
- .. versionadded:: 0.18
- Examples
- --------
- >>> from sklearn.datasets import load_iris
- >>> from sklearn import tree
- >>> clf = tree.DecisionTreeClassifier()
- >>> iris = load_iris()
- >>> clf = clf.fit(iris.data, iris.target)
- >>> tree.export_graphviz(clf)
- 'digraph Tree {...
- """
- if feature_names is not None:
- feature_names = check_array(
- feature_names, ensure_2d=False, dtype=None, ensure_min_samples=0
- )
- if class_names is not None and not isinstance(class_names, bool):
- class_names = check_array(
- class_names, ensure_2d=False, dtype=None, ensure_min_samples=0
- )
- check_is_fitted(decision_tree)
- own_file = False
- return_string = False
- try:
- if isinstance(out_file, str):
- out_file = open(out_file, "w", encoding="utf-8")
- own_file = True
- if out_file is None:
- return_string = True
- out_file = StringIO()
- exporter = _DOTTreeExporter(
- out_file=out_file,
- max_depth=max_depth,
- feature_names=feature_names,
- class_names=class_names,
- label=label,
- filled=filled,
- leaves_parallel=leaves_parallel,
- impurity=impurity,
- node_ids=node_ids,
- proportion=proportion,
- rotate=rotate,
- rounded=rounded,
- special_characters=special_characters,
- precision=precision,
- fontname=fontname,
- )
- exporter.export(decision_tree)
- if return_string:
- return exporter.out_file.getvalue()
- finally:
- if own_file:
- out_file.close()
- def _compute_depth(tree, node):
- """
- Returns the depth of the subtree rooted in node.
- """
- def compute_depth_(
- current_node, current_depth, children_left, children_right, depths
- ):
- depths += [current_depth]
- left = children_left[current_node]
- right = children_right[current_node]
- if left != -1 and right != -1:
- compute_depth_(
- left, current_depth + 1, children_left, children_right, depths
- )
- compute_depth_(
- right, current_depth + 1, children_left, children_right, depths
- )
- depths = []
- compute_depth_(node, 1, tree.children_left, tree.children_right, depths)
- return max(depths)
- @validate_params(
- {
- "decision_tree": [DecisionTreeClassifier, DecisionTreeRegressor],
- "feature_names": ["array-like", None],
- "class_names": ["array-like", None],
- "max_depth": [Interval(Integral, 0, None, closed="left"), None],
- "spacing": [Interval(Integral, 1, None, closed="left"), None],
- "decimals": [Interval(Integral, 0, None, closed="left"), None],
- "show_weights": ["boolean"],
- },
- prefer_skip_nested_validation=True,
- )
- def export_text(
- decision_tree,
- *,
- feature_names=None,
- class_names=None,
- max_depth=10,
- spacing=3,
- decimals=2,
- show_weights=False,
- ):
- """Build a text report showing the rules of a decision tree.
- Note that backwards compatibility may not be supported.
- Parameters
- ----------
- decision_tree : object
- The decision tree estimator to be exported.
- It can be an instance of
- DecisionTreeClassifier or DecisionTreeRegressor.
- feature_names : array-like of shape (n_features,), default=None
- An array containing the feature names.
- If None generic names will be used ("feature_0", "feature_1", ...).
- class_names : array-like of shape (n_classes,), default=None
- Names of each of the target classes in ascending numerical order.
- Only relevant for classification and not supported for multi-output.
- - if `None`, the class names are delegated to `decision_tree.classes_`;
- - otherwise, `class_names` will be used as class names instead of
- `decision_tree.classes_`. The length of `class_names` must match
- the length of `decision_tree.classes_`.
- .. versionadded:: 1.3
- max_depth : int, default=10
- Only the first max_depth levels of the tree are exported.
- Truncated branches will be marked with "...".
- spacing : int, default=3
- Number of spaces between edges. The higher it is, the wider the result.
- decimals : int, default=2
- Number of decimal digits to display.
- show_weights : bool, default=False
- If true the classification weights will be exported on each leaf.
- The classification weights are the number of samples each class.
- Returns
- -------
- report : str
- Text summary of all the rules in the decision tree.
- Examples
- --------
- >>> from sklearn.datasets import load_iris
- >>> from sklearn.tree import DecisionTreeClassifier
- >>> from sklearn.tree import export_text
- >>> iris = load_iris()
- >>> X = iris['data']
- >>> y = iris['target']
- >>> decision_tree = DecisionTreeClassifier(random_state=0, max_depth=2)
- >>> decision_tree = decision_tree.fit(X, y)
- >>> r = export_text(decision_tree, feature_names=iris['feature_names'])
- >>> print(r)
- |--- petal width (cm) <= 0.80
- | |--- class: 0
- |--- petal width (cm) > 0.80
- | |--- petal width (cm) <= 1.75
- | | |--- class: 1
- | |--- petal width (cm) > 1.75
- | | |--- class: 2
- """
- if feature_names is not None:
- feature_names = check_array(
- feature_names, ensure_2d=False, dtype=None, ensure_min_samples=0
- )
- if class_names is not None:
- class_names = check_array(
- class_names, ensure_2d=False, dtype=None, ensure_min_samples=0
- )
- check_is_fitted(decision_tree)
- tree_ = decision_tree.tree_
- if is_classifier(decision_tree):
- if class_names is None:
- class_names = decision_tree.classes_
- elif len(class_names) != len(decision_tree.classes_):
- raise ValueError(
- "When `class_names` is an array, it should contain as"
- " many items as `decision_tree.classes_`. Got"
- f" {len(class_names)} while the tree was fitted with"
- f" {len(decision_tree.classes_)} classes."
- )
- right_child_fmt = "{} {} <= {}\n"
- left_child_fmt = "{} {} > {}\n"
- truncation_fmt = "{} {}\n"
- if feature_names is not None and len(feature_names) != tree_.n_features:
- raise ValueError(
- "feature_names must contain %d elements, got %d"
- % (tree_.n_features, len(feature_names))
- )
- if isinstance(decision_tree, DecisionTreeClassifier):
- value_fmt = "{}{} weights: {}\n"
- if not show_weights:
- value_fmt = "{}{}{}\n"
- else:
- value_fmt = "{}{} value: {}\n"
- if feature_names is not None:
- feature_names_ = [
- feature_names[i] if i != _tree.TREE_UNDEFINED else None
- for i in tree_.feature
- ]
- else:
- feature_names_ = ["feature_{}".format(i) for i in tree_.feature]
- export_text.report = ""
- def _add_leaf(value, class_name, indent):
- val = ""
- is_classification = isinstance(decision_tree, DecisionTreeClassifier)
- if show_weights or not is_classification:
- val = ["{1:.{0}f}, ".format(decimals, v) for v in value]
- val = "[" + "".join(val)[:-2] + "]"
- if is_classification:
- val += " class: " + str(class_name)
- export_text.report += value_fmt.format(indent, "", val)
- def print_tree_recurse(node, depth):
- indent = ("|" + (" " * spacing)) * depth
- indent = indent[:-spacing] + "-" * spacing
- value = None
- if tree_.n_outputs == 1:
- value = tree_.value[node][0]
- else:
- value = tree_.value[node].T[0]
- class_name = np.argmax(value)
- if tree_.n_classes[0] != 1 and tree_.n_outputs == 1:
- class_name = class_names[class_name]
- if depth <= max_depth + 1:
- info_fmt = ""
- info_fmt_left = info_fmt
- info_fmt_right = info_fmt
- if tree_.feature[node] != _tree.TREE_UNDEFINED:
- name = feature_names_[node]
- threshold = tree_.threshold[node]
- threshold = "{1:.{0}f}".format(decimals, threshold)
- export_text.report += right_child_fmt.format(indent, name, threshold)
- export_text.report += info_fmt_left
- print_tree_recurse(tree_.children_left[node], depth + 1)
- export_text.report += left_child_fmt.format(indent, name, threshold)
- export_text.report += info_fmt_right
- print_tree_recurse(tree_.children_right[node], depth + 1)
- else: # leaf
- _add_leaf(value, class_name, indent)
- else:
- subtree_depth = _compute_depth(tree_, node)
- if subtree_depth == 1:
- _add_leaf(value, class_name, indent)
- else:
- trunc_report = "truncated branch of depth %d" % subtree_depth
- export_text.report += truncation_fmt.format(indent, trunc_report)
- print_tree_recurse(0, 1)
- return export_text.report
|