"""
Forest of trees-based ensemble methods.

Those methods include random forests and extremely randomized trees.

The module structure is the following:

- The ``BaseForest`` base class implements a common ``fit`` method for all
  the estimators in the module. The ``fit`` method of the base ``Forest``
  class calls the ``fit`` method of each sub-estimator on random samples
  (with replacement, a.k.a. bootstrap) of the training set.

  The init of the sub-estimator is further delegated to the
  ``BaseEnsemble`` constructor.

- The ``ForestClassifier`` and ``ForestRegressor`` base classes further
  implement the prediction logic by computing an average of the predicted
  outcomes of the sub-estimators.

- The ``RandomForestClassifier`` and ``RandomForestRegressor`` derived
  classes provide the user with concrete implementations of
  the forest ensemble method using classical, deterministic
  ``DecisionTreeClassifier`` and ``DecisionTreeRegressor`` as
  sub-estimator implementations.

- The ``ExtraTreesClassifier`` and ``ExtraTreesRegressor`` derived
  classes provide the user with concrete implementations of the
  forest ensemble method using the extremely randomized trees
  ``ExtraTreeClassifier`` and ``ExtraTreeRegressor`` as
  sub-estimator implementations.

Single and multi-output problems are both handled.
"""

# Authors: Gilles Louppe <g.louppe@gmail.com>
#          Brian Holt <bdholt1@gmail.com>
#          Joly Arnaud <arnaud.v.joly@gmail.com>
#          Fares Hedayati <fares.hedayati@gmail.com>
#
# License: BSD 3 clause

import threading
from abc import ABCMeta, abstractmethod
from numbers import Integral, Real
from warnings import catch_warnings, simplefilter, warn

import numpy as np
from scipy.sparse import hstack as sparse_hstack
from scipy.sparse import issparse

from ..base import (
    ClassifierMixin,
    MultiOutputMixin,
    RegressorMixin,
    TransformerMixin,
    _fit_context,
    is_classifier,
)
from ..exceptions import DataConversionWarning
from ..metrics import accuracy_score, r2_score
from ..preprocessing import OneHotEncoder
from ..tree import (
    BaseDecisionTree,
    DecisionTreeClassifier,
    DecisionTreeRegressor,
    ExtraTreeClassifier,
    ExtraTreeRegressor,
)
from ..tree._tree import DOUBLE, DTYPE
from ..utils import check_random_state, compute_sample_weight
from ..utils._param_validation import Interval, RealNotInt, StrOptions
from ..utils.multiclass import check_classification_targets, type_of_target
from ..utils.parallel import Parallel, delayed
from ..utils.validation import (
    _check_feature_names_in,
    _check_sample_weight,
    _num_samples,
    check_is_fitted,
)
from ._base import BaseEnsemble, _partition_estimators

__all__ = [
    "RandomForestClassifier",
    "RandomForestRegressor",
    "ExtraTreesClassifier",
    "ExtraTreesRegressor",
    "RandomTreesEmbedding",
]

MAX_INT = np.iinfo(np.int32).max


def _get_n_samples_bootstrap(n_samples, max_samples):
    """
    Get the number of samples in a bootstrap sample.

    Parameters
    ----------
    n_samples : int
        Number of samples in the dataset.
    max_samples : int or float
        The maximum number of samples to draw from the total available:
        - if float, this indicates a fraction of the total and should be
          in the interval `(0.0, 1.0]`;
        - if int, this indicates the exact number of samples;
        - if None, this indicates the total number of samples.

    Returns
    -------
    n_samples_bootstrap : int
        The total number of samples to draw for the bootstrap sample.
    """
    if max_samples is None:
        return n_samples

    if isinstance(max_samples, Integral):
        if max_samples > n_samples:
            msg = "`max_samples` must be <= n_samples={} but got value {}"
            raise ValueError(msg.format(n_samples, max_samples))
        return max_samples

    if isinstance(max_samples, Real):
        return max(round(n_samples * max_samples), 1)
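
# Illustrative mapping, derived from the function above (comments only, not
# executed): for a dataset with 1000 samples,
#   _get_n_samples_bootstrap(1000, None)  returns 1000  (draw as many as available)
#   _get_n_samples_bootstrap(1000, 256)   returns 256   (exact count, must be <= 1000)
#   _get_n_samples_bootstrap(1000, 0.25)  returns 250   (fraction of the total, rounded)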


def _generate_sample_indices(random_state, n_samples, n_samples_bootstrap):
    """
    Private function used by the _parallel_build_trees function."""

    random_instance = check_random_state(random_state)
    sample_indices = random_instance.randint(0, n_samples, n_samples_bootstrap)

    return sample_indices


def _generate_unsampled_indices(random_state, n_samples, n_samples_bootstrap):
    """
    Private function used by the forest._set_oob_score function."""

    sample_indices = _generate_sample_indices(
        random_state, n_samples, n_samples_bootstrap
    )
    sample_counts = np.bincount(sample_indices, minlength=n_samples)
    unsampled_mask = sample_counts == 0
    indices_range = np.arange(n_samples)
    unsampled_indices = indices_range[unsampled_mask]

    return unsampled_indices
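
# Relationship between the two helpers above (comments only, not executed):
# for the same `random_state`, the OOB indices are exactly the samples that the
# bootstrap draw never picked, so together they cover every index, e.g.
#   sampled = _generate_sample_indices(0, 8, 8)
#   oob = _generate_unsampled_indices(0, 8, 8)
#   set(sampled) | set(oob) == set(range(8))   # True
#   set(sampled) & set(oob) == set()           # True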


def _parallel_build_trees(
    tree,
    bootstrap,
    X,
    y,
    sample_weight,
    tree_idx,
    n_trees,
    verbose=0,
    class_weight=None,
    n_samples_bootstrap=None,
):
    """
    Private function used to fit a single tree in parallel."""
    if verbose > 1:
        print("building tree %d of %d" % (tree_idx + 1, n_trees))

    if bootstrap:
        n_samples = X.shape[0]
        if sample_weight is None:
            curr_sample_weight = np.ones((n_samples,), dtype=np.float64)
        else:
            curr_sample_weight = sample_weight.copy()

        indices = _generate_sample_indices(
            tree.random_state, n_samples, n_samples_bootstrap
        )
        sample_counts = np.bincount(indices, minlength=n_samples)
        curr_sample_weight *= sample_counts

        if class_weight == "subsample":
            with catch_warnings():
                simplefilter("ignore", DeprecationWarning)
                curr_sample_weight *= compute_sample_weight("auto", y, indices=indices)
        elif class_weight == "balanced_subsample":
            curr_sample_weight *= compute_sample_weight("balanced", y, indices=indices)

        tree.fit(X, y, sample_weight=curr_sample_weight, check_input=False)
    else:
        tree.fit(X, y, sample_weight=sample_weight, check_input=False)

    return tree
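
# Bootstrapping above is implemented by re-weighting rather than by copying rows:
# each tree is fit on the full X with `sample_weight` multiplied by how many times
# each row was drawn. Illustrative (comments only, not executed):
#   indices = [2, 2, 0, 3]                  # one bootstrap draw over 5 samples
#   np.bincount(indices, minlength=5)       # -> array([1, 0, 2, 1, 0])
# so samples 1 and 4 get zero weight and effectively drop out of that tree's fit.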


class BaseForest(MultiOutputMixin, BaseEnsemble, metaclass=ABCMeta):
    """
    Base class for forests of trees.

    Warning: This class should not be used directly. Use derived classes
    instead.
    """

    _parameter_constraints: dict = {
        "n_estimators": [Interval(Integral, 1, None, closed="left")],
        "bootstrap": ["boolean"],
        "oob_score": ["boolean", callable],
        "n_jobs": [Integral, None],
        "random_state": ["random_state"],
        "verbose": ["verbose"],
        "warm_start": ["boolean"],
        "max_samples": [
            None,
            Interval(RealNotInt, 0.0, 1.0, closed="right"),
            Interval(Integral, 1, None, closed="left"),
        ],
    }

    @abstractmethod
    def __init__(
        self,
        estimator,
        n_estimators=100,
        *,
        estimator_params=tuple(),
        bootstrap=False,
        oob_score=False,
        n_jobs=None,
        random_state=None,
        verbose=0,
        warm_start=False,
        class_weight=None,
        max_samples=None,
        base_estimator="deprecated",
    ):
        super().__init__(
            estimator=estimator,
            n_estimators=n_estimators,
            estimator_params=estimator_params,
            base_estimator=base_estimator,
        )

        self.bootstrap = bootstrap
        self.oob_score = oob_score
        self.n_jobs = n_jobs
        self.random_state = random_state
        self.verbose = verbose
        self.warm_start = warm_start
        self.class_weight = class_weight
        self.max_samples = max_samples

    def apply(self, X):
        """
        Apply trees in the forest to X, return leaf indices.

        Parameters
        ----------
        X : {array-like, sparse matrix} of shape (n_samples, n_features)
            The input samples. Internally, its dtype will be converted to
            ``dtype=np.float32``. If a sparse matrix is provided, it will be
            converted into a sparse ``csr_matrix``.

        Returns
        -------
        X_leaves : ndarray of shape (n_samples, n_estimators)
            For each datapoint x in X and for each tree in the forest,
            return the index of the leaf x ends up in.
        """
        X = self._validate_X_predict(X)
        results = Parallel(
            n_jobs=self.n_jobs,
            verbose=self.verbose,
            prefer="threads",
        )(delayed(tree.apply)(X, check_input=False) for tree in self.estimators_)

        return np.array(results).T

    def decision_path(self, X):
        """
        Return the decision path in the forest.

        .. versionadded:: 0.18

        Parameters
        ----------
        X : {array-like, sparse matrix} of shape (n_samples, n_features)
            The input samples. Internally, its dtype will be converted to
            ``dtype=np.float32``. If a sparse matrix is provided, it will be
            converted into a sparse ``csr_matrix``.

        Returns
        -------
        indicator : sparse matrix of shape (n_samples, n_nodes)
            Return a node indicator matrix where non-zero elements indicate
            that the samples go through the nodes. The matrix is of CSR
            format.
        n_nodes_ptr : ndarray of shape (n_estimators + 1,)
            The columns from indicator[n_nodes_ptr[i]:n_nodes_ptr[i+1]]
            give the indicator value for the i-th estimator.
        """
        X = self._validate_X_predict(X)
        indicators = Parallel(
            n_jobs=self.n_jobs,
            verbose=self.verbose,
            prefer="threads",
        )(
            delayed(tree.decision_path)(X, check_input=False)
            for tree in self.estimators_
        )

        n_nodes = [0]
        n_nodes.extend([i.shape[1] for i in indicators])
        n_nodes_ptr = np.array(n_nodes).cumsum()

        return sparse_hstack(indicators).tocsr(), n_nodes_ptr

    @_fit_context(prefer_skip_nested_validation=True)
    def fit(self, X, y, sample_weight=None):
        """
        Build a forest of trees from the training set (X, y).

        Parameters
        ----------
        X : {array-like, sparse matrix} of shape (n_samples, n_features)
            The training input samples. Internally, its dtype will be converted
            to ``dtype=np.float32``. If a sparse matrix is provided, it will be
            converted into a sparse ``csc_matrix``.
        y : array-like of shape (n_samples,) or (n_samples, n_outputs)
            The target values (class labels in classification, real numbers in
            regression).
        sample_weight : array-like of shape (n_samples,), default=None
            Sample weights. If None, then samples are equally weighted. Splits
            that would create child nodes with net zero or negative weight are
            ignored while searching for a split in each node. In the case of
            classification, splits are also ignored if they would result in any
            single class carrying a negative weight in either child node.

        Returns
        -------
        self : object
            Fitted estimator.
        """
        # Validate or convert input data
        if issparse(y):
            raise ValueError("sparse multilabel-indicator for y is not supported.")

        X, y = self._validate_data(
            X, y, multi_output=True, accept_sparse="csc", dtype=DTYPE
        )
        if sample_weight is not None:
            sample_weight = _check_sample_weight(sample_weight, X)

        if issparse(X):
            # Pre-sort indices to avoid that each individual tree of the
            # ensemble sorts the indices.
            X.sort_indices()

        y = np.atleast_1d(y)
        if y.ndim == 2 and y.shape[1] == 1:
            warn(
                (
                    "A column-vector y was passed when a 1d array was"
                    " expected. Please change the shape of y to "
                    "(n_samples,), for example using ravel()."
                ),
                DataConversionWarning,
                stacklevel=2,
            )

        if y.ndim == 1:
            # reshape is necessary to preserve the data contiguity, which
            # [:, np.newaxis] does not.
            y = np.reshape(y, (-1, 1))

        if self.criterion == "poisson":
            if np.any(y < 0):
                raise ValueError(
                    "Some value(s) of y are negative which is "
                    "not allowed for Poisson regression."
                )
            if np.sum(y) <= 0:
                raise ValueError(
                    "Sum of y is not strictly positive which "
                    "is necessary for Poisson regression."
                )

        self.n_outputs_ = y.shape[1]

        y, expanded_class_weight = self._validate_y_class_weight(y)

        if getattr(y, "dtype", None) != DOUBLE or not y.flags.contiguous:
            y = np.ascontiguousarray(y, dtype=DOUBLE)

        if expanded_class_weight is not None:
            if sample_weight is not None:
                sample_weight = sample_weight * expanded_class_weight
            else:
                sample_weight = expanded_class_weight

        if not self.bootstrap and self.max_samples is not None:
            raise ValueError(
                "`max_samples` cannot be set if `bootstrap=False`. "
                "Either switch to `bootstrap=True` or set "
                "`max_samples=None`."
            )
        elif self.bootstrap:
            n_samples_bootstrap = _get_n_samples_bootstrap(
                n_samples=X.shape[0], max_samples=self.max_samples
            )
        else:
            n_samples_bootstrap = None

        self._validate_estimator()

        if not self.bootstrap and self.oob_score:
            raise ValueError("Out of bag estimation only available if bootstrap=True")

        random_state = check_random_state(self.random_state)

        if not self.warm_start or not hasattr(self, "estimators_"):
            # Free allocated memory, if any
            self.estimators_ = []

        n_more_estimators = self.n_estimators - len(self.estimators_)

        if n_more_estimators < 0:
            raise ValueError(
                "n_estimators=%d must be larger or equal to "
                "len(estimators_)=%d when warm_start==True"
                % (self.n_estimators, len(self.estimators_))
            )

        elif n_more_estimators == 0:
            warn(
                "Warm-start fitting without increasing n_estimators does not "
                "fit new trees."
            )
        else:
            if self.warm_start and len(self.estimators_) > 0:
                # We draw from the random state to get the random state we
                # would have got if we hadn't used a warm_start.
                random_state.randint(MAX_INT, size=len(self.estimators_))

            trees = [
                self._make_estimator(append=False, random_state=random_state)
                for i in range(n_more_estimators)
            ]

            # Parallel loop: we prefer the threading backend as the Cython code
            # for fitting the trees is internally releasing the Python GIL
            # making threading more efficient than multiprocessing in
            # that case. However, for joblib 0.12+ we respect any
            # parallel_backend contexts set at a higher level,
            # since correctness does not rely on using threads.
            trees = Parallel(
                n_jobs=self.n_jobs,
                verbose=self.verbose,
                prefer="threads",
            )(
                delayed(_parallel_build_trees)(
                    t,
                    self.bootstrap,
                    X,
                    y,
                    sample_weight,
                    i,
                    len(trees),
                    verbose=self.verbose,
                    class_weight=self.class_weight,
                    n_samples_bootstrap=n_samples_bootstrap,
                )
                for i, t in enumerate(trees)
            )

            # Collect newly grown trees
            self.estimators_.extend(trees)

        if self.oob_score and (
            n_more_estimators > 0 or not hasattr(self, "oob_score_")
        ):
            y_type = type_of_target(y)
            if y_type in ("multiclass-multioutput", "unknown"):
                # FIXME: we could consider supporting multiclass-multioutput if
                # we introduce or reuse a constructor parameter (e.g.
                # oob_score) allowing our user to pass a callable defining the
                # scoring strategy on OOB sample.
                raise ValueError(
                    "The type of target cannot be used to compute OOB "
                    f"estimates. Got {y_type} while only the following are "
                    "supported: continuous, continuous-multioutput, binary, "
                    "multiclass, multilabel-indicator."
                )

            if callable(self.oob_score):
                self._set_oob_score_and_attributes(
                    X, y, scoring_function=self.oob_score
                )
            else:
                self._set_oob_score_and_attributes(X, y)

        # Decapsulate classes_ attributes
        if hasattr(self, "classes_") and self.n_outputs_ == 1:
            self.n_classes_ = self.n_classes_[0]
            self.classes_ = self.classes_[0]

        return self

    @abstractmethod
    def _set_oob_score_and_attributes(self, X, y, scoring_function=None):
        """Compute and set the OOB score and attributes.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            The data matrix.
        y : ndarray of shape (n_samples, n_outputs)
            The target matrix.
        scoring_function : callable, default=None
            Scoring function for OOB score. Default depends on whether
            this is a regression (R2 score) or classification problem
            (accuracy score).
        """

    def _compute_oob_predictions(self, X, y):
        """Compute and set the OOB score.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            The data matrix.
        y : ndarray of shape (n_samples, n_outputs)
            The target matrix.

        Returns
        -------
        oob_pred : ndarray of shape (n_samples, n_classes, n_outputs) or \
                (n_samples, 1, n_outputs)
            The OOB predictions.
        """
        # Prediction requires X to be in CSR format
        if issparse(X):
            X = X.tocsr()

        n_samples = y.shape[0]
        n_outputs = self.n_outputs_
        if is_classifier(self) and hasattr(self, "n_classes_"):
            # n_classes_ is a ndarray at this stage
            # all the supported types of target will have the same number of
            # classes in all outputs
            oob_pred_shape = (n_samples, self.n_classes_[0], n_outputs)
        else:
            # for regression, n_classes_ does not exist and we create an empty
            # axis to be consistent with the classification case and make
            # the array operations compatible with the 2 settings
            oob_pred_shape = (n_samples, 1, n_outputs)

        oob_pred = np.zeros(shape=oob_pred_shape, dtype=np.float64)
        n_oob_pred = np.zeros((n_samples, n_outputs), dtype=np.int64)

        n_samples_bootstrap = _get_n_samples_bootstrap(
            n_samples,
            self.max_samples,
        )
        for estimator in self.estimators_:
            unsampled_indices = _generate_unsampled_indices(
                estimator.random_state,
                n_samples,
                n_samples_bootstrap,
            )

            y_pred = self._get_oob_predictions(estimator, X[unsampled_indices, :])
            oob_pred[unsampled_indices, ...] += y_pred
            n_oob_pred[unsampled_indices, :] += 1

        for k in range(n_outputs):
            if (n_oob_pred == 0).any():
                warn(
                    (
                        "Some inputs do not have OOB scores. This probably means "
                        "too few trees were used to compute any reliable OOB "
                        "estimates."
                    ),
                    UserWarning,
                )
                n_oob_pred[n_oob_pred == 0] = 1
            oob_pred[..., k] /= n_oob_pred[..., [k]]

        return oob_pred

    def _validate_y_class_weight(self, y):
        # Default implementation
        return y, None

    def _validate_X_predict(self, X):
        """
        Validate X whenever one tries to predict, apply, predict_proba."""
        check_is_fitted(self)
        X = self._validate_data(X, dtype=DTYPE, accept_sparse="csr", reset=False)
        if issparse(X) and (X.indices.dtype != np.intc or X.indptr.dtype != np.intc):
            raise ValueError("No support for np.int64 index based sparse matrices")
        return X

    @property
    def feature_importances_(self):
        """
        The impurity-based feature importances.

        The higher, the more important the feature.
        The importance of a feature is computed as the (normalized)
        total reduction of the criterion brought by that feature. It is also
        known as the Gini importance.

        Warning: impurity-based feature importances can be misleading for
        high cardinality features (many unique values). See
        :func:`sklearn.inspection.permutation_importance` as an alternative.

        Returns
        -------
        feature_importances_ : ndarray of shape (n_features,)
            The values of this array sum to 1, unless all trees are single node
            trees consisting of only the root node, in which case it will be an
            array of zeros.
        """
        check_is_fitted(self)

        all_importances = Parallel(n_jobs=self.n_jobs, prefer="threads")(
            delayed(getattr)(tree, "feature_importances_")
            for tree in self.estimators_
            if tree.tree_.node_count > 1
        )

        if not all_importances:
            return np.zeros(self.n_features_in_, dtype=np.float64)

        all_importances = np.mean(all_importances, axis=0, dtype=np.float64)
        return all_importances / np.sum(all_importances)
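
# Illustrative note on the property above (comments only, not executed): the
# forest-level importances are the per-tree `feature_importances_` averaged over
# trees with more than one node, then renormalized so the result sums to 1:
#   per_tree = np.array([[0.7, 0.3], [0.4, 0.6]])
#   mean = per_tree.mean(axis=0)                 # array([0.55, 0.45])
#   mean / mean.sum()                            # array([0.55, 0.45])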


def _accumulate_prediction(predict, X, out, lock):
    """
    This is a utility function for joblib's Parallel.

    It can't go locally in ForestClassifier or ForestRegressor, because joblib
    complains that it cannot pickle it when placed there.
    """
    prediction = predict(X, check_input=False)
    with lock:
        if len(out) == 1:
            out[0] += prediction
        else:
            for i in range(len(out)):
                out[i] += prediction[i]
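
# The helper above is the reduction step used by the forest `predict` and
# `predict_proba` methods below: worker threads share one output buffer (joblib
# `require="sharedmem"`) and add each tree's prediction into it under `lock`, so
# peak memory stays proportional to a single prediction array rather than to
# `n_estimators`.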


class ForestClassifier(ClassifierMixin, BaseForest, metaclass=ABCMeta):
    """
    Base class for forest of trees-based classifiers.

    Warning: This class should not be used directly. Use derived classes
    instead.
    """

    @abstractmethod
    def __init__(
        self,
        estimator,
        n_estimators=100,
        *,
        estimator_params=tuple(),
        bootstrap=False,
        oob_score=False,
        n_jobs=None,
        random_state=None,
        verbose=0,
        warm_start=False,
        class_weight=None,
        max_samples=None,
        base_estimator="deprecated",
    ):
        super().__init__(
            estimator=estimator,
            n_estimators=n_estimators,
            estimator_params=estimator_params,
            bootstrap=bootstrap,
            oob_score=oob_score,
            n_jobs=n_jobs,
            random_state=random_state,
            verbose=verbose,
            warm_start=warm_start,
            class_weight=class_weight,
            max_samples=max_samples,
            base_estimator=base_estimator,
        )

    @staticmethod
    def _get_oob_predictions(tree, X):
        """Compute the OOB predictions for an individual tree.

        Parameters
        ----------
        tree : DecisionTreeClassifier object
            A single decision tree classifier.
        X : ndarray of shape (n_samples, n_features)
            The OOB samples.

        Returns
        -------
        y_pred : ndarray of shape (n_samples, n_classes, n_outputs)
            The OOB associated predictions.
        """
        y_pred = tree.predict_proba(X, check_input=False)
        y_pred = np.array(y_pred, copy=False)
        if y_pred.ndim == 2:
            # binary and multiclass
            y_pred = y_pred[..., np.newaxis]
        else:
            # Roll the first `n_outputs` axis to the last axis. We will reshape
            # from a shape of (n_outputs, n_samples, n_classes) to a shape of
            # (n_samples, n_classes, n_outputs).
            y_pred = np.rollaxis(y_pred, axis=0, start=3)
        return y_pred

    def _set_oob_score_and_attributes(self, X, y, scoring_function=None):
        """Compute and set the OOB score and attributes.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            The data matrix.
        y : ndarray of shape (n_samples, n_outputs)
            The target matrix.
        scoring_function : callable, default=None
            Scoring function for OOB score. Defaults to `accuracy_score`.
        """
        self.oob_decision_function_ = super()._compute_oob_predictions(X, y)
        if self.oob_decision_function_.shape[-1] == 1:
            # drop the n_outputs axis if there is a single output
            self.oob_decision_function_ = self.oob_decision_function_.squeeze(axis=-1)

        if scoring_function is None:
            scoring_function = accuracy_score

        self.oob_score_ = scoring_function(
            y, np.argmax(self.oob_decision_function_, axis=1)
        )

    def _validate_y_class_weight(self, y):
        check_classification_targets(y)

        y = np.copy(y)
        expanded_class_weight = None

        if self.class_weight is not None:
            y_original = np.copy(y)

        self.classes_ = []
        self.n_classes_ = []

        y_store_unique_indices = np.zeros(y.shape, dtype=int)
        for k in range(self.n_outputs_):
            classes_k, y_store_unique_indices[:, k] = np.unique(
                y[:, k], return_inverse=True
            )
            self.classes_.append(classes_k)
            self.n_classes_.append(classes_k.shape[0])
        y = y_store_unique_indices

        if self.class_weight is not None:
            valid_presets = ("balanced", "balanced_subsample")
            if isinstance(self.class_weight, str):
                if self.class_weight not in valid_presets:
                    raise ValueError(
                        "Valid presets for class_weight include "
                        '"balanced" and "balanced_subsample". '
                        'Given "%s".' % self.class_weight
                    )
                if self.warm_start:
                    warn(
                        'class_weight presets "balanced" or '
                        '"balanced_subsample" are '
                        "not recommended for warm_start if the fitted data "
                        "differs from the full dataset. In order to use "
                        '"balanced" weights, use compute_class_weight '
                        '("balanced", classes, y). In place of y you can use '
                        "a large enough sample of the full training set "
                        "target to properly estimate the class frequency "
                        "distributions. Pass the resulting weights as the "
                        "class_weight parameter."
                    )

            if self.class_weight != "balanced_subsample" or not self.bootstrap:
                if self.class_weight == "balanced_subsample":
                    class_weight = "balanced"
                else:
                    class_weight = self.class_weight
                expanded_class_weight = compute_sample_weight(class_weight, y_original)

        return y, expanded_class_weight
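
    # Encoding sketch for the method above (comments only, not executed): labels
    # are remapped column-wise to integer codes via np.unique before fitting, e.g.
    # for a single output y = ["b", "a", "b"]:
    #   classes_   -> [array(["a", "b"])]
    #   encoded y  -> [[1], [0], [1]]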

    def predict(self, X):
        """
        Predict class for X.

        The predicted class of an input sample is a vote by the trees in
        the forest, weighted by their probability estimates. That is,
        the predicted class is the one with highest mean probability
        estimate across the trees.

        Parameters
        ----------
        X : {array-like, sparse matrix} of shape (n_samples, n_features)
            The input samples. Internally, its dtype will be converted to
            ``dtype=np.float32``. If a sparse matrix is provided, it will be
            converted into a sparse ``csr_matrix``.

        Returns
        -------
        y : ndarray of shape (n_samples,) or (n_samples, n_outputs)
            The predicted classes.
        """
        proba = self.predict_proba(X)

        if self.n_outputs_ == 1:
            return self.classes_.take(np.argmax(proba, axis=1), axis=0)
        else:
            n_samples = proba[0].shape[0]
            # all dtypes should be the same, so just take the first
            class_type = self.classes_[0].dtype
            predictions = np.empty((n_samples, self.n_outputs_), dtype=class_type)

            for k in range(self.n_outputs_):
                predictions[:, k] = self.classes_[k].take(
                    np.argmax(proba[k], axis=1), axis=0
                )

            return predictions

    def predict_proba(self, X):
        """
        Predict class probabilities for X.

        The predicted class probabilities of an input sample are computed as
        the mean predicted class probabilities of the trees in the forest.
        The class probability of a single tree is the fraction of samples of
        the same class in a leaf.

        Parameters
        ----------
        X : {array-like, sparse matrix} of shape (n_samples, n_features)
            The input samples. Internally, its dtype will be converted to
            ``dtype=np.float32``. If a sparse matrix is provided, it will be
            converted into a sparse ``csr_matrix``.

        Returns
        -------
        p : ndarray of shape (n_samples, n_classes), or a list of such arrays
            The class probabilities of the input samples. The order of the
            classes corresponds to that in the attribute :term:`classes_`.
        """
        check_is_fitted(self)
        # Check data
        X = self._validate_X_predict(X)

        # Assign chunk of trees to jobs
        n_jobs, _, _ = _partition_estimators(self.n_estimators, self.n_jobs)

        # avoid storing the output of every estimator by summing them here
        all_proba = [
            np.zeros((X.shape[0], j), dtype=np.float64)
            for j in np.atleast_1d(self.n_classes_)
        ]
        lock = threading.Lock()
        Parallel(n_jobs=n_jobs, verbose=self.verbose, require="sharedmem")(
            delayed(_accumulate_prediction)(e.predict_proba, X, all_proba, lock)
            for e in self.estimators_
        )

        for proba in all_proba:
            proba /= len(self.estimators_)

        if len(all_proba) == 1:
            return all_proba[0]
        else:
            return all_proba

    def predict_log_proba(self, X):
        """
        Predict class log-probabilities for X.

        The predicted class log-probabilities of an input sample are computed
        as the log of the mean predicted class probabilities of the trees in
        the forest.

        Parameters
        ----------
        X : {array-like, sparse matrix} of shape (n_samples, n_features)
            The input samples. Internally, its dtype will be converted to
            ``dtype=np.float32``. If a sparse matrix is provided, it will be
            converted into a sparse ``csr_matrix``.

        Returns
        -------
        p : ndarray of shape (n_samples, n_classes), or a list of such arrays
            The class probabilities of the input samples. The order of the
            classes corresponds to that in the attribute :term:`classes_`.
        """
        proba = self.predict_proba(X)

        if self.n_outputs_ == 1:
            return np.log(proba)

        else:
            for k in range(self.n_outputs_):
                proba[k] = np.log(proba[k])

            return proba

    def _more_tags(self):
        return {"multilabel": True}


class ForestRegressor(RegressorMixin, BaseForest, metaclass=ABCMeta):
    """
    Base class for forest of trees-based regressors.

    Warning: This class should not be used directly. Use derived classes
    instead.
    """

    @abstractmethod
    def __init__(
        self,
        estimator,
        n_estimators=100,
        *,
        estimator_params=tuple(),
        bootstrap=False,
        oob_score=False,
        n_jobs=None,
        random_state=None,
        verbose=0,
        warm_start=False,
        max_samples=None,
        base_estimator="deprecated",
    ):
        super().__init__(
            estimator,
            n_estimators=n_estimators,
            estimator_params=estimator_params,
            bootstrap=bootstrap,
            oob_score=oob_score,
            n_jobs=n_jobs,
            random_state=random_state,
            verbose=verbose,
            warm_start=warm_start,
            max_samples=max_samples,
            base_estimator=base_estimator,
        )

    def predict(self, X):
        """
        Predict regression target for X.

        The predicted regression target of an input sample is computed as the
        mean predicted regression targets of the trees in the forest.

        Parameters
        ----------
        X : {array-like, sparse matrix} of shape (n_samples, n_features)
            The input samples. Internally, its dtype will be converted to
            ``dtype=np.float32``. If a sparse matrix is provided, it will be
            converted into a sparse ``csr_matrix``.

        Returns
        -------
        y : ndarray of shape (n_samples,) or (n_samples, n_outputs)
            The predicted values.
        """
        check_is_fitted(self)
        # Check data
        X = self._validate_X_predict(X)

        # Assign chunk of trees to jobs
        n_jobs, _, _ = _partition_estimators(self.n_estimators, self.n_jobs)

        # avoid storing the output of every estimator by summing them here
        if self.n_outputs_ > 1:
            y_hat = np.zeros((X.shape[0], self.n_outputs_), dtype=np.float64)
        else:
            y_hat = np.zeros((X.shape[0]), dtype=np.float64)

        # Parallel loop
        lock = threading.Lock()
        Parallel(n_jobs=n_jobs, verbose=self.verbose, require="sharedmem")(
            delayed(_accumulate_prediction)(e.predict, X, [y_hat], lock)
            for e in self.estimators_
        )

        y_hat /= len(self.estimators_)

        return y_hat

    @staticmethod
    def _get_oob_predictions(tree, X):
        """Compute the OOB predictions for an individual tree.

        Parameters
        ----------
        tree : DecisionTreeRegressor object
            A single decision tree regressor.
        X : ndarray of shape (n_samples, n_features)
            The OOB samples.

        Returns
        -------
        y_pred : ndarray of shape (n_samples, 1, n_outputs)
            The OOB associated predictions.
        """
        y_pred = tree.predict(X, check_input=False)
        if y_pred.ndim == 1:
            # single output regression
            y_pred = y_pred[:, np.newaxis, np.newaxis]
        else:
            # multioutput regression
            y_pred = y_pred[:, np.newaxis, :]
        return y_pred
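
    # Shape note for the method above (comments only): a single-output prediction
    # of shape (n_samples,) becomes (n_samples, 1, 1), and a k-output prediction of
    # shape (n_samples, k) becomes (n_samples, 1, k), mirroring the classifier's
    # (n_samples, n_classes, n_outputs) layout with a length-1 "class" axis.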
  868. def _set_oob_score_and_attributes(self, X, y, scoring_function=None):
  869. """Compute and set the OOB score and attributes.
  870. Parameters
  871. ----------
  872. X : array-like of shape (n_samples, n_features)
  873. The data matrix.
  874. y : ndarray of shape (n_samples, n_outputs)
  875. The target matrix.
  876. scoring_function : callable, default=None
  877. Scoring function for OOB score. Defaults to `r2_score`.
  878. """
  879. self.oob_prediction_ = super()._compute_oob_predictions(X, y).squeeze(axis=1)
  880. if self.oob_prediction_.shape[-1] == 1:
  881. # drop the n_outputs axis if there is a single output
  882. self.oob_prediction_ = self.oob_prediction_.squeeze(axis=-1)
  883. if scoring_function is None:
  884. scoring_function = r2_score
  885. self.oob_score_ = scoring_function(y, self.oob_prediction_)
  886. def _compute_partial_dependence_recursion(self, grid, target_features):
  887. """Fast partial dependence computation.
  888. Parameters
  889. ----------
  890. grid : ndarray of shape (n_samples, n_target_features)
  891. The grid points on which the partial dependence should be
  892. evaluated.
  893. target_features : ndarray of shape (n_target_features)
  894. The set of target features for which the partial dependence
  895. should be evaluated.
  896. Returns
  897. -------
  898. averaged_predictions : ndarray of shape (n_samples,)
  899. The value of the partial dependence function on each grid point.
  900. """
  901. grid = np.asarray(grid, dtype=DTYPE, order="C")
  902. averaged_predictions = np.zeros(
  903. shape=grid.shape[0], dtype=np.float64, order="C"
  904. )
  905. for tree in self.estimators_:
  906. # Note: we don't sum in parallel because the GIL isn't released in
  907. # the fast method.
  908. tree.tree_.compute_partial_dependence(
  909. grid, target_features, averaged_predictions
  910. )
  911. # Average over the forest
  912. averaged_predictions /= len(self.estimators_)
  913. return averaged_predictions
  914. def _more_tags(self):
  915. return {"multilabel": True}
  916. class RandomForestClassifier(ForestClassifier):
  917. """
  918. A random forest classifier.
  919. A random forest is a meta estimator that fits a number of decision tree
  920. classifiers on various sub-samples of the dataset and uses averaging to
  921. improve the predictive accuracy and control over-fitting.
  922. The sub-sample size is controlled with the `max_samples` parameter if
  923. `bootstrap=True` (default), otherwise the whole dataset is used to build
  924. each tree.
  925. For a comparison between tree-based ensemble models see the example
  926. :ref:`sphx_glr_auto_examples_ensemble_plot_forest_hist_grad_boosting_comparison.py`.
  927. Read more in the :ref:`User Guide <forest>`.
  928. Parameters
  929. ----------
  930. n_estimators : int, default=100
  931. The number of trees in the forest.
  932. .. versionchanged:: 0.22
  933. The default value of ``n_estimators`` changed from 10 to 100
  934. in 0.22.
  935. criterion : {"gini", "entropy", "log_loss"}, default="gini"
  936. The function to measure the quality of a split. Supported criteria are
  937. "gini" for the Gini impurity and "log_loss" and "entropy" both for the
  938. Shannon information gain, see :ref:`tree_mathematical_formulation`.
  939. Note: This parameter is tree-specific.
  940. max_depth : int, default=None
  941. The maximum depth of the tree. If None, then nodes are expanded until
  942. all leaves are pure or until all leaves contain less than
  943. min_samples_split samples.
  944. min_samples_split : int or float, default=2
  945. The minimum number of samples required to split an internal node:
  946. - If int, then consider `min_samples_split` as the minimum number.
  947. - If float, then `min_samples_split` is a fraction and
  948. `ceil(min_samples_split * n_samples)` are the minimum
  949. number of samples for each split.
  950. .. versionchanged:: 0.18
  951. Added float values for fractions.
  952. min_samples_leaf : int or float, default=1
  953. The minimum number of samples required to be at a leaf node.
  954. A split point at any depth will only be considered if it leaves at
  955. least ``min_samples_leaf`` training samples in each of the left and
  956. right branches. This may have the effect of smoothing the model,
  957. especially in regression.
  958. - If int, then consider `min_samples_leaf` as the minimum number.
  959. - If float, then `min_samples_leaf` is a fraction and
  960. `ceil(min_samples_leaf * n_samples)` are the minimum
  961. number of samples for each node.
  962. .. versionchanged:: 0.18
  963. Added float values for fractions.
  964. min_weight_fraction_leaf : float, default=0.0
  965. The minimum weighted fraction of the sum total of weights (of all
  966. the input samples) required to be at a leaf node. Samples have
  967. equal weight when sample_weight is not provided.
  968. max_features : {"sqrt", "log2", None}, int or float, default="sqrt"
  969. The number of features to consider when looking for the best split:
  970. - If int, then consider `max_features` features at each split.
  971. - If float, then `max_features` is a fraction and
  972. `max(1, int(max_features * n_features_in_))` features are considered at each
  973. split.
  974. - If "sqrt", then `max_features=sqrt(n_features)`.
  975. - If "log2", then `max_features=log2(n_features)`.
  976. - If None, then `max_features=n_features`.
  977. .. versionchanged:: 1.1
  978. The default of `max_features` changed from `"auto"` to `"sqrt"`.
  979. Note: the search for a split does not stop until at least one
  980. valid partition of the node samples is found, even if it requires to
  981. effectively inspect more than ``max_features`` features.
  982. max_leaf_nodes : int, default=None
  983. Grow trees with ``max_leaf_nodes`` in best-first fashion.
  984. Best nodes are defined as relative reduction in impurity.
  985. If None then unlimited number of leaf nodes.
  986. min_impurity_decrease : float, default=0.0
  987. A node will be split if this split induces a decrease of the impurity
  988. greater than or equal to this value.
  989. The weighted impurity decrease equation is the following::
  990. N_t / N * (impurity - N_t_R / N_t * right_impurity
  991. - N_t_L / N_t * left_impurity)
  992. where ``N`` is the total number of samples, ``N_t`` is the number of
  993. samples at the current node, ``N_t_L`` is the number of samples in the
  994. left child, and ``N_t_R`` is the number of samples in the right child.
  995. ``N``, ``N_t``, ``N_t_R`` and ``N_t_L`` all refer to the weighted sum,
  996. if ``sample_weight`` is passed.
  997. .. versionadded:: 0.19
  998. bootstrap : bool, default=True
  999. Whether bootstrap samples are used when building trees. If False, the
  1000. whole dataset is used to build each tree.
  1001. oob_score : bool or callable, default=False
  1002. Whether to use out-of-bag samples to estimate the generalization score.
  1003. By default, :func:`~sklearn.metrics.accuracy_score` is used.
  1004. Provide a callable with signature `metric(y_true, y_pred)` to use a
  1005. custom metric. Only available if `bootstrap=True`.
  1006. n_jobs : int, default=None
  1007. The number of jobs to run in parallel. :meth:`fit`, :meth:`predict`,
  1008. :meth:`decision_path` and :meth:`apply` are all parallelized over the
  1009. trees. ``None`` means 1 unless in a :obj:`joblib.parallel_backend`
  1010. context. ``-1`` means using all processors. See :term:`Glossary
  1011. <n_jobs>` for more details.
  1012. random_state : int, RandomState instance or None, default=None
  1013. Controls both the randomness of the bootstrapping of the samples used
  1014. when building trees (if ``bootstrap=True``) and the sampling of the
  1015. features to consider when looking for the best split at each node
  1016. (if ``max_features < n_features``).
  1017. See :term:`Glossary <random_state>` for details.
  1018. verbose : int, default=0
  1019. Controls the verbosity when fitting and predicting.
  1020. warm_start : bool, default=False
  1021. When set to ``True``, reuse the solution of the previous call to fit
  1022. and add more estimators to the ensemble, otherwise, just fit a whole
  1023. new forest. See :term:`Glossary <warm_start>` and
  1024. :ref:`gradient_boosting_warm_start` for details.
  1025. class_weight : {"balanced", "balanced_subsample"}, dict or list of dicts, \
  1026. default=None
  1027. Weights associated with classes in the form ``{class_label: weight}``.
  1028. If not given, all classes are supposed to have weight one. For
  1029. multi-output problems, a list of dicts can be provided in the same
  1030. order as the columns of y.
  1031. Note that for multioutput (including multilabel) weights should be
  1032. defined for each class of every column in its own dict. For example,
  1033. for four-class multilabel classification weights should be
  1034. [{0: 1, 1: 1}, {0: 1, 1: 5}, {0: 1, 1: 1}, {0: 1, 1: 1}] instead of
  1035. [{1:1}, {2:5}, {3:1}, {4:1}].
  1036. The "balanced" mode uses the values of y to automatically adjust
  1037. weights inversely proportional to class frequencies in the input data
  1038. as ``n_samples / (n_classes * np.bincount(y))``
  1039. The "balanced_subsample" mode is the same as "balanced" except that
  1040. weights are computed based on the bootstrap sample for every tree
  1041. grown.
  1042. For multi-output, the weights of each column of y will be multiplied.
  1043. Note that these weights will be multiplied with sample_weight (passed
  1044. through the fit method) if sample_weight is specified.
  1045. ccp_alpha : non-negative float, default=0.0
  1046. Complexity parameter used for Minimal Cost-Complexity Pruning. The
  1047. subtree with the largest cost complexity that is smaller than
  1048. ``ccp_alpha`` will be chosen. By default, no pruning is performed. See
  1049. :ref:`minimal_cost_complexity_pruning` for details.
  1050. .. versionadded:: 0.22
  1051. max_samples : int or float, default=None
  1052. If bootstrap is True, the number of samples to draw from X
  1053. to train each base estimator.
  1054. - If None (default), then draw `X.shape[0]` samples.
  1055. - If int, then draw `max_samples` samples.
  1056. - If float, then draw `max(round(n_samples * max_samples), 1)` samples. Thus,
  1057. `max_samples` should be in the interval `(0.0, 1.0]`.
  1058. .. versionadded:: 0.22
  1059. Attributes
  1060. ----------
  1061. estimator_ : :class:`~sklearn.tree.DecisionTreeClassifier`
  1062. The child estimator template used to create the collection of fitted
  1063. sub-estimators.
  1064. .. versionadded:: 1.2
  1065. `base_estimator_` was renamed to `estimator_`.
  1066. base_estimator_ : DecisionTreeClassifier
  1067. The child estimator template used to create the collection of fitted
  1068. sub-estimators.
  1069. .. deprecated:: 1.2
  1070. `base_estimator_` is deprecated and will be removed in 1.4.
  1071. Use `estimator_` instead.
  1072. estimators_ : list of DecisionTreeClassifier
  1073. The collection of fitted sub-estimators.
  1074. classes_ : ndarray of shape (n_classes,) or a list of such arrays
  1075. The classes labels (single output problem), or a list of arrays of
  1076. class labels (multi-output problem).
  1077. n_classes_ : int or list
  1078. The number of classes (single output problem), or a list containing the
  1079. number of classes for each output (multi-output problem).
  1080. n_features_in_ : int
  1081. Number of features seen during :term:`fit`.
  1082. .. versionadded:: 0.24
  1083. feature_names_in_ : ndarray of shape (`n_features_in_`,)
  1084. Names of features seen during :term:`fit`. Defined only when `X`
  1085. has feature names that are all strings.
  1086. .. versionadded:: 1.0
  1087. n_outputs_ : int
  1088. The number of outputs when ``fit`` is performed.
  1089. feature_importances_ : ndarray of shape (n_features,)
  1090. The impurity-based feature importances.
  1091. The higher, the more important the feature.
  1092. The importance of a feature is computed as the (normalized)
  1093. total reduction of the criterion brought by that feature. It is also
  1094. known as the Gini importance.
  1095. Warning: impurity-based feature importances can be misleading for
  1096. high cardinality features (many unique values). See
  1097. :func:`sklearn.inspection.permutation_importance` as an alternative.
  1098. oob_score_ : float
  1099. Score of the training dataset obtained using an out-of-bag estimate.
  1100. This attribute exists only when ``oob_score`` is True.
  1101. oob_decision_function_ : ndarray of shape (n_samples, n_classes) or \
  1102. (n_samples, n_classes, n_outputs)
  1103. Decision function computed with out-of-bag estimate on the training
  1104. set. If n_estimators is small it might be possible that a data point
  1105. was never left out during the bootstrap. In this case,
  1106. `oob_decision_function_` might contain NaN. This attribute exists
  1107. only when ``oob_score`` is True.
  1108. See Also
  1109. --------
  1110. sklearn.tree.DecisionTreeClassifier : A decision tree classifier.
  1111. sklearn.ensemble.ExtraTreesClassifier : Ensemble of extremely randomized
  1112. tree classifiers.
  1113. sklearn.ensemble.HistGradientBoostingClassifier : A Histogram-based Gradient
  1114. Boosting Classification Tree, very fast for big datasets (n_samples >=
  1115. 10_000).
  1116. Notes
  1117. -----
  1118. The default values for the parameters controlling the size of the trees
  1119. (e.g. ``max_depth``, ``min_samples_leaf``, etc.) lead to fully grown and
  1120. unpruned trees which can potentially be very large on some data sets. To
  1121. reduce memory consumption, the complexity and size of the trees should be
  1122. controlled by setting those parameter values.
  1123. The features are always randomly permuted at each split. Therefore,
  1124. the best found split may vary, even with the same training data,
  1125. ``max_features=n_features`` and ``bootstrap=False``, if the improvement
  1126. of the criterion is identical for several splits enumerated during the
  1127. search of the best split. To obtain a deterministic behaviour during
  1128. fitting, ``random_state`` has to be fixed.
  1129. References
  1130. ----------
  1131. .. [1] L. Breiman, "Random Forests", Machine Learning, 45(1), 5-32, 2001.
  1132. Examples
  1133. --------
  1134. >>> from sklearn.ensemble import RandomForestClassifier
  1135. >>> from sklearn.datasets import make_classification
  1136. >>> X, y = make_classification(n_samples=1000, n_features=4,
  1137. ... n_informative=2, n_redundant=0,
  1138. ... random_state=0, shuffle=False)
  1139. >>> clf = RandomForestClassifier(max_depth=2, random_state=0)
  1140. >>> clf.fit(X, y)
  1141. RandomForestClassifier(...)
  1142. >>> print(clf.predict([[0, 0, 0, 0]]))
  1143. [1]
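Continuing the snippet above, the fitted ensemble also exposes class
probability estimates and impurity-based importances (only shapes are shown
here, since the exact values depend on the fitted trees):
>>> clf.predict_proba([[0, 0, 0, 0]]).shape
(1, 2)
>>> clf.feature_importances_.shape
(4,)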
  1144. """
  1145. _parameter_constraints: dict = {
  1146. **ForestClassifier._parameter_constraints,
  1147. **DecisionTreeClassifier._parameter_constraints,
  1148. "class_weight": [
  1149. StrOptions({"balanced_subsample", "balanced"}),
  1150. dict,
  1151. list,
  1152. None,
  1153. ],
  1154. }
  1155. _parameter_constraints.pop("splitter")
  1156. def __init__(
  1157. self,
  1158. n_estimators=100,
  1159. *,
  1160. criterion="gini",
  1161. max_depth=None,
  1162. min_samples_split=2,
  1163. min_samples_leaf=1,
  1164. min_weight_fraction_leaf=0.0,
  1165. max_features="sqrt",
  1166. max_leaf_nodes=None,
  1167. min_impurity_decrease=0.0,
  1168. bootstrap=True,
  1169. oob_score=False,
  1170. n_jobs=None,
  1171. random_state=None,
  1172. verbose=0,
  1173. warm_start=False,
  1174. class_weight=None,
  1175. ccp_alpha=0.0,
  1176. max_samples=None,
  1177. ):
  1178. super().__init__(
  1179. estimator=DecisionTreeClassifier(),
  1180. n_estimators=n_estimators,
  1181. estimator_params=(
  1182. "criterion",
  1183. "max_depth",
  1184. "min_samples_split",
  1185. "min_samples_leaf",
  1186. "min_weight_fraction_leaf",
  1187. "max_features",
  1188. "max_leaf_nodes",
  1189. "min_impurity_decrease",
  1190. "random_state",
  1191. "ccp_alpha",
  1192. ),
  1193. bootstrap=bootstrap,
  1194. oob_score=oob_score,
  1195. n_jobs=n_jobs,
  1196. random_state=random_state,
  1197. verbose=verbose,
  1198. warm_start=warm_start,
  1199. class_weight=class_weight,
  1200. max_samples=max_samples,
  1201. )
  1202. self.criterion = criterion
  1203. self.max_depth = max_depth
  1204. self.min_samples_split = min_samples_split
  1205. self.min_samples_leaf = min_samples_leaf
  1206. self.min_weight_fraction_leaf = min_weight_fraction_leaf
  1207. self.max_features = max_features
  1208. self.max_leaf_nodes = max_leaf_nodes
  1209. self.min_impurity_decrease = min_impurity_decrease
  1210. self.ccp_alpha = ccp_alpha
  1211. class RandomForestRegressor(ForestRegressor):
  1212. """
  1213. A random forest regressor.
1214. A random forest is a meta estimator that fits a number of decision tree
1215. regressors on various sub-samples of the dataset and uses averaging
1216. to improve the predictive accuracy and control over-fitting.
  1217. The sub-sample size is controlled with the `max_samples` parameter if
  1218. `bootstrap=True` (default), otherwise the whole dataset is used to build
  1219. each tree.
  1220. For a comparison between tree-based ensemble models see the example
  1221. :ref:`sphx_glr_auto_examples_ensemble_plot_forest_hist_grad_boosting_comparison.py`.
  1222. Read more in the :ref:`User Guide <forest>`.
  1223. Parameters
  1224. ----------
  1225. n_estimators : int, default=100
  1226. The number of trees in the forest.
  1227. .. versionchanged:: 0.22
  1228. The default value of ``n_estimators`` changed from 10 to 100
  1229. in 0.22.
  1230. criterion : {"squared_error", "absolute_error", "friedman_mse", "poisson"}, \
  1231. default="squared_error"
  1232. The function to measure the quality of a split. Supported criteria
  1233. are "squared_error" for the mean squared error, which is equal to
  1234. variance reduction as feature selection criterion and minimizes the L2
  1235. loss using the mean of each terminal node, "friedman_mse", which uses
  1236. mean squared error with Friedman's improvement score for potential
  1237. splits, "absolute_error" for the mean absolute error, which minimizes
  1238. the L1 loss using the median of each terminal node, and "poisson" which
  1239. uses reduction in Poisson deviance to find splits.
  1240. Training using "absolute_error" is significantly slower
  1241. than when using "squared_error".
  1242. .. versionadded:: 0.18
  1243. Mean Absolute Error (MAE) criterion.
  1244. .. versionadded:: 1.0
  1245. Poisson criterion.
  1246. max_depth : int, default=None
  1247. The maximum depth of the tree. If None, then nodes are expanded until
  1248. all leaves are pure or until all leaves contain less than
  1249. min_samples_split samples.
  1250. min_samples_split : int or float, default=2
  1251. The minimum number of samples required to split an internal node:
  1252. - If int, then consider `min_samples_split` as the minimum number.
  1253. - If float, then `min_samples_split` is a fraction and
  1254. `ceil(min_samples_split * n_samples)` are the minimum
  1255. number of samples for each split.
  1256. .. versionchanged:: 0.18
  1257. Added float values for fractions.
  1258. min_samples_leaf : int or float, default=1
  1259. The minimum number of samples required to be at a leaf node.
  1260. A split point at any depth will only be considered if it leaves at
  1261. least ``min_samples_leaf`` training samples in each of the left and
  1262. right branches. This may have the effect of smoothing the model,
  1263. especially in regression.
  1264. - If int, then consider `min_samples_leaf` as the minimum number.
  1265. - If float, then `min_samples_leaf` is a fraction and
  1266. `ceil(min_samples_leaf * n_samples)` are the minimum
  1267. number of samples for each node.
  1268. .. versionchanged:: 0.18
  1269. Added float values for fractions.
  1270. min_weight_fraction_leaf : float, default=0.0
  1271. The minimum weighted fraction of the sum total of weights (of all
  1272. the input samples) required to be at a leaf node. Samples have
  1273. equal weight when sample_weight is not provided.
  1274. max_features : {"sqrt", "log2", None}, int or float, default=1.0
  1275. The number of features to consider when looking for the best split:
  1276. - If int, then consider `max_features` features at each split.
  1277. - If float, then `max_features` is a fraction and
  1278. `max(1, int(max_features * n_features_in_))` features are considered at each
  1279. split.
  1280. - If "sqrt", then `max_features=sqrt(n_features)`.
  1281. - If "log2", then `max_features=log2(n_features)`.
  1282. - If None or 1.0, then `max_features=n_features`.
  1283. .. note::
  1284. The default of 1.0 is equivalent to bagged trees and more
  1285. randomness can be achieved by setting smaller values, e.g. 0.3.
  1286. .. versionchanged:: 1.1
  1287. The default of `max_features` changed from `"auto"` to 1.0.
1288. Note: the search for a split does not stop until at least one
1289. valid partition of the node samples is found, even if it requires
1290. effectively inspecting more than ``max_features`` features.
  1291. max_leaf_nodes : int, default=None
  1292. Grow trees with ``max_leaf_nodes`` in best-first fashion.
  1293. Best nodes are defined as relative reduction in impurity.
  1294. If None then unlimited number of leaf nodes.
  1295. min_impurity_decrease : float, default=0.0
  1296. A node will be split if this split induces a decrease of the impurity
  1297. greater than or equal to this value.
  1298. The weighted impurity decrease equation is the following::
  1299. N_t / N * (impurity - N_t_R / N_t * right_impurity
  1300. - N_t_L / N_t * left_impurity)
  1301. where ``N`` is the total number of samples, ``N_t`` is the number of
  1302. samples at the current node, ``N_t_L`` is the number of samples in the
  1303. left child, and ``N_t_R`` is the number of samples in the right child.
  1304. ``N``, ``N_t``, ``N_t_R`` and ``N_t_L`` all refer to the weighted sum,
  1305. if ``sample_weight`` is passed.
  1306. .. versionadded:: 0.19
  1307. bootstrap : bool, default=True
  1308. Whether bootstrap samples are used when building trees. If False, the
  1309. whole dataset is used to build each tree.
  1310. oob_score : bool or callable, default=False
  1311. Whether to use out-of-bag samples to estimate the generalization score.
  1312. By default, :func:`~sklearn.metrics.r2_score` is used.
  1313. Provide a callable with signature `metric(y_true, y_pred)` to use a
  1314. custom metric. Only available if `bootstrap=True`.
  1315. n_jobs : int, default=None
  1316. The number of jobs to run in parallel. :meth:`fit`, :meth:`predict`,
  1317. :meth:`decision_path` and :meth:`apply` are all parallelized over the
  1318. trees. ``None`` means 1 unless in a :obj:`joblib.parallel_backend`
  1319. context. ``-1`` means using all processors. See :term:`Glossary
  1320. <n_jobs>` for more details.
  1321. random_state : int, RandomState instance or None, default=None
  1322. Controls both the randomness of the bootstrapping of the samples used
  1323. when building trees (if ``bootstrap=True``) and the sampling of the
  1324. features to consider when looking for the best split at each node
  1325. (if ``max_features < n_features``).
  1326. See :term:`Glossary <random_state>` for details.
  1327. verbose : int, default=0
  1328. Controls the verbosity when fitting and predicting.
  1329. warm_start : bool, default=False
  1330. When set to ``True``, reuse the solution of the previous call to fit
  1331. and add more estimators to the ensemble, otherwise, just fit a whole
  1332. new forest. See :term:`Glossary <warm_start>` and
  1333. :ref:`gradient_boosting_warm_start` for details.
  1334. ccp_alpha : non-negative float, default=0.0
  1335. Complexity parameter used for Minimal Cost-Complexity Pruning. The
  1336. subtree with the largest cost complexity that is smaller than
  1337. ``ccp_alpha`` will be chosen. By default, no pruning is performed. See
  1338. :ref:`minimal_cost_complexity_pruning` for details.
  1339. .. versionadded:: 0.22
  1340. max_samples : int or float, default=None
  1341. If bootstrap is True, the number of samples to draw from X
  1342. to train each base estimator.
  1343. - If None (default), then draw `X.shape[0]` samples.
  1344. - If int, then draw `max_samples` samples.
  1345. - If float, then draw `max(round(n_samples * max_samples), 1)` samples. Thus,
  1346. `max_samples` should be in the interval `(0.0, 1.0]`.
  1347. .. versionadded:: 0.22
  1348. Attributes
  1349. ----------
  1350. estimator_ : :class:`~sklearn.tree.DecisionTreeRegressor`
  1351. The child estimator template used to create the collection of fitted
  1352. sub-estimators.
  1353. .. versionadded:: 1.2
  1354. `base_estimator_` was renamed to `estimator_`.
  1355. base_estimator_ : DecisionTreeRegressor
  1356. The child estimator template used to create the collection of fitted
  1357. sub-estimators.
  1358. .. deprecated:: 1.2
  1359. `base_estimator_` is deprecated and will be removed in 1.4.
  1360. Use `estimator_` instead.
  1361. estimators_ : list of DecisionTreeRegressor
  1362. The collection of fitted sub-estimators.
  1363. feature_importances_ : ndarray of shape (n_features,)
  1364. The impurity-based feature importances.
  1365. The higher, the more important the feature.
  1366. The importance of a feature is computed as the (normalized)
  1367. total reduction of the criterion brought by that feature. It is also
  1368. known as the Gini importance.
  1369. Warning: impurity-based feature importances can be misleading for
  1370. high cardinality features (many unique values). See
  1371. :func:`sklearn.inspection.permutation_importance` as an alternative.
  1372. n_features_in_ : int
  1373. Number of features seen during :term:`fit`.
  1374. .. versionadded:: 0.24
  1375. feature_names_in_ : ndarray of shape (`n_features_in_`,)
  1376. Names of features seen during :term:`fit`. Defined only when `X`
  1377. has feature names that are all strings.
  1378. .. versionadded:: 1.0
  1379. n_outputs_ : int
  1380. The number of outputs when ``fit`` is performed.
  1381. oob_score_ : float
  1382. Score of the training dataset obtained using an out-of-bag estimate.
  1383. This attribute exists only when ``oob_score`` is True.
  1384. oob_prediction_ : ndarray of shape (n_samples,) or (n_samples, n_outputs)
  1385. Prediction computed with out-of-bag estimate on the training set.
  1386. This attribute exists only when ``oob_score`` is True.
  1387. See Also
  1388. --------
  1389. sklearn.tree.DecisionTreeRegressor : A decision tree regressor.
  1390. sklearn.ensemble.ExtraTreesRegressor : Ensemble of extremely randomized
  1391. tree regressors.
  1392. sklearn.ensemble.HistGradientBoostingRegressor : A Histogram-based Gradient
  1393. Boosting Regression Tree, very fast for big datasets (n_samples >=
  1394. 10_000).
  1395. Notes
  1396. -----
  1397. The default values for the parameters controlling the size of the trees
  1398. (e.g. ``max_depth``, ``min_samples_leaf``, etc.) lead to fully grown and
  1399. unpruned trees which can potentially be very large on some data sets. To
  1400. reduce memory consumption, the complexity and size of the trees should be
  1401. controlled by setting those parameter values.
  1402. The features are always randomly permuted at each split. Therefore,
  1403. the best found split may vary, even with the same training data,
  1404. ``max_features=n_features`` and ``bootstrap=False``, if the improvement
  1405. of the criterion is identical for several splits enumerated during the
  1406. search of the best split. To obtain a deterministic behaviour during
  1407. fitting, ``random_state`` has to be fixed.
  1408. The default value ``max_features=1.0`` uses ``n_features``
  1409. rather than ``n_features / 3``. The latter was originally suggested in
  1410. [1], whereas the former was more recently justified empirically in [2].
  1411. References
  1412. ----------
  1413. .. [1] L. Breiman, "Random Forests", Machine Learning, 45(1), 5-32, 2001.
  1414. .. [2] P. Geurts, D. Ernst., and L. Wehenkel, "Extremely randomized
  1415. trees", Machine Learning, 63(1), 3-42, 2006.
  1416. Examples
  1417. --------
  1418. >>> from sklearn.ensemble import RandomForestRegressor
  1419. >>> from sklearn.datasets import make_regression
  1420. >>> X, y = make_regression(n_features=4, n_informative=2,
  1421. ... random_state=0, shuffle=False)
  1422. >>> regr = RandomForestRegressor(max_depth=2, random_state=0)
  1423. >>> regr.fit(X, y)
  1424. RandomForestRegressor(...)
  1425. >>> print(regr.predict([[0, 0, 0, 0]]))
  1426. [-8.32987858]
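Continuing the snippet above, a couple of fitted attributes can be checked
directly (only counts and shapes are shown, not tree-dependent values):
>>> len(regr.estimators_)
100
>>> regr.feature_importances_.shape
(4,)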
  1427. """
  1428. _parameter_constraints: dict = {
  1429. **ForestRegressor._parameter_constraints,
  1430. **DecisionTreeRegressor._parameter_constraints,
  1431. }
  1432. _parameter_constraints.pop("splitter")
  1433. def __init__(
  1434. self,
  1435. n_estimators=100,
  1436. *,
  1437. criterion="squared_error",
  1438. max_depth=None,
  1439. min_samples_split=2,
  1440. min_samples_leaf=1,
  1441. min_weight_fraction_leaf=0.0,
  1442. max_features=1.0,
  1443. max_leaf_nodes=None,
  1444. min_impurity_decrease=0.0,
  1445. bootstrap=True,
  1446. oob_score=False,
  1447. n_jobs=None,
  1448. random_state=None,
  1449. verbose=0,
  1450. warm_start=False,
  1451. ccp_alpha=0.0,
  1452. max_samples=None,
  1453. ):
  1454. super().__init__(
  1455. estimator=DecisionTreeRegressor(),
  1456. n_estimators=n_estimators,
  1457. estimator_params=(
  1458. "criterion",
  1459. "max_depth",
  1460. "min_samples_split",
  1461. "min_samples_leaf",
  1462. "min_weight_fraction_leaf",
  1463. "max_features",
  1464. "max_leaf_nodes",
  1465. "min_impurity_decrease",
  1466. "random_state",
  1467. "ccp_alpha",
  1468. ),
  1469. bootstrap=bootstrap,
  1470. oob_score=oob_score,
  1471. n_jobs=n_jobs,
  1472. random_state=random_state,
  1473. verbose=verbose,
  1474. warm_start=warm_start,
  1475. max_samples=max_samples,
  1476. )
  1477. self.criterion = criterion
  1478. self.max_depth = max_depth
  1479. self.min_samples_split = min_samples_split
  1480. self.min_samples_leaf = min_samples_leaf
  1481. self.min_weight_fraction_leaf = min_weight_fraction_leaf
  1482. self.max_features = max_features
  1483. self.max_leaf_nodes = max_leaf_nodes
  1484. self.min_impurity_decrease = min_impurity_decrease
  1485. self.ccp_alpha = ccp_alpha
  1486. class ExtraTreesClassifier(ForestClassifier):
  1487. """
  1488. An extra-trees classifier.
  1489. This class implements a meta estimator that fits a number of
  1490. randomized decision trees (a.k.a. extra-trees) on various sub-samples
  1491. of the dataset and uses averaging to improve the predictive accuracy
  1492. and control over-fitting.
  1493. Read more in the :ref:`User Guide <forest>`.
  1494. Parameters
  1495. ----------
  1496. n_estimators : int, default=100
  1497. The number of trees in the forest.
  1498. .. versionchanged:: 0.22
  1499. The default value of ``n_estimators`` changed from 10 to 100
  1500. in 0.22.
  1501. criterion : {"gini", "entropy", "log_loss"}, default="gini"
  1502. The function to measure the quality of a split. Supported criteria are
  1503. "gini" for the Gini impurity and "log_loss" and "entropy" both for the
  1504. Shannon information gain, see :ref:`tree_mathematical_formulation`.
  1505. Note: This parameter is tree-specific.
  1506. max_depth : int, default=None
  1507. The maximum depth of the tree. If None, then nodes are expanded until
  1508. all leaves are pure or until all leaves contain less than
  1509. min_samples_split samples.
  1510. min_samples_split : int or float, default=2
  1511. The minimum number of samples required to split an internal node:
  1512. - If int, then consider `min_samples_split` as the minimum number.
  1513. - If float, then `min_samples_split` is a fraction and
  1514. `ceil(min_samples_split * n_samples)` are the minimum
  1515. number of samples for each split.
  1516. .. versionchanged:: 0.18
  1517. Added float values for fractions.
  1518. min_samples_leaf : int or float, default=1
  1519. The minimum number of samples required to be at a leaf node.
  1520. A split point at any depth will only be considered if it leaves at
  1521. least ``min_samples_leaf`` training samples in each of the left and
  1522. right branches. This may have the effect of smoothing the model,
  1523. especially in regression.
  1524. - If int, then consider `min_samples_leaf` as the minimum number.
  1525. - If float, then `min_samples_leaf` is a fraction and
  1526. `ceil(min_samples_leaf * n_samples)` are the minimum
  1527. number of samples for each node.
  1528. .. versionchanged:: 0.18
  1529. Added float values for fractions.
  1530. min_weight_fraction_leaf : float, default=0.0
  1531. The minimum weighted fraction of the sum total of weights (of all
  1532. the input samples) required to be at a leaf node. Samples have
  1533. equal weight when sample_weight is not provided.
  1534. max_features : {"sqrt", "log2", None}, int or float, default="sqrt"
  1535. The number of features to consider when looking for the best split:
  1536. - If int, then consider `max_features` features at each split.
  1537. - If float, then `max_features` is a fraction and
  1538. `max(1, int(max_features * n_features_in_))` features are considered at each
  1539. split.
  1540. - If "sqrt", then `max_features=sqrt(n_features)`.
  1541. - If "log2", then `max_features=log2(n_features)`.
  1542. - If None, then `max_features=n_features`.
  1543. .. versionchanged:: 1.1
  1544. The default of `max_features` changed from `"auto"` to `"sqrt"`.
1545. Note: the search for a split does not stop until at least one
1546. valid partition of the node samples is found, even if it requires
1547. effectively inspecting more than ``max_features`` features.
  1548. max_leaf_nodes : int, default=None
  1549. Grow trees with ``max_leaf_nodes`` in best-first fashion.
  1550. Best nodes are defined as relative reduction in impurity.
  1551. If None then unlimited number of leaf nodes.
  1552. min_impurity_decrease : float, default=0.0
  1553. A node will be split if this split induces a decrease of the impurity
  1554. greater than or equal to this value.
  1555. The weighted impurity decrease equation is the following::
  1556. N_t / N * (impurity - N_t_R / N_t * right_impurity
  1557. - N_t_L / N_t * left_impurity)
  1558. where ``N`` is the total number of samples, ``N_t`` is the number of
  1559. samples at the current node, ``N_t_L`` is the number of samples in the
  1560. left child, and ``N_t_R`` is the number of samples in the right child.
  1561. ``N``, ``N_t``, ``N_t_R`` and ``N_t_L`` all refer to the weighted sum,
  1562. if ``sample_weight`` is passed.
  1563. .. versionadded:: 0.19
  1564. bootstrap : bool, default=False
  1565. Whether bootstrap samples are used when building trees. If False, the
  1566. whole dataset is used to build each tree.
  1567. oob_score : bool or callable, default=False
  1568. Whether to use out-of-bag samples to estimate the generalization score.
  1569. By default, :func:`~sklearn.metrics.accuracy_score` is used.
  1570. Provide a callable with signature `metric(y_true, y_pred)` to use a
  1571. custom metric. Only available if `bootstrap=True`.
  1572. n_jobs : int, default=None
  1573. The number of jobs to run in parallel. :meth:`fit`, :meth:`predict`,
  1574. :meth:`decision_path` and :meth:`apply` are all parallelized over the
  1575. trees. ``None`` means 1 unless in a :obj:`joblib.parallel_backend`
  1576. context. ``-1`` means using all processors. See :term:`Glossary
  1577. <n_jobs>` for more details.
  1578. random_state : int, RandomState instance or None, default=None
  1579. Controls 3 sources of randomness:
  1580. - the bootstrapping of the samples used when building trees
  1581. (if ``bootstrap=True``)
  1582. - the sampling of the features to consider when looking for the best
  1583. split at each node (if ``max_features < n_features``)
1584. - the draw of the splits for each of the `max_features` sampled features
  1585. See :term:`Glossary <random_state>` for details.
  1586. verbose : int, default=0
  1587. Controls the verbosity when fitting and predicting.
  1588. warm_start : bool, default=False
  1589. When set to ``True``, reuse the solution of the previous call to fit
  1590. and add more estimators to the ensemble, otherwise, just fit a whole
  1591. new forest. See :term:`Glossary <warm_start>` and
  1592. :ref:`gradient_boosting_warm_start` for details.
  1593. class_weight : {"balanced", "balanced_subsample"}, dict or list of dicts, \
  1594. default=None
  1595. Weights associated with classes in the form ``{class_label: weight}``.
  1596. If not given, all classes are supposed to have weight one. For
  1597. multi-output problems, a list of dicts can be provided in the same
  1598. order as the columns of y.
  1599. Note that for multioutput (including multilabel) weights should be
  1600. defined for each class of every column in its own dict. For example,
  1601. for four-class multilabel classification weights should be
  1602. [{0: 1, 1: 1}, {0: 1, 1: 5}, {0: 1, 1: 1}, {0: 1, 1: 1}] instead of
  1603. [{1:1}, {2:5}, {3:1}, {4:1}].
  1604. The "balanced" mode uses the values of y to automatically adjust
  1605. weights inversely proportional to class frequencies in the input data
  1606. as ``n_samples / (n_classes * np.bincount(y))``
  1607. The "balanced_subsample" mode is the same as "balanced" except that
  1608. weights are computed based on the bootstrap sample for every tree
  1609. grown.
  1610. For multi-output, the weights of each column of y will be multiplied.
  1611. Note that these weights will be multiplied with sample_weight (passed
  1612. through the fit method) if sample_weight is specified.
  1613. ccp_alpha : non-negative float, default=0.0
  1614. Complexity parameter used for Minimal Cost-Complexity Pruning. The
  1615. subtree with the largest cost complexity that is smaller than
  1616. ``ccp_alpha`` will be chosen. By default, no pruning is performed. See
  1617. :ref:`minimal_cost_complexity_pruning` for details.
  1618. .. versionadded:: 0.22
  1619. max_samples : int or float, default=None
  1620. If bootstrap is True, the number of samples to draw from X
  1621. to train each base estimator.
  1622. - If None (default), then draw `X.shape[0]` samples.
  1623. - If int, then draw `max_samples` samples.
1624. - If float, then draw `max(round(n_samples * max_samples), 1)` samples. Thus,
  1625. `max_samples` should be in the interval `(0.0, 1.0]`.
  1626. .. versionadded:: 0.22
  1627. Attributes
  1628. ----------
  1629. estimator_ : :class:`~sklearn.tree.ExtraTreeClassifier`
  1630. The child estimator template used to create the collection of fitted
  1631. sub-estimators.
  1632. .. versionadded:: 1.2
  1633. `base_estimator_` was renamed to `estimator_`.
1634. base_estimator_ : ExtraTreeClassifier
  1635. The child estimator template used to create the collection of fitted
  1636. sub-estimators.
  1637. .. deprecated:: 1.2
  1638. `base_estimator_` is deprecated and will be removed in 1.4.
  1639. Use `estimator_` instead.
1640. estimators_ : list of ExtraTreeClassifier
  1641. The collection of fitted sub-estimators.
  1642. classes_ : ndarray of shape (n_classes,) or a list of such arrays
1643. The class labels (single output problem), or a list of arrays of
  1644. class labels (multi-output problem).
  1645. n_classes_ : int or list
  1646. The number of classes (single output problem), or a list containing the
  1647. number of classes for each output (multi-output problem).
  1648. feature_importances_ : ndarray of shape (n_features,)
  1649. The impurity-based feature importances.
  1650. The higher, the more important the feature.
  1651. The importance of a feature is computed as the (normalized)
  1652. total reduction of the criterion brought by that feature. It is also
  1653. known as the Gini importance.
  1654. Warning: impurity-based feature importances can be misleading for
  1655. high cardinality features (many unique values). See
  1656. :func:`sklearn.inspection.permutation_importance` as an alternative.
  1657. n_features_in_ : int
  1658. Number of features seen during :term:`fit`.
  1659. .. versionadded:: 0.24
  1660. feature_names_in_ : ndarray of shape (`n_features_in_`,)
  1661. Names of features seen during :term:`fit`. Defined only when `X`
  1662. has feature names that are all strings.
  1663. .. versionadded:: 1.0
  1664. n_outputs_ : int
  1665. The number of outputs when ``fit`` is performed.
  1666. oob_score_ : float
  1667. Score of the training dataset obtained using an out-of-bag estimate.
  1668. This attribute exists only when ``oob_score`` is True.
  1669. oob_decision_function_ : ndarray of shape (n_samples, n_classes) or \
  1670. (n_samples, n_classes, n_outputs)
  1671. Decision function computed with out-of-bag estimate on the training
  1672. set. If n_estimators is small it might be possible that a data point
  1673. was never left out during the bootstrap. In this case,
  1674. `oob_decision_function_` might contain NaN. This attribute exists
  1675. only when ``oob_score`` is True.
  1676. See Also
  1677. --------
  1678. ExtraTreesRegressor : An extra-trees regressor with random splits.
  1679. RandomForestClassifier : A random forest classifier with optimal splits.
  1680. RandomForestRegressor : Ensemble regressor using trees with optimal splits.
  1681. Notes
  1682. -----
  1683. The default values for the parameters controlling the size of the trees
  1684. (e.g. ``max_depth``, ``min_samples_leaf``, etc.) lead to fully grown and
  1685. unpruned trees which can potentially be very large on some data sets. To
  1686. reduce memory consumption, the complexity and size of the trees should be
  1687. controlled by setting those parameter values.
  1688. References
  1689. ----------
  1690. .. [1] P. Geurts, D. Ernst., and L. Wehenkel, "Extremely randomized
  1691. trees", Machine Learning, 63(1), 3-42, 2006.
  1692. Examples
  1693. --------
  1694. >>> from sklearn.ensemble import ExtraTreesClassifier
  1695. >>> from sklearn.datasets import make_classification
  1696. >>> X, y = make_classification(n_features=4, random_state=0)
  1697. >>> clf = ExtraTreesClassifier(n_estimators=100, random_state=0)
  1698. >>> clf.fit(X, y)
  1699. ExtraTreesClassifier(random_state=0)
  1700. >>> clf.predict([[0, 0, 0, 0]])
  1701. array([1])
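Continuing the snippet above, the number of classes and of fitted trees can
be inspected on the estimator:
>>> clf.n_classes_
2
>>> len(clf.estimators_)
100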
  1702. """
  1703. _parameter_constraints: dict = {
  1704. **ForestClassifier._parameter_constraints,
  1705. **DecisionTreeClassifier._parameter_constraints,
  1706. "class_weight": [
  1707. StrOptions({"balanced_subsample", "balanced"}),
  1708. dict,
  1709. list,
  1710. None,
  1711. ],
  1712. }
  1713. _parameter_constraints.pop("splitter")
  1714. def __init__(
  1715. self,
  1716. n_estimators=100,
  1717. *,
  1718. criterion="gini",
  1719. max_depth=None,
  1720. min_samples_split=2,
  1721. min_samples_leaf=1,
  1722. min_weight_fraction_leaf=0.0,
  1723. max_features="sqrt",
  1724. max_leaf_nodes=None,
  1725. min_impurity_decrease=0.0,
  1726. bootstrap=False,
  1727. oob_score=False,
  1728. n_jobs=None,
  1729. random_state=None,
  1730. verbose=0,
  1731. warm_start=False,
  1732. class_weight=None,
  1733. ccp_alpha=0.0,
  1734. max_samples=None,
  1735. ):
  1736. super().__init__(
  1737. estimator=ExtraTreeClassifier(),
  1738. n_estimators=n_estimators,
  1739. estimator_params=(
  1740. "criterion",
  1741. "max_depth",
  1742. "min_samples_split",
  1743. "min_samples_leaf",
  1744. "min_weight_fraction_leaf",
  1745. "max_features",
  1746. "max_leaf_nodes",
  1747. "min_impurity_decrease",
  1748. "random_state",
  1749. "ccp_alpha",
  1750. ),
  1751. bootstrap=bootstrap,
  1752. oob_score=oob_score,
  1753. n_jobs=n_jobs,
  1754. random_state=random_state,
  1755. verbose=verbose,
  1756. warm_start=warm_start,
  1757. class_weight=class_weight,
  1758. max_samples=max_samples,
  1759. )
  1760. self.criterion = criterion
  1761. self.max_depth = max_depth
  1762. self.min_samples_split = min_samples_split
  1763. self.min_samples_leaf = min_samples_leaf
  1764. self.min_weight_fraction_leaf = min_weight_fraction_leaf
  1765. self.max_features = max_features
  1766. self.max_leaf_nodes = max_leaf_nodes
  1767. self.min_impurity_decrease = min_impurity_decrease
  1768. self.ccp_alpha = ccp_alpha
  1769. class ExtraTreesRegressor(ForestRegressor):
  1770. """
  1771. An extra-trees regressor.
  1772. This class implements a meta estimator that fits a number of
  1773. randomized decision trees (a.k.a. extra-trees) on various sub-samples
  1774. of the dataset and uses averaging to improve the predictive accuracy
  1775. and control over-fitting.
  1776. Read more in the :ref:`User Guide <forest>`.
  1777. Parameters
  1778. ----------
  1779. n_estimators : int, default=100
  1780. The number of trees in the forest.
  1781. .. versionchanged:: 0.22
  1782. The default value of ``n_estimators`` changed from 10 to 100
  1783. in 0.22.
  1784. criterion : {"squared_error", "absolute_error", "friedman_mse", "poisson"}, \
  1785. default="squared_error"
  1786. The function to measure the quality of a split. Supported criteria
  1787. are "squared_error" for the mean squared error, which is equal to
  1788. variance reduction as feature selection criterion and minimizes the L2
  1789. loss using the mean of each terminal node, "friedman_mse", which uses
  1790. mean squared error with Friedman's improvement score for potential
  1791. splits, "absolute_error" for the mean absolute error, which minimizes
  1792. the L1 loss using the median of each terminal node, and "poisson" which
  1793. uses reduction in Poisson deviance to find splits.
  1794. Training using "absolute_error" is significantly slower
  1795. than when using "squared_error".
  1796. .. versionadded:: 0.18
  1797. Mean Absolute Error (MAE) criterion.
  1798. max_depth : int, default=None
  1799. The maximum depth of the tree. If None, then nodes are expanded until
  1800. all leaves are pure or until all leaves contain less than
  1801. min_samples_split samples.
  1802. min_samples_split : int or float, default=2
  1803. The minimum number of samples required to split an internal node:
  1804. - If int, then consider `min_samples_split` as the minimum number.
  1805. - If float, then `min_samples_split` is a fraction and
  1806. `ceil(min_samples_split * n_samples)` are the minimum
  1807. number of samples for each split.
  1808. .. versionchanged:: 0.18
  1809. Added float values for fractions.
  1810. min_samples_leaf : int or float, default=1
  1811. The minimum number of samples required to be at a leaf node.
  1812. A split point at any depth will only be considered if it leaves at
  1813. least ``min_samples_leaf`` training samples in each of the left and
  1814. right branches. This may have the effect of smoothing the model,
  1815. especially in regression.
  1816. - If int, then consider `min_samples_leaf` as the minimum number.
  1817. - If float, then `min_samples_leaf` is a fraction and
  1818. `ceil(min_samples_leaf * n_samples)` are the minimum
  1819. number of samples for each node.
  1820. .. versionchanged:: 0.18
  1821. Added float values for fractions.
  1822. min_weight_fraction_leaf : float, default=0.0
  1823. The minimum weighted fraction of the sum total of weights (of all
  1824. the input samples) required to be at a leaf node. Samples have
  1825. equal weight when sample_weight is not provided.
  1826. max_features : {"sqrt", "log2", None}, int or float, default=1.0
  1827. The number of features to consider when looking for the best split:
  1828. - If int, then consider `max_features` features at each split.
  1829. - If float, then `max_features` is a fraction and
  1830. `max(1, int(max_features * n_features_in_))` features are considered at each
  1831. split.
  1832. - If "sqrt", then `max_features=sqrt(n_features)`.
  1833. - If "log2", then `max_features=log2(n_features)`.
  1834. - If None or 1.0, then `max_features=n_features`.
  1835. .. note::
  1836. The default of 1.0 is equivalent to bagged trees and more
  1837. randomness can be achieved by setting smaller values, e.g. 0.3.
  1838. .. versionchanged:: 1.1
  1839. The default of `max_features` changed from `"auto"` to 1.0.
1840. Note: the search for a split does not stop until at least one
1841. valid partition of the node samples is found, even if it requires
1842. effectively inspecting more than ``max_features`` features.
  1843. max_leaf_nodes : int, default=None
  1844. Grow trees with ``max_leaf_nodes`` in best-first fashion.
  1845. Best nodes are defined as relative reduction in impurity.
  1846. If None then unlimited number of leaf nodes.
  1847. min_impurity_decrease : float, default=0.0
  1848. A node will be split if this split induces a decrease of the impurity
  1849. greater than or equal to this value.
  1850. The weighted impurity decrease equation is the following::
  1851. N_t / N * (impurity - N_t_R / N_t * right_impurity
  1852. - N_t_L / N_t * left_impurity)
  1853. where ``N`` is the total number of samples, ``N_t`` is the number of
  1854. samples at the current node, ``N_t_L`` is the number of samples in the
  1855. left child, and ``N_t_R`` is the number of samples in the right child.
  1856. ``N``, ``N_t``, ``N_t_R`` and ``N_t_L`` all refer to the weighted sum,
  1857. if ``sample_weight`` is passed.
  1858. .. versionadded:: 0.19
  1859. bootstrap : bool, default=False
  1860. Whether bootstrap samples are used when building trees. If False, the
  1861. whole dataset is used to build each tree.
  1862. oob_score : bool or callable, default=False
  1863. Whether to use out-of-bag samples to estimate the generalization score.
  1864. By default, :func:`~sklearn.metrics.r2_score` is used.
  1865. Provide a callable with signature `metric(y_true, y_pred)` to use a
  1866. custom metric. Only available if `bootstrap=True`.
  1867. n_jobs : int, default=None
  1868. The number of jobs to run in parallel. :meth:`fit`, :meth:`predict`,
  1869. :meth:`decision_path` and :meth:`apply` are all parallelized over the
  1870. trees. ``None`` means 1 unless in a :obj:`joblib.parallel_backend`
  1871. context. ``-1`` means using all processors. See :term:`Glossary
  1872. <n_jobs>` for more details.
  1873. random_state : int, RandomState instance or None, default=None
  1874. Controls 3 sources of randomness:
  1875. - the bootstrapping of the samples used when building trees
  1876. (if ``bootstrap=True``)
  1877. - the sampling of the features to consider when looking for the best
  1878. split at each node (if ``max_features < n_features``)
1879. - the draw of the splits for each of the `max_features` sampled features
  1880. See :term:`Glossary <random_state>` for details.
  1881. verbose : int, default=0
  1882. Controls the verbosity when fitting and predicting.
  1883. warm_start : bool, default=False
  1884. When set to ``True``, reuse the solution of the previous call to fit
  1885. and add more estimators to the ensemble, otherwise, just fit a whole
  1886. new forest. See :term:`Glossary <warm_start>` and
  1887. :ref:`gradient_boosting_warm_start` for details.
  1888. ccp_alpha : non-negative float, default=0.0
  1889. Complexity parameter used for Minimal Cost-Complexity Pruning. The
  1890. subtree with the largest cost complexity that is smaller than
  1891. ``ccp_alpha`` will be chosen. By default, no pruning is performed. See
  1892. :ref:`minimal_cost_complexity_pruning` for details.
  1893. .. versionadded:: 0.22
  1894. max_samples : int or float, default=None
  1895. If bootstrap is True, the number of samples to draw from X
  1896. to train each base estimator.
  1897. - If None (default), then draw `X.shape[0]` samples.
  1898. - If int, then draw `max_samples` samples.
1899. - If float, then draw `max(round(n_samples * max_samples), 1)` samples. Thus,
  1900. `max_samples` should be in the interval `(0.0, 1.0]`.
  1901. .. versionadded:: 0.22
  1902. Attributes
  1903. ----------
  1904. estimator_ : :class:`~sklearn.tree.ExtraTreeRegressor`
  1905. The child estimator template used to create the collection of fitted
  1906. sub-estimators.
  1907. .. versionadded:: 1.2
  1908. `base_estimator_` was renamed to `estimator_`.
  1909. base_estimator_ : ExtraTreeRegressor
  1910. The child estimator template used to create the collection of fitted
  1911. sub-estimators.
  1912. .. deprecated:: 1.2
  1913. `base_estimator_` is deprecated and will be removed in 1.4.
  1914. Use `estimator_` instead.
1915. estimators_ : list of ExtraTreeRegressor
  1916. The collection of fitted sub-estimators.
  1917. feature_importances_ : ndarray of shape (n_features,)
  1918. The impurity-based feature importances.
  1919. The higher, the more important the feature.
  1920. The importance of a feature is computed as the (normalized)
  1921. total reduction of the criterion brought by that feature. It is also
  1922. known as the Gini importance.
  1923. Warning: impurity-based feature importances can be misleading for
  1924. high cardinality features (many unique values). See
  1925. :func:`sklearn.inspection.permutation_importance` as an alternative.
  1926. n_features_in_ : int
  1927. Number of features seen during :term:`fit`.
  1928. .. versionadded:: 0.24
  1929. feature_names_in_ : ndarray of shape (`n_features_in_`,)
  1930. Names of features seen during :term:`fit`. Defined only when `X`
  1931. has feature names that are all strings.
  1932. .. versionadded:: 1.0
  1933. n_outputs_ : int
  1934. The number of outputs.
  1935. oob_score_ : float
  1936. Score of the training dataset obtained using an out-of-bag estimate.
  1937. This attribute exists only when ``oob_score`` is True.
  1938. oob_prediction_ : ndarray of shape (n_samples,) or (n_samples, n_outputs)
  1939. Prediction computed with out-of-bag estimate on the training set.
  1940. This attribute exists only when ``oob_score`` is True.
  1941. See Also
  1942. --------
  1943. ExtraTreesClassifier : An extra-trees classifier with random splits.
  1944. RandomForestClassifier : A random forest classifier with optimal splits.
  1945. RandomForestRegressor : Ensemble regressor using trees with optimal splits.
  1946. Notes
  1947. -----
  1948. The default values for the parameters controlling the size of the trees
  1949. (e.g. ``max_depth``, ``min_samples_leaf``, etc.) lead to fully grown and
  1950. unpruned trees which can potentially be very large on some data sets. To
  1951. reduce memory consumption, the complexity and size of the trees should be
  1952. controlled by setting those parameter values.
  1953. References
  1954. ----------
  1955. .. [1] P. Geurts, D. Ernst., and L. Wehenkel, "Extremely randomized trees",
  1956. Machine Learning, 63(1), 3-42, 2006.
  1957. Examples
  1958. --------
  1959. >>> from sklearn.datasets import load_diabetes
  1960. >>> from sklearn.model_selection import train_test_split
  1961. >>> from sklearn.ensemble import ExtraTreesRegressor
  1962. >>> X, y = load_diabetes(return_X_y=True)
  1963. >>> X_train, X_test, y_train, y_test = train_test_split(
  1964. ... X, y, random_state=0)
  1965. >>> reg = ExtraTreesRegressor(n_estimators=100, random_state=0).fit(
  1966. ... X_train, y_train)
  1967. >>> reg.score(X_test, y_test)
  1968. 0.2727...
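Continuing the snippet above, basic fitted attributes are available as well:
>>> reg.n_features_in_
10
>>> len(reg.estimators_)
100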
  1969. """
  1970. _parameter_constraints: dict = {
  1971. **ForestRegressor._parameter_constraints,
  1972. **DecisionTreeRegressor._parameter_constraints,
  1973. }
  1974. _parameter_constraints.pop("splitter")
  1975. def __init__(
  1976. self,
  1977. n_estimators=100,
  1978. *,
  1979. criterion="squared_error",
  1980. max_depth=None,
  1981. min_samples_split=2,
  1982. min_samples_leaf=1,
  1983. min_weight_fraction_leaf=0.0,
  1984. max_features=1.0,
  1985. max_leaf_nodes=None,
  1986. min_impurity_decrease=0.0,
  1987. bootstrap=False,
  1988. oob_score=False,
  1989. n_jobs=None,
  1990. random_state=None,
  1991. verbose=0,
  1992. warm_start=False,
  1993. ccp_alpha=0.0,
  1994. max_samples=None,
  1995. ):
  1996. super().__init__(
  1997. estimator=ExtraTreeRegressor(),
  1998. n_estimators=n_estimators,
  1999. estimator_params=(
  2000. "criterion",
  2001. "max_depth",
  2002. "min_samples_split",
  2003. "min_samples_leaf",
  2004. "min_weight_fraction_leaf",
  2005. "max_features",
  2006. "max_leaf_nodes",
  2007. "min_impurity_decrease",
  2008. "random_state",
  2009. "ccp_alpha",
  2010. ),
  2011. bootstrap=bootstrap,
  2012. oob_score=oob_score,
  2013. n_jobs=n_jobs,
  2014. random_state=random_state,
  2015. verbose=verbose,
  2016. warm_start=warm_start,
  2017. max_samples=max_samples,
  2018. )
  2019. self.criterion = criterion
  2020. self.max_depth = max_depth
  2021. self.min_samples_split = min_samples_split
  2022. self.min_samples_leaf = min_samples_leaf
  2023. self.min_weight_fraction_leaf = min_weight_fraction_leaf
  2024. self.max_features = max_features
  2025. self.max_leaf_nodes = max_leaf_nodes
  2026. self.min_impurity_decrease = min_impurity_decrease
  2027. self.ccp_alpha = ccp_alpha
  2028. class RandomTreesEmbedding(TransformerMixin, BaseForest):
  2029. """
  2030. An ensemble of totally random trees.
  2031. An unsupervised transformation of a dataset to a high-dimensional
  2032. sparse representation. A datapoint is coded according to which leaf of
  2033. each tree it is sorted into. Using a one-hot encoding of the leaves,
  2034. this leads to a binary coding with as many ones as there are trees in
  2035. the forest.
  2036. The dimensionality of the resulting representation is
  2037. ``n_out <= n_estimators * max_leaf_nodes``. If ``max_leaf_nodes == None``,
  2038. the number of leaf nodes is at most ``n_estimators * 2 ** max_depth``.
  2039. Read more in the :ref:`User Guide <random_trees_embedding>`.
  2040. Parameters
  2041. ----------
  2042. n_estimators : int, default=100
  2043. Number of trees in the forest.
  2044. .. versionchanged:: 0.22
  2045. The default value of ``n_estimators`` changed from 10 to 100
  2046. in 0.22.
  2047. max_depth : int, default=5
  2048. The maximum depth of each tree. If None, then nodes are expanded until
  2049. all leaves are pure or until all leaves contain less than
  2050. min_samples_split samples.
  2051. min_samples_split : int or float, default=2
  2052. The minimum number of samples required to split an internal node:
  2053. - If int, then consider `min_samples_split` as the minimum number.
  2054. - If float, then `min_samples_split` is a fraction and
  2055. `ceil(min_samples_split * n_samples)` is the minimum
  2056. number of samples for each split.
  2057. .. versionchanged:: 0.18
  2058. Added float values for fractions.
  2059. min_samples_leaf : int or float, default=1
  2060. The minimum number of samples required to be at a leaf node.
  2061. A split point at any depth will only be considered if it leaves at
  2062. least ``min_samples_leaf`` training samples in each of the left and
  2063. right branches. This may have the effect of smoothing the model,
  2064. especially in regression.
  2065. - If int, then consider `min_samples_leaf` as the minimum number.
  2066. - If float, then `min_samples_leaf` is a fraction and
  2067. `ceil(min_samples_leaf * n_samples)` is the minimum
  2068. number of samples for each node.
  2069. .. versionchanged:: 0.18
  2070. Added float values for fractions.
  2071. min_weight_fraction_leaf : float, default=0.0
  2072. The minimum weighted fraction of the sum total of weights (of all
  2073. the input samples) required to be at a leaf node. Samples have
  2074. equal weight when sample_weight is not provided.
  2075. max_leaf_nodes : int, default=None
  2076. Grow trees with ``max_leaf_nodes`` in best-first fashion.
  2077. Best nodes are defined as relative reduction in impurity.
  2078. If None then unlimited number of leaf nodes.
  2079. min_impurity_decrease : float, default=0.0
  2080. A node will be split if this split induces a decrease of the impurity
  2081. greater than or equal to this value.
  2082. The weighted impurity decrease equation is the following::
  2083. N_t / N * (impurity - N_t_R / N_t * right_impurity
  2084. - N_t_L / N_t * left_impurity)
  2085. where ``N`` is the total number of samples, ``N_t`` is the number of
  2086. samples at the current node, ``N_t_L`` is the number of samples in the
  2087. left child, and ``N_t_R`` is the number of samples in the right child.
  2088. ``N``, ``N_t``, ``N_t_R`` and ``N_t_L`` all refer to the weighted sum,
  2089. if ``sample_weight`` is passed.
  2090. .. versionadded:: 0.19
  2091. sparse_output : bool, default=True
  2092. Whether or not to return a sparse CSR matrix, as default behavior,
  2093. or to return a dense array compatible with dense pipeline operators.
  2094. n_jobs : int, default=None
  2095. The number of jobs to run in parallel. :meth:`fit`, :meth:`transform`,
  2096. :meth:`decision_path` and :meth:`apply` are all parallelized over the
  2097. trees. ``None`` means 1 unless in a :obj:`joblib.parallel_backend`
  2098. context. ``-1`` means using all processors. See :term:`Glossary
  2099. <n_jobs>` for more details.
  2100. random_state : int, RandomState instance or None, default=None
  2101. Controls the generation of the random `y` used to fit the trees
  2102. and the draw of the splits for each feature at the trees' nodes.
  2103. See :term:`Glossary <random_state>` for details.
  2104. verbose : int, default=0
  2105. Controls the verbosity when fitting and predicting.
  2106. warm_start : bool, default=False
  2107. When set to ``True``, reuse the solution of the previous call to fit
  2108. and add more estimators to the ensemble, otherwise, just fit a whole
  2109. new forest. See :term:`Glossary <warm_start>` and
  2110. :ref:`gradient_boosting_warm_start` for details.
  2111. Attributes
  2112. ----------
  2113. estimator_ : :class:`~sklearn.tree.ExtraTreeRegressor` instance
  2114. The child estimator template used to create the collection of fitted
  2115. sub-estimators.
  2116. .. versionadded:: 1.2
  2117. `base_estimator_` was renamed to `estimator_`.
  2118. base_estimator_ : :class:`~sklearn.tree.ExtraTreeRegressor` instance
  2119. The child estimator template used to create the collection of fitted
  2120. sub-estimators.
  2121. .. deprecated:: 1.2
  2122. `base_estimator_` is deprecated and will be removed in 1.4.
  2123. Use `estimator_` instead.
  2124. estimators_ : list of :class:`~sklearn.tree.ExtraTreeRegressor` instances
  2125. The collection of fitted sub-estimators.
  2126. feature_importances_ : ndarray of shape (n_features,)
  2127. The feature importances (the higher, the more important the feature).
  2128. n_features_in_ : int
  2129. Number of features seen during :term:`fit`.
  2130. .. versionadded:: 0.24
  2131. feature_names_in_ : ndarray of shape (`n_features_in_`,)
  2132. Names of features seen during :term:`fit`. Defined only when `X`
  2133. has feature names that are all strings.
  2134. .. versionadded:: 1.0
  2135. n_outputs_ : int
  2136. The number of outputs when ``fit`` is performed.
  2137. one_hot_encoder_ : OneHotEncoder instance
  2138. One-hot encoder used to create the sparse embedding.
  2139. See Also
  2140. --------
  2141. ExtraTreesClassifier : An extra-trees classifier.
  2142. ExtraTreesRegressor : An extra-trees regressor.
  2143. RandomForestClassifier : A random forest classifier.
  2144. RandomForestRegressor : A random forest regressor.
2145. sklearn.tree.ExtraTreeClassifier : An extremely randomized
  2146. tree classifier.
  2147. sklearn.tree.ExtraTreeRegressor : An extremely randomized
  2148. tree regressor.
  2149. References
  2150. ----------
  2151. .. [1] P. Geurts, D. Ernst., and L. Wehenkel, "Extremely randomized trees",
  2152. Machine Learning, 63(1), 3-42, 2006.
  2153. .. [2] Moosmann, F. and Triggs, B. and Jurie, F. "Fast discriminative
  2154. visual codebooks using randomized clustering forests"
  2155. NIPS 2007
  2156. Examples
  2157. --------
  2158. >>> from sklearn.ensemble import RandomTreesEmbedding
  2159. >>> X = [[0,0], [1,0], [0,1], [-1,0], [0,-1]]
  2160. >>> random_trees = RandomTreesEmbedding(
  2161. ... n_estimators=5, random_state=0, max_depth=1).fit(X)
  2162. >>> X_sparse_embedding = random_trees.transform(X)
  2163. >>> X_sparse_embedding.toarray()
  2164. array([[0., 1., 1., 0., 1., 0., 0., 1., 1., 0.],
  2165. [0., 1., 1., 0., 1., 0., 0., 1., 1., 0.],
  2166. [0., 1., 0., 1., 0., 1., 0., 1., 0., 1.],
  2167. [1., 0., 1., 0., 1., 0., 1., 0., 1., 0.],
  2168. [0., 1., 1., 0., 1., 0., 0., 1., 1., 0.]])
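The width of the embedding is bounded by ``n_estimators * 2 ** max_depth``;
in this example each of the 5 depth-1 trees contributes two leaf columns:
>>> X_sparse_embedding.shape
(5, 10)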
  2169. """
  2170. _parameter_constraints: dict = {
  2171. "n_estimators": [Interval(Integral, 1, None, closed="left")],
  2172. "n_jobs": [Integral, None],
  2173. "verbose": ["verbose"],
  2174. "warm_start": ["boolean"],
  2175. **BaseDecisionTree._parameter_constraints,
  2176. "sparse_output": ["boolean"],
  2177. }
  2178. for param in ("max_features", "ccp_alpha", "splitter"):
  2179. _parameter_constraints.pop(param)
  2180. criterion = "squared_error"
  2181. max_features = 1
  2182. def __init__(
  2183. self,
  2184. n_estimators=100,
  2185. *,
  2186. max_depth=5,
  2187. min_samples_split=2,
  2188. min_samples_leaf=1,
  2189. min_weight_fraction_leaf=0.0,
  2190. max_leaf_nodes=None,
  2191. min_impurity_decrease=0.0,
  2192. sparse_output=True,
  2193. n_jobs=None,
  2194. random_state=None,
  2195. verbose=0,
  2196. warm_start=False,
  2197. ):
  2198. super().__init__(
  2199. estimator=ExtraTreeRegressor(),
  2200. n_estimators=n_estimators,
  2201. estimator_params=(
  2202. "criterion",
  2203. "max_depth",
  2204. "min_samples_split",
  2205. "min_samples_leaf",
  2206. "min_weight_fraction_leaf",
  2207. "max_features",
  2208. "max_leaf_nodes",
  2209. "min_impurity_decrease",
  2210. "random_state",
  2211. ),
  2212. bootstrap=False,
  2213. oob_score=False,
  2214. n_jobs=n_jobs,
  2215. random_state=random_state,
  2216. verbose=verbose,
  2217. warm_start=warm_start,
  2218. max_samples=None,
  2219. )
  2220. self.max_depth = max_depth
  2221. self.min_samples_split = min_samples_split
  2222. self.min_samples_leaf = min_samples_leaf
  2223. self.min_weight_fraction_leaf = min_weight_fraction_leaf
  2224. self.max_leaf_nodes = max_leaf_nodes
  2225. self.min_impurity_decrease = min_impurity_decrease
  2226. self.sparse_output = sparse_output
  2227. def _set_oob_score_and_attributes(self, X, y, scoring_function=None):
  2228. raise NotImplementedError("OOB score not supported by tree embedding")
  2229. def fit(self, X, y=None, sample_weight=None):
  2230. """
  2231. Fit estimator.
  2232. Parameters
  2233. ----------
  2234. X : {array-like, sparse matrix} of shape (n_samples, n_features)
  2235. The input samples. Use ``dtype=np.float32`` for maximum
  2236. efficiency. Sparse matrices are also supported, use sparse
  2237. ``csc_matrix`` for maximum efficiency.
  2238. y : Ignored
  2239. Not used, present for API consistency by convention.
  2240. sample_weight : array-like of shape (n_samples,), default=None
  2241. Sample weights. If None, then samples are equally weighted. Splits
  2242. that would create child nodes with net zero or negative weight are
  2243. ignored while searching for a split in each node. In the case of
  2244. classification, splits are also ignored if they would result in any
  2245. single class carrying a negative weight in either child node.
  2246. Returns
  2247. -------
  2248. self : object
  2249. Returns the instance itself.
  2250. """
  2251. # Parameters are validated in fit_transform
  2252. self.fit_transform(X, y, sample_weight=sample_weight)
  2253. return self
  2254. @_fit_context(prefer_skip_nested_validation=True)
  2255. def fit_transform(self, X, y=None, sample_weight=None):
  2256. """
  2257. Fit estimator and transform dataset.
  2258. Parameters
  2259. ----------
  2260. X : {array-like, sparse matrix} of shape (n_samples, n_features)
  2261. Input data used to build forests. Use ``dtype=np.float32`` for
  2262. maximum efficiency.
  2263. y : Ignored
  2264. Not used, present for API consistency by convention.
  2265. sample_weight : array-like of shape (n_samples,), default=None
  2266. Sample weights. If None, then samples are equally weighted. Splits
  2267. that would create child nodes with net zero or negative weight are
  2268. ignored while searching for a split in each node. In the case of
  2269. classification, splits are also ignored if they would result in any
  2270. single class carrying a negative weight in either child node.
  2271. Returns
  2272. -------
  2273. X_transformed : sparse matrix of shape (n_samples, n_out)
  2274. Transformed dataset.
  2275. """
  2276. rnd = check_random_state(self.random_state)
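# Fit the forest on uniformly random targets: with a random y the splits
# carry no supervised signal, which yields the totally random trees used
# for the embedding.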
  2277. y = rnd.uniform(size=_num_samples(X))
  2278. super().fit(X, y, sample_weight=sample_weight)
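# ``apply`` maps each sample to its leaf index in every tree; one-hot
# encoding those indices produces the sparse high-dimensional embedding.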
  2279. self.one_hot_encoder_ = OneHotEncoder(sparse_output=self.sparse_output)
  2280. output = self.one_hot_encoder_.fit_transform(self.apply(X))
  2281. self._n_features_out = output.shape[1]
  2282. return output
  2283. def get_feature_names_out(self, input_features=None):
  2284. """Get output feature names for transformation.
  2285. Parameters
  2286. ----------
  2287. input_features : array-like of str or None, default=None
  2288. Only used to validate feature names with the names seen in :meth:`fit`.
  2289. Returns
  2290. -------
  2291. feature_names_out : ndarray of str objects
  2292. Transformed feature names, in the format of
  2293. `randomtreesembedding_{tree}_{leaf}`, where `tree` is the tree used
  2294. to generate the leaf and `leaf` is the index of a leaf node
  2295. in that tree. Note that the node indexing scheme is used to
  2296. index both nodes with children (split nodes) and leaf nodes.
  2297. Only the latter can be present as output features.
  2298. As a consequence, there are missing indices in the output
  2299. feature names.
  2300. """
  2301. check_is_fitted(self, "_n_features_out")
  2302. _check_feature_names_in(
  2303. self, input_features=input_features, generate_names=False
  2304. )
  2305. feature_names = [
  2306. f"randomtreesembedding_{tree}_{leaf}"
  2307. for tree in range(self.n_estimators)
  2308. for leaf in self.one_hot_encoder_.categories_[tree]
  2309. ]
  2310. return np.asarray(feature_names, dtype=object)
  2311. def transform(self, X):
  2312. """
  2313. Transform dataset.
  2314. Parameters
  2315. ----------
  2316. X : {array-like, sparse matrix} of shape (n_samples, n_features)
  2317. Input data to be transformed. Use ``dtype=np.float32`` for maximum
  2318. efficiency. Sparse matrices are also supported, use sparse
  2319. ``csr_matrix`` for maximum efficiency.
  2320. Returns
  2321. -------
  2322. X_transformed : sparse matrix of shape (n_samples, n_out)
  2323. Transformed dataset.
  2324. """
  2325. check_is_fitted(self)
  2326. return self.one_hot_encoder_.transform(self.apply(X))