_splitter.pxd 4.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103
  1. # Authors: Gilles Louppe <g.louppe@gmail.com>
  2. # Peter Prettenhofer <peter.prettenhofer@gmail.com>
  3. # Brian Holt <bdholt1@gmail.com>
  4. # Joel Nothman <joel.nothman@gmail.com>
  5. # Arnaud Joly <arnaud.v.joly@gmail.com>
  6. # Jacob Schreiber <jmschreiber91@gmail.com>
  7. #
  8. # License: BSD 3 clause
  9. # See _splitter.pyx for details.
  10. from ._criterion cimport Criterion
  11. from ._tree cimport DTYPE_t # Type of X
  12. from ._tree cimport DOUBLE_t # Type of y, sample_weight
  13. from ._tree cimport SIZE_t # Type for indices and counters
  14. from ._tree cimport INT32_t # Signed 32 bit integer
  15. from ._tree cimport UINT32_t # Unsigned 32 bit integer
  cdef struct SplitRecord:
      # Bookkeeping for one split of a node's samples.

      # Feature on which the split is made.
      SIZE_t feature

      # Position at which the samples array is split, i.e. the count of
      # samples below the threshold for the feature.
      # pos is >= end if the node is a leaf.
      SIZE_t pos

      # Threshold value the split is made at.
      double threshold

      # Improvement in impurity relative to the parent node.
      double improvement

      # Impurity of the left side of the split.
      double impurity_left

      # Impurity of the right side of the split.
      double impurity_right

      # Controls whether missing values are sent to the left node.
      unsigned char missing_go_to_left

      # Number of missing values for the feature being split on.
      SIZE_t n_missing
  cdef class Splitter:
      # Searches the input space for a feature and a threshold with which to
      # split the samples samples[start:end].
      #
      # Impurity computations are not done here; they are delegated to a
      # criterion object.

      # ---- Internal structures ----

      cdef public:
          # Impurity criterion.
          Criterion criterion
          # Number of features to test.
          SIZE_t max_features
          # Minimum number of samples in a leaf.
          SIZE_t min_samples_leaf
          # Minimum weight in a leaf.
          double min_weight_leaf

      cdef:
          # Random state.
          object random_state
          # sklearn_rand_r random number state.
          UINT32_t rand_r_state

          # Sample indices in X, y.
          SIZE_t[::1] samples
          # X.shape[0]
          SIZE_t n_samples
          # Weighted number of samples.
          double weighted_n_samples
          # Feature indices in X.
          SIZE_t[::1] features
          # Indices of constant features.
          SIZE_t[::1] constant_features
          # X.shape[1]
          SIZE_t n_features
          # Temporary array holding feature values.
          DTYPE_t[::1] feature_values

          # Start position for the current node.
          SIZE_t start
          # End position for the current node.
          SIZE_t end

          # Targets and per-sample weights (types documented at the cimports
          # as "Type of y, sample_weight").
          const DOUBLE_t[:, ::1] y
          const DOUBLE_t[:] sample_weight

      # The Splitter object maintains the vector `samples` so that the samples
      # belonging to a node are contiguous. Given that layout, `node_split`
      # reorganizes the node samples `samples[start:end]` into two subsets,
      # `samples[start:pos]` and `samples[pos:end]`.
      #
      # `features` is a 1-d array of size n_features containing the feature
      # indices; it allows fast sampling of features without replacement.
      #
      # `constant_features` is a 1-d array of size n_features. Its prefix
      # `constant_features[:n_constant_features]` holds the ids of the features
      # whose values are constant for all the samples that reached a specific
      # node. The value `n_constant_features` is handed from the parent node to
      # its child nodes. The content of the range `[n_constant_features:]` is
      # left undefined, but is preallocated for performance reasons.
      # This allows optimization with depth-based tree building.

      # ---- Methods ----
      # Only the signatures are declared here; see _splitter.pyx for the
      # implementations. Methods declared `except -1` return an int and use -1
      # to signal a raised exception; `noexcept` methods do not propagate
      # exceptions; `nogil` methods may be called without holding the GIL.

      # Takes the data X, the targets y, the sample weights and a per-feature
      # mask of missing values.
      # NOTE(review): presumably prepares the internal structures above for
      # this data -- confirm against _splitter.pyx.
      cdef int init(
          self,
          object X,
          const DOUBLE_t[:, ::1] y,
          const DOUBLE_t[:] sample_weight,
          const unsigned char[::1] missing_values_in_feature_mask,
      ) except -1

      # Takes the [start, end) positions of a node and a pointer through which
      # a weighted sample count is passed.
      # NOTE(review): presumably repositions the splitter on that node and
      # writes the node's weighted sample count to the pointer -- confirm.
      cdef int node_reset(
          self, SIZE_t start, SIZE_t end, double* weighted_n_node_samples
      ) except -1 nogil

      # Splits the current node: `impurity` is the impurity of the node,
      # `split` points to the SplitRecord involved, and `n_constant_features`
      # points to the constant-feature count passed between parent and child.
      cdef int node_split(
          self, double impurity, SplitRecord* split, SIZE_t* n_constant_features
      ) except -1 nogil

      # NOTE(review): presumably writes the value of the current node into
      # `dest` -- confirm against _splitter.pyx.
      cdef void node_value(self, double* dest) noexcept nogil

      # NOTE(review): presumably returns the impurity of the current node --
      # confirm against _splitter.pyx.
      cdef double node_impurity(self) noexcept nogil