  1. """
  2. This module contains the TreeGrower class.
  3. TreeGrower builds a regression tree fitting a Newton-Raphson step, based on
  4. the gradients and hessians of the training data.
  5. """
  6. # Author: Nicolas Hug
  7. import numbers
  8. from heapq import heappop, heappush
  9. from timeit import default_timer as time
  10. import numpy as np
  11. from sklearn.utils._openmp_helpers import _openmp_effective_n_threads
  12. from ._bitset import set_raw_bitset_from_binned_bitset
  13. from .common import (
  14. PREDICTOR_RECORD_DTYPE,
  15. X_BITSET_INNER_DTYPE,
  16. Y_DTYPE,
  17. MonotonicConstraint,
  18. )
  19. from .histogram import HistogramBuilder
  20. from .predictor import TreePredictor
  21. from .splitting import Splitter
  22. from .utils import sum_parallel
  23. EPS = np.finfo(Y_DTYPE).eps # to avoid zero division errors


class TreeNode:
    """Tree Node class used in TreeGrower.

    This isn't used for prediction purposes, only for training (see
    TreePredictor).

    Parameters
    ----------
    depth : int
        The depth of the node, i.e. its distance from the root.
    sample_indices : ndarray of shape (n_samples_at_node,), dtype=np.uint
        The indices of the samples at the node.
    sum_gradients : float
        The sum of the gradients of the samples at the node.
    sum_hessians : float
        The sum of the hessians of the samples at the node.

    Attributes
    ----------
    depth : int
        The depth of the node, i.e. its distance from the root.
    sample_indices : ndarray of shape (n_samples_at_node,), dtype=np.uint
        The indices of the samples at the node.
    sum_gradients : float
        The sum of the gradients of the samples at the node.
    sum_hessians : float
        The sum of the hessians of the samples at the node.
    split_info : SplitInfo or None
        The result of the split evaluation.
    is_leaf : bool
        True if node is a leaf.
    left_child : TreeNode or None
        The left child of the node. None for leaves.
    right_child : TreeNode or None
        The right child of the node. None for leaves.
    value : float or None
        The value of the leaf, as computed in finalize_leaf(). None for
        non-leaf nodes.
    partition_start : int
        Start position of the node's sample_indices in splitter.partition.
    partition_stop : int
        Stop position of the node's sample_indices in splitter.partition.
    allowed_features : None or ndarray, dtype=int
        Indices of features allowed to split for children.
    interaction_cst_indices : None or list of ints
        Indices of the interaction sets that have to be applied on splits of
        child nodes. The fewer sets that apply, the stronger the constraint,
        since the union of fewer sets contains fewer features.
    children_lower_bound : float
    children_upper_bound : float
    """

    split_info = None
    left_child = None
    right_child = None
    histograms = None

    # start and stop indices of the node in the splitter.partition
    # array. Concretely,
    # self.sample_indices = view(self.splitter.partition[start:stop])
    # Please see the comments about splitter.partition and
    # splitter.split_indices for more info about this design.
    # These 2 attributes are only used in _update_raw_prediction, because we
    # need to iterate over the leaves and I don't know how to efficiently
    # store the sample_indices views because they're all of different sizes.
    partition_start = 0
    partition_stop = 0

    def __init__(self, depth, sample_indices, sum_gradients, sum_hessians, value=None):
        self.depth = depth
        self.sample_indices = sample_indices
        self.n_samples = sample_indices.shape[0]
        self.sum_gradients = sum_gradients
        self.sum_hessians = sum_hessians
        self.value = value
        self.is_leaf = False
        self.allowed_features = None
        self.interaction_cst_indices = None
        self.set_children_bounds(float("-inf"), float("+inf"))

    def set_children_bounds(self, lower, upper):
        """Set children values bounds to respect monotonic constraints."""
        # These are bounds for the node's *children* values, not the node's
        # value. The bounds are used in the splitter when considering potential
        # left and right child.
        self.children_lower_bound = lower
        self.children_upper_bound = upper

    def __lt__(self, other_node):
        """Comparison for priority queue.

        Nodes with high gain are higher priority than nodes with low gain.

        heapq.heappush only needs the '<' operator.
        heapq.heappop takes the smallest item first (smaller is higher
        priority).

        Parameters
        ----------
        other_node : TreeNode
            The node to compare with.
        """
        return self.split_info.gain > other_node.split_info.gain
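

# The inverted comparison above turns Python's min-heap into a max-heap on
# gain. A minimal sketch (illustrative only; SimpleNamespace stands in for a
# real SplitInfo object, and np/heappush/heappop are this module's imports):
#
#     from types import SimpleNamespace
#
#     nodes = []
#     for gain in (0.5, 3.0, 1.2):
#         node = TreeNode(depth=0, sample_indices=np.arange(4),
#                         sum_gradients=0.0, sum_hessians=0.0)
#         node.split_info = SimpleNamespace(gain=gain)
#         heappush(nodes, node)
#     assert heappop(nodes).split_info.gain == 3.0  # highest gain pops first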


class TreeGrower:
    """Tree grower class used to build a tree.

    The tree is fitted to predict the values of a Newton-Raphson step. The
    splits are considered in a best-first fashion, and the quality of a
    split is defined in splitting._split_gain.

    Parameters
    ----------
    X_binned : ndarray of shape (n_samples, n_features), dtype=np.uint8
        The binned input samples. Must be Fortran-aligned.
    gradients : ndarray of shape (n_samples,)
        The gradients of each training sample. Those are the gradients of the
        loss w.r.t the predictions, evaluated at iteration ``i - 1``.
    hessians : ndarray of shape (n_samples,)
        The hessians of each training sample. Those are the hessians of the
        loss w.r.t the predictions, evaluated at iteration ``i - 1``.
    max_leaf_nodes : int, default=None
        The maximum number of leaves for each tree. If None, there is no
        maximum limit.
    max_depth : int, default=None
        The maximum depth of each tree. The depth of a tree is the number of
        edges to go from the root to the deepest leaf.
        Depth isn't constrained by default.
    min_samples_leaf : int, default=20
        The minimum number of samples per leaf.
    min_gain_to_split : float, default=0.
        The minimum gain needed to split a node. Splits with lower gain will
        be ignored.
    n_bins : int, default=256
        The total number of bins, including the bin for missing values. Used
        to define the shape of the histograms.
    n_bins_non_missing : ndarray, dtype=np.uint32, default=None
        For each feature, gives the number of bins actually used for
        non-missing values. For features with a lot of unique values, this
        is equal to ``n_bins - 1``. If it's an int, all features are
        considered to have the same number of bins. If None, all features
        are considered to have ``n_bins - 1`` bins.
    has_missing_values : bool or ndarray, dtype=bool, default=False
        Whether each feature contains missing values (in the training data).
        If it's a bool, the same value is used for all features.
    is_categorical : ndarray of bool of shape (n_features,), default=None
        Indicates categorical features.
    monotonic_cst : array-like of int of shape (n_features,), dtype=int, default=None
        Indicates the monotonic constraint to enforce on each feature.
          - 1: monotonic increase
          - 0: no constraint
          - -1: monotonic decrease

        Read more in the :ref:`User Guide <monotonic_cst_gbdt>`.
    interaction_cst : list of sets of integers, default=None
        List of interaction constraints.
    l2_regularization : float, default=0.
        The L2 regularization parameter.
    min_hessian_to_split : float, default=1e-3
        The minimum sum of hessians needed in each node. Splits that result in
        at least one child having a sum of hessians less than
        ``min_hessian_to_split`` are discarded.
    shrinkage : float, default=1.
        The shrinkage parameter to apply to the leaves values, also known as
        learning rate.
    n_threads : int, default=None
        Number of OpenMP threads to use. `_openmp_effective_n_threads` is
        called to determine the effective number of threads to use, which
        takes cgroups CPU quotas into account. See the docstring of
        `_openmp_effective_n_threads` for details.

    Attributes
    ----------
    histogram_builder : HistogramBuilder
    splitter : Splitter
    root : TreeNode
    finalized_leaves : list of TreeNode
    splittable_nodes : list of TreeNode
    missing_values_bin_idx : int
        Equals n_bins - 1.
    n_categorical_splits : int
    n_features : int
    n_nodes : int
    total_find_split_time : float
        Time spent finding the best splits.
    total_compute_hist_time : float
        Time spent computing histograms.
    total_apply_split_time : float
        Time spent splitting nodes.
    with_monotonic_cst : bool
        Whether there are monotonic constraints that apply. False if
        monotonic_cst is None or if no feature has a constraint.
    """

    def __init__(
        self,
        X_binned,
        gradients,
        hessians,
        max_leaf_nodes=None,
        max_depth=None,
        min_samples_leaf=20,
        min_gain_to_split=0.0,
        n_bins=256,
        n_bins_non_missing=None,
        has_missing_values=False,
        is_categorical=None,
        monotonic_cst=None,
        interaction_cst=None,
        l2_regularization=0.0,
        min_hessian_to_split=1e-3,
        shrinkage=1.0,
        n_threads=None,
    ):
        self._validate_parameters(
            X_binned,
            min_gain_to_split,
            min_hessian_to_split,
        )
        n_threads = _openmp_effective_n_threads(n_threads)

        if n_bins_non_missing is None:
            n_bins_non_missing = n_bins - 1

        if isinstance(n_bins_non_missing, numbers.Integral):
            n_bins_non_missing = np.array(
                [n_bins_non_missing] * X_binned.shape[1], dtype=np.uint32
            )
        else:
            n_bins_non_missing = np.asarray(n_bins_non_missing, dtype=np.uint32)

        if isinstance(has_missing_values, bool):
            has_missing_values = [has_missing_values] * X_binned.shape[1]
        has_missing_values = np.asarray(has_missing_values, dtype=np.uint8)

        # `monotonic_cst` validation is done in _validate_monotonic_cst
        # at the estimator level and therefore the following should not be
        # needed when using the public API.
        if monotonic_cst is None:
            monotonic_cst = np.full(
                shape=X_binned.shape[1],
                fill_value=MonotonicConstraint.NO_CST,
                dtype=np.int8,
            )
        else:
            monotonic_cst = np.asarray(monotonic_cst, dtype=np.int8)
        self.with_monotonic_cst = np.any(monotonic_cst != MonotonicConstraint.NO_CST)

        if is_categorical is None:
            is_categorical = np.zeros(shape=X_binned.shape[1], dtype=np.uint8)
        else:
            is_categorical = np.asarray(is_categorical, dtype=np.uint8)

        if np.any(
            np.logical_and(
                is_categorical == 1, monotonic_cst != MonotonicConstraint.NO_CST
            )
        ):
            raise ValueError("Categorical features cannot have monotonic constraints.")

        hessians_are_constant = hessians.shape[0] == 1
        self.histogram_builder = HistogramBuilder(
            X_binned, n_bins, gradients, hessians, hessians_are_constant, n_threads
        )
        missing_values_bin_idx = n_bins - 1
        self.splitter = Splitter(
            X_binned,
            n_bins_non_missing,
            missing_values_bin_idx,
            has_missing_values,
            is_categorical,
            monotonic_cst,
            l2_regularization,
            min_hessian_to_split,
            min_samples_leaf,
            min_gain_to_split,
            hessians_are_constant,
            n_threads,
        )
        self.n_bins_non_missing = n_bins_non_missing
        self.missing_values_bin_idx = missing_values_bin_idx
        self.max_leaf_nodes = max_leaf_nodes
        self.has_missing_values = has_missing_values
        self.monotonic_cst = monotonic_cst
        self.interaction_cst = interaction_cst
        self.is_categorical = is_categorical
        self.l2_regularization = l2_regularization
        self.n_features = X_binned.shape[1]
        self.max_depth = max_depth
        self.min_samples_leaf = min_samples_leaf
        self.X_binned = X_binned
        self.min_gain_to_split = min_gain_to_split
        self.shrinkage = shrinkage
        self.n_threads = n_threads
        self.splittable_nodes = []
        self.finalized_leaves = []
        self.total_find_split_time = 0.0  # time spent finding the best splits
        self.total_compute_hist_time = 0.0  # time spent computing histograms
        self.total_apply_split_time = 0.0  # time spent splitting nodes
        self.n_categorical_splits = 0
        self._initialize_root(gradients, hessians, hessians_are_constant)
        self.n_nodes = 1

    def _validate_parameters(
        self,
        X_binned,
        min_gain_to_split,
        min_hessian_to_split,
    ):
        """Validate parameters passed to __init__.

        Also validate parameters passed to splitter.
        """
        if X_binned.dtype != np.uint8:
            raise NotImplementedError("X_binned must be of type uint8.")
        if not X_binned.flags.f_contiguous:
            raise ValueError(
                "X_binned should be passed as Fortran contiguous "
                "array for maximum efficiency."
            )
        if min_gain_to_split < 0:
            raise ValueError(
                "min_gain_to_split={} must be non-negative.".format(min_gain_to_split)
            )
        if min_hessian_to_split < 0:
            raise ValueError(
                "min_hessian_to_split={} must be non-negative.".format(
                    min_hessian_to_split
                )
            )

    def grow(self):
        """Grow the tree, from root to leaves."""
        while self.splittable_nodes:
            self.split_next()

        self._apply_shrinkage()

    def _apply_shrinkage(self):
        """Multiply leaves values by shrinkage parameter.

        This must be done at the very end of the growing process. If this were
        done during the growing process e.g. in finalize_leaf(), then a leaf
        would be shrunk but its sibling would potentially not be (if it's a
        non-leaf), which would lead to a wrong computation of the 'middle'
        value needed to enforce the monotonic constraints.
        """
        for leaf in self.finalized_leaves:
            leaf.value *= self.shrinkage
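
    # A tiny worked example (illustrative numbers, not from the original
    # source): with shrinkage=0.1 and finalized leaf values [-1.5, 2.0, 0.4],
    # _apply_shrinkage() rescales them to [-0.15, 0.2, 0.04], so each tree
    # contributes only a learning-rate-sized fraction of the full
    # Newton-Raphson step to the ensemble's predictions.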

    def _initialize_root(self, gradients, hessians, hessians_are_constant):
        """Initialize root node and finalize it if needed."""
        n_samples = self.X_binned.shape[0]
        depth = 0
        sum_gradients = sum_parallel(gradients, self.n_threads)
        if self.histogram_builder.hessians_are_constant:
            sum_hessians = hessians[0] * n_samples
        else:
            sum_hessians = sum_parallel(hessians, self.n_threads)
        self.root = TreeNode(
            depth=depth,
            sample_indices=self.splitter.partition,
            sum_gradients=sum_gradients,
            sum_hessians=sum_hessians,
            value=0,
        )

        self.root.partition_start = 0
        self.root.partition_stop = n_samples

        if self.root.n_samples < 2 * self.min_samples_leaf:
            # Do not even bother computing any splitting statistics.
            self._finalize_leaf(self.root)
            return
        if sum_hessians < self.splitter.min_hessian_to_split:
            self._finalize_leaf(self.root)
            return

        if self.interaction_cst is not None:
            self.root.interaction_cst_indices = range(len(self.interaction_cst))
            allowed_features = set().union(*self.interaction_cst)
            self.root.allowed_features = np.fromiter(
                allowed_features, dtype=np.uint32, count=len(allowed_features)
            )

        tic = time()
        self.root.histograms = self.histogram_builder.compute_histograms_brute(
            self.root.sample_indices, self.root.allowed_features
        )
        self.total_compute_hist_time += time() - tic

        tic = time()
        self._compute_best_split_and_push(self.root)
        self.total_find_split_time += time() - tic

    def _compute_best_split_and_push(self, node):
        """Compute the best possible split (SplitInfo) of a given node.

        Also push it onto the heap of splittable nodes if the gain is
        positive. The gain of a node is 0 if either all the leaves are pure
        (best gain = 0), or if no split would satisfy the constraints
        (min_hessian_to_split, min_gain_to_split, min_samples_leaf).
        """
        node.split_info = self.splitter.find_node_split(
            n_samples=node.n_samples,
            histograms=node.histograms,
            sum_gradients=node.sum_gradients,
            sum_hessians=node.sum_hessians,
            value=node.value,
            lower_bound=node.children_lower_bound,
            upper_bound=node.children_upper_bound,
            allowed_features=node.allowed_features,
        )

        if node.split_info.gain <= 0:  # no valid split
            self._finalize_leaf(node)
        else:
            heappush(self.splittable_nodes, node)

    def split_next(self):
        """Split the node with highest potential gain.

        Returns
        -------
        left : TreeNode
            The resulting left child.
        right : TreeNode
            The resulting right child.
        """
        # Consider the node with the highest loss reduction (a.k.a. gain)
        node = heappop(self.splittable_nodes)

        tic = time()
        (
            sample_indices_left,
            sample_indices_right,
            right_child_pos,
        ) = self.splitter.split_indices(node.split_info, node.sample_indices)
        self.total_apply_split_time += time() - tic

        depth = node.depth + 1
        n_leaf_nodes = len(self.finalized_leaves) + len(self.splittable_nodes)
        n_leaf_nodes += 2

        left_child_node = TreeNode(
            depth,
            sample_indices_left,
            node.split_info.sum_gradient_left,
            node.split_info.sum_hessian_left,
            value=node.split_info.value_left,
        )
        right_child_node = TreeNode(
            depth,
            sample_indices_right,
            node.split_info.sum_gradient_right,
            node.split_info.sum_hessian_right,
            value=node.split_info.value_right,
        )

        node.right_child = right_child_node
        node.left_child = left_child_node

        # set start and stop indices
        left_child_node.partition_start = node.partition_start
        left_child_node.partition_stop = node.partition_start + right_child_pos
        right_child_node.partition_start = left_child_node.partition_stop
        right_child_node.partition_stop = node.partition_stop

        # set interaction constraints (the indices of the constraints sets)
        if self.interaction_cst is not None:
            # Calculate allowed_features and interaction_cst_indices only once.
            # Child nodes inherit them before they get split.
            (
                left_child_node.allowed_features,
                left_child_node.interaction_cst_indices,
            ) = self._compute_interactions(node)
            right_child_node.interaction_cst_indices = (
                left_child_node.interaction_cst_indices
            )
            right_child_node.allowed_features = left_child_node.allowed_features

        if not self.has_missing_values[node.split_info.feature_idx]:
            # If no missing values are encountered at fit time, then samples
            # with missing values during predict() will go to whichever child
            # has the most samples.
            node.split_info.missing_go_to_left = (
                left_child_node.n_samples > right_child_node.n_samples
            )

        self.n_nodes += 2
        self.n_categorical_splits += node.split_info.is_categorical

        if self.max_leaf_nodes is not None and n_leaf_nodes == self.max_leaf_nodes:
            self._finalize_leaf(left_child_node)
            self._finalize_leaf(right_child_node)
            self._finalize_splittable_nodes()
            return left_child_node, right_child_node

        if self.max_depth is not None and depth == self.max_depth:
            self._finalize_leaf(left_child_node)
            self._finalize_leaf(right_child_node)
            return left_child_node, right_child_node

        if left_child_node.n_samples < self.min_samples_leaf * 2:
            self._finalize_leaf(left_child_node)
        if right_child_node.n_samples < self.min_samples_leaf * 2:
            self._finalize_leaf(right_child_node)

        if self.with_monotonic_cst:
            # Set value bounds for respecting monotonic constraints
            # See test_nodes_values() for details
            if (
                self.monotonic_cst[node.split_info.feature_idx]
                == MonotonicConstraint.NO_CST
            ):
                lower_left = lower_right = node.children_lower_bound
                upper_left = upper_right = node.children_upper_bound
            else:
                mid = (left_child_node.value + right_child_node.value) / 2
                if (
                    self.monotonic_cst[node.split_info.feature_idx]
                    == MonotonicConstraint.POS
                ):
                    lower_left, upper_left = node.children_lower_bound, mid
                    lower_right, upper_right = mid, node.children_upper_bound
                else:  # NEG
                    lower_left, upper_left = mid, node.children_upper_bound
                    lower_right, upper_right = node.children_lower_bound, mid
            left_child_node.set_children_bounds(lower_left, upper_left)
            right_child_node.set_children_bounds(lower_right, upper_right)
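
        # A worked example (illustrative numbers, not from the original
        # source): with a POS constraint on the split feature,
        # value_left = -1.0 and value_right = 3.0 give mid = 1.0, so values
        # in the left subtree are capped at 1.0 while values in the right
        # subtree are floored at 1.0. Every left descendant therefore stays
        # <= every right descendant, as the increasing constraint requires.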

        # Compute histograms of children, and compute their best possible split
        # (if needed)
        should_split_left = not left_child_node.is_leaf
        should_split_right = not right_child_node.is_leaf
        if should_split_left or should_split_right:
            # We will compute the histograms of both nodes even if one of them
            # is a leaf, since computing the second histogram is very cheap
            # (using histogram subtraction).
            n_samples_left = left_child_node.sample_indices.shape[0]
            n_samples_right = right_child_node.sample_indices.shape[0]
            if n_samples_left < n_samples_right:
                smallest_child = left_child_node
                largest_child = right_child_node
            else:
                smallest_child = right_child_node
                largest_child = left_child_node

            # We use the brute O(n_samples) method on the child that has the
            # smallest number of samples, and the subtraction trick O(n_bins)
            # on the other one.
            # Note that both left and right child have the same allowed_features.
            tic = time()
            smallest_child.histograms = self.histogram_builder.compute_histograms_brute(
                smallest_child.sample_indices, smallest_child.allowed_features
            )
            largest_child.histograms = (
                self.histogram_builder.compute_histograms_subtraction(
                    node.histograms,
                    smallest_child.histograms,
                    smallest_child.allowed_features,
                )
            )
            self.total_compute_hist_time += time() - tic

            tic = time()
            if should_split_left:
                self._compute_best_split_and_push(left_child_node)
            if should_split_right:
                self._compute_best_split_and_push(right_child_node)
            self.total_find_split_time += time() - tic

        # Release memory used by histograms as they are no longer needed
        # for leaf nodes since they won't be split.
        for child in (left_child_node, right_child_node):
            if child.is_leaf:
                del child.histograms

        # Release memory used by histograms as they are no longer needed for
        # internal nodes once children histograms have been computed.
        del node.histograms

        return left_child_node, right_child_node
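
    # The subtraction trick used above relies on the fact that, per feature
    # and per bin, a parent's histogram is the element-wise sum of its
    # children's histograms (for the count, sum_gradients and sum_hessians
    # fields):
    #
    #     hist_parent[f, b] == hist_left[f, b] + hist_right[f, b]
    #
    # so the larger child can be obtained as hist_parent - hist_smallest in
    # O(n_bins) instead of O(n_samples).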

    def _compute_interactions(self, node):
        r"""Compute features allowed by interactions to be inherited by child nodes.

        Example: Assume constraints [{0, 1}, {1, 2}].
           1      <- Both constraint groups could be applied from now on
          / \
         1   2    <- Left split still fulfills both constraint groups.
        / \ / \      Right split at feature 2 has only group {1, 2} from now on.

        LightGBM uses the same logic for overlapping groups. See
        https://github.com/microsoft/LightGBM/issues/4481 for details.

        Parameters
        ----------
        node : TreeNode
            A node that might have children. Based on its feature_idx, the
            interaction constraints for possible child nodes are computed.

        Returns
        -------
        allowed_features : ndarray, dtype=uint32
            Indices of features allowed to split for children.
        interaction_cst_indices : list of ints
            Indices of the interaction sets that have to be applied on splits
            of child nodes. The fewer sets that apply, the stronger the
            constraint, since the union of fewer sets contains fewer features.
        """
        # Note:
        #  - Case of no interactions is already captured before function call.
        #  - This is for nodes that are already split and have a
        #    node.split_info.feature_idx.
        allowed_features = set()
        interaction_cst_indices = []
        for i in node.interaction_cst_indices:
            if node.split_info.feature_idx in self.interaction_cst[i]:
                interaction_cst_indices.append(i)
                allowed_features.update(self.interaction_cst[i])
        return (
            np.fromiter(allowed_features, dtype=np.uint32, count=len(allowed_features)),
            interaction_cst_indices,
        )
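
    # A worked example (mirroring the docstring, illustrative only): with
    # interaction_cst = [{0, 1}, {1, 2}] and a parent node split on feature
    # 2, only group {1, 2} still applies, so the children may only split on
    # features 1 and 2. Had the parent split on feature 1 instead, both
    # groups would still apply and the children could split on 0, 1 and 2.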

    def _finalize_leaf(self, node):
        """Make node a leaf of the tree being grown."""
        node.is_leaf = True
        self.finalized_leaves.append(node)

    def _finalize_splittable_nodes(self):
        """Transform all splittable nodes into leaves.

        Used when some constraint is met e.g. maximum number of leaves or
        maximum depth."""
        while len(self.splittable_nodes) > 0:
            node = self.splittable_nodes.pop()
            self._finalize_leaf(node)

    def make_predictor(self, binning_thresholds):
        """Make a TreePredictor object out of the current tree.

        Parameters
        ----------
        binning_thresholds : array-like of floats
            Corresponds to the bin_thresholds_ attribute of the BinMapper.
            For each feature, this stores:

            - the bin frontiers for continuous features
            - the unique raw category values for categorical features

        Returns
        -------
        A TreePredictor object.
        """
        predictor_nodes = np.zeros(self.n_nodes, dtype=PREDICTOR_RECORD_DTYPE)
        binned_left_cat_bitsets = np.zeros(
            (self.n_categorical_splits, 8), dtype=X_BITSET_INNER_DTYPE
        )
        raw_left_cat_bitsets = np.zeros(
            (self.n_categorical_splits, 8), dtype=X_BITSET_INNER_DTYPE
        )
        _fill_predictor_arrays(
            predictor_nodes,
            binned_left_cat_bitsets,
            raw_left_cat_bitsets,
            self.root,
            binning_thresholds,
            self.n_bins_non_missing,
        )
        return TreePredictor(
            predictor_nodes, binned_left_cat_bitsets, raw_left_cat_bitsets
        )


def _fill_predictor_arrays(
    predictor_nodes,
    binned_left_cat_bitsets,
    raw_left_cat_bitsets,
    grower_node,
    binning_thresholds,
    n_bins_non_missing,
    next_free_node_idx=0,
    next_free_bitset_idx=0,
):
    """Helper used in make_predictor to set the TreePredictor fields."""
    node = predictor_nodes[next_free_node_idx]
    node["count"] = grower_node.n_samples
    node["depth"] = grower_node.depth
    if grower_node.split_info is not None:
        node["gain"] = grower_node.split_info.gain
    else:
        node["gain"] = -1

    node["value"] = grower_node.value

    if grower_node.is_leaf:
        # Leaf node
        node["is_leaf"] = True
        return next_free_node_idx + 1, next_free_bitset_idx

    split_info = grower_node.split_info
    feature_idx, bin_idx = split_info.feature_idx, split_info.bin_idx
    node["feature_idx"] = feature_idx
    node["bin_threshold"] = bin_idx
    node["missing_go_to_left"] = split_info.missing_go_to_left
    node["is_categorical"] = split_info.is_categorical

    if split_info.bin_idx == n_bins_non_missing[feature_idx] - 1:
        # Split is on the last non-missing bin: it's a "split on nans".
        # All nans go to the right, the rest go to the left.
        # Note: for categorical splits, bin_idx is 0 and we rely on the bitset
        node["num_threshold"] = np.inf
    elif split_info.is_categorical:
        categories = binning_thresholds[feature_idx]
        node["bitset_idx"] = next_free_bitset_idx
        binned_left_cat_bitsets[next_free_bitset_idx] = split_info.left_cat_bitset
        set_raw_bitset_from_binned_bitset(
            raw_left_cat_bitsets[next_free_bitset_idx],
            split_info.left_cat_bitset,
            categories,
        )
        next_free_bitset_idx += 1
    else:
        node["num_threshold"] = binning_thresholds[feature_idx][bin_idx]

    next_free_node_idx += 1

    node["left"] = next_free_node_idx
    next_free_node_idx, next_free_bitset_idx = _fill_predictor_arrays(
        predictor_nodes,
        binned_left_cat_bitsets,
        raw_left_cat_bitsets,
        grower_node.left_child,
        binning_thresholds=binning_thresholds,
        n_bins_non_missing=n_bins_non_missing,
        next_free_node_idx=next_free_node_idx,
        next_free_bitset_idx=next_free_bitset_idx,
    )

    node["right"] = next_free_node_idx
    return _fill_predictor_arrays(
        predictor_nodes,
        binned_left_cat_bitsets,
        raw_left_cat_bitsets,
        grower_node.right_child,
        binning_thresholds=binning_thresholds,
        n_bins_non_missing=n_bins_non_missing,
        next_free_node_idx=next_free_node_idx,
        next_free_bitset_idx=next_free_bitset_idx,
    )
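

# ---------------------------------------------------------------------------
# Minimal end-to-end usage sketch (illustrative only, not part of this
# module). It assumes the private _BinMapper helper from the sibling
# `binning` module and float32 gradient/hessian arrays; both are internal
# details of this package and may change between versions.
#
#     import numpy as np
#     from sklearn.ensemble._hist_gradient_boosting.binning import _BinMapper
#     from sklearn.ensemble._hist_gradient_boosting.grower import TreeGrower
#
#     rng = np.random.default_rng(0)
#     X = rng.normal(size=(1000, 3))
#     y = X[:, 0] - 2 * X[:, 1]
#
#     bin_mapper = _BinMapper(n_bins=256)
#     X_binned = bin_mapper.fit_transform(X)  # uint8, Fortran-contiguous
#
#     # For a least-squares loss with raw predictions initialized to 0, the
#     # gradients are (predictions - y) = -y and the hessians are constant,
#     # which is signalled by passing a single-element hessian array.
#     gradients = np.asarray(-y, dtype=np.float32)
#     hessians = np.ones(1, dtype=np.float32)
#
#     grower = TreeGrower(X_binned, gradients, hessians, max_leaf_nodes=31,
#                         shrinkage=0.1)
#     grower.grow()
#     predictor = grower.make_predictor(bin_mapper.bin_thresholds_)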