"""
Testing for the tree module (sklearn.tree).
"""
import copy
import copyreg
import io
import pickle
import struct
from itertools import chain, product

import joblib
import numpy as np
import pytest
from joblib.numpy_pickle import NumpyPickler
from numpy.testing import assert_allclose
from scipy.sparse import coo_matrix, csc_matrix, csr_matrix

from sklearn import datasets, tree
from sklearn.dummy import DummyRegressor
from sklearn.exceptions import NotFittedError
from sklearn.metrics import accuracy_score, mean_poisson_deviance, mean_squared_error
from sklearn.model_selection import train_test_split
from sklearn.random_projection import _sparse_random_matrix
from sklearn.tree import (
    DecisionTreeClassifier,
    DecisionTreeRegressor,
    ExtraTreeClassifier,
    ExtraTreeRegressor,
)
from sklearn.tree._classes import (
    CRITERIA_CLF,
    CRITERIA_REG,
    DENSE_SPLITTERS,
    SPARSE_SPLITTERS,
)
from sklearn.tree._tree import (
    NODE_DTYPE,
    TREE_LEAF,
    TREE_UNDEFINED,
    _check_n_classes,
    _check_node_ndarray,
    _check_value_ndarray,
)
from sklearn.tree._tree import Tree as CythonTree
from sklearn.utils import _IS_32BIT, compute_sample_weight
from sklearn.utils._testing import (
    assert_almost_equal,
    assert_array_almost_equal,
    assert_array_equal,
    create_memmap_backed_data,
    ignore_warnings,
    skip_if_32bit,
)
from sklearn.utils.estimator_checks import check_sample_weights_invariance
from sklearn.utils.validation import check_random_state

CLF_CRITERIONS = ("gini", "log_loss")
REG_CRITERIONS = ("squared_error", "absolute_error", "friedman_mse", "poisson")

CLF_TREES = {
    "DecisionTreeClassifier": DecisionTreeClassifier,
    "ExtraTreeClassifier": ExtraTreeClassifier,
}

REG_TREES = {
    "DecisionTreeRegressor": DecisionTreeRegressor,
    "ExtraTreeRegressor": ExtraTreeRegressor,
}

ALL_TREES: dict = dict()
ALL_TREES.update(CLF_TREES)
ALL_TREES.update(REG_TREES)

SPARSE_TREES = [
    "DecisionTreeClassifier",
    "DecisionTreeRegressor",
    "ExtraTreeClassifier",
    "ExtraTreeRegressor",
]

X_small = np.array(
    [
        [0, 0, 4, 0, 0, 0, 1, -14, 0, -4, 0, 0, 0, 0],
        [0, 0, 5, 3, 0, -4, 0, 0, 1, -5, 0.2, 0, 4, 1],
        [-1, -1, 0, 0, -4.5, 0, 0, 2.1, 1, 0, 0, -4.5, 0, 1],
        [-1, -1, 0, -1.2, 0, 0, 0, 0, 0, 0, 0.2, 0, 0, 1],
        [-1, -1, 0, 0, 0, 0, 0, 3, 0, 0, 0, 0, 0, 1],
        [-1, -2, 0, 4, -3, 10, 4, 0, -3.2, 0, 4, 3, -4, 1],
        [2.11, 0, -6, -0.5, 0, 11, 0, 0, -3.2, 6, 0.5, 0, -3, 1],
        [2.11, 0, -6, -0.5, 0, 11, 0, 0, -3.2, 6, 0, 0, -2, 1],
        [2.11, 8, -6, -0.5, 0, 11, 0, 0, -3.2, 6, 0, 0, -2, 1],
        [2.11, 8, -6, -0.5, 0, 11, 0, 0, -3.2, 6, 0.5, 0, -1, 0],
        [2, 8, 5, 1, 0.5, -4, 10, 0, 1, -5, 3, 0, 2, 0],
        [2, 0, 1, 1, 1, -1, 1, 0, 0, -2, 3, 0, 1, 0],
        [2, 0, 1, 2, 3, -1, 10, 2, 0, -1, 1, 2, 2, 0],
        [1, 1, 0, 2, 2, -1, 1, 2, 0, -5, 1, 2, 3, 0],
        [3, 1, 0, 3, 0, -4, 10, 0, 1, -5, 3, 0, 3, 1],
        [2.11, 8, -6, -0.5, 0, 1, 0, 0, -3.2, 6, 0.5, 0, -3, 1],
        [2.11, 8, -6, -0.5, 0, 1, 0, 0, -3.2, 6, 1.5, 1, -1, -1],
        [2.11, 8, -6, -0.5, 0, 10, 0, 0, -3.2, 6, 0.5, 0, -1, -1],
        [2, 0, 5, 1, 0.5, -2, 10, 0, 1, -5, 3, 1, 0, -1],
        [2, 0, 1, 1, 1, -2, 1, 0, 0, -2, 0, 0, 0, 1],
        [2, 1, 1, 1, 2, -1, 10, 2, 0, -1, 0, 2, 1, 1],
        [1, 1, 0, 0, 1, -3, 1, 2, 0, -5, 1, 2, 1, 1],
        [3, 1, 0, 1, 0, -4, 1, 0, 1, -2, 0, 0, 1, 0],
    ]
)

y_small = [1, 1, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0]
y_small_reg = [
    1.0,
    2.1,
    1.2,
    0.05,
    10,
    2.4,
    3.1,
    1.01,
    0.01,
    2.98,
    3.1,
    1.1,
    0.0,
    1.2,
    2,
    11,
    0,
    0,
    4.5,
    0.201,
    1.06,
    0.9,
    0,
]

# toy sample
X = [[-2, -1], [-1, -1], [-1, -2], [1, 1], [1, 2], [2, 1]]
y = [-1, -1, -1, 1, 1, 1]
T = [[-1, -1], [2, 2], [3, 2]]
true_result = [-1, 1, 1]

# also load the iris dataset
# and randomly permute it
iris = datasets.load_iris()
rng = np.random.RandomState(1)
perm = rng.permutation(iris.target.size)
iris.data = iris.data[perm]
iris.target = iris.target[perm]

# also load the diabetes dataset
# and randomly permute it
diabetes = datasets.load_diabetes()
perm = rng.permutation(diabetes.target.size)
diabetes.data = diabetes.data[perm]
diabetes.target = diabetes.target[perm]

digits = datasets.load_digits()
perm = rng.permutation(digits.target.size)
digits.data = digits.data[perm]
digits.target = digits.target[perm]

random_state = check_random_state(0)
X_multilabel, y_multilabel = datasets.make_multilabel_classification(
    random_state=0, n_samples=30, n_features=10
)

# NB: despite their names X_sparse_* are numpy arrays (and not sparse matrices)
X_sparse_pos = random_state.uniform(size=(20, 5))
X_sparse_pos[X_sparse_pos <= 0.8] = 0.0
y_random = random_state.randint(0, 4, size=(20,))
X_sparse_mix = _sparse_random_matrix(20, 10, density=0.25, random_state=0).toarray()

DATASETS = {
    "iris": {"X": iris.data, "y": iris.target},
    "diabetes": {"X": diabetes.data, "y": diabetes.target},
    "digits": {"X": digits.data, "y": digits.target},
    "toy": {"X": X, "y": y},
    "clf_small": {"X": X_small, "y": y_small},
    "reg_small": {"X": X_small, "y": y_small_reg},
    "multilabel": {"X": X_multilabel, "y": y_multilabel},
    "sparse-pos": {"X": X_sparse_pos, "y": y_random},
    "sparse-neg": {"X": -X_sparse_pos, "y": y_random},
    "sparse-mix": {"X": X_sparse_mix, "y": y_random},
    "zeros": {"X": np.zeros((20, 3)), "y": y_random},
}

for name in DATASETS:
    DATASETS[name]["X_sparse"] = csc_matrix(DATASETS[name]["X"])


def assert_tree_equal(d, s, message):
    assert (
        s.node_count == d.node_count
    ), "{0}: inequal number of node ({1} != {2})".format(
        message, s.node_count, d.node_count
    )

    assert_array_equal(
        d.children_right, s.children_right, message + ": inequal children_right"
    )
    assert_array_equal(
        d.children_left, s.children_left, message + ": inequal children_left"
    )

    external = d.children_right == TREE_LEAF
    internal = np.logical_not(external)

    assert_array_equal(
        d.feature[internal], s.feature[internal], message + ": inequal features"
    )
    assert_array_equal(
        d.threshold[internal], s.threshold[internal], message + ": inequal threshold"
    )
    assert_array_equal(
        d.n_node_samples.sum(),
        s.n_node_samples.sum(),
        message + ": inequal sum(n_node_samples)",
    )
    assert_array_equal(
        d.n_node_samples, s.n_node_samples, message + ": inequal n_node_samples"
    )

    assert_almost_equal(d.impurity, s.impurity, err_msg=message + ": inequal impurity")
    assert_array_almost_equal(
        d.value[external], s.value[external], err_msg=message + ": inequal value"
    )


def test_classification_toy():
    # Check classification on a toy dataset.
    for name, Tree in CLF_TREES.items():
        clf = Tree(random_state=0)
        clf.fit(X, y)
        assert_array_equal(clf.predict(T), true_result, "Failed with {0}".format(name))

        clf = Tree(max_features=1, random_state=1)
        clf.fit(X, y)
        assert_array_equal(clf.predict(T), true_result, "Failed with {0}".format(name))


def test_weighted_classification_toy():
    # Check classification on a weighted toy dataset.
    for name, Tree in CLF_TREES.items():
        clf = Tree(random_state=0)
        clf.fit(X, y, sample_weight=np.ones(len(X)))
        assert_array_equal(clf.predict(T), true_result, "Failed with {0}".format(name))

        clf.fit(X, y, sample_weight=np.full(len(X), 0.5))
        assert_array_equal(clf.predict(T), true_result, "Failed with {0}".format(name))


@pytest.mark.parametrize("Tree", REG_TREES.values())
@pytest.mark.parametrize("criterion", REG_CRITERIONS)
def test_regression_toy(Tree, criterion):
    # Check regression on a toy dataset.
    if criterion == "poisson":
        # make target positive while not touching the original y and
        # true_result
        a = np.abs(np.min(y)) + 1
        y_train = np.array(y) + a
        y_test = np.array(true_result) + a
    else:
        y_train = y
        y_test = true_result

    reg = Tree(criterion=criterion, random_state=1)
    reg.fit(X, y_train)
    assert_allclose(reg.predict(T), y_test)

    clf = Tree(criterion=criterion, max_features=1, random_state=1)
    clf.fit(X, y_train)
    assert_allclose(clf.predict(T), y_test)


def test_xor():
    # Check on a XOR problem
    y = np.zeros((10, 10))
    y[:5, :5] = 1
    y[5:, 5:] = 1

    gridx, gridy = np.indices(y.shape)

    X = np.vstack([gridx.ravel(), gridy.ravel()]).T
    y = y.ravel()

    for name, Tree in CLF_TREES.items():
        clf = Tree(random_state=0)
        clf.fit(X, y)
        assert clf.score(X, y) == 1.0, "Failed with {0}".format(name)

        clf = Tree(random_state=0, max_features=1)
        clf.fit(X, y)
        assert clf.score(X, y) == 1.0, "Failed with {0}".format(name)


def test_iris():
    # Check consistency on dataset iris.
    for (name, Tree), criterion in product(CLF_TREES.items(), CLF_CRITERIONS):
        clf = Tree(criterion=criterion, random_state=0)
        clf.fit(iris.data, iris.target)
        score = accuracy_score(clf.predict(iris.data), iris.target)
        assert score > 0.9, "Failed with {0}, criterion = {1} and score = {2}".format(
            name, criterion, score
        )

        clf = Tree(criterion=criterion, max_features=2, random_state=0)
        clf.fit(iris.data, iris.target)
        score = accuracy_score(clf.predict(iris.data), iris.target)
        assert score > 0.5, "Failed with {0}, criterion = {1} and score = {2}".format(
            name, criterion, score
        )


@pytest.mark.parametrize("name, Tree", REG_TREES.items())
@pytest.mark.parametrize("criterion", REG_CRITERIONS)
def test_diabetes_overfit(name, Tree, criterion):
    # check consistency of overfitted trees on the diabetes dataset
    # since the trees will overfit, we expect an MSE of 0
    reg = Tree(criterion=criterion, random_state=0)
    reg.fit(diabetes.data, diabetes.target)
    score = mean_squared_error(diabetes.target, reg.predict(diabetes.data))
    assert score == pytest.approx(
        0
    ), f"Failed with {name}, criterion = {criterion} and score = {score}"


@skip_if_32bit
@pytest.mark.parametrize("name, Tree", REG_TREES.items())
@pytest.mark.parametrize(
    "criterion, max_depth, metric, max_loss",
    [
        ("squared_error", 15, mean_squared_error, 60),
        ("absolute_error", 20, mean_squared_error, 60),
        ("friedman_mse", 15, mean_squared_error, 60),
        ("poisson", 15, mean_poisson_deviance, 30),
    ],
)
def test_diabetes_underfit(name, Tree, criterion, max_depth, metric, max_loss):
    # check consistency of trees when the depth and the number of features are
    # limited
    reg = Tree(criterion=criterion, max_depth=max_depth, max_features=6, random_state=0)
    reg.fit(diabetes.data, diabetes.target)
    loss = metric(diabetes.target, reg.predict(diabetes.data))
    assert 0 < loss < max_loss


def test_probability():
    # Predict probabilities using DecisionTreeClassifier.
    for name, Tree in CLF_TREES.items():
        clf = Tree(max_depth=1, max_features=1, random_state=42)
        clf.fit(iris.data, iris.target)

        prob_predict = clf.predict_proba(iris.data)
        assert_array_almost_equal(
            np.sum(prob_predict, 1),
            np.ones(iris.data.shape[0]),
            err_msg="Failed with {0}".format(name),
        )
        assert_array_equal(
            np.argmax(prob_predict, 1),
            clf.predict(iris.data),
            err_msg="Failed with {0}".format(name),
        )
        assert_almost_equal(
            clf.predict_proba(iris.data),
            np.exp(clf.predict_log_proba(iris.data)),
            8,
            err_msg="Failed with {0}".format(name),
        )


def test_arrayrepr():
    # Check the array representation.
    # Check resize
    X = np.arange(10000)[:, np.newaxis]
    y = np.arange(10000)

    for name, Tree in REG_TREES.items():
        reg = Tree(max_depth=None, random_state=0)
        reg.fit(X, y)


def test_pure_set():
    # Check when y is pure.
    X = [[-2, -1], [-1, -1], [-1, -2], [1, 1], [1, 2], [2, 1]]
    y = [1, 1, 1, 1, 1, 1]

    for name, TreeClassifier in CLF_TREES.items():
        clf = TreeClassifier(random_state=0)
        clf.fit(X, y)
        assert_array_equal(clf.predict(X), y, err_msg="Failed with {0}".format(name))

    for name, TreeRegressor in REG_TREES.items():
        reg = TreeRegressor(random_state=0)
        reg.fit(X, y)
        assert_almost_equal(reg.predict(X), y, err_msg="Failed with {0}".format(name))


def test_numerical_stability():
    # Check numerical stability.
    X = np.array(
        [
            [152.08097839, 140.40744019, 129.75102234, 159.90493774],
            [142.50700378, 135.81935120, 117.82884979, 162.75781250],
            [127.28772736, 140.40744019, 129.75102234, 159.90493774],
            [132.37025452, 143.71923828, 138.35694885, 157.84558105],
            [103.10237122, 143.71928406, 138.35696411, 157.84559631],
            [127.71276855, 143.71923828, 138.35694885, 157.84558105],
            [120.91514587, 140.40744019, 129.75102234, 159.90493774],
        ]
    )

    y = np.array([1.0, 0.70209277, 0.53896582, 0.0, 0.90914464, 0.48026916, 0.49622521])

    with np.errstate(all="raise"):
        for name, Tree in REG_TREES.items():
            reg = Tree(random_state=0)
            reg.fit(X, y)
            reg.fit(X, -y)
            reg.fit(-X, y)
            reg.fit(-X, -y)


def test_importances():
    # Check variable importances.
    X, y = datasets.make_classification(
        n_samples=5000,
        n_features=10,
        n_informative=3,
        n_redundant=0,
        n_repeated=0,
        shuffle=False,
        random_state=0,
    )

    for name, Tree in CLF_TREES.items():
        clf = Tree(random_state=0)
        clf.fit(X, y)
        importances = clf.feature_importances_
        n_important = np.sum(importances > 0.1)

        assert importances.shape[0] == 10, "Failed with {0}".format(name)
        assert n_important == 3, "Failed with {0}".format(name)

    # Check on iris that importances are the same for all builders
    clf = DecisionTreeClassifier(random_state=0)
    clf.fit(iris.data, iris.target)
    clf2 = DecisionTreeClassifier(random_state=0, max_leaf_nodes=len(iris.data))
    clf2.fit(iris.data, iris.target)

    assert_array_equal(clf.feature_importances_, clf2.feature_importances_)


def test_importances_raises():
    # Check if variable importance before fit raises ValueError.
    clf = DecisionTreeClassifier()
    with pytest.raises(ValueError):
        getattr(clf, "feature_importances_")


def test_importances_gini_equal_squared_error():
    # Check that gini is equivalent to squared_error for binary output variable
    X, y = datasets.make_classification(
        n_samples=2000,
        n_features=10,
        n_informative=3,
        n_redundant=0,
        n_repeated=0,
        shuffle=False,
        random_state=0,
    )

    # The gini index and the mean squared error (variance) might differ due
    # to numerical instability. Since those instabilities mainly occur at
    # high tree depth, we restrict the maximal depth.
    clf = DecisionTreeClassifier(criterion="gini", max_depth=5, random_state=0).fit(
        X, y
    )
    reg = DecisionTreeRegressor(
        criterion="squared_error", max_depth=5, random_state=0
    ).fit(X, y)

    assert_almost_equal(clf.feature_importances_, reg.feature_importances_)
    assert_array_equal(clf.tree_.feature, reg.tree_.feature)
    assert_array_equal(clf.tree_.children_left, reg.tree_.children_left)
    assert_array_equal(clf.tree_.children_right, reg.tree_.children_right)
    assert_array_equal(clf.tree_.n_node_samples, reg.tree_.n_node_samples)


def test_max_features():
    # Check max_features.
    for name, TreeEstimator in ALL_TREES.items():
        est = TreeEstimator(max_features="sqrt")
        est.fit(iris.data, iris.target)
        assert est.max_features_ == int(np.sqrt(iris.data.shape[1]))

        est = TreeEstimator(max_features="log2")
        est.fit(iris.data, iris.target)
        assert est.max_features_ == int(np.log2(iris.data.shape[1]))

        est = TreeEstimator(max_features=1)
        est.fit(iris.data, iris.target)
        assert est.max_features_ == 1

        est = TreeEstimator(max_features=3)
        est.fit(iris.data, iris.target)
        assert est.max_features_ == 3

        est = TreeEstimator(max_features=0.01)
        est.fit(iris.data, iris.target)
        assert est.max_features_ == 1

        est = TreeEstimator(max_features=0.5)
        est.fit(iris.data, iris.target)
        assert est.max_features_ == int(0.5 * iris.data.shape[1])

        est = TreeEstimator(max_features=1.0)
        est.fit(iris.data, iris.target)
        assert est.max_features_ == iris.data.shape[1]

        est = TreeEstimator(max_features=None)
        est.fit(iris.data, iris.target)
        assert est.max_features_ == iris.data.shape[1]


def test_error():
    # Test that it gives proper exception on deficient input.
    for name, TreeEstimator in CLF_TREES.items():
        # predict before fit
        est = TreeEstimator()
        with pytest.raises(NotFittedError):
            est.predict_proba(X)

        est.fit(X, y)
        X2 = [[-2, -1, 1]]  # wrong feature shape for sample
        with pytest.raises(ValueError):
            est.predict_proba(X2)

        # Wrong dimensions
        est = TreeEstimator()
        y2 = y[:-1]
        with pytest.raises(ValueError):
            est.fit(X, y2)

        # Test with arrays that are non-contiguous.
        Xf = np.asfortranarray(X)
        est = TreeEstimator()
        est.fit(Xf, y)
        assert_almost_equal(est.predict(T), true_result)

        # predict before fitting
        est = TreeEstimator()
        with pytest.raises(NotFittedError):
            est.predict(T)

        # predict on vector with different dims
        est.fit(X, y)
        t = np.asarray(T)
        with pytest.raises(ValueError):
            est.predict(t[:, 1:])

        # wrong sample shape
        Xt = np.array(X).T

        est = TreeEstimator()
        est.fit(np.dot(X, Xt), y)
        with pytest.raises(ValueError):
            est.predict(X)
        with pytest.raises(ValueError):
            est.apply(X)

        clf = TreeEstimator()
        clf.fit(X, y)
        with pytest.raises(ValueError):
            clf.predict(Xt)
        with pytest.raises(ValueError):
            clf.apply(Xt)

        # apply before fitting
        est = TreeEstimator()
        with pytest.raises(NotFittedError):
            est.apply(T)

    # non positive target for Poisson splitting Criterion
    est = DecisionTreeRegressor(criterion="poisson")
    with pytest.raises(ValueError, match="y is not positive.*Poisson"):
        est.fit([[0, 1, 2]], [0, 0, 0])
    with pytest.raises(ValueError, match="Some.*y are negative.*Poisson"):
        est.fit([[0, 1, 2]], [5, -0.1, 2])


def test_min_samples_split():
    """Test min_samples_split parameter"""
    X = np.asfortranarray(iris.data, dtype=tree._tree.DTYPE)
    y = iris.target

    # test both DepthFirstTreeBuilder and BestFirstTreeBuilder
    # by setting max_leaf_nodes
    for max_leaf_nodes, name in product((None, 1000), ALL_TREES.keys()):
        TreeEstimator = ALL_TREES[name]

        # test for integer parameter
        est = TreeEstimator(
            min_samples_split=10, max_leaf_nodes=max_leaf_nodes, random_state=0
        )
        est.fit(X, y)
        # count samples on nodes, -1 means it is a leaf
        node_samples = est.tree_.n_node_samples[est.tree_.children_left != -1]
        assert np.min(node_samples) > 9, "Failed with {0}".format(name)

        # test for float parameter
        est = TreeEstimator(
            min_samples_split=0.2, max_leaf_nodes=max_leaf_nodes, random_state=0
        )
        est.fit(X, y)
        # count samples on nodes, -1 means it is a leaf
        node_samples = est.tree_.n_node_samples[est.tree_.children_left != -1]
        assert np.min(node_samples) > 9, "Failed with {0}".format(name)


def test_min_samples_leaf():
    # Test if leaves contain more than leaf_count training examples
    X = np.asfortranarray(iris.data, dtype=tree._tree.DTYPE)
    y = iris.target

    # test both DepthFirstTreeBuilder and BestFirstTreeBuilder
    # by setting max_leaf_nodes
    for max_leaf_nodes, name in product((None, 1000), ALL_TREES.keys()):
        TreeEstimator = ALL_TREES[name]

        # test integer parameter
        est = TreeEstimator(
            min_samples_leaf=5, max_leaf_nodes=max_leaf_nodes, random_state=0
        )
        est.fit(X, y)
        out = est.tree_.apply(X)
        node_counts = np.bincount(out)
        # drop inner nodes
        leaf_count = node_counts[node_counts != 0]
        assert np.min(leaf_count) > 4, "Failed with {0}".format(name)

        # test float parameter
        est = TreeEstimator(
            min_samples_leaf=0.1, max_leaf_nodes=max_leaf_nodes, random_state=0
        )
        est.fit(X, y)
        out = est.tree_.apply(X)
        node_counts = np.bincount(out)
        # drop inner nodes
        leaf_count = node_counts[node_counts != 0]
        assert np.min(leaf_count) > 4, "Failed with {0}".format(name)


def check_min_weight_fraction_leaf(name, datasets, sparse=False):
    """Test if leaves contain at least min_weight_fraction_leaf of the
    training set"""
    if sparse:
        X = DATASETS[datasets]["X_sparse"].astype(np.float32)
    else:
        X = DATASETS[datasets]["X"].astype(np.float32)
    y = DATASETS[datasets]["y"]

    weights = rng.rand(X.shape[0])
    total_weight = np.sum(weights)

    TreeEstimator = ALL_TREES[name]

    # test both DepthFirstTreeBuilder and BestFirstTreeBuilder
    # by setting max_leaf_nodes
    for max_leaf_nodes, frac in product((None, 1000), np.linspace(0, 0.5, 6)):
        est = TreeEstimator(
            min_weight_fraction_leaf=frac, max_leaf_nodes=max_leaf_nodes, random_state=0
        )
        est.fit(X, y, sample_weight=weights)

        if sparse:
            out = est.tree_.apply(X.tocsr())
        else:
            out = est.tree_.apply(X)

        node_weights = np.bincount(out, weights=weights)
        # drop inner nodes
        leaf_weights = node_weights[node_weights != 0]
        assert (
            np.min(leaf_weights) >= total_weight * est.min_weight_fraction_leaf
        ), "Failed with {0} min_weight_fraction_leaf={1}".format(
            name, est.min_weight_fraction_leaf
        )

    # test case with no weights passed in
    total_weight = X.shape[0]

    for max_leaf_nodes, frac in product((None, 1000), np.linspace(0, 0.5, 6)):
        est = TreeEstimator(
            min_weight_fraction_leaf=frac, max_leaf_nodes=max_leaf_nodes, random_state=0
        )
        est.fit(X, y)

        if sparse:
            out = est.tree_.apply(X.tocsr())
        else:
            out = est.tree_.apply(X)

        node_weights = np.bincount(out)
        # drop inner nodes
        leaf_weights = node_weights[node_weights != 0]
        assert (
            np.min(leaf_weights) >= total_weight * est.min_weight_fraction_leaf
        ), "Failed with {0} min_weight_fraction_leaf={1}".format(
            name, est.min_weight_fraction_leaf
        )


@pytest.mark.parametrize("name", ALL_TREES)
def test_min_weight_fraction_leaf_on_dense_input(name):
    check_min_weight_fraction_leaf(name, "iris")


@pytest.mark.parametrize("name", SPARSE_TREES)
def test_min_weight_fraction_leaf_on_sparse_input(name):
    check_min_weight_fraction_leaf(name, "multilabel", True)


def check_min_weight_fraction_leaf_with_min_samples_leaf(name, datasets, sparse=False):
    """Test the interaction between min_weight_fraction_leaf and
    min_samples_leaf when sample_weights is not provided in fit."""
    if sparse:
        X = DATASETS[datasets]["X_sparse"].astype(np.float32)
    else:
        X = DATASETS[datasets]["X"].astype(np.float32)
    y = DATASETS[datasets]["y"]

    total_weight = X.shape[0]
    TreeEstimator = ALL_TREES[name]
    for max_leaf_nodes, frac in product((None, 1000), np.linspace(0, 0.5, 3)):
        # test integer min_samples_leaf
        est = TreeEstimator(
            min_weight_fraction_leaf=frac,
            max_leaf_nodes=max_leaf_nodes,
            min_samples_leaf=5,
            random_state=0,
        )
        est.fit(X, y)

        if sparse:
            out = est.tree_.apply(X.tocsr())
        else:
            out = est.tree_.apply(X)

        node_weights = np.bincount(out)
        # drop inner nodes
        leaf_weights = node_weights[node_weights != 0]
        assert np.min(leaf_weights) >= max(
            (total_weight * est.min_weight_fraction_leaf), 5
        ), "Failed with {0} min_weight_fraction_leaf={1}, min_samples_leaf={2}".format(
            name, est.min_weight_fraction_leaf, est.min_samples_leaf
        )

    for max_leaf_nodes, frac in product((None, 1000), np.linspace(0, 0.5, 3)):
        # test float min_samples_leaf
        est = TreeEstimator(
            min_weight_fraction_leaf=frac,
            max_leaf_nodes=max_leaf_nodes,
            min_samples_leaf=0.1,
            random_state=0,
        )
        est.fit(X, y)

        if sparse:
            out = est.tree_.apply(X.tocsr())
        else:
            out = est.tree_.apply(X)

        node_weights = np.bincount(out)
        # drop inner nodes
        leaf_weights = node_weights[node_weights != 0]
        assert np.min(leaf_weights) >= max(
            (total_weight * est.min_weight_fraction_leaf),
            (total_weight * est.min_samples_leaf),
        ), "Failed with {0} min_weight_fraction_leaf={1}, min_samples_leaf={2}".format(
            name, est.min_weight_fraction_leaf, est.min_samples_leaf
        )


@pytest.mark.parametrize("name", ALL_TREES)
def test_min_weight_fraction_leaf_with_min_samples_leaf_on_dense_input(name):
    check_min_weight_fraction_leaf_with_min_samples_leaf(name, "iris")


@pytest.mark.parametrize("name", SPARSE_TREES)
def test_min_weight_fraction_leaf_with_min_samples_leaf_on_sparse_input(name):
    check_min_weight_fraction_leaf_with_min_samples_leaf(name, "multilabel", True)


def test_min_impurity_decrease(global_random_seed):
    # test that min_impurity_decrease ensures a split is made only if
    # the impurity decrease is at least that value
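    # For reference, the weighted impurity decrease compared against that
    # threshold is (as described in the tree estimator docstrings):
    #     N_t / N * (impurity - N_t_L / N_t * left_impurity
    #                         - N_t_R / N_t * right_impurity)
    # with N the total number of samples and N_t, N_t_L, N_t_R the weighted
    # sample counts of the node and of its left/right children; the loop
    # below recomputes this quantity for every internal node.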
    X, y = datasets.make_classification(n_samples=100, random_state=global_random_seed)

    # test both DepthFirstTreeBuilder and BestFirstTreeBuilder
    # by setting max_leaf_nodes
    for max_leaf_nodes, name in product((None, 1000), ALL_TREES.keys()):
        TreeEstimator = ALL_TREES[name]

        # Check default value of min_impurity_decrease, 1e-7
        est1 = TreeEstimator(max_leaf_nodes=max_leaf_nodes, random_state=0)
        # Check with explicit value of 0.05
        est2 = TreeEstimator(
            max_leaf_nodes=max_leaf_nodes, min_impurity_decrease=0.05, random_state=0
        )
        # Check with a much lower value of 0.0001
        est3 = TreeEstimator(
            max_leaf_nodes=max_leaf_nodes, min_impurity_decrease=0.0001, random_state=0
        )
        # Check with a much higher value of 0.1
        est4 = TreeEstimator(
            max_leaf_nodes=max_leaf_nodes, min_impurity_decrease=0.1, random_state=0
        )

        for est, expected_decrease in (
            (est1, 1e-7),
            (est2, 0.05),
            (est3, 0.0001),
            (est4, 0.1),
        ):
            assert (
                est.min_impurity_decrease <= expected_decrease
            ), "Failed, min_impurity_decrease = {0} > {1}".format(
                est.min_impurity_decrease, expected_decrease
            )
            est.fit(X, y)

            for node in range(est.tree_.node_count):
                # If the current node is not a leaf, check whether the split
                # was justified w.r.t. the min_impurity_decrease
                if est.tree_.children_left[node] != TREE_LEAF:
                    imp_parent = est.tree_.impurity[node]
                    wtd_n_node = est.tree_.weighted_n_node_samples[node]

                    left = est.tree_.children_left[node]
                    wtd_n_left = est.tree_.weighted_n_node_samples[left]
                    imp_left = est.tree_.impurity[left]
                    wtd_imp_left = wtd_n_left * imp_left

                    right = est.tree_.children_right[node]
                    wtd_n_right = est.tree_.weighted_n_node_samples[right]
                    imp_right = est.tree_.impurity[right]
                    wtd_imp_right = wtd_n_right * imp_right

                    wtd_avg_left_right_imp = wtd_imp_right + wtd_imp_left
                    wtd_avg_left_right_imp /= wtd_n_node

                    fractional_node_weight = (
                        est.tree_.weighted_n_node_samples[node] / X.shape[0]
                    )

                    actual_decrease = fractional_node_weight * (
                        imp_parent - wtd_avg_left_right_imp
                    )

                    assert (
                        actual_decrease >= expected_decrease
                    ), "Failed with {0} expected min_impurity_decrease={1}".format(
                        actual_decrease, expected_decrease
                    )


def test_pickle():
    """Test pickling preserves Tree properties and performance."""
    for name, TreeEstimator in ALL_TREES.items():
        if "Classifier" in name:
            X, y = iris.data, iris.target
        else:
            X, y = diabetes.data, diabetes.target

        est = TreeEstimator(random_state=0)
        est.fit(X, y)
        score = est.score(X, y)

        # test that all class properties are maintained
        attributes = [
            "max_depth",
            "node_count",
            "capacity",
            "n_classes",
            "children_left",
            "children_right",
            "n_leaves",
            "feature",
            "threshold",
            "impurity",
            "n_node_samples",
            "weighted_n_node_samples",
            "value",
        ]
        fitted_attribute = {
            attribute: getattr(est.tree_, attribute) for attribute in attributes
        }

        serialized_object = pickle.dumps(est)
        est2 = pickle.loads(serialized_object)
        assert type(est2) == est.__class__

        score2 = est2.score(X, y)
        assert (
            score == score2
        ), "Failed to generate same score after pickling with {0}".format(name)
        for attribute in fitted_attribute:
            assert_array_equal(
                getattr(est2.tree_, attribute),
                fitted_attribute[attribute],
                err_msg=(
                    f"Failed to generate same attribute {attribute} after pickling with"
                    f" {name}"
                ),
            )


def test_multioutput():
    # Check estimators on multi-output problems.
    X = [
        [-2, -1],
        [-1, -1],
        [-1, -2],
        [1, 1],
        [1, 2],
        [2, 1],
        [-2, 1],
        [-1, 1],
        [-1, 2],
        [2, -1],
        [1, -1],
        [1, -2],
    ]

    y = [
        [-1, 0],
        [-1, 0],
        [-1, 0],
        [1, 1],
        [1, 1],
        [1, 1],
        [-1, 2],
        [-1, 2],
        [-1, 2],
        [1, 3],
        [1, 3],
        [1, 3],
    ]

    T = [[-1, -1], [1, 1], [-1, 1], [1, -1]]
    y_true = [[-1, 0], [1, 1], [-1, 2], [1, 3]]
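    # Each sample has two outputs: the first is a binary label in {-1, 1}, the
    # second takes the four values 0-3, which is why predict_proba and
    # predict_log_proba below return a list of two arrays with 2 and 4 columns.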

    # toy classification problem
    for name, TreeClassifier in CLF_TREES.items():
        clf = TreeClassifier(random_state=0)
        y_hat = clf.fit(X, y).predict(T)
        assert_array_equal(y_hat, y_true)
        assert y_hat.shape == (4, 2)

        proba = clf.predict_proba(T)
        assert len(proba) == 2
        assert proba[0].shape == (4, 2)
        assert proba[1].shape == (4, 4)

        log_proba = clf.predict_log_proba(T)
        assert len(log_proba) == 2
        assert log_proba[0].shape == (4, 2)
        assert log_proba[1].shape == (4, 4)

    # toy regression problem
    for name, TreeRegressor in REG_TREES.items():
        reg = TreeRegressor(random_state=0)
        y_hat = reg.fit(X, y).predict(T)
        assert_almost_equal(y_hat, y_true)
        assert y_hat.shape == (4, 2)


def test_classes_shape():
    # Test that n_classes_ and classes_ have proper shape.
    for name, TreeClassifier in CLF_TREES.items():
        # Classification, single output
        clf = TreeClassifier(random_state=0)
        clf.fit(X, y)

        assert clf.n_classes_ == 2
        assert_array_equal(clf.classes_, [-1, 1])

        # Classification, multi-output
        _y = np.vstack((y, np.array(y) * 2)).T
        clf = TreeClassifier(random_state=0)
        clf.fit(X, _y)
        assert len(clf.n_classes_) == 2
        assert len(clf.classes_) == 2
        assert_array_equal(clf.n_classes_, [2, 2])
        assert_array_equal(clf.classes_, [[-1, 1], [-2, 2]])


def test_unbalanced_iris():
    # Check class rebalancing.
    unbalanced_X = iris.data[:125]
    unbalanced_y = iris.target[:125]
    sample_weight = compute_sample_weight("balanced", unbalanced_y)

    for name, TreeClassifier in CLF_TREES.items():
        clf = TreeClassifier(random_state=0)
        clf.fit(unbalanced_X, unbalanced_y, sample_weight=sample_weight)
        assert_almost_equal(clf.predict(unbalanced_X), unbalanced_y)


def test_memory_layout():
    # Check that it works no matter the memory layout
    for (name, TreeEstimator), dtype in product(
        ALL_TREES.items(), [np.float64, np.float32]
    ):
        est = TreeEstimator(random_state=0)

        # Nothing
        X = np.asarray(iris.data, dtype=dtype)
        y = iris.target
        assert_array_equal(est.fit(X, y).predict(X), y)

        # C-order
        X = np.asarray(iris.data, order="C", dtype=dtype)
        y = iris.target
        assert_array_equal(est.fit(X, y).predict(X), y)

        # F-order
        X = np.asarray(iris.data, order="F", dtype=dtype)
        y = iris.target
        assert_array_equal(est.fit(X, y).predict(X), y)

        # Contiguous
        X = np.ascontiguousarray(iris.data, dtype=dtype)
        y = iris.target
        assert_array_equal(est.fit(X, y).predict(X), y)

        # csr matrix
        X = csr_matrix(iris.data, dtype=dtype)
        y = iris.target
        assert_array_equal(est.fit(X, y).predict(X), y)

        # csc_matrix
        X = csc_matrix(iris.data, dtype=dtype)
        y = iris.target
        assert_array_equal(est.fit(X, y).predict(X), y)

        # Strided
        X = np.asarray(iris.data[::3], dtype=dtype)
        y = iris.target[::3]
        assert_array_equal(est.fit(X, y).predict(X), y)


def test_sample_weight():
    # Check sample weighting.
    # Test that zero-weighted samples are not taken into account
    X = np.arange(100)[:, np.newaxis]
    y = np.ones(100)
    y[:50] = 0.0

    sample_weight = np.ones(100)
    sample_weight[y == 0] = 0.0

    clf = DecisionTreeClassifier(random_state=0)
    clf.fit(X, y, sample_weight=sample_weight)
    assert_array_equal(clf.predict(X), np.ones(100))

    # Test that low weighted samples are not taken into account at low depth
    X = np.arange(200)[:, np.newaxis]
    y = np.zeros(200)
    y[50:100] = 1
    y[100:200] = 2
    X[100:200, 0] = 200

    sample_weight = np.ones(200)
    sample_weight[y == 2] = 0.51  # Samples of class '2' are still weightier

    clf = DecisionTreeClassifier(max_depth=1, random_state=0)
    clf.fit(X, y, sample_weight=sample_weight)
    assert clf.tree_.threshold[0] == 149.5

    sample_weight[y == 2] = 0.5  # Samples of class '2' are no longer weightier
    clf = DecisionTreeClassifier(max_depth=1, random_state=0)
    clf.fit(X, y, sample_weight=sample_weight)
    assert clf.tree_.threshold[0] == 49.5  # Threshold should have moved

    # Test that sample weighting is the same as having duplicates
    X = iris.data
    y = iris.target

    duplicates = rng.randint(0, X.shape[0], 100)

    clf = DecisionTreeClassifier(random_state=1)
    clf.fit(X[duplicates], y[duplicates])

    sample_weight = np.bincount(duplicates, minlength=X.shape[0])
    clf2 = DecisionTreeClassifier(random_state=1)
    clf2.fit(X, y, sample_weight=sample_weight)

    internal = clf.tree_.children_left != tree._tree.TREE_LEAF
    assert_array_almost_equal(
        clf.tree_.threshold[internal], clf2.tree_.threshold[internal]
    )


def test_sample_weight_invalid():
    # Check sample weighting raises errors.
    X = np.arange(100)[:, np.newaxis]
    y = np.ones(100)
    y[:50] = 0.0

    clf = DecisionTreeClassifier(random_state=0)

    sample_weight = np.random.rand(100, 1)
    with pytest.raises(ValueError):
        clf.fit(X, y, sample_weight=sample_weight)

    sample_weight = np.array(0)
    expected_err = r"Singleton.* cannot be considered a valid collection"
    with pytest.raises(TypeError, match=expected_err):
        clf.fit(X, y, sample_weight=sample_weight)


def check_class_weights(name):
    """Check class_weights resemble sample_weights behavior."""
    TreeClassifier = CLF_TREES[name]

    # Iris is balanced, so no effect expected for using 'balanced' weights
    clf1 = TreeClassifier(random_state=0)
    clf1.fit(iris.data, iris.target)
    clf2 = TreeClassifier(class_weight="balanced", random_state=0)
    clf2.fit(iris.data, iris.target)
    assert_almost_equal(clf1.feature_importances_, clf2.feature_importances_)

    # Make a multi-output problem with three copies of Iris
    iris_multi = np.vstack((iris.target, iris.target, iris.target)).T
    # Create user-defined weights that should balance over the outputs
    clf3 = TreeClassifier(
        class_weight=[
            {0: 2.0, 1: 2.0, 2: 1.0},
            {0: 2.0, 1: 1.0, 2: 2.0},
            {0: 1.0, 1: 2.0, 2: 2.0},
        ],
        random_state=0,
    )
    clf3.fit(iris.data, iris_multi)
    assert_almost_equal(clf2.feature_importances_, clf3.feature_importances_)
    # Check against multi-output "auto" which should also have no effect
    clf4 = TreeClassifier(class_weight="balanced", random_state=0)
    clf4.fit(iris.data, iris_multi)
    assert_almost_equal(clf3.feature_importances_, clf4.feature_importances_)

    # Inflate importance of class 1, check against user-defined weights
    sample_weight = np.ones(iris.target.shape)
    sample_weight[iris.target == 1] *= 100
    class_weight = {0: 1.0, 1: 100.0, 2: 1.0}
    clf1 = TreeClassifier(random_state=0)
    clf1.fit(iris.data, iris.target, sample_weight)
    clf2 = TreeClassifier(class_weight=class_weight, random_state=0)
    clf2.fit(iris.data, iris.target)
    assert_almost_equal(clf1.feature_importances_, clf2.feature_importances_)

    # Check that sample_weight and class_weight are multiplicative
    clf1 = TreeClassifier(random_state=0)
    clf1.fit(iris.data, iris.target, sample_weight**2)
    clf2 = TreeClassifier(class_weight=class_weight, random_state=0)
    clf2.fit(iris.data, iris.target, sample_weight)
    assert_almost_equal(clf1.feature_importances_, clf2.feature_importances_)


@pytest.mark.parametrize("name", CLF_TREES)
def test_class_weights(name):
    check_class_weights(name)


def check_class_weight_errors(name):
    # Test if class_weight raises errors and warnings when expected.
    TreeClassifier = CLF_TREES[name]
    _y = np.vstack((y, np.array(y) * 2)).T

    # Incorrect length list for multi-output
    clf = TreeClassifier(class_weight=[{-1: 0.5, 1: 1.0}], random_state=0)
    err_msg = "number of elements in class_weight should match number of outputs."
    with pytest.raises(ValueError, match=err_msg):
        clf.fit(X, _y)


@pytest.mark.parametrize("name", CLF_TREES)
def test_class_weight_errors(name):
    check_class_weight_errors(name)


def test_max_leaf_nodes():
    # Test greedy trees with max_depth + 1 leaves.
    X, y = datasets.make_hastie_10_2(n_samples=100, random_state=1)
    k = 4
    for name, TreeEstimator in ALL_TREES.items():
        est = TreeEstimator(max_depth=None, max_leaf_nodes=k + 1).fit(X, y)
        assert est.get_n_leaves() == k + 1


def test_max_leaf_nodes_max_depth():
    # Test precedence of max_leaf_nodes over max_depth.
    X, y = datasets.make_hastie_10_2(n_samples=100, random_state=1)
    k = 4
    for name, TreeEstimator in ALL_TREES.items():
        est = TreeEstimator(max_depth=1, max_leaf_nodes=k).fit(X, y)
        assert est.get_depth() == 1


def test_arrays_persist():
    # Ensure property arrays' memory stays alive when tree disappears
    # non-regression for #2726
    for attr in [
        "n_classes",
        "value",
        "children_left",
        "children_right",
        "threshold",
        "impurity",
        "feature",
        "n_node_samples",
    ]:
        value = getattr(DecisionTreeClassifier().fit([[0], [1]], [0, 1]).tree_, attr)
        # if pointing to freed memory, contents may be arbitrary
        assert -3 <= value.flat[0] < 3, "Array points to arbitrary memory"


def test_only_constant_features():
    random_state = check_random_state(0)
    X = np.zeros((10, 20))
    y = random_state.randint(0, 2, (10,))
    for name, TreeEstimator in ALL_TREES.items():
        est = TreeEstimator(random_state=0)
        est.fit(X, y)
        assert est.tree_.max_depth == 0


def test_behaviour_constant_feature_after_splits():
    X = np.transpose(
        np.vstack(([[0, 0, 0, 0, 0, 1, 2, 4, 5, 6, 7]], np.zeros((4, 11))))
    )
    y = [0, 0, 0, 1, 1, 2, 2, 2, 3, 3, 3]
    for name, TreeEstimator in ALL_TREES.items():
        # do not check extra random trees
        if "ExtraTree" not in name:
            est = TreeEstimator(random_state=0, max_features=1)
            est.fit(X, y)
            assert est.tree_.max_depth == 2
            assert est.tree_.node_count == 5


def test_with_only_one_non_constant_features():
    X = np.hstack([np.array([[1.0], [1.0], [0.0], [0.0]]), np.zeros((4, 1000))])
    y = np.array([0.0, 1.0, 0.0, 1.0])
    for name, TreeEstimator in CLF_TREES.items():
        est = TreeEstimator(random_state=0, max_features=1)
        est.fit(X, y)
        assert est.tree_.max_depth == 1
        assert_array_equal(est.predict_proba(X), np.full((4, 2), 0.5))

    for name, TreeEstimator in REG_TREES.items():
        est = TreeEstimator(random_state=0, max_features=1)
        est.fit(X, y)
        assert est.tree_.max_depth == 1
        assert_array_equal(est.predict(X), np.full((4,), 0.5))


def test_big_input():
    # Test that an informative error is raised when inputs are too large for float32.
    X = np.repeat(10**40.0, 4).astype(np.float64).reshape(-1, 1)
    clf = DecisionTreeClassifier()
    with pytest.raises(ValueError, match="float32"):
        clf.fit(X, [0, 1, 0, 1])


def test_realloc():
    from sklearn.tree._utils import _realloc_test

    with pytest.raises(MemoryError):
        _realloc_test()


def test_huge_allocations():
    n_bits = 8 * struct.calcsize("P")
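    # struct.calcsize("P") is the size of a C pointer in bytes, so n_bits is
    # the width of the address space on the current platform (typically 64).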
  1063. X = np.random.randn(10, 2)
  1064. y = np.random.randint(0, 2, 10)
  1065. # Sanity check: we cannot request more memory than the size of the address
  1066. # space. Currently raises OverflowError.
  1067. huge = 2 ** (n_bits + 1)
  1068. clf = DecisionTreeClassifier(splitter="best", max_leaf_nodes=huge)
  1069. with pytest.raises(Exception):
  1070. clf.fit(X, y)
  1071. # Non-regression test: MemoryError used to be dropped by Cython
  1072. # because of missing "except *".
  1073. huge = 2 ** (n_bits - 1) - 1
  1074. clf = DecisionTreeClassifier(splitter="best", max_leaf_nodes=huge)
  1075. with pytest.raises(MemoryError):
  1076. clf.fit(X, y)
  1077. def check_sparse_input(tree, dataset, max_depth=None):
  1078. TreeEstimator = ALL_TREES[tree]
  1079. X = DATASETS[dataset]["X"]
  1080. X_sparse = DATASETS[dataset]["X_sparse"]
  1081. y = DATASETS[dataset]["y"]
  1082. # Gain testing time
  1083. if dataset in ["digits", "diabetes"]:
  1084. n_samples = X.shape[0] // 5
  1085. X = X[:n_samples]
  1086. X_sparse = X_sparse[:n_samples]
  1087. y = y[:n_samples]
  1088. for sparse_format in (csr_matrix, csc_matrix, coo_matrix):
  1089. X_sparse = sparse_format(X_sparse)
  1090. # Check the default (depth first search)
  1091. d = TreeEstimator(random_state=0, max_depth=max_depth).fit(X, y)
  1092. s = TreeEstimator(random_state=0, max_depth=max_depth).fit(X_sparse, y)
  1093. assert_tree_equal(
  1094. d.tree_,
  1095. s.tree_,
  1096. "{0} with dense and sparse format gave different trees".format(tree),
  1097. )
  1098. y_pred = d.predict(X)
  1099. if tree in CLF_TREES:
  1100. y_proba = d.predict_proba(X)
  1101. y_log_proba = d.predict_log_proba(X)
  1102. for sparse_matrix in (csr_matrix, csc_matrix, coo_matrix):
  1103. X_sparse_test = sparse_matrix(X_sparse, dtype=np.float32)
  1104. assert_array_almost_equal(s.predict(X_sparse_test), y_pred)
  1105. if tree in CLF_TREES:
  1106. assert_array_almost_equal(s.predict_proba(X_sparse_test), y_proba)
  1107. assert_array_almost_equal(
  1108. s.predict_log_proba(X_sparse_test), y_log_proba
  1109. )
  1110. @pytest.mark.parametrize("tree_type", SPARSE_TREES)
  1111. @pytest.mark.parametrize(
  1112. "dataset",
  1113. (
  1114. "clf_small",
  1115. "toy",
  1116. "digits",
  1117. "multilabel",
  1118. "sparse-pos",
  1119. "sparse-neg",
  1120. "sparse-mix",
  1121. "zeros",
  1122. ),
  1123. )
  1124. def test_sparse_input(tree_type, dataset):
  1125. max_depth = 3 if dataset == "digits" else None
  1126. check_sparse_input(tree_type, dataset, max_depth)
  1127. @pytest.mark.parametrize("tree_type", sorted(set(SPARSE_TREES).intersection(REG_TREES)))
  1128. @pytest.mark.parametrize("dataset", ["diabetes", "reg_small"])
  1129. def test_sparse_input_reg_trees(tree_type, dataset):
  1130. # Due to numerical instability of MSE and too strict test, we limit the
  1131. # maximal depth
  1132. check_sparse_input(tree_type, dataset, 2)


def check_sparse_parameters(tree, dataset):
    TreeEstimator = ALL_TREES[tree]
    X = DATASETS[dataset]["X"]
    X_sparse = DATASETS[dataset]["X_sparse"]
    y = DATASETS[dataset]["y"]

    # Check max_features
    d = TreeEstimator(random_state=0, max_features=1, max_depth=2).fit(X, y)
    s = TreeEstimator(random_state=0, max_features=1, max_depth=2).fit(X_sparse, y)
    assert_tree_equal(
        d.tree_,
        s.tree_,
        "{0} with dense and sparse format gave different trees".format(tree),
    )
    assert_array_almost_equal(s.predict(X), d.predict(X))

    # Check min_samples_split
    d = TreeEstimator(random_state=0, max_features=1, min_samples_split=10).fit(X, y)
    s = TreeEstimator(random_state=0, max_features=1, min_samples_split=10).fit(
        X_sparse, y
    )
    assert_tree_equal(
        d.tree_,
        s.tree_,
        "{0} with dense and sparse format gave different trees".format(tree),
    )
    assert_array_almost_equal(s.predict(X), d.predict(X))

    # Check min_samples_leaf
    d = TreeEstimator(random_state=0, min_samples_leaf=X_sparse.shape[0] // 2).fit(X, y)
    s = TreeEstimator(random_state=0, min_samples_leaf=X_sparse.shape[0] // 2).fit(
        X_sparse, y
    )
    assert_tree_equal(
        d.tree_,
        s.tree_,
        "{0} with dense and sparse format gave different trees".format(tree),
    )
    assert_array_almost_equal(s.predict(X), d.predict(X))

    # Check best-first search
    d = TreeEstimator(random_state=0, max_leaf_nodes=3).fit(X, y)
    s = TreeEstimator(random_state=0, max_leaf_nodes=3).fit(X_sparse, y)
    assert_tree_equal(
        d.tree_,
        s.tree_,
        "{0} with dense and sparse format gave different trees".format(tree),
    )
    assert_array_almost_equal(s.predict(X), d.predict(X))


def check_sparse_criterion(tree, dataset):
    TreeEstimator = ALL_TREES[tree]
    X = DATASETS[dataset]["X"]
    X_sparse = DATASETS[dataset]["X_sparse"]
    y = DATASETS[dataset]["y"]

    # Check various criterion
    CRITERIONS = REG_CRITERIONS if tree in REG_TREES else CLF_CRITERIONS
    for criterion in CRITERIONS:
        d = TreeEstimator(random_state=0, max_depth=3, criterion=criterion).fit(X, y)
        s = TreeEstimator(random_state=0, max_depth=3, criterion=criterion).fit(
            X_sparse, y
        )
        assert_tree_equal(
            d.tree_,
            s.tree_,
            "{0} with dense and sparse format gave different trees".format(tree),
        )
        assert_array_almost_equal(s.predict(X), d.predict(X))


@pytest.mark.parametrize("tree_type", SPARSE_TREES)
@pytest.mark.parametrize("dataset", ["sparse-pos", "sparse-neg", "sparse-mix", "zeros"])
@pytest.mark.parametrize("check", [check_sparse_parameters, check_sparse_criterion])
def test_sparse(tree_type, dataset, check):
    check(tree_type, dataset)


def check_explicit_sparse_zeros(tree, max_depth=3, n_features=10):
    TreeEstimator = ALL_TREES[tree]

    # Set n_samples equal to n_features to ease the simultaneous construction
    # of a CSR and a CSC matrix.
    n_samples = n_features
    samples = np.arange(n_samples)

    # Generate X, y
    random_state = check_random_state(0)
    indices = []
    data = []
    offset = 0
    indptr = [offset]
    for i in range(n_features):
        n_nonzero_i = random_state.binomial(n_samples, 0.5)
        indices_i = random_state.permutation(samples)[:n_nonzero_i]
        indices.append(indices_i)
        data_i = random_state.binomial(3, 0.5, size=(n_nonzero_i,)) - 1
        data.append(data_i)
        offset += n_nonzero_i
        indptr.append(offset)

    indices = np.concatenate(indices)
    data = np.array(np.concatenate(data), dtype=np.float32)
    X_sparse = csc_matrix((data, indices, indptr), shape=(n_samples, n_features))
    X = X_sparse.toarray()
    X_sparse_test = csr_matrix((data, indices, indptr), shape=(n_samples, n_features))
    X_test = X_sparse_test.toarray()
    y = random_state.randint(0, 3, size=(n_samples,))

    # Ensure that X_sparse_test owns its data, indices and indptr array
    X_sparse_test = X_sparse_test.copy()

    # Ensure that we have explicit zeros
    assert (X_sparse.data == 0.0).sum() > 0
    assert (X_sparse_test.data == 0.0).sum() > 0

    # Perform the comparison
    d = TreeEstimator(random_state=0, max_depth=max_depth).fit(X, y)
    s = TreeEstimator(random_state=0, max_depth=max_depth).fit(X_sparse, y)

    assert_tree_equal(
        d.tree_,
        s.tree_,
        "{0} with dense and sparse format gave different trees".format(tree),
    )

    Xs = (X_test, X_sparse_test)
    for X1, X2 in product(Xs, Xs):
        assert_array_almost_equal(s.tree_.apply(X1), d.tree_.apply(X2))
        assert_array_almost_equal(s.apply(X1), d.apply(X2))
        assert_array_almost_equal(s.apply(X1), s.tree_.apply(X1))

        assert_array_almost_equal(
            s.tree_.decision_path(X1).toarray(), d.tree_.decision_path(X2).toarray()
        )
        assert_array_almost_equal(
            s.decision_path(X1).toarray(), d.decision_path(X2).toarray()
        )
        assert_array_almost_equal(
            s.decision_path(X1).toarray(), s.tree_.decision_path(X1).toarray()
        )

        assert_array_almost_equal(s.predict(X1), d.predict(X2))

        if tree in CLF_TREES:
            assert_array_almost_equal(s.predict_proba(X1), d.predict_proba(X2))


@pytest.mark.parametrize("tree_type", SPARSE_TREES)
def test_explicit_sparse_zeros(tree_type):
    check_explicit_sparse_zeros(tree_type)
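

# Illustrative sketch, not part of the original test suite: "explicit zeros" are
# values that are physically stored in a sparse matrix's ``data`` array even
# though they equal 0.0, so the sparse splitter actually visits them. The check
# above relies on ``random_state.binomial(3, 0.5) - 1`` producing such stored
# zeros. A minimal standalone example using the ``csr_matrix`` imported above:
def _sketch_explicit_zeros_example():
    data, indices, indptr = [0.0, 1.0], [0, 2], [0, 2, 2]
    a = csr_matrix((data, indices, indptr), shape=(2, 3))
    assert a.nnz == 2  # the stored 0.0 is counted as a stored entry
    a.eliminate_zeros()
    assert a.nnz == 1  # only the true non-zero remains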


@ignore_warnings
def check_raise_error_on_1d_input(name):
    TreeEstimator = ALL_TREES[name]

    X = iris.data[:, 0].ravel()
    X_2d = iris.data[:, 0].reshape((-1, 1))
    y = iris.target

    with pytest.raises(ValueError):
        TreeEstimator(random_state=0).fit(X, y)

    est = TreeEstimator(random_state=0)
    est.fit(X_2d, y)
    with pytest.raises(ValueError):
        est.predict([X])


@pytest.mark.parametrize("name", ALL_TREES)
def test_1d_input(name):
    with ignore_warnings():
        check_raise_error_on_1d_input(name)


def _check_min_weight_leaf_split_level(TreeEstimator, X, y, sample_weight):
    est = TreeEstimator(random_state=0)
    est.fit(X, y, sample_weight=sample_weight)
    assert est.tree_.max_depth == 1

    est = TreeEstimator(random_state=0, min_weight_fraction_leaf=0.4)
    est.fit(X, y, sample_weight=sample_weight)
    assert est.tree_.max_depth == 0


def check_min_weight_leaf_split_level(name):
    TreeEstimator = ALL_TREES[name]

    X = np.array([[0], [0], [0], [0], [1]])
    y = [0, 0, 0, 0, 1]
    sample_weight = [0.2, 0.2, 0.2, 0.2, 0.2]
    _check_min_weight_leaf_split_level(TreeEstimator, X, y, sample_weight)

    _check_min_weight_leaf_split_level(TreeEstimator, csc_matrix(X), y, sample_weight)


@pytest.mark.parametrize("name", ALL_TREES)
def test_min_weight_leaf_split_level(name):
    check_min_weight_leaf_split_level(name)
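

# A short worked example of the arithmetic behind the check above (sketch only,
# not part of the original tests): the five samples carry weight 0.2 each, so
# the total weight is 1.0 and min_weight_fraction_leaf=0.4 requires every leaf
# to hold at least 0.4 of it. The only informative split isolates the single
# X == 1 sample, whose leaf would weigh 0.2 < 0.4, so the split is rejected and
# the tree stays a single root node (max_depth == 0).
def _sketch_min_weight_fraction_leaf_arithmetic():
    sample_weight = [0.2, 0.2, 0.2, 0.2, 0.2]
    total_weight = sum(sample_weight)            # 1.0
    min_leaf_weight = 0.4 * total_weight         # 0.4
    lone_sample_weight = sample_weight[-1]       # 0.2, the X == 1 sample
    assert lone_sample_weight < min_leaf_weight  # split rejected -> depth 0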


def check_public_apply(name):
    X_small32 = X_small.astype(tree._tree.DTYPE, copy=False)

    est = ALL_TREES[name]()
    est.fit(X_small, y_small)
    assert_array_equal(est.apply(X_small), est.tree_.apply(X_small32))


def check_public_apply_sparse(name):
    X_small32 = csr_matrix(X_small.astype(tree._tree.DTYPE, copy=False))

    est = ALL_TREES[name]()
    est.fit(X_small, y_small)
    assert_array_equal(est.apply(X_small), est.tree_.apply(X_small32))


@pytest.mark.parametrize("name", ALL_TREES)
def test_public_apply_all_trees(name):
    check_public_apply(name)


@pytest.mark.parametrize("name", SPARSE_TREES)
def test_public_apply_sparse_trees(name):
    check_public_apply_sparse(name)


def test_decision_path_hardcoded():
    X = iris.data
    y = iris.target
    est = DecisionTreeClassifier(random_state=0, max_depth=1).fit(X, y)
    node_indicator = est.decision_path(X[:2]).toarray()
    assert_array_equal(node_indicator, [[1, 1, 0], [1, 0, 1]])


def check_decision_path(name):
    X = iris.data
    y = iris.target
    n_samples = X.shape[0]

    TreeEstimator = ALL_TREES[name]
    est = TreeEstimator(random_state=0, max_depth=2)
    est.fit(X, y)

    node_indicator_csr = est.decision_path(X)
    node_indicator = node_indicator_csr.toarray()
    assert node_indicator.shape == (n_samples, est.tree_.node_count)

    # Assert that the leaf indices are correct
    leaves = est.apply(X)
    leave_indicator = [node_indicator[i, j] for i, j in enumerate(leaves)]
    assert_array_almost_equal(leave_indicator, np.ones(shape=n_samples))

    # Ensure only one leaf node per sample
    all_leaves = est.tree_.children_left == TREE_LEAF
    assert_array_almost_equal(
        np.dot(node_indicator, all_leaves), np.ones(shape=n_samples)
    )

    # Ensure max depth is consistent with the sum of the indicator
    max_depth = node_indicator.sum(axis=1).max()
    assert est.tree_.max_depth <= max_depth


@pytest.mark.parametrize("name", ALL_TREES)
def test_decision_path(name):
    check_decision_path(name)
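

# Sketch (illustration only, reusing the iris fixture above): ``decision_path``
# returns a CSR indicator with one row per sample whose column indices are the
# ids of the nodes that sample traverses, which is exactly what the leaf and
# max-depth assertions above exploit.
def _sketch_read_single_decision_path():
    est = DecisionTreeClassifier(random_state=0, max_depth=2).fit(
        iris.data, iris.target
    )
    indicator = est.decision_path(iris.data[:1])
    node_ids = indicator.indices[indicator.indptr[0] : indicator.indptr[1]]
    # The visited nodes include the leaf that ``apply`` reports for this sample.
    assert est.apply(iris.data[:1])[0] in node_ids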


def check_no_sparse_y_support(name):
    X, y = X_multilabel, csr_matrix(y_multilabel)
    TreeEstimator = ALL_TREES[name]
    with pytest.raises(TypeError):
        TreeEstimator(random_state=0).fit(X, y)


@pytest.mark.parametrize("name", ALL_TREES)
def test_no_sparse_y_support(name):
    # Currently we don't support sparse y
    check_no_sparse_y_support(name)


def test_mae():
    """Check MAE criterion produces correct results on small toy dataset:

    ------------------
    | X | y | weight |
    ------------------
    | 3 | 3 |  0.1   |
    | 5 | 3 |  0.3   |
    | 8 | 4 |  1.0   |
    | 3 | 6 |  0.6   |
    | 5 | 7 |  0.3   |
    ------------------
    |sum wt:|  2.3   |
    ------------------

    Because we are dealing with sample weights, we cannot find the median by
    simply choosing/averaging the centre value(s). Instead, the median is the
    y value at which 50% of the cumulative weight is reached (with the data
    sorted by y). For this test data, the cumulative weight reaches 50% of the
    total when y = 4. Therefore:
    Median = 4

    For all the samples, we can get the total error by summing:
    Absolute(Median - y) * weight

    I.e., total error = (Absolute(4 - 3) * 0.1)
                      + (Absolute(4 - 3) * 0.3)
                      + (Absolute(4 - 4) * 1.0)
                      + (Absolute(4 - 6) * 0.6)
                      + (Absolute(4 - 7) * 0.3)
                      = 2.5

    Impurity = Total error / total weight
             = 2.5 / 2.3
             = 1.08695652173913
    ------------------

    From this root node, the next best split is between X values of 3 and 5.
    Thus, we have left and right child nodes:

    LEFT                     RIGHT
    ------------------       ------------------
    | X | y | weight |       | X | y | weight |
    ------------------       ------------------
    | 3 | 3 |  0.1   |       | 5 | 3 |  0.3   |
    | 3 | 6 |  0.6   |       | 8 | 4 |  1.0   |
    ------------------       | 5 | 7 |  0.3   |
    |sum wt:|  0.7   |       ------------------
    ------------------       |sum wt:|  1.6   |
                             ------------------

    Impurity is found in the same way:
    Left node Median = 6
    Total error = (Absolute(6 - 3) * 0.1)
                + (Absolute(6 - 6) * 0.6)
                = 0.3

    Left Impurity = Total error / total weight
                  = 0.3 / 0.7
                  = 0.428571428571429
    -------------------

    Likewise for the Right node:
    Right node Median = 4
    Total error = (Absolute(4 - 3) * 0.3)
                + (Absolute(4 - 4) * 1.0)
                + (Absolute(4 - 7) * 0.3)
                = 1.2

    Right Impurity = Total error / total weight
                   = 1.2 / 1.6
                   = 0.75
    ------
    """
    dt_mae = DecisionTreeRegressor(
        random_state=0, criterion="absolute_error", max_leaf_nodes=2
    )

    # Test MAE where sample weights are non-uniform (as illustrated above):
    dt_mae.fit(
        X=[[3], [5], [3], [8], [5]],
        y=[6, 7, 3, 4, 3],
        sample_weight=[0.6, 0.3, 0.1, 1.0, 0.3],
    )
    assert_allclose(dt_mae.tree_.impurity, [2.5 / 2.3, 0.3 / 0.7, 1.2 / 1.6])
    assert_array_equal(dt_mae.tree_.value.flat, [4.0, 6.0, 4.0])

    # Test MAE where all sample weights are uniform:
    dt_mae.fit(X=[[3], [5], [3], [8], [5]], y=[6, 7, 3, 4, 3], sample_weight=np.ones(5))
    assert_array_equal(dt_mae.tree_.impurity, [1.4, 1.5, 4.0 / 3.0])
    assert_array_equal(dt_mae.tree_.value.flat, [4, 4.5, 4.0])

    # Test MAE where a `sample_weight` is not explicitly provided.
    # This is equivalent to providing uniform sample weights, though
    # the internal logic is different:
    dt_mae.fit(X=[[3], [5], [3], [8], [5]], y=[6, 7, 3, 4, 3])
    assert_array_equal(dt_mae.tree_.impurity, [1.4, 1.5, 4.0 / 3.0])
    assert_array_equal(dt_mae.tree_.value.flat, [4, 4.5, 4.0])
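

# Sketch mirroring the docstring arithmetic above (illustration only, not part
# of the original test): the weighted median is the first y value, in sorted
# order, at which the cumulative weight reaches half of the total weight, and
# the MAE impurity is the weighted mean absolute deviation from that median.
def _sketch_weighted_median_mae():
    y = np.array([3, 3, 4, 6, 7], dtype=float)
    w = np.array([0.1, 0.3, 1.0, 0.6, 0.3])  # already sorted by y
    median = y[np.searchsorted(np.cumsum(w), 0.5 * w.sum())]
    impurity = np.sum(w * np.abs(y - median)) / w.sum()
    assert median == 4
    assert np.isclose(impurity, 2.5 / 2.3)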


def test_criterion_copy():
    # Check that a copy of our criterion has the same type and properties
    # as the original.
    n_outputs = 3
    n_classes = np.arange(3, dtype=np.intp)
    n_samples = 100

    def _pickle_copy(obj):
        return pickle.loads(pickle.dumps(obj))

    for copy_func in [copy.copy, copy.deepcopy, _pickle_copy]:
        for _, typename in CRITERIA_CLF.items():
            criteria = typename(n_outputs, n_classes)
            result = copy_func(criteria).__reduce__()
            typename_, (n_outputs_, n_classes_), _ = result
            assert typename == typename_
            assert n_outputs == n_outputs_
            assert_array_equal(n_classes, n_classes_)

        for _, typename in CRITERIA_REG.items():
            criteria = typename(n_outputs, n_samples)
            result = copy_func(criteria).__reduce__()
            typename_, (n_outputs_, n_samples_), _ = result
            assert typename == typename_
            assert n_outputs == n_outputs_
            assert n_samples == n_samples_


def test_empty_leaf_infinite_threshold():
    # Try to make an empty leaf by using a near-infinite value.
    data = np.random.RandomState(0).randn(100, 11) * 2e38
    data = np.nan_to_num(data.astype("float32"))
    X_full = data[:, :-1]
    X_sparse = csc_matrix(X_full)
    y = data[:, -1]
    for X in [X_full, X_sparse]:
        tree = DecisionTreeRegressor(random_state=0).fit(X, y)
        terminal_regions = tree.apply(X)
        left_leaf = set(np.where(tree.tree_.children_left == TREE_LEAF)[0])
        empty_leaf = left_leaf.difference(terminal_regions)
        infinite_threshold = np.where(~np.isfinite(tree.tree_.threshold))[0]
        assert len(infinite_threshold) == 0
        assert len(empty_leaf) == 0


@pytest.mark.parametrize("criterion", CLF_CRITERIONS)
@pytest.mark.parametrize(
    "dataset", sorted(set(DATASETS.keys()) - {"reg_small", "diabetes"})
)
@pytest.mark.parametrize("tree_cls", [DecisionTreeClassifier, ExtraTreeClassifier])
def test_prune_tree_classifier_are_subtrees(criterion, dataset, tree_cls):
    dataset = DATASETS[dataset]
    X, y = dataset["X"], dataset["y"]

    est = tree_cls(max_leaf_nodes=20, random_state=0)
    info = est.cost_complexity_pruning_path(X, y)

    pruning_path = info.ccp_alphas
    impurities = info.impurities

    assert np.all(np.diff(pruning_path) >= 0)
    assert np.all(np.diff(impurities) >= 0)

    assert_pruning_creates_subtree(tree_cls, X, y, pruning_path)


@pytest.mark.parametrize("criterion", REG_CRITERIONS)
@pytest.mark.parametrize("dataset", DATASETS.keys())
@pytest.mark.parametrize("tree_cls", [DecisionTreeRegressor, ExtraTreeRegressor])
def test_prune_tree_regression_are_subtrees(criterion, dataset, tree_cls):
    dataset = DATASETS[dataset]
    X, y = dataset["X"], dataset["y"]

    est = tree_cls(max_leaf_nodes=20, random_state=0)
    info = est.cost_complexity_pruning_path(X, y)

    pruning_path = info.ccp_alphas
    impurities = info.impurities

    assert np.all(np.diff(pruning_path) >= 0)
    assert np.all(np.diff(impurities) >= 0)

    assert_pruning_creates_subtree(tree_cls, X, y, pruning_path)


def test_prune_single_node_tree():
    # single node tree
    clf1 = DecisionTreeClassifier(random_state=0)
    clf1.fit([[0], [1]], [0, 0])

    # pruned single node tree
    clf2 = DecisionTreeClassifier(random_state=0, ccp_alpha=10)
    clf2.fit([[0], [1]], [0, 0])

    assert_is_subtree(clf1.tree_, clf2.tree_)


def assert_pruning_creates_subtree(estimator_cls, X, y, pruning_path):
    # generate trees with increasing alphas
    estimators = []
    for ccp_alpha in pruning_path:
        est = estimator_cls(max_leaf_nodes=20, ccp_alpha=ccp_alpha, random_state=0).fit(
            X, y
        )
        estimators.append(est)

    # A pruned tree must be a subtree of the previous tree (which had a
    # smaller ccp_alpha)
    for prev_est, next_est in zip(estimators, estimators[1:]):
        assert_is_subtree(prev_est.tree_, next_est.tree_)


def assert_is_subtree(tree, subtree):
    assert tree.node_count >= subtree.node_count
    assert tree.max_depth >= subtree.max_depth

    tree_c_left = tree.children_left
    tree_c_right = tree.children_right
    subtree_c_left = subtree.children_left
    subtree_c_right = subtree.children_right

    stack = [(0, 0)]
    while stack:
        tree_node_idx, subtree_node_idx = stack.pop()
        assert_array_almost_equal(
            tree.value[tree_node_idx], subtree.value[subtree_node_idx]
        )
        assert_almost_equal(
            tree.impurity[tree_node_idx], subtree.impurity[subtree_node_idx]
        )
        assert_almost_equal(
            tree.n_node_samples[tree_node_idx], subtree.n_node_samples[subtree_node_idx]
        )
        assert_almost_equal(
            tree.weighted_n_node_samples[tree_node_idx],
            subtree.weighted_n_node_samples[subtree_node_idx],
        )

        if subtree_c_left[subtree_node_idx] == subtree_c_right[subtree_node_idx]:
            # is a leaf
            assert_almost_equal(TREE_UNDEFINED, subtree.threshold[subtree_node_idx])
        else:
            # not a leaf
            assert_almost_equal(
                tree.threshold[tree_node_idx], subtree.threshold[subtree_node_idx]
            )
            stack.append((tree_c_left[tree_node_idx], subtree_c_left[subtree_node_idx]))
            stack.append(
                (tree_c_right[tree_node_idx], subtree_c_right[subtree_node_idx])
            )


@pytest.mark.parametrize("name", ALL_TREES)
@pytest.mark.parametrize("splitter", ["best", "random"])
@pytest.mark.parametrize("X_format", ["dense", "csr", "csc"])
def test_apply_path_readonly_all_trees(name, splitter, X_format):
    dataset = DATASETS["clf_small"]
    X_small = dataset["X"].astype(tree._tree.DTYPE, copy=False)

    if X_format == "dense":
        X_readonly = create_memmap_backed_data(X_small)
    else:
        X_readonly = dataset["X_sparse"]  # CSR
        if X_format == "csc":
            # Cheap CSR to CSC conversion
            X_readonly = X_readonly.tocsc()

        X_readonly.data = np.array(X_readonly.data, dtype=tree._tree.DTYPE)
        (
            X_readonly.data,
            X_readonly.indices,
            X_readonly.indptr,
        ) = create_memmap_backed_data(
            (X_readonly.data, X_readonly.indices, X_readonly.indptr)
        )

    y_readonly = create_memmap_backed_data(np.array(y_small, dtype=tree._tree.DTYPE))
    est = ALL_TREES[name](splitter=splitter)
    est.fit(X_readonly, y_readonly)

    assert_array_equal(est.predict(X_readonly), est.predict(X_small))
    assert_array_equal(
        est.decision_path(X_readonly).todense(), est.decision_path(X_small).todense()
    )


@pytest.mark.parametrize("criterion", ["squared_error", "friedman_mse", "poisson"])
@pytest.mark.parametrize("Tree", REG_TREES.values())
def test_balance_property(criterion, Tree):
    # Test that sum(y_pred)=sum(y_true) on training set.
    # This works if the mean is predicted (should even be true for each leaf).
    # MAE predicts the median and is therefore excluded from this test.

    # Choose a training set with non-negative targets (for poisson)
    X, y = diabetes.data, diabetes.target

    reg = Tree(criterion=criterion)
    reg.fit(X, y)
    assert np.sum(reg.predict(X)) == pytest.approx(np.sum(y))
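

# Sketch of why the balance property holds (illustration only, not part of the
# original test): every training sample lands in exactly one leaf, and for
# mean-predicting criteria each leaf predicts the (weighted) mean of its
# training targets, so the per-leaf sums of predictions and targets match and
# therefore so do the global sums.
def _sketch_balance_property_per_leaf(reg, X, y):
    leaves = reg.apply(X)
    for leaf in np.unique(leaves):
        mask = leaves == leaf
        assert np.sum(reg.predict(X[mask])) == pytest.approx(np.sum(y[mask]))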


@pytest.mark.parametrize("seed", range(3))
def test_poisson_zero_nodes(seed):
    # Test that sum(y)=0 and therefore y_pred=0 is forbidden on nodes.
    X = [[0, 0], [0, 1], [0, 2], [0, 3], [1, 0], [1, 2], [1, 2], [1, 3]]
    y = [0, 0, 0, 0, 1, 2, 3, 4]
    # Note that X[:, 0] == 0 is a 100% indicator for y == 0. The tree can
    # easily learn that:
    reg = DecisionTreeRegressor(criterion="squared_error", random_state=seed)
    reg.fit(X, y)
    assert np.amin(reg.predict(X)) == 0
    # whereas Poisson must predict strictly positive numbers
    reg = DecisionTreeRegressor(criterion="poisson", random_state=seed)
    reg.fit(X, y)
    assert np.all(reg.predict(X) > 0)

    # Test additional dataset where something could go wrong.
    n_features = 10
    X, y = datasets.make_regression(
        effective_rank=n_features * 2 // 3,
        tail_strength=0.6,
        n_samples=1_000,
        n_features=n_features,
        n_informative=n_features * 2 // 3,
        random_state=seed,
    )
    # some excess zeros
    y[(-1 < y) & (y < 0)] = 0
    # make sure the target is positive
    y = np.abs(y)
    reg = DecisionTreeRegressor(criterion="poisson", random_state=seed)
    reg.fit(X, y)
    assert np.all(reg.predict(X) > 0)


def test_poisson_vs_mse():
    # For a Poisson distributed target, Poisson loss should give better results
    # than squared error measured in Poisson deviance as metric.
    # We have a similar test, test_poisson(), in
    # sklearn/ensemble/_hist_gradient_boosting/tests/test_gradient_boosting.py
    rng = np.random.RandomState(42)
    n_train, n_test, n_features = 500, 500, 10
    X = datasets.make_low_rank_matrix(
        n_samples=n_train + n_test, n_features=n_features, random_state=rng
    )
    # We create a log-linear Poisson model and downscale coef as it will get
    # exponentiated.
    coef = rng.uniform(low=-2, high=2, size=n_features) / np.max(X, axis=0)
    y = rng.poisson(lam=np.exp(X @ coef))
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=n_test, random_state=rng
    )

    # We prevent some overfitting by setting min_samples_split=10.
    tree_poi = DecisionTreeRegressor(
        criterion="poisson", min_samples_split=10, random_state=rng
    )
    tree_mse = DecisionTreeRegressor(
        criterion="squared_error", min_samples_split=10, random_state=rng
    )

    tree_poi.fit(X_train, y_train)
    tree_mse.fit(X_train, y_train)
    dummy = DummyRegressor(strategy="mean").fit(X_train, y_train)

    for X, y, val in [(X_train, y_train, "train"), (X_test, y_test, "test")]:
        metric_poi = mean_poisson_deviance(y, tree_poi.predict(X))
        # squared_error might produce non-positive predictions => clip
        metric_mse = mean_poisson_deviance(y, np.clip(tree_mse.predict(X), 1e-15, None))
        metric_dummy = mean_poisson_deviance(y, dummy.predict(X))
        # As squared_error might correctly predict 0 in train set, its train
        # score can be better than Poisson. This is no longer the case for the
        # test set.
        if val == "test":
            assert metric_poi < 0.5 * metric_mse
            assert metric_poi < 0.75 * metric_dummy
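

# Sketch of the metric used above (illustration only): the mean Poisson deviance
# compares strictly positive predictions against count-like targets and is
# minimized when predictions match the conditional mean, which is why it favours
# the Poisson criterion here. This is a hypothetical re-implementation; the test
# itself uses sklearn.metrics.mean_poisson_deviance.
def _sketch_mean_poisson_deviance(y_true, y_pred):
    from scipy.special import xlogy

    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    # xlogy handles y_true == 0, where y * log(y / mu) is defined as 0.
    return np.mean(2 * (xlogy(y_true, y_true / y_pred) - y_true + y_pred))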


@pytest.mark.parametrize("criterion", REG_CRITERIONS)
def test_decision_tree_regressor_sample_weight_consistency(criterion):
    """Test that the impact of sample_weight is consistent."""
    tree_params = dict(criterion=criterion)
    tree = DecisionTreeRegressor(**tree_params, random_state=42)
    for kind in ["zeros", "ones"]:
        check_sample_weights_invariance(
            "DecisionTreeRegressor_" + criterion, tree, kind=kind
        )

    rng = np.random.RandomState(0)
    n_samples, n_features = 10, 5

    X = rng.rand(n_samples, n_features)
    y = np.mean(X, axis=1) + rng.rand(n_samples)
    # make it positive in order to work also for poisson criterion
    y += np.min(y) + 0.1

    # check that multiplying sample_weight by 2 is equivalent
    # to repeating corresponding samples twice
    X2 = np.concatenate([X, X[: n_samples // 2]], axis=0)
    y2 = np.concatenate([y, y[: n_samples // 2]])
    sample_weight_1 = np.ones(len(y))
    sample_weight_1[: n_samples // 2] = 2

    tree1 = DecisionTreeRegressor(**tree_params).fit(
        X, y, sample_weight=sample_weight_1
    )

    tree2 = DecisionTreeRegressor(**tree_params).fit(X2, y2, sample_weight=None)

    assert tree1.tree_.node_count == tree2.tree_.node_count
    # Thresholds, tree.tree_.threshold, and values, tree.tree_.value, are not
    # exactly the same, but on the training set, those differences do not
    # matter and thus predictions are the same.
    assert_allclose(tree1.predict(X), tree2.predict(X))


@pytest.mark.parametrize("Tree", [DecisionTreeClassifier, ExtraTreeClassifier])
@pytest.mark.parametrize("n_classes", [2, 4])
def test_criterion_entropy_same_as_log_loss(Tree, n_classes):
    """Test that criterion=entropy gives same as log_loss."""
    n_samples, n_features = 50, 5
    X, y = datasets.make_classification(
        n_classes=n_classes,
        n_samples=n_samples,
        n_features=n_features,
        n_informative=n_features,
        n_redundant=0,
        random_state=42,
    )
    tree_log_loss = Tree(criterion="log_loss", random_state=43).fit(X, y)
    tree_entropy = Tree(criterion="entropy", random_state=43).fit(X, y)

    assert_tree_equal(
        tree_log_loss.tree_,
        tree_entropy.tree_,
        f"{Tree!r} with criterion 'entropy' and 'log_loss' gave different trees.",
    )
    assert_allclose(tree_log_loss.predict(X), tree_entropy.predict(X))


def test_different_endianness_pickle():
    X, y = datasets.make_classification(random_state=0)
    clf = DecisionTreeClassifier(random_state=0, max_depth=3)
    clf.fit(X, y)
    score = clf.score(X, y)

    def reduce_ndarray(arr):
        return arr.byteswap().newbyteorder().__reduce__()

    def get_pickle_non_native_endianness():
        f = io.BytesIO()
        p = pickle.Pickler(f)
        p.dispatch_table = copyreg.dispatch_table.copy()
        p.dispatch_table[np.ndarray] = reduce_ndarray

        p.dump(clf)
        f.seek(0)
        return f

    new_clf = pickle.load(get_pickle_non_native_endianness())
    new_score = new_clf.score(X, y)
    assert np.isclose(score, new_score)


def test_different_endianness_joblib_pickle():
    X, y = datasets.make_classification(random_state=0)
    clf = DecisionTreeClassifier(random_state=0, max_depth=3)
    clf.fit(X, y)
    score = clf.score(X, y)

    class NonNativeEndiannessNumpyPickler(NumpyPickler):
        def save(self, obj):
            if isinstance(obj, np.ndarray):
                obj = obj.byteswap().newbyteorder()
            super().save(obj)

    def get_joblib_pickle_non_native_endianness():
        f = io.BytesIO()
        p = NonNativeEndiannessNumpyPickler(f)

        p.dump(clf)
        f.seek(0)
        return f

    new_clf = joblib.load(get_joblib_pickle_non_native_endianness())
    new_score = new_clf.score(X, y)
    assert np.isclose(score, new_score)


def get_different_bitness_node_ndarray(node_ndarray):
    new_dtype_for_indexing_fields = np.int64 if _IS_32BIT else np.int32

    # field names in Node struct with SIZE_t types (see sklearn/tree/_tree.pxd)
    indexing_field_names = ["left_child", "right_child", "feature", "n_node_samples"]

    new_dtype_dict = {
        name: dtype for name, (dtype, _) in node_ndarray.dtype.fields.items()
    }
    for name in indexing_field_names:
        new_dtype_dict[name] = new_dtype_for_indexing_fields

    new_dtype = np.dtype(
        {"names": list(new_dtype_dict.keys()), "formats": list(new_dtype_dict.values())}
    )
    return node_ndarray.astype(new_dtype, casting="same_kind")


def get_different_alignment_node_ndarray(node_ndarray):
    new_dtype_dict = {
        name: dtype for name, (dtype, _) in node_ndarray.dtype.fields.items()
    }
    offsets = [offset for dtype, offset in node_ndarray.dtype.fields.values()]
    shifted_offsets = [8 + offset for offset in offsets]

    new_dtype = np.dtype(
        {
            "names": list(new_dtype_dict.keys()),
            "formats": list(new_dtype_dict.values()),
            "offsets": shifted_offsets,
        }
    )
    return node_ndarray.astype(new_dtype, casting="same_kind")


def reduce_tree_with_different_bitness(tree):
    new_dtype = np.int64 if _IS_32BIT else np.int32
    tree_cls, (n_features, n_classes, n_outputs), state = tree.__reduce__()
    new_n_classes = n_classes.astype(new_dtype, casting="same_kind")

    new_state = state.copy()
    new_state["nodes"] = get_different_bitness_node_ndarray(new_state["nodes"])

    return (tree_cls, (n_features, new_n_classes, n_outputs), new_state)


def test_different_bitness_pickle():
    X, y = datasets.make_classification(random_state=0)
    clf = DecisionTreeClassifier(random_state=0, max_depth=3)
    clf.fit(X, y)
    score = clf.score(X, y)

    def pickle_dump_with_different_bitness():
        f = io.BytesIO()
        p = pickle.Pickler(f)
        p.dispatch_table = copyreg.dispatch_table.copy()
        p.dispatch_table[CythonTree] = reduce_tree_with_different_bitness

        p.dump(clf)
        f.seek(0)
        return f

    new_clf = pickle.load(pickle_dump_with_different_bitness())
    new_score = new_clf.score(X, y)
    assert score == pytest.approx(new_score)


def test_different_bitness_joblib_pickle():
    # Make sure that a platform specific pickle generated on a 64 bit
    # platform can be converted at pickle load time into an estimator
    # with Cython code that works with the host's native integer precision
    # to index nodes in the tree data structure when the host is a 32 bit
    # platform (and vice versa).
    X, y = datasets.make_classification(random_state=0)
    clf = DecisionTreeClassifier(random_state=0, max_depth=3)
    clf.fit(X, y)
    score = clf.score(X, y)

    def joblib_dump_with_different_bitness():
        f = io.BytesIO()
        p = NumpyPickler(f)
        p.dispatch_table = copyreg.dispatch_table.copy()
        p.dispatch_table[CythonTree] = reduce_tree_with_different_bitness

        p.dump(clf)
        f.seek(0)
        return f

    new_clf = joblib.load(joblib_dump_with_different_bitness())
    new_score = new_clf.score(X, y)
    assert score == pytest.approx(new_score)


def test_check_n_classes():
    expected_dtype = np.dtype(np.int32) if _IS_32BIT else np.dtype(np.int64)
    allowed_dtypes = [np.dtype(np.int32), np.dtype(np.int64)]
    allowed_dtypes += [dt.newbyteorder() for dt in allowed_dtypes]

    n_classes = np.array([0, 1], dtype=expected_dtype)
    for dt in allowed_dtypes:
        _check_n_classes(n_classes.astype(dt), expected_dtype)

    with pytest.raises(ValueError, match="Wrong dimensions.+n_classes"):
        wrong_dim_n_classes = np.array([[0, 1]], dtype=expected_dtype)
        _check_n_classes(wrong_dim_n_classes, expected_dtype)

    with pytest.raises(ValueError, match="n_classes.+incompatible dtype"):
        wrong_dtype_n_classes = n_classes.astype(np.float64)
        _check_n_classes(wrong_dtype_n_classes, expected_dtype)


def test_check_value_ndarray():
    expected_dtype = np.dtype(np.float64)
    expected_shape = (5, 1, 2)
    value_ndarray = np.zeros(expected_shape, dtype=expected_dtype)

    allowed_dtypes = [expected_dtype, expected_dtype.newbyteorder()]

    for dt in allowed_dtypes:
        _check_value_ndarray(
            value_ndarray, expected_dtype=dt, expected_shape=expected_shape
        )

    with pytest.raises(ValueError, match="Wrong shape.+value array"):
        _check_value_ndarray(
            value_ndarray, expected_dtype=expected_dtype, expected_shape=(1, 2)
        )

    for problematic_arr in [value_ndarray[:, :, :1], np.asfortranarray(value_ndarray)]:
        with pytest.raises(ValueError, match="value array.+C-contiguous"):
            _check_value_ndarray(
                problematic_arr,
                expected_dtype=expected_dtype,
                expected_shape=problematic_arr.shape,
            )

    with pytest.raises(ValueError, match="value array.+incompatible dtype"):
        _check_value_ndarray(
            value_ndarray.astype(np.float32),
            expected_dtype=expected_dtype,
            expected_shape=expected_shape,
        )


def test_check_node_ndarray():
    expected_dtype = NODE_DTYPE

    node_ndarray = np.zeros((5,), dtype=expected_dtype)

    valid_node_ndarrays = [
        node_ndarray,
        get_different_bitness_node_ndarray(node_ndarray),
        get_different_alignment_node_ndarray(node_ndarray),
    ]
    valid_node_ndarrays += [
        arr.astype(arr.dtype.newbyteorder()) for arr in valid_node_ndarrays
    ]

    for arr in valid_node_ndarrays:
        _check_node_ndarray(arr, expected_dtype=expected_dtype)

    with pytest.raises(ValueError, match="Wrong dimensions.+node array"):
        problematic_node_ndarray = np.zeros((5, 2), dtype=expected_dtype)
        _check_node_ndarray(problematic_node_ndarray, expected_dtype=expected_dtype)

    with pytest.raises(ValueError, match="node array.+C-contiguous"):
        problematic_node_ndarray = node_ndarray[::2]
        _check_node_ndarray(problematic_node_ndarray, expected_dtype=expected_dtype)

    dtype_dict = {name: dtype for name, (dtype, _) in node_ndarray.dtype.fields.items()}

    # array with wrong 'threshold' field dtype (int64 rather than float64)
    new_dtype_dict = dtype_dict.copy()
    new_dtype_dict["threshold"] = np.int64

    new_dtype = np.dtype(
        {"names": list(new_dtype_dict.keys()), "formats": list(new_dtype_dict.values())}
    )
    problematic_node_ndarray = node_ndarray.astype(new_dtype)

    with pytest.raises(ValueError, match="node array.+incompatible dtype"):
        _check_node_ndarray(problematic_node_ndarray, expected_dtype=expected_dtype)

    # array with wrong 'left_child' field dtype (float64 rather than int64 or int32)
    new_dtype_dict = dtype_dict.copy()
    new_dtype_dict["left_child"] = np.float64

    new_dtype = np.dtype(
        {"names": list(new_dtype_dict.keys()), "formats": list(new_dtype_dict.values())}
    )
    problematic_node_ndarray = node_ndarray.astype(new_dtype)

    with pytest.raises(ValueError, match="node array.+incompatible dtype"):
        _check_node_ndarray(problematic_node_ndarray, expected_dtype=expected_dtype)


@pytest.mark.parametrize(
    "Splitter", chain(DENSE_SPLITTERS.values(), SPARSE_SPLITTERS.values())
)
def test_splitter_serializable(Splitter):
    """Check that splitters are serializable."""
    rng = np.random.RandomState(42)
    max_features = 10
    n_outputs, n_classes = 2, np.array([3, 2], dtype=np.intp)

    criterion = CRITERIA_CLF["gini"](n_outputs, n_classes)
    splitter = Splitter(criterion, max_features, 5, 0.5, rng)
    splitter_serialize = pickle.dumps(splitter)

    splitter_back = pickle.loads(splitter_serialize)
    assert splitter_back.max_features == max_features
    assert isinstance(splitter_back, Splitter)


def test_tree_deserialization_from_read_only_buffer(tmpdir):
    """Check that Trees can be deserialized with read only buffers.

    Non-regression test for gh-25584.
    """
    pickle_path = str(tmpdir.join("clf.joblib"))
    clf = DecisionTreeClassifier(random_state=0)
    clf.fit(X_small, y_small)

    joblib.dump(clf, pickle_path)
    loaded_clf = joblib.load(pickle_path, mmap_mode="r")

    assert_tree_equal(
        loaded_clf.tree_,
        clf.tree_,
        "The trees of the original and loaded classifiers are not equal.",
    )


@pytest.mark.parametrize("Tree", ALL_TREES.values())
def test_min_sample_split_1_error(Tree):
    """Check that an error is raised when min_samples_split=1.

    Non-regression test for issue gh-25481.
    """
    X = np.array([[0, 0], [1, 1]])
    y = np.array([0, 1])

    # min_samples_split=1.0 is valid
    Tree(min_samples_split=1.0).fit(X, y)

    # min_samples_split=1 is invalid
    tree = Tree(min_samples_split=1)
    msg = (
        r"'min_samples_split' .* must be an int in the range \[2, inf\) "
        r"or a float in the range \(0.0, 1.0\]"
    )
    with pytest.raises(ValueError, match=msg):
        tree.fit(X, y)


@pytest.mark.parametrize("criterion", ["squared_error", "friedman_mse"])
def test_missing_values_on_equal_nodes_no_missing(criterion):
    """Check that missing values go to the correct node during prediction."""
    X = np.array([[0, 1, 2, 3, 8, 9, 11, 12, 15]]).T
    y = np.array([0.1, 0.2, 0.3, 0.2, 1.4, 1.4, 1.5, 1.6, 2.6])

    dtc = DecisionTreeRegressor(random_state=42, max_depth=1, criterion=criterion)
    dtc.fit(X, y)

    # Goes to the right node because it has the most data points
    y_pred = dtc.predict([[np.nan]])
    assert_allclose(y_pred, [np.mean(y[-5:])])

    # equal number of elements in both nodes
    X_equal = X[:-1]
    y_equal = y[:-1]

    dtc = DecisionTreeRegressor(random_state=42, max_depth=1, criterion=criterion)
    dtc.fit(X_equal, y_equal)

    # Goes to the right node because the implementation sets:
    # missing_go_to_left = n_left > n_right, which is False
    y_pred = dtc.predict([[np.nan]])
    assert_allclose(y_pred, [np.mean(y_equal[-4:])])


@pytest.mark.parametrize("criterion", ["entropy", "gini"])
def test_missing_values_best_splitter_three_classes(criterion):
    """Test when missing values are uniquely present in a class among 3 classes."""
    missing_values_class = 0
    X = np.array([[np.nan] * 4 + [0, 1, 2, 3, 8, 9, 11, 12]]).T
    y = np.array([missing_values_class] * 4 + [1] * 4 + [2] * 4)
    dtc = DecisionTreeClassifier(random_state=42, max_depth=2, criterion=criterion)
    dtc.fit(X, y)

    X_test = np.array([[np.nan, 3, 12]]).T
    y_nan_pred = dtc.predict(X_test)
    # Missing values necessarily are associated to the observed class.
    assert_array_equal(y_nan_pred, [missing_values_class, 1, 2])


@pytest.mark.parametrize("criterion", ["entropy", "gini"])
def test_missing_values_best_splitter_to_left(criterion):
    """Missing values spanning only one class at fit-time must make missing
    values at predict-time be classified as belonging to this class."""
    X = np.array([[np.nan] * 4 + [0, 1, 2, 3, 4, 5]]).T
    y = np.array([0] * 4 + [1] * 6)

    dtc = DecisionTreeClassifier(random_state=42, max_depth=2, criterion=criterion)
    dtc.fit(X, y)

    X_test = np.array([[np.nan, 5, np.nan]]).T
    y_pred = dtc.predict(X_test)

    assert_array_equal(y_pred, [0, 1, 0])


@pytest.mark.parametrize("criterion", ["entropy", "gini"])
def test_missing_values_best_splitter_to_right(criterion):
    """Missing values and non-missing values sharing one class at fit-time
    must make missing values at predict-time be classified as belonging
    to this class."""
    X = np.array([[np.nan] * 4 + [0, 1, 2, 3, 4, 5]]).T
    y = np.array([1] * 4 + [0] * 4 + [1] * 2)

    dtc = DecisionTreeClassifier(random_state=42, max_depth=2, criterion=criterion)
    dtc.fit(X, y)

    X_test = np.array([[np.nan, 1.2, 4.8]]).T
    y_pred = dtc.predict(X_test)

    assert_array_equal(y_pred, [1, 0, 1])


@pytest.mark.parametrize("criterion", ["entropy", "gini"])
def test_missing_values_missing_both_classes_has_nan(criterion):
    """Check behavior of missing value when there is one missing value in each class."""
    X = np.array([[1, 2, 3, 5, np.nan, 10, 20, 30, 60, np.nan]]).T
    y = np.array([0] * 5 + [1] * 5)

    dtc = DecisionTreeClassifier(random_state=42, max_depth=1, criterion=criterion)
    dtc.fit(X, y)
    X_test = np.array([[np.nan, 2.3, 34.2]]).T
    y_pred = dtc.predict(X_test)

    # Missing value goes to the class at the right (here 1) because the implementation
    # searches right first.
    assert_array_equal(y_pred, [1, 0, 1])


@pytest.mark.parametrize("is_sparse", [True, False])
@pytest.mark.parametrize(
    "tree",
    [
        DecisionTreeClassifier(splitter="random"),
        DecisionTreeRegressor(criterion="absolute_error"),
    ],
)
def test_missing_value_errors(is_sparse, tree):
    """Check unsupported configurations for missing values."""
    X = np.array([[1, 2, 3, 5, np.nan, 10, 20, 30, 60, np.nan]]).T
    y = np.array([0] * 5 + [1] * 5)

    if is_sparse:
        X = csr_matrix(X)

    with pytest.raises(ValueError, match="Input X contains NaN"):
        tree.fit(X, y)


def test_missing_values_poisson():
    """Smoke test for poisson regression and missing values."""
    X, y = diabetes.data.copy(), diabetes.target

    # Set some values missing
    X[::5, 0] = np.nan
    X[::6, -1] = np.nan

    reg = DecisionTreeRegressor(criterion="poisson", random_state=42)
    reg.fit(X, y)

    y_pred = reg.predict(X)
    assert (y_pred >= 0.0).all()


@pytest.mark.parametrize(
    "make_data, Tree",
    [
        (datasets.make_regression, DecisionTreeRegressor),
        (datasets.make_classification, DecisionTreeClassifier),
    ],
)
@pytest.mark.parametrize("sample_weight_train", [None, "ones"])
def test_missing_values_is_resilience(make_data, Tree, sample_weight_train):
    """Check that trees can deal with missing values and have decent performance."""
    rng = np.random.RandomState(0)
    n_samples, n_features = 1000, 50
    X, y = make_data(n_samples=n_samples, n_features=n_features, random_state=rng)

    # Create dataset with missing values
    X_missing = X.copy()
    X_missing[rng.choice([False, True], size=X.shape, p=[0.9, 0.1])] = np.nan
    X_missing_train, X_missing_test, y_train, y_test = train_test_split(
        X_missing, y, random_state=0
    )
    if sample_weight_train == "ones":
        sample_weight_train = np.ones(X_missing_train.shape[0])

    # Train tree with missing values
    tree_with_missing = Tree(random_state=rng)
    tree_with_missing.fit(X_missing_train, y_train, sample_weight=sample_weight_train)
    score_with_missing = tree_with_missing.score(X_missing_test, y_test)

    # Train tree without missing values
    X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)
    tree = Tree(random_state=rng)
    tree.fit(X_train, y_train, sample_weight=sample_weight_train)
    score_without_missing = tree.score(X_test, y_test)

    # The score with missing values is still at least 90 percent of the score
    # of the tree fit on data without missing values
    assert score_with_missing >= 0.9 * score_without_missing


def test_missing_value_is_predictive():
    """Check the tree learns when only the missing value is predictive."""
    rng = np.random.RandomState(0)
    n_samples = 1000

    X = rng.standard_normal(size=(n_samples, 10))
    y = rng.randint(0, high=2, size=n_samples)

    # Create a predictive feature using `y` and with some noise
    X_random_mask = rng.choice([False, True], size=n_samples, p=[0.95, 0.05])
    y_mask = y.copy().astype(bool)
    y_mask[X_random_mask] = ~y_mask[X_random_mask]

    X_predictive = rng.standard_normal(size=n_samples)
    X_predictive[y_mask] = np.nan

    X[:, 5] = X_predictive

    X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=rng)
    tree = DecisionTreeClassifier(random_state=rng).fit(X_train, y_train)

    assert tree.score(X_train, y_train) >= 0.85
    assert tree.score(X_test, y_test) >= 0.85


@pytest.mark.parametrize(
    "make_data, Tree",
    [
        (datasets.make_regression, DecisionTreeRegressor),
        (datasets.make_classification, DecisionTreeClassifier),
    ],
)
def test_sample_weight_non_uniform(make_data, Tree):
    """Check sample weight is correctly handled with missing values."""
    rng = np.random.RandomState(0)
    n_samples, n_features = 1000, 10
    X, y = make_data(n_samples=n_samples, n_features=n_features, random_state=rng)

    # Create dataset with missing values
    X[rng.choice([False, True], size=X.shape, p=[0.9, 0.1])] = np.nan

    # Zero sample weight is the same as removing the sample
    sample_weight = np.ones(X.shape[0])
    sample_weight[::2] = 0.0

    tree_with_sw = Tree(random_state=0)
    tree_with_sw.fit(X, y, sample_weight=sample_weight)

    tree_samples_removed = Tree(random_state=0)
    tree_samples_removed.fit(X[1::2, :], y[1::2])

    assert_allclose(tree_samples_removed.predict(X), tree_with_sw.predict(X))


def test_deterministic_pickle():
    # Non-regression test for:
    # https://github.com/scikit-learn/scikit-learn/issues/27268
    # Uninitialised memory would lead to the two pickle strings being different.
    tree1 = DecisionTreeClassifier(random_state=0).fit(iris.data, iris.target)
    tree2 = DecisionTreeClassifier(random_state=0).fit(iris.data, iris.target)

    pickle1 = pickle.dumps(tree1)
    pickle2 = pickle.dumps(tree2)

    assert pickle1 == pickle2