test_split.py 67 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
717781779178017811782178317841785178617871788178917901791179217931794179517961797179817991800180118021803180418051806180718081809181018111812181318141815181618171818181918201821182218231824182518261827182818291830183118321833183418351836183718381839184018411842184318441845184618471848184918501851185218531854185518561857185818591860186118621863186418651866186718681869187018711872187318741875187618771878187918801881188218831884188518861887188818891890189118921893189418951896189718981899190019011902190319041905190619071908190919101911191219131914191519161917191819191920192119221923192419251926192719281929193019311932193319341935193619371938193919401941194219431944194519461947194819491950
  1. """Test the split module"""
  2. import re
  3. import warnings
  4. from itertools import combinations, combinations_with_replacement, permutations
  5. import numpy as np
  6. import pytest
  7. from scipy import stats
  8. from scipy.sparse import (
  9. coo_matrix,
  10. csc_matrix,
  11. csr_matrix,
  12. issparse,
  13. )
  14. from scipy.special import comb
  15. from sklearn.datasets import load_digits, make_classification
  16. from sklearn.dummy import DummyClassifier
  17. from sklearn.model_selection import (
  18. GridSearchCV,
  19. GroupKFold,
  20. GroupShuffleSplit,
  21. KFold,
  22. LeaveOneGroupOut,
  23. LeaveOneOut,
  24. LeavePGroupsOut,
  25. LeavePOut,
  26. PredefinedSplit,
  27. RepeatedKFold,
  28. RepeatedStratifiedKFold,
  29. ShuffleSplit,
  30. StratifiedGroupKFold,
  31. StratifiedKFold,
  32. StratifiedShuffleSplit,
  33. TimeSeriesSplit,
  34. check_cv,
  35. cross_val_score,
  36. train_test_split,
  37. )
  38. from sklearn.model_selection._split import (
  39. _build_repr,
  40. _validate_shuffle_split,
  41. _yields_constant_splits,
  42. )
  43. from sklearn.svm import SVC
  44. from sklearn.tests.test_metadata_routing import assert_request_is_empty
  45. from sklearn.utils._mocking import MockDataFrame
  46. from sklearn.utils._testing import (
  47. assert_allclose,
  48. assert_array_almost_equal,
  49. assert_array_equal,
  50. ignore_warnings,
  51. )
  52. from sklearn.utils.validation import _num_samples
# Cross-validators listed (per the constant's name) as not using ``groups``,
# each instantiated once at import time. Presumably consumed by parametrized
# tests further down the file (not visible in this chunk) -- TODO confirm.
NO_GROUP_SPLITTERS = [
    KFold(),
    StratifiedKFold(),
    TimeSeriesSplit(),
    LeaveOneOut(),
    LeavePOut(p=2),
    ShuffleSplit(),
    StratifiedShuffleSplit(test_size=0.5),
    PredefinedSplit([1, 1, 2, 2]),
    RepeatedKFold(),
    RepeatedStratifiedKFold(),
]
# Group-aware cross-validators (per the constant's name), instantiated once at
# import time.
GROUP_SPLITTERS = [
    GroupKFold(),
    LeavePGroupsOut(n_groups=1),
    StratifiedGroupKFold(),
    LeaveOneGroupOut(),
    GroupShuffleSplit(),
]
# Concatenation of the two lists above.
ALL_SPLITTERS = NO_GROUP_SPLITTERS + GROUP_SPLITTERS  # type: ignore
# Shared toy data: 10 samples all equal to 1.0, and labels 0..4 with exactly
# two samples per label ([0, 0, 1, 1, ..., 4, 4]).
X = np.ones(10)
y = np.arange(10) // 2
# 5x5 sparse identity matrix in COO format; not referenced in the visible part
# of the file -- presumably used by later tests.
P_sparse = coo_matrix(np.eye(5))
# Group label sequences of different lengths and container types: integer
# ndarrays, a plain Python list and a list of strings. Presumably used to check
# that group-aware splitters accept all of them -- confirm against the tests
# that consume it.
test_groups = (
    np.array([1, 1, 1, 1, 2, 2, 2, 3, 3, 3, 3, 3]),
    np.array([0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3]),
    np.array([0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2]),
    np.array([1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4]),
    [1, 1, 1, 1, 2, 2, 2, 3, 3, 3, 3, 3],
    ["1", "1", "1", "1", "2", "2", "2", "3", "3", "3", "3", "3"],
)
# Digits dataset, loaded once at import time; the first 600 samples are used by
# test_kfold_can_detect_dependent_samples_on_digits.
digits = load_digits()
  85. @ignore_warnings
  86. def test_cross_validator_with_default_params():
  87. n_samples = 4
  88. n_unique_groups = 4
  89. n_splits = 2
  90. p = 2
  91. n_shuffle_splits = 10 # (the default value)
  92. X = np.array([[1, 2], [3, 4], [5, 6], [7, 8]])
  93. X_1d = np.array([1, 2, 3, 4])
  94. y = np.array([1, 1, 2, 2])
  95. groups = np.array([1, 2, 3, 4])
  96. loo = LeaveOneOut()
  97. lpo = LeavePOut(p)
  98. kf = KFold(n_splits)
  99. skf = StratifiedKFold(n_splits)
  100. lolo = LeaveOneGroupOut()
  101. lopo = LeavePGroupsOut(p)
  102. ss = ShuffleSplit(random_state=0)
  103. ps = PredefinedSplit([1, 1, 2, 2]) # n_splits = np of unique folds = 2
  104. sgkf = StratifiedGroupKFold(n_splits)
  105. loo_repr = "LeaveOneOut()"
  106. lpo_repr = "LeavePOut(p=2)"
  107. kf_repr = "KFold(n_splits=2, random_state=None, shuffle=False)"
  108. skf_repr = "StratifiedKFold(n_splits=2, random_state=None, shuffle=False)"
  109. lolo_repr = "LeaveOneGroupOut()"
  110. lopo_repr = "LeavePGroupsOut(n_groups=2)"
  111. ss_repr = (
  112. "ShuffleSplit(n_splits=10, random_state=0, test_size=None, train_size=None)"
  113. )
  114. ps_repr = "PredefinedSplit(test_fold=array([1, 1, 2, 2]))"
  115. sgkf_repr = "StratifiedGroupKFold(n_splits=2, random_state=None, shuffle=False)"
  116. n_splits_expected = [
  117. n_samples,
  118. comb(n_samples, p),
  119. n_splits,
  120. n_splits,
  121. n_unique_groups,
  122. comb(n_unique_groups, p),
  123. n_shuffle_splits,
  124. 2,
  125. n_splits,
  126. ]
  127. for i, (cv, cv_repr) in enumerate(
  128. zip(
  129. [loo, lpo, kf, skf, lolo, lopo, ss, ps, sgkf],
  130. [
  131. loo_repr,
  132. lpo_repr,
  133. kf_repr,
  134. skf_repr,
  135. lolo_repr,
  136. lopo_repr,
  137. ss_repr,
  138. ps_repr,
  139. sgkf_repr,
  140. ],
  141. )
  142. ):
  143. # Test if get_n_splits works correctly
  144. assert n_splits_expected[i] == cv.get_n_splits(X, y, groups)
  145. # Test if the cross-validator works as expected even if
  146. # the data is 1d
  147. np.testing.assert_equal(
  148. list(cv.split(X, y, groups)), list(cv.split(X_1d, y, groups))
  149. )
  150. # Test that train, test indices returned are integers
  151. for train, test in cv.split(X, y, groups):
  152. assert np.asarray(train).dtype.kind == "i"
  153. assert np.asarray(test).dtype.kind == "i"
  154. # Test if the repr works without any errors
  155. assert cv_repr == repr(cv)
  156. # ValueError for get_n_splits methods
  157. msg = "The 'X' parameter should not be None."
  158. with pytest.raises(ValueError, match=msg):
  159. loo.get_n_splits(None, y, groups)
  160. with pytest.raises(ValueError, match=msg):
  161. lpo.get_n_splits(None, y, groups)
  162. def test_2d_y():
  163. # smoke test for 2d y and multi-label
  164. n_samples = 30
  165. rng = np.random.RandomState(1)
  166. X = rng.randint(0, 3, size=(n_samples, 2))
  167. y = rng.randint(0, 3, size=(n_samples,))
  168. y_2d = y.reshape(-1, 1)
  169. y_multilabel = rng.randint(0, 2, size=(n_samples, 3))
  170. groups = rng.randint(0, 3, size=(n_samples,))
  171. splitters = [
  172. LeaveOneOut(),
  173. LeavePOut(p=2),
  174. KFold(),
  175. StratifiedKFold(),
  176. RepeatedKFold(),
  177. RepeatedStratifiedKFold(),
  178. StratifiedGroupKFold(),
  179. ShuffleSplit(),
  180. StratifiedShuffleSplit(test_size=0.5),
  181. GroupShuffleSplit(),
  182. LeaveOneGroupOut(),
  183. LeavePGroupsOut(n_groups=2),
  184. GroupKFold(n_splits=3),
  185. TimeSeriesSplit(),
  186. PredefinedSplit(test_fold=groups),
  187. ]
  188. for splitter in splitters:
  189. list(splitter.split(X, y, groups))
  190. list(splitter.split(X, y_2d, groups))
  191. try:
  192. list(splitter.split(X, y_multilabel, groups))
  193. except ValueError as e:
  194. allowed_target_types = ("binary", "multiclass")
  195. msg = "Supported target types are: {}. Got 'multilabel".format(
  196. allowed_target_types
  197. )
  198. assert msg in str(e)
  199. def check_valid_split(train, test, n_samples=None):
  200. # Use python sets to get more informative assertion failure messages
  201. train, test = set(train), set(test)
  202. # Train and test split should not overlap
  203. assert train.intersection(test) == set()
  204. if n_samples is not None:
  205. # Check that the union of train an test split cover all the indices
  206. assert train.union(test) == set(range(n_samples))
  207. def check_cv_coverage(cv, X, y, groups, expected_n_splits):
  208. n_samples = _num_samples(X)
  209. # Check that a all the samples appear at least once in a test fold
  210. assert cv.get_n_splits(X, y, groups) == expected_n_splits
  211. collected_test_samples = set()
  212. iterations = 0
  213. for train, test in cv.split(X, y, groups):
  214. check_valid_split(train, test, n_samples=n_samples)
  215. iterations += 1
  216. collected_test_samples.update(test)
  217. # Check that the accumulated test samples cover the whole dataset
  218. assert iterations == expected_n_splits
  219. if n_samples is not None:
  220. assert collected_test_samples == set(range(n_samples))
  221. def test_kfold_valueerrors():
  222. X1 = np.array([[1, 2], [3, 4], [5, 6]])
  223. X2 = np.array([[1, 2], [3, 4], [5, 6], [7, 8], [9, 10]])
  224. # Check that errors are raised if there is not enough samples
  225. (ValueError, next, KFold(4).split(X1))
  226. # Check that a warning is raised if the least populated class has too few
  227. # members.
  228. y = np.array([3, 3, -1, -1, 3])
  229. skf_3 = StratifiedKFold(3)
  230. with pytest.warns(Warning, match="The least populated class"):
  231. next(skf_3.split(X2, y))
  232. sgkf_3 = StratifiedGroupKFold(3)
  233. naive_groups = np.arange(len(y))
  234. with pytest.warns(Warning, match="The least populated class"):
  235. next(sgkf_3.split(X2, y, naive_groups))
  236. # Check that despite the warning the folds are still computed even
  237. # though all the classes are not necessarily represented at on each
  238. # side of the split at each split
  239. with warnings.catch_warnings():
  240. warnings.simplefilter("ignore")
  241. check_cv_coverage(skf_3, X2, y, groups=None, expected_n_splits=3)
  242. with warnings.catch_warnings():
  243. warnings.simplefilter("ignore")
  244. check_cv_coverage(sgkf_3, X2, y, groups=naive_groups, expected_n_splits=3)
  245. # Check that errors are raised if all n_groups for individual
  246. # classes are less than n_splits.
  247. y = np.array([3, 3, -1, -1, 2])
  248. with pytest.raises(ValueError):
  249. next(skf_3.split(X2, y))
  250. with pytest.raises(ValueError):
  251. next(sgkf_3.split(X2, y))
  252. # Error when number of folds is <= 1
  253. with pytest.raises(ValueError):
  254. KFold(0)
  255. with pytest.raises(ValueError):
  256. KFold(1)
  257. error_string = "k-fold cross-validation requires at least one train/test split"
  258. with pytest.raises(ValueError, match=error_string):
  259. StratifiedKFold(0)
  260. with pytest.raises(ValueError, match=error_string):
  261. StratifiedKFold(1)
  262. with pytest.raises(ValueError, match=error_string):
  263. StratifiedGroupKFold(0)
  264. with pytest.raises(ValueError, match=error_string):
  265. StratifiedGroupKFold(1)
  266. # When n_splits is not integer:
  267. with pytest.raises(ValueError):
  268. KFold(1.5)
  269. with pytest.raises(ValueError):
  270. KFold(2.0)
  271. with pytest.raises(ValueError):
  272. StratifiedKFold(1.5)
  273. with pytest.raises(ValueError):
  274. StratifiedKFold(2.0)
  275. with pytest.raises(ValueError):
  276. StratifiedGroupKFold(1.5)
  277. with pytest.raises(ValueError):
  278. StratifiedGroupKFold(2.0)
  279. # When shuffle is not a bool:
  280. with pytest.raises(TypeError):
  281. KFold(n_splits=4, shuffle=None)
  282. def test_kfold_indices():
  283. # Check all indices are returned in the test folds
  284. X1 = np.ones(18)
  285. kf = KFold(3)
  286. check_cv_coverage(kf, X1, y=None, groups=None, expected_n_splits=3)
  287. # Check all indices are returned in the test folds even when equal-sized
  288. # folds are not possible
  289. X2 = np.ones(17)
  290. kf = KFold(3)
  291. check_cv_coverage(kf, X2, y=None, groups=None, expected_n_splits=3)
  292. # Check if get_n_splits returns the number of folds
  293. assert 5 == KFold(5).get_n_splits(X2)
  294. def test_kfold_no_shuffle():
  295. # Manually check that KFold preserves the data ordering on toy datasets
  296. X2 = [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10]]
  297. splits = KFold(2).split(X2[:-1])
  298. train, test = next(splits)
  299. assert_array_equal(test, [0, 1])
  300. assert_array_equal(train, [2, 3])
  301. train, test = next(splits)
  302. assert_array_equal(test, [2, 3])
  303. assert_array_equal(train, [0, 1])
  304. splits = KFold(2).split(X2)
  305. train, test = next(splits)
  306. assert_array_equal(test, [0, 1, 2])
  307. assert_array_equal(train, [3, 4])
  308. train, test = next(splits)
  309. assert_array_equal(test, [3, 4])
  310. assert_array_equal(train, [0, 1, 2])
  311. def test_stratified_kfold_no_shuffle():
  312. # Manually check that StratifiedKFold preserves the data ordering as much
  313. # as possible on toy datasets in order to avoid hiding sample dependencies
  314. # when possible
  315. X, y = np.ones(4), [1, 1, 0, 0]
  316. splits = StratifiedKFold(2).split(X, y)
  317. train, test = next(splits)
  318. assert_array_equal(test, [0, 2])
  319. assert_array_equal(train, [1, 3])
  320. train, test = next(splits)
  321. assert_array_equal(test, [1, 3])
  322. assert_array_equal(train, [0, 2])
  323. X, y = np.ones(7), [1, 1, 1, 0, 0, 0, 0]
  324. splits = StratifiedKFold(2).split(X, y)
  325. train, test = next(splits)
  326. assert_array_equal(test, [0, 1, 3, 4])
  327. assert_array_equal(train, [2, 5, 6])
  328. train, test = next(splits)
  329. assert_array_equal(test, [2, 5, 6])
  330. assert_array_equal(train, [0, 1, 3, 4])
  331. # Check if get_n_splits returns the number of folds
  332. assert 5 == StratifiedKFold(5).get_n_splits(X, y)
  333. # Make sure string labels are also supported
  334. X = np.ones(7)
  335. y1 = ["1", "1", "1", "0", "0", "0", "0"]
  336. y2 = [1, 1, 1, 0, 0, 0, 0]
  337. np.testing.assert_equal(
  338. list(StratifiedKFold(2).split(X, y1)), list(StratifiedKFold(2).split(X, y2))
  339. )
  340. # Check equivalence to KFold
  341. y = [0, 1, 0, 1, 0, 1, 0, 1]
  342. X = np.ones_like(y)
  343. np.testing.assert_equal(
  344. list(StratifiedKFold(3).split(X, y)), list(KFold(3).split(X, y))
  345. )
  346. @pytest.mark.parametrize("shuffle", [False, True])
  347. @pytest.mark.parametrize("k", [4, 5, 6, 7, 8, 9, 10])
  348. @pytest.mark.parametrize("kfold", [StratifiedKFold, StratifiedGroupKFold])
  349. def test_stratified_kfold_ratios(k, shuffle, kfold):
  350. # Check that stratified kfold preserves class ratios in individual splits
  351. # Repeat with shuffling turned off and on
  352. n_samples = 1000
  353. X = np.ones(n_samples)
  354. y = np.array(
  355. [4] * int(0.10 * n_samples)
  356. + [0] * int(0.89 * n_samples)
  357. + [1] * int(0.01 * n_samples)
  358. )
  359. # ensure perfect stratification with StratifiedGroupKFold
  360. groups = np.arange(len(y))
  361. distr = np.bincount(y) / len(y)
  362. test_sizes = []
  363. random_state = None if not shuffle else 0
  364. skf = kfold(k, random_state=random_state, shuffle=shuffle)
  365. for train, test in skf.split(X, y, groups=groups):
  366. assert_allclose(np.bincount(y[train]) / len(train), distr, atol=0.02)
  367. assert_allclose(np.bincount(y[test]) / len(test), distr, atol=0.02)
  368. test_sizes.append(len(test))
  369. assert np.ptp(test_sizes) <= 1
  370. @pytest.mark.parametrize("shuffle", [False, True])
  371. @pytest.mark.parametrize("k", [4, 6, 7])
  372. @pytest.mark.parametrize("kfold", [StratifiedKFold, StratifiedGroupKFold])
  373. def test_stratified_kfold_label_invariance(k, shuffle, kfold):
  374. # Check that stratified kfold gives the same indices regardless of labels
  375. n_samples = 100
  376. y = np.array(
  377. [2] * int(0.10 * n_samples)
  378. + [0] * int(0.89 * n_samples)
  379. + [1] * int(0.01 * n_samples)
  380. )
  381. X = np.ones(len(y))
  382. # ensure perfect stratification with StratifiedGroupKFold
  383. groups = np.arange(len(y))
  384. def get_splits(y):
  385. random_state = None if not shuffle else 0
  386. return [
  387. (list(train), list(test))
  388. for train, test in kfold(
  389. k, random_state=random_state, shuffle=shuffle
  390. ).split(X, y, groups=groups)
  391. ]
  392. splits_base = get_splits(y)
  393. for perm in permutations([0, 1, 2]):
  394. y_perm = np.take(perm, y)
  395. splits_perm = get_splits(y_perm)
  396. assert splits_perm == splits_base
  397. def test_kfold_balance():
  398. # Check that KFold returns folds with balanced sizes
  399. for i in range(11, 17):
  400. kf = KFold(5).split(X=np.ones(i))
  401. sizes = [len(test) for _, test in kf]
  402. assert (np.max(sizes) - np.min(sizes)) <= 1
  403. assert np.sum(sizes) == i
  404. @pytest.mark.parametrize("kfold", [StratifiedKFold, StratifiedGroupKFold])
  405. def test_stratifiedkfold_balance(kfold):
  406. # Check that KFold returns folds with balanced sizes (only when
  407. # stratification is possible)
  408. # Repeat with shuffling turned off and on
  409. X = np.ones(17)
  410. y = [0] * 3 + [1] * 14
  411. # ensure perfect stratification with StratifiedGroupKFold
  412. groups = np.arange(len(y))
  413. for shuffle in (True, False):
  414. cv = kfold(3, shuffle=shuffle)
  415. for i in range(11, 17):
  416. skf = cv.split(X[:i], y[:i], groups[:i])
  417. sizes = [len(test) for _, test in skf]
  418. assert (np.max(sizes) - np.min(sizes)) <= 1
  419. assert np.sum(sizes) == i
  420. def test_shuffle_kfold():
  421. # Check the indices are shuffled properly
  422. kf = KFold(3)
  423. kf2 = KFold(3, shuffle=True, random_state=0)
  424. kf3 = KFold(3, shuffle=True, random_state=1)
  425. X = np.ones(300)
  426. all_folds = np.zeros(300)
  427. for (tr1, te1), (tr2, te2), (tr3, te3) in zip(
  428. kf.split(X), kf2.split(X), kf3.split(X)
  429. ):
  430. for tr_a, tr_b in combinations((tr1, tr2, tr3), 2):
  431. # Assert that there is no complete overlap
  432. assert len(np.intersect1d(tr_a, tr_b)) != len(tr1)
  433. # Set all test indices in successive iterations of kf2 to 1
  434. all_folds[te2] = 1
  435. # Check that all indices are returned in the different test folds
  436. assert sum(all_folds) == 300
  437. @pytest.mark.parametrize("kfold", [KFold, StratifiedKFold, StratifiedGroupKFold])
  438. def test_shuffle_kfold_stratifiedkfold_reproducibility(kfold):
  439. X = np.ones(15) # Divisible by 3
  440. y = [0] * 7 + [1] * 8
  441. groups_1 = np.arange(len(y))
  442. X2 = np.ones(16) # Not divisible by 3
  443. y2 = [0] * 8 + [1] * 8
  444. groups_2 = np.arange(len(y2))
  445. # Check that when the shuffle is True, multiple split calls produce the
  446. # same split when random_state is int
  447. kf = kfold(3, shuffle=True, random_state=0)
  448. np.testing.assert_equal(
  449. list(kf.split(X, y, groups_1)), list(kf.split(X, y, groups_1))
  450. )
  451. # Check that when the shuffle is True, multiple split calls often
  452. # (not always) produce different splits when random_state is
  453. # RandomState instance or None
  454. kf = kfold(3, shuffle=True, random_state=np.random.RandomState(0))
  455. for data in zip((X, X2), (y, y2), (groups_1, groups_2)):
  456. # Test if the two splits are different cv
  457. for (_, test_a), (_, test_b) in zip(kf.split(*data), kf.split(*data)):
  458. # cv.split(...) returns an array of tuples, each tuple
  459. # consisting of an array with train indices and test indices
  460. # Ensure that the splits for data are not same
  461. # when random state is not set
  462. with pytest.raises(AssertionError):
  463. np.testing.assert_array_equal(test_a, test_b)
  464. def test_shuffle_stratifiedkfold():
  465. # Check that shuffling is happening when requested, and for proper
  466. # sample coverage
  467. X_40 = np.ones(40)
  468. y = [0] * 20 + [1] * 20
  469. kf0 = StratifiedKFold(5, shuffle=True, random_state=0)
  470. kf1 = StratifiedKFold(5, shuffle=True, random_state=1)
  471. for (_, test0), (_, test1) in zip(kf0.split(X_40, y), kf1.split(X_40, y)):
  472. assert set(test0) != set(test1)
  473. check_cv_coverage(kf0, X_40, y, groups=None, expected_n_splits=5)
  474. # Ensure that we shuffle each class's samples with different
  475. # random_state in StratifiedKFold
  476. # See https://github.com/scikit-learn/scikit-learn/pull/13124
  477. X = np.arange(10)
  478. y = [0] * 5 + [1] * 5
  479. kf1 = StratifiedKFold(5, shuffle=True, random_state=0)
  480. kf2 = StratifiedKFold(5, shuffle=True, random_state=1)
  481. test_set1 = sorted([tuple(s[1]) for s in kf1.split(X, y)])
  482. test_set2 = sorted([tuple(s[1]) for s in kf2.split(X, y)])
  483. assert test_set1 != test_set2
  484. def test_kfold_can_detect_dependent_samples_on_digits(): # see #2372
  485. # The digits samples are dependent: they are apparently grouped by authors
  486. # although we don't have any information on the groups segment locations
  487. # for this data. We can highlight this fact by computing k-fold cross-
  488. # validation with and without shuffling: we observe that the shuffling case
  489. # wrongly makes the IID assumption and is therefore too optimistic: it
  490. # estimates a much higher accuracy (around 0.93) than that the non
  491. # shuffling variant (around 0.81).
  492. X, y = digits.data[:600], digits.target[:600]
  493. model = SVC(C=10, gamma=0.005)
  494. n_splits = 3
  495. cv = KFold(n_splits=n_splits, shuffle=False)
  496. mean_score = cross_val_score(model, X, y, cv=cv).mean()
  497. assert 0.92 > mean_score
  498. assert mean_score > 0.80
  499. # Shuffling the data artificially breaks the dependency and hides the
  500. # overfitting of the model with regards to the writing style of the authors
  501. # by yielding a seriously overestimated score:
  502. cv = KFold(n_splits, shuffle=True, random_state=0)
  503. mean_score = cross_val_score(model, X, y, cv=cv).mean()
  504. assert mean_score > 0.92
  505. cv = KFold(n_splits, shuffle=True, random_state=1)
  506. mean_score = cross_val_score(model, X, y, cv=cv).mean()
  507. assert mean_score > 0.92
  508. # Similarly, StratifiedKFold should try to shuffle the data as little
  509. # as possible (while respecting the balanced class constraints)
  510. # and thus be able to detect the dependency by not overestimating
  511. # the CV score either. As the digits dataset is approximately balanced
  512. # the estimated mean score is close to the score measured with
  513. # non-shuffled KFold
  514. cv = StratifiedKFold(n_splits)
  515. mean_score = cross_val_score(model, X, y, cv=cv).mean()
  516. assert 0.94 > mean_score
  517. assert mean_score > 0.80
  518. def test_stratified_group_kfold_trivial():
  519. sgkf = StratifiedGroupKFold(n_splits=3)
  520. # Trivial example - groups with the same distribution
  521. y = np.array([1] * 6 + [0] * 12)
  522. X = np.ones_like(y).reshape(-1, 1)
  523. groups = np.asarray((1, 2, 3, 4, 5, 6, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6))
  524. distr = np.bincount(y) / len(y)
  525. test_sizes = []
  526. for train, test in sgkf.split(X, y, groups):
  527. # check group constraint
  528. assert np.intersect1d(groups[train], groups[test]).size == 0
  529. # check y distribution
  530. assert_allclose(np.bincount(y[train]) / len(train), distr, atol=0.02)
  531. assert_allclose(np.bincount(y[test]) / len(test), distr, atol=0.02)
  532. test_sizes.append(len(test))
  533. assert np.ptp(test_sizes) <= 1
  534. def test_stratified_group_kfold_approximate():
  535. # Not perfect stratification (even though it is possible) because of
  536. # iteration over groups
  537. sgkf = StratifiedGroupKFold(n_splits=3)
  538. y = np.array([1] * 6 + [0] * 12)
  539. X = np.ones_like(y).reshape(-1, 1)
  540. groups = np.array([1, 2, 3, 3, 4, 4, 1, 1, 2, 2, 3, 4, 5, 5, 5, 6, 6, 6])
  541. expected = np.asarray([[0.833, 0.166], [0.666, 0.333], [0.5, 0.5]])
  542. test_sizes = []
  543. for (train, test), expect_dist in zip(sgkf.split(X, y, groups), expected):
  544. # check group constraint
  545. assert np.intersect1d(groups[train], groups[test]).size == 0
  546. split_dist = np.bincount(y[test]) / len(test)
  547. assert_allclose(split_dist, expect_dist, atol=0.001)
  548. test_sizes.append(len(test))
  549. assert np.ptp(test_sizes) <= 1
  550. @pytest.mark.parametrize(
  551. "y, groups, expected",
  552. [
  553. (
  554. np.array([0] * 6 + [1] * 6),
  555. np.array([1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6]),
  556. np.asarray([[0.5, 0.5], [0.5, 0.5], [0.5, 0.5]]),
  557. ),
  558. (
  559. np.array([0] * 9 + [1] * 3),
  560. np.array([1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 5, 6]),
  561. np.asarray([[0.75, 0.25], [0.75, 0.25], [0.75, 0.25]]),
  562. ),
  563. ],
  564. )
  565. def test_stratified_group_kfold_homogeneous_groups(y, groups, expected):
  566. sgkf = StratifiedGroupKFold(n_splits=3)
  567. X = np.ones_like(y).reshape(-1, 1)
  568. for (train, test), expect_dist in zip(sgkf.split(X, y, groups), expected):
  569. # check group constraint
  570. assert np.intersect1d(groups[train], groups[test]).size == 0
  571. split_dist = np.bincount(y[test]) / len(test)
  572. assert_allclose(split_dist, expect_dist, atol=0.001)
  573. @pytest.mark.parametrize("cls_distr", [(0.4, 0.6), (0.3, 0.7), (0.2, 0.8), (0.8, 0.2)])
  574. @pytest.mark.parametrize("n_groups", [5, 30, 70])
  575. def test_stratified_group_kfold_against_group_kfold(cls_distr, n_groups):
  576. # Check that given sufficient amount of samples StratifiedGroupKFold
  577. # produces better stratified folds than regular GroupKFold
  578. n_splits = 5
  579. sgkf = StratifiedGroupKFold(n_splits=n_splits)
  580. gkf = GroupKFold(n_splits=n_splits)
  581. rng = np.random.RandomState(0)
  582. n_points = 1000
  583. y = rng.choice(2, size=n_points, p=cls_distr)
  584. X = np.ones_like(y).reshape(-1, 1)
  585. g = rng.choice(n_groups, n_points)
  586. sgkf_folds = sgkf.split(X, y, groups=g)
  587. gkf_folds = gkf.split(X, y, groups=g)
  588. sgkf_entr = 0
  589. gkf_entr = 0
  590. for (sgkf_train, sgkf_test), (_, gkf_test) in zip(sgkf_folds, gkf_folds):
  591. # check group constraint
  592. assert np.intersect1d(g[sgkf_train], g[sgkf_test]).size == 0
  593. sgkf_distr = np.bincount(y[sgkf_test]) / len(sgkf_test)
  594. gkf_distr = np.bincount(y[gkf_test]) / len(gkf_test)
  595. sgkf_entr += stats.entropy(sgkf_distr, qk=cls_distr)
  596. gkf_entr += stats.entropy(gkf_distr, qk=cls_distr)
  597. sgkf_entr /= n_splits
  598. gkf_entr /= n_splits
  599. assert sgkf_entr <= gkf_entr
  600. def test_shuffle_split():
  601. ss1 = ShuffleSplit(test_size=0.2, random_state=0).split(X)
  602. ss2 = ShuffleSplit(test_size=2, random_state=0).split(X)
  603. ss3 = ShuffleSplit(test_size=np.int32(2), random_state=0).split(X)
  604. ss4 = ShuffleSplit(test_size=int(2), random_state=0).split(X)
  605. for t1, t2, t3, t4 in zip(ss1, ss2, ss3, ss4):
  606. assert_array_equal(t1[0], t2[0])
  607. assert_array_equal(t2[0], t3[0])
  608. assert_array_equal(t3[0], t4[0])
  609. assert_array_equal(t1[1], t2[1])
  610. assert_array_equal(t2[1], t3[1])
  611. assert_array_equal(t3[1], t4[1])
  612. @pytest.mark.parametrize("split_class", [ShuffleSplit, StratifiedShuffleSplit])
  613. @pytest.mark.parametrize(
  614. "train_size, exp_train, exp_test", [(None, 9, 1), (8, 8, 2), (0.8, 8, 2)]
  615. )
  616. def test_shuffle_split_default_test_size(split_class, train_size, exp_train, exp_test):
  617. # Check that the default value has the expected behavior, i.e. 0.1 if both
  618. # unspecified or complement train_size unless both are specified.
  619. X = np.ones(10)
  620. y = np.ones(10)
  621. X_train, X_test = next(split_class(train_size=train_size).split(X, y))
  622. assert len(X_train) == exp_train
  623. assert len(X_test) == exp_test
  624. @pytest.mark.parametrize(
  625. "train_size, exp_train, exp_test", [(None, 8, 2), (7, 7, 3), (0.7, 7, 3)]
  626. )
  627. def test_group_shuffle_split_default_test_size(train_size, exp_train, exp_test):
  628. # Check that the default value has the expected behavior, i.e. 0.2 if both
  629. # unspecified or complement train_size unless both are specified.
  630. X = np.ones(10)
  631. y = np.ones(10)
  632. groups = range(10)
  633. X_train, X_test = next(GroupShuffleSplit(train_size=train_size).split(X, y, groups))
  634. assert len(X_train) == exp_train
  635. assert len(X_test) == exp_test
  636. @ignore_warnings
  637. def test_stratified_shuffle_split_init():
  638. X = np.arange(7)
  639. y = np.asarray([0, 1, 1, 1, 2, 2, 2])
  640. # Check that error is raised if there is a class with only one sample
  641. with pytest.raises(ValueError):
  642. next(StratifiedShuffleSplit(3, test_size=0.2).split(X, y))
  643. # Check that error is raised if the test set size is smaller than n_classes
  644. with pytest.raises(ValueError):
  645. next(StratifiedShuffleSplit(3, test_size=2).split(X, y))
  646. # Check that error is raised if the train set size is smaller than
  647. # n_classes
  648. with pytest.raises(ValueError):
  649. next(StratifiedShuffleSplit(3, test_size=3, train_size=2).split(X, y))
  650. X = np.arange(9)
  651. y = np.asarray([0, 0, 0, 1, 1, 1, 2, 2, 2])
  652. # Train size or test size too small
  653. with pytest.raises(ValueError):
  654. next(StratifiedShuffleSplit(train_size=2).split(X, y))
  655. with pytest.raises(ValueError):
  656. next(StratifiedShuffleSplit(test_size=2).split(X, y))
  657. def test_stratified_shuffle_split_respects_test_size():
  658. y = np.array([0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2])
  659. test_size = 5
  660. train_size = 10
  661. sss = StratifiedShuffleSplit(
  662. 6, test_size=test_size, train_size=train_size, random_state=0
  663. ).split(np.ones(len(y)), y)
  664. for train, test in sss:
  665. assert len(train) == train_size
  666. assert len(test) == test_size
  667. def test_stratified_shuffle_split_iter():
  668. ys = [
  669. np.array([1, 1, 1, 1, 2, 2, 2, 3, 3, 3, 3, 3]),
  670. np.array([0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3]),
  671. np.array([0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2] * 2),
  672. np.array([1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4]),
  673. np.array([-1] * 800 + [1] * 50),
  674. np.concatenate([[i] * (100 + i) for i in range(11)]),
  675. [1, 1, 1, 1, 2, 2, 2, 3, 3, 3, 3, 3],
  676. ["1", "1", "1", "1", "2", "2", "2", "3", "3", "3", "3", "3"],
  677. ]
  678. for y in ys:
  679. sss = StratifiedShuffleSplit(6, test_size=0.33, random_state=0).split(
  680. np.ones(len(y)), y
  681. )
  682. y = np.asanyarray(y) # To make it indexable for y[train]
  683. # this is how test-size is computed internally
  684. # in _validate_shuffle_split
  685. test_size = np.ceil(0.33 * len(y))
  686. train_size = len(y) - test_size
  687. for train, test in sss:
  688. assert_array_equal(np.unique(y[train]), np.unique(y[test]))
  689. # Checks if folds keep classes proportions
  690. p_train = np.bincount(np.unique(y[train], return_inverse=True)[1]) / float(
  691. len(y[train])
  692. )
  693. p_test = np.bincount(np.unique(y[test], return_inverse=True)[1]) / float(
  694. len(y[test])
  695. )
  696. assert_array_almost_equal(p_train, p_test, 1)
  697. assert len(train) + len(test) == y.size
  698. assert len(train) == train_size
  699. assert len(test) == test_size
  700. assert_array_equal(np.intersect1d(train, test), [])
  701. def test_stratified_shuffle_split_even():
  702. # Test the StratifiedShuffleSplit, indices are drawn with a
  703. # equal chance
  704. n_folds = 5
  705. n_splits = 1000
  706. def assert_counts_are_ok(idx_counts, p):
  707. # Here we test that the distribution of the counts
  708. # per index is close enough to a binomial
  709. threshold = 0.05 / n_splits
  710. bf = stats.binom(n_splits, p)
  711. for count in idx_counts:
  712. prob = bf.pmf(count)
  713. assert (
  714. prob > threshold
  715. ), "An index is not drawn with chance corresponding to even draws"
  716. for n_samples in (6, 22):
  717. groups = np.array((n_samples // 2) * [0, 1])
  718. splits = StratifiedShuffleSplit(
  719. n_splits=n_splits, test_size=1.0 / n_folds, random_state=0
  720. )
  721. train_counts = [0] * n_samples
  722. test_counts = [0] * n_samples
  723. n_splits_actual = 0
  724. for train, test in splits.split(X=np.ones(n_samples), y=groups):
  725. n_splits_actual += 1
  726. for counter, ids in [(train_counts, train), (test_counts, test)]:
  727. for id in ids:
  728. counter[id] += 1
  729. assert n_splits_actual == n_splits
  730. n_train, n_test = _validate_shuffle_split(
  731. n_samples, test_size=1.0 / n_folds, train_size=1.0 - (1.0 / n_folds)
  732. )
  733. assert len(train) == n_train
  734. assert len(test) == n_test
  735. assert len(set(train).intersection(test)) == 0
  736. group_counts = np.unique(groups)
  737. assert splits.test_size == 1.0 / n_folds
  738. assert n_train + n_test == len(groups)
  739. assert len(group_counts) == 2
  740. ex_test_p = float(n_test) / n_samples
  741. ex_train_p = float(n_train) / n_samples
  742. assert_counts_are_ok(train_counts, ex_train_p)
  743. assert_counts_are_ok(test_counts, ex_test_p)
  744. def test_stratified_shuffle_split_overlap_train_test_bug():
  745. # See https://github.com/scikit-learn/scikit-learn/issues/6121 for
  746. # the original bug report
  747. y = [0, 1, 2, 3] * 3 + [4, 5] * 5
  748. X = np.ones_like(y)
  749. sss = StratifiedShuffleSplit(n_splits=1, test_size=0.5, random_state=0)
  750. train, test = next(sss.split(X=X, y=y))
  751. # no overlap
  752. assert_array_equal(np.intersect1d(train, test), [])
  753. # complete partition
  754. assert_array_equal(np.union1d(train, test), np.arange(len(y)))
  755. def test_stratified_shuffle_split_multilabel():
  756. # fix for issue 9037
  757. for y in [
  758. np.array([[0, 1], [1, 0], [1, 0], [0, 1]]),
  759. np.array([[0, 1], [1, 1], [1, 1], [0, 1]]),
  760. ]:
  761. X = np.ones_like(y)
  762. sss = StratifiedShuffleSplit(n_splits=1, test_size=0.5, random_state=0)
  763. train, test = next(sss.split(X=X, y=y))
  764. y_train = y[train]
  765. y_test = y[test]
  766. # no overlap
  767. assert_array_equal(np.intersect1d(train, test), [])
  768. # complete partition
  769. assert_array_equal(np.union1d(train, test), np.arange(len(y)))
  770. # correct stratification of entire rows
  771. # (by design, here y[:, 0] uniquely determines the entire row of y)
  772. expected_ratio = np.mean(y[:, 0])
  773. assert expected_ratio == np.mean(y_train[:, 0])
  774. assert expected_ratio == np.mean(y_test[:, 0])
  775. def test_stratified_shuffle_split_multilabel_many_labels():
  776. # fix in PR #9922: for multilabel data with > 1000 labels, str(row)
  777. # truncates with an ellipsis for elements in positions 4 through
  778. # len(row) - 4, so labels were not being correctly split using the powerset
  779. # method for transforming a multilabel problem to a multiclass one; this
  780. # test checks that this problem is fixed.
  781. row_with_many_zeros = [1, 0, 1] + [0] * 1000 + [1, 0, 1]
  782. row_with_many_ones = [1, 0, 1] + [1] * 1000 + [1, 0, 1]
  783. y = np.array([row_with_many_zeros] * 10 + [row_with_many_ones] * 100)
  784. X = np.ones_like(y)
  785. sss = StratifiedShuffleSplit(n_splits=1, test_size=0.5, random_state=0)
  786. train, test = next(sss.split(X=X, y=y))
  787. y_train = y[train]
  788. y_test = y[test]
  789. # correct stratification of entire rows
  790. # (by design, here y[:, 4] uniquely determines the entire row of y)
  791. expected_ratio = np.mean(y[:, 4])
  792. assert expected_ratio == np.mean(y_train[:, 4])
  793. assert expected_ratio == np.mean(y_test[:, 4])
  794. def test_predefinedsplit_with_kfold_split():
  795. # Check that PredefinedSplit can reproduce a split generated by Kfold.
  796. folds = np.full(10, -1.0)
  797. kf_train = []
  798. kf_test = []
  799. for i, (train_ind, test_ind) in enumerate(KFold(5, shuffle=True).split(X)):
  800. kf_train.append(train_ind)
  801. kf_test.append(test_ind)
  802. folds[test_ind] = i
  803. ps = PredefinedSplit(folds)
  804. # n_splits is simply the no of unique folds
  805. assert len(np.unique(folds)) == ps.get_n_splits()
  806. ps_train, ps_test = zip(*ps.split())
  807. assert_array_equal(ps_train, kf_train)
  808. assert_array_equal(ps_test, kf_test)
  809. def test_group_shuffle_split():
  810. for groups_i in test_groups:
  811. X = y = np.ones(len(groups_i))
  812. n_splits = 6
  813. test_size = 1.0 / 3
  814. slo = GroupShuffleSplit(n_splits, test_size=test_size, random_state=0)
  815. # Make sure the repr works
  816. repr(slo)
  817. # Test that the length is correct
  818. assert slo.get_n_splits(X, y, groups=groups_i) == n_splits
  819. l_unique = np.unique(groups_i)
  820. l = np.asarray(groups_i)
  821. for train, test in slo.split(X, y, groups=groups_i):
  822. # First test: no train group is in the test set and vice versa
  823. l_train_unique = np.unique(l[train])
  824. l_test_unique = np.unique(l[test])
  825. assert not np.any(np.isin(l[train], l_test_unique))
  826. assert not np.any(np.isin(l[test], l_train_unique))
  827. # Second test: train and test add up to all the data
  828. assert l[train].size + l[test].size == l.size
  829. # Third test: train and test are disjoint
  830. assert_array_equal(np.intersect1d(train, test), [])
  831. # Fourth test:
  832. # unique train and test groups are correct, +- 1 for rounding error
  833. assert abs(len(l_test_unique) - round(test_size * len(l_unique))) <= 1
  834. assert (
  835. abs(len(l_train_unique) - round((1.0 - test_size) * len(l_unique))) <= 1
  836. )
  837. def test_leave_one_p_group_out():
  838. logo = LeaveOneGroupOut()
  839. lpgo_1 = LeavePGroupsOut(n_groups=1)
  840. lpgo_2 = LeavePGroupsOut(n_groups=2)
  841. # Make sure the repr works
  842. assert repr(logo) == "LeaveOneGroupOut()"
  843. assert repr(lpgo_1) == "LeavePGroupsOut(n_groups=1)"
  844. assert repr(lpgo_2) == "LeavePGroupsOut(n_groups=2)"
  845. assert repr(LeavePGroupsOut(n_groups=3)) == "LeavePGroupsOut(n_groups=3)"
  846. for j, (cv, p_groups_out) in enumerate(((logo, 1), (lpgo_1, 1), (lpgo_2, 2))):
  847. for i, groups_i in enumerate(test_groups):
  848. n_groups = len(np.unique(groups_i))
  849. n_splits = n_groups if p_groups_out == 1 else n_groups * (n_groups - 1) / 2
  850. X = y = np.ones(len(groups_i))
  851. # Test that the length is correct
  852. assert cv.get_n_splits(X, y, groups=groups_i) == n_splits
  853. groups_arr = np.asarray(groups_i)
  854. # Split using the original list / array / list of string groups_i
  855. for train, test in cv.split(X, y, groups=groups_i):
  856. # First test: no train group is in the test set and vice versa
  857. assert_array_equal(
  858. np.intersect1d(groups_arr[train], groups_arr[test]).tolist(), []
  859. )
  860. # Second test: train and test add up to all the data
  861. assert len(train) + len(test) == len(groups_i)
  862. # Third test:
  863. # The number of groups in test must be equal to p_groups_out
  864. assert np.unique(groups_arr[test]).shape[0], p_groups_out
  865. # check get_n_splits() with dummy parameters
  866. assert logo.get_n_splits(None, None, ["a", "b", "c", "b", "c"]) == 3
  867. assert logo.get_n_splits(groups=[1.0, 1.1, 1.0, 1.2]) == 3
  868. assert lpgo_2.get_n_splits(None, None, np.arange(4)) == 6
  869. assert lpgo_1.get_n_splits(groups=np.arange(4)) == 4
  870. # raise ValueError if a `groups` parameter is illegal
  871. with pytest.raises(ValueError):
  872. logo.get_n_splits(None, None, [0.0, np.nan, 0.0])
  873. with pytest.raises(ValueError):
  874. lpgo_2.get_n_splits(None, None, [0.0, np.inf, 0.0])
  875. msg = "The 'groups' parameter should not be None."
  876. with pytest.raises(ValueError, match=msg):
  877. logo.get_n_splits(None, None, None)
  878. with pytest.raises(ValueError, match=msg):
  879. lpgo_1.get_n_splits(None, None, None)
  880. def test_leave_group_out_changing_groups():
  881. # Check that LeaveOneGroupOut and LeavePGroupsOut work normally if
  882. # the groups variable is changed before calling split
  883. groups = np.array([0, 1, 2, 1, 1, 2, 0, 0])
  884. X = np.ones(len(groups))
  885. groups_changing = np.array(groups, copy=True)
  886. lolo = LeaveOneGroupOut().split(X, groups=groups)
  887. lolo_changing = LeaveOneGroupOut().split(X, groups=groups)
  888. lplo = LeavePGroupsOut(n_groups=2).split(X, groups=groups)
  889. lplo_changing = LeavePGroupsOut(n_groups=2).split(X, groups=groups)
  890. groups_changing[:] = 0
  891. for llo, llo_changing in [(lolo, lolo_changing), (lplo, lplo_changing)]:
  892. for (train, test), (train_chan, test_chan) in zip(llo, llo_changing):
  893. assert_array_equal(train, train_chan)
  894. assert_array_equal(test, test_chan)
  895. # n_splits = no of 2 (p) group combinations of the unique groups = 3C2 = 3
  896. assert 3 == LeavePGroupsOut(n_groups=2).get_n_splits(X, y=X, groups=groups)
  897. # n_splits = no of unique groups (C(uniq_lbls, 1) = n_unique_groups)
  898. assert 3 == LeaveOneGroupOut().get_n_splits(X, y=X, groups=groups)
  899. def test_leave_group_out_order_dependence():
  900. # Check that LeaveOneGroupOut orders the splits according to the index
  901. # of the group left out.
  902. groups = np.array([2, 2, 0, 0, 1, 1])
  903. X = np.ones(len(groups))
  904. splits = iter(LeaveOneGroupOut().split(X, groups=groups))
  905. expected_indices = [
  906. ([0, 1, 4, 5], [2, 3]),
  907. ([0, 1, 2, 3], [4, 5]),
  908. ([2, 3, 4, 5], [0, 1]),
  909. ]
  910. for expected_train, expected_test in expected_indices:
  911. train, test = next(splits)
  912. assert_array_equal(train, expected_train)
  913. assert_array_equal(test, expected_test)
  914. def test_leave_one_p_group_out_error_on_fewer_number_of_groups():
  915. X = y = groups = np.ones(0)
  916. msg = re.escape("Found array with 0 sample(s)")
  917. with pytest.raises(ValueError, match=msg):
  918. next(LeaveOneGroupOut().split(X, y, groups))
  919. X = y = groups = np.ones(1)
  920. msg = re.escape(
  921. f"The groups parameter contains fewer than 2 unique groups ({groups})."
  922. " LeaveOneGroupOut expects at least 2."
  923. )
  924. with pytest.raises(ValueError, match=msg):
  925. next(LeaveOneGroupOut().split(X, y, groups))
  926. X = y = groups = np.ones(1)
  927. msg = re.escape(
  928. "The groups parameter contains fewer than (or equal to) n_groups "
  929. f"(3) numbers of unique groups ({groups}). LeavePGroupsOut expects "
  930. "that at least n_groups + 1 (4) unique groups "
  931. "be present"
  932. )
  933. with pytest.raises(ValueError, match=msg):
  934. next(LeavePGroupsOut(n_groups=3).split(X, y, groups))
  935. X = y = groups = np.arange(3)
  936. msg = re.escape(
  937. "The groups parameter contains fewer than (or equal to) n_groups "
  938. f"(3) numbers of unique groups ({groups}). LeavePGroupsOut expects "
  939. "that at least n_groups + 1 (4) unique groups "
  940. "be present"
  941. )
  942. with pytest.raises(ValueError, match=msg):
  943. next(LeavePGroupsOut(n_groups=3).split(X, y, groups))
  944. @ignore_warnings
  945. def test_repeated_cv_value_errors():
  946. # n_repeats is not integer or <= 0
  947. for cv in (RepeatedKFold, RepeatedStratifiedKFold):
  948. with pytest.raises(ValueError):
  949. cv(n_repeats=0)
  950. with pytest.raises(ValueError):
  951. cv(n_repeats=1.5)
  952. @pytest.mark.parametrize("RepeatedCV", [RepeatedKFold, RepeatedStratifiedKFold])
  953. def test_repeated_cv_repr(RepeatedCV):
  954. n_splits, n_repeats = 2, 6
  955. repeated_cv = RepeatedCV(n_splits=n_splits, n_repeats=n_repeats)
  956. repeated_cv_repr = "{}(n_repeats=6, n_splits=2, random_state=None)".format(
  957. repeated_cv.__class__.__name__
  958. )
  959. assert repeated_cv_repr == repr(repeated_cv)
  960. def test_repeated_kfold_determinstic_split():
  961. X = [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10]]
  962. random_state = 258173307
  963. rkf = RepeatedKFold(n_splits=2, n_repeats=2, random_state=random_state)
  964. # split should produce same and deterministic splits on
  965. # each call
  966. for _ in range(3):
  967. splits = rkf.split(X)
  968. train, test = next(splits)
  969. assert_array_equal(train, [2, 4])
  970. assert_array_equal(test, [0, 1, 3])
  971. train, test = next(splits)
  972. assert_array_equal(train, [0, 1, 3])
  973. assert_array_equal(test, [2, 4])
  974. train, test = next(splits)
  975. assert_array_equal(train, [0, 1])
  976. assert_array_equal(test, [2, 3, 4])
  977. train, test = next(splits)
  978. assert_array_equal(train, [2, 3, 4])
  979. assert_array_equal(test, [0, 1])
  980. with pytest.raises(StopIteration):
  981. next(splits)
  982. def test_get_n_splits_for_repeated_kfold():
  983. n_splits = 3
  984. n_repeats = 4
  985. rkf = RepeatedKFold(n_splits=n_splits, n_repeats=n_repeats)
  986. expected_n_splits = n_splits * n_repeats
  987. assert expected_n_splits == rkf.get_n_splits()
  988. def test_get_n_splits_for_repeated_stratified_kfold():
  989. n_splits = 3
  990. n_repeats = 4
  991. rskf = RepeatedStratifiedKFold(n_splits=n_splits, n_repeats=n_repeats)
  992. expected_n_splits = n_splits * n_repeats
  993. assert expected_n_splits == rskf.get_n_splits()
  994. def test_repeated_stratified_kfold_determinstic_split():
  995. X = [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10]]
  996. y = [1, 1, 1, 0, 0]
  997. random_state = 1944695409
  998. rskf = RepeatedStratifiedKFold(n_splits=2, n_repeats=2, random_state=random_state)
  999. # split should produce same and deterministic splits on
  1000. # each call
  1001. for _ in range(3):
  1002. splits = rskf.split(X, y)
  1003. train, test = next(splits)
  1004. assert_array_equal(train, [1, 4])
  1005. assert_array_equal(test, [0, 2, 3])
  1006. train, test = next(splits)
  1007. assert_array_equal(train, [0, 2, 3])
  1008. assert_array_equal(test, [1, 4])
  1009. train, test = next(splits)
  1010. assert_array_equal(train, [2, 3])
  1011. assert_array_equal(test, [0, 1, 4])
  1012. train, test = next(splits)
  1013. assert_array_equal(train, [0, 1, 4])
  1014. assert_array_equal(test, [2, 3])
  1015. with pytest.raises(StopIteration):
  1016. next(splits)
  1017. def test_train_test_split_errors():
  1018. pytest.raises(ValueError, train_test_split)
  1019. pytest.raises(ValueError, train_test_split, range(3), train_size=1.1)
  1020. pytest.raises(ValueError, train_test_split, range(3), test_size=0.6, train_size=0.6)
  1021. pytest.raises(
  1022. ValueError,
  1023. train_test_split,
  1024. range(3),
  1025. test_size=np.float32(0.6),
  1026. train_size=np.float32(0.6),
  1027. )
  1028. pytest.raises(ValueError, train_test_split, range(3), test_size="wrong_type")
  1029. pytest.raises(ValueError, train_test_split, range(3), test_size=2, train_size=4)
  1030. pytest.raises(TypeError, train_test_split, range(3), some_argument=1.1)
  1031. pytest.raises(ValueError, train_test_split, range(3), range(42))
  1032. pytest.raises(ValueError, train_test_split, range(10), shuffle=False, stratify=True)
  1033. with pytest.raises(
  1034. ValueError,
  1035. match=r"train_size=11 should be either positive and "
  1036. r"smaller than the number of samples 10 or a "
  1037. r"float in the \(0, 1\) range",
  1038. ):
  1039. train_test_split(range(10), train_size=11, test_size=1)
  1040. @pytest.mark.parametrize(
  1041. "train_size, exp_train, exp_test", [(None, 7, 3), (8, 8, 2), (0.8, 8, 2)]
  1042. )
  1043. def test_train_test_split_default_test_size(train_size, exp_train, exp_test):
  1044. # Check that the default value has the expected behavior, i.e. complement
  1045. # train_size unless both are specified.
  1046. X_train, X_test = train_test_split(X, train_size=train_size)
  1047. assert len(X_train) == exp_train
  1048. assert len(X_test) == exp_test
  1049. def test_train_test_split():
  1050. X = np.arange(100).reshape((10, 10))
  1051. X_s = coo_matrix(X)
  1052. y = np.arange(10)
  1053. # simple test
  1054. split = train_test_split(X, y, test_size=None, train_size=0.5)
  1055. X_train, X_test, y_train, y_test = split
  1056. assert len(y_test) == len(y_train)
  1057. # test correspondence of X and y
  1058. assert_array_equal(X_train[:, 0], y_train * 10)
  1059. assert_array_equal(X_test[:, 0], y_test * 10)
  1060. # don't convert lists to anything else by default
  1061. split = train_test_split(X, X_s, y.tolist())
  1062. X_train, X_test, X_s_train, X_s_test, y_train, y_test = split
  1063. assert isinstance(y_train, list)
  1064. assert isinstance(y_test, list)
  1065. # allow nd-arrays
  1066. X_4d = np.arange(10 * 5 * 3 * 2).reshape(10, 5, 3, 2)
  1067. y_3d = np.arange(10 * 7 * 11).reshape(10, 7, 11)
  1068. split = train_test_split(X_4d, y_3d)
  1069. assert split[0].shape == (7, 5, 3, 2)
  1070. assert split[1].shape == (3, 5, 3, 2)
  1071. assert split[2].shape == (7, 7, 11)
  1072. assert split[3].shape == (3, 7, 11)
  1073. # test stratification option
  1074. y = np.array([1, 1, 1, 1, 2, 2, 2, 2])
  1075. for test_size, exp_test_size in zip([2, 4, 0.25, 0.5, 0.75], [2, 4, 2, 4, 6]):
  1076. train, test = train_test_split(
  1077. y, test_size=test_size, stratify=y, random_state=0
  1078. )
  1079. assert len(test) == exp_test_size
  1080. assert len(test) + len(train) == len(y)
  1081. # check the 1:1 ratio of ones and twos in the data is preserved
  1082. assert np.sum(train == 1) == np.sum(train == 2)
  1083. # test unshuffled split
  1084. y = np.arange(10)
  1085. for test_size in [2, 0.2]:
  1086. train, test = train_test_split(y, shuffle=False, test_size=test_size)
  1087. assert_array_equal(test, [8, 9])
  1088. assert_array_equal(train, [0, 1, 2, 3, 4, 5, 6, 7])
  1089. def test_train_test_split_32bit_overflow():
  1090. """Check for integer overflow on 32-bit platforms.
  1091. Non-regression test for:
  1092. https://github.com/scikit-learn/scikit-learn/issues/20774
  1093. """
  1094. # A number 'n' big enough for expression 'n * n * train_size' to cause
  1095. # an overflow for signed 32-bit integer
  1096. big_number = 100000
  1097. # Definition of 'y' is a part of reproduction - population for at least
  1098. # one class should be in the same order of magnitude as size of X
  1099. X = np.arange(big_number)
  1100. y = X > (0.99 * big_number)
  1101. split = train_test_split(X, y, stratify=y, train_size=0.25)
  1102. X_train, X_test, y_train, y_test = split
  1103. assert X_train.size + X_test.size == big_number
  1104. assert y_train.size + y_test.size == big_number
  1105. @ignore_warnings
  1106. def test_train_test_split_pandas():
  1107. # check train_test_split doesn't destroy pandas dataframe
  1108. types = [MockDataFrame]
  1109. try:
  1110. from pandas import DataFrame
  1111. types.append(DataFrame)
  1112. except ImportError:
  1113. pass
  1114. for InputFeatureType in types:
  1115. # X dataframe
  1116. X_df = InputFeatureType(X)
  1117. X_train, X_test = train_test_split(X_df)
  1118. assert isinstance(X_train, InputFeatureType)
  1119. assert isinstance(X_test, InputFeatureType)
  1120. def test_train_test_split_sparse():
  1121. # check that train_test_split converts scipy sparse matrices
  1122. # to csr, as stated in the documentation
  1123. X = np.arange(100).reshape((10, 10))
  1124. sparse_types = [csr_matrix, csc_matrix, coo_matrix]
  1125. for InputFeatureType in sparse_types:
  1126. X_s = InputFeatureType(X)
  1127. X_train, X_test = train_test_split(X_s)
  1128. assert issparse(X_train) and X_train.format == "csr"
  1129. assert issparse(X_test) and X_test.format == "csr"
  1130. def test_train_test_split_mock_pandas():
  1131. # X mock dataframe
  1132. X_df = MockDataFrame(X)
  1133. X_train, X_test = train_test_split(X_df)
  1134. assert isinstance(X_train, MockDataFrame)
  1135. assert isinstance(X_test, MockDataFrame)
  1136. X_train_arr, X_test_arr = train_test_split(X_df)
  1137. def test_train_test_split_list_input():
  1138. # Check that when y is a list / list of string labels, it works.
  1139. X = np.ones(7)
  1140. y1 = ["1"] * 4 + ["0"] * 3
  1141. y2 = np.hstack((np.ones(4), np.zeros(3)))
  1142. y3 = y2.tolist()
  1143. for stratify in (True, False):
  1144. X_train1, X_test1, y_train1, y_test1 = train_test_split(
  1145. X, y1, stratify=y1 if stratify else None, random_state=0
  1146. )
  1147. X_train2, X_test2, y_train2, y_test2 = train_test_split(
  1148. X, y2, stratify=y2 if stratify else None, random_state=0
  1149. )
  1150. X_train3, X_test3, y_train3, y_test3 = train_test_split(
  1151. X, y3, stratify=y3 if stratify else None, random_state=0
  1152. )
  1153. np.testing.assert_equal(X_train1, X_train2)
  1154. np.testing.assert_equal(y_train2, y_train3)
  1155. np.testing.assert_equal(X_test1, X_test3)
  1156. np.testing.assert_equal(y_test3, y_test2)
  1157. @pytest.mark.parametrize(
  1158. "test_size, train_size",
  1159. [(2.0, None), (1.0, None), (0.1, 0.95), (None, 1j), (11, None), (10, None), (8, 3)],
  1160. )
  1161. def test_shufflesplit_errors(test_size, train_size):
  1162. with pytest.raises(ValueError):
  1163. next(ShuffleSplit(test_size=test_size, train_size=train_size).split(X))
  1164. def test_shufflesplit_reproducible():
  1165. # Check that iterating twice on the ShuffleSplit gives the same
  1166. # sequence of train-test when the random_state is given
  1167. ss = ShuffleSplit(random_state=21)
  1168. assert_array_equal([a for a, b in ss.split(X)], [a for a, b in ss.split(X)])
  1169. def test_stratifiedshufflesplit_list_input():
  1170. # Check that when y is a list / list of string labels, it works.
  1171. sss = StratifiedShuffleSplit(test_size=2, random_state=42)
  1172. X = np.ones(7)
  1173. y1 = ["1"] * 4 + ["0"] * 3
  1174. y2 = np.hstack((np.ones(4), np.zeros(3)))
  1175. y3 = y2.tolist()
  1176. np.testing.assert_equal(list(sss.split(X, y1)), list(sss.split(X, y2)))
  1177. np.testing.assert_equal(list(sss.split(X, y3)), list(sss.split(X, y2)))
  1178. def test_train_test_split_allow_nans():
  1179. # Check that train_test_split allows input data with NaNs
  1180. X = np.arange(200, dtype=np.float64).reshape(10, -1)
  1181. X[2, :] = np.nan
  1182. y = np.repeat([0, 1], X.shape[0] / 2)
  1183. train_test_split(X, y, test_size=0.2, random_state=42)
  1184. def test_check_cv():
  1185. X = np.ones(9)
  1186. cv = check_cv(3, classifier=False)
  1187. # Use numpy.testing.assert_equal which recursively compares
  1188. # lists of lists
  1189. np.testing.assert_equal(list(KFold(3).split(X)), list(cv.split(X)))
  1190. y_binary = np.array([0, 1, 0, 1, 0, 0, 1, 1, 1])
  1191. cv = check_cv(3, y_binary, classifier=True)
  1192. np.testing.assert_equal(
  1193. list(StratifiedKFold(3).split(X, y_binary)), list(cv.split(X, y_binary))
  1194. )
  1195. y_multiclass = np.array([0, 1, 0, 1, 2, 1, 2, 0, 2])
  1196. cv = check_cv(3, y_multiclass, classifier=True)
  1197. np.testing.assert_equal(
  1198. list(StratifiedKFold(3).split(X, y_multiclass)), list(cv.split(X, y_multiclass))
  1199. )
  1200. # also works with 2d multiclass
  1201. y_multiclass_2d = y_multiclass.reshape(-1, 1)
  1202. cv = check_cv(3, y_multiclass_2d, classifier=True)
  1203. np.testing.assert_equal(
  1204. list(StratifiedKFold(3).split(X, y_multiclass_2d)),
  1205. list(cv.split(X, y_multiclass_2d)),
  1206. )
  1207. assert not np.all(
  1208. next(StratifiedKFold(3).split(X, y_multiclass_2d))[0]
  1209. == next(KFold(3).split(X, y_multiclass_2d))[0]
  1210. )
  1211. X = np.ones(5)
  1212. y_multilabel = np.array(
  1213. [[0, 0, 0, 0], [0, 1, 1, 0], [0, 0, 0, 1], [1, 1, 0, 1], [0, 0, 1, 0]]
  1214. )
  1215. cv = check_cv(3, y_multilabel, classifier=True)
  1216. np.testing.assert_equal(list(KFold(3).split(X)), list(cv.split(X)))
  1217. y_multioutput = np.array([[1, 2], [0, 3], [0, 0], [3, 1], [2, 0]])
  1218. cv = check_cv(3, y_multioutput, classifier=True)
  1219. np.testing.assert_equal(list(KFold(3).split(X)), list(cv.split(X)))
  1220. with pytest.raises(ValueError):
  1221. check_cv(cv="lolo")
  1222. def test_cv_iterable_wrapper():
  1223. kf_iter = KFold().split(X, y)
  1224. kf_iter_wrapped = check_cv(kf_iter)
  1225. # Since the wrapped iterable is enlisted and stored,
  1226. # split can be called any number of times to produce
  1227. # consistent results.
  1228. np.testing.assert_equal(
  1229. list(kf_iter_wrapped.split(X, y)), list(kf_iter_wrapped.split(X, y))
  1230. )
  1231. # If the splits are randomized, successive calls to split yields different
  1232. # results
  1233. kf_randomized_iter = KFold(shuffle=True, random_state=0).split(X, y)
  1234. kf_randomized_iter_wrapped = check_cv(kf_randomized_iter)
  1235. # numpy's assert_array_equal properly compares nested lists
  1236. np.testing.assert_equal(
  1237. list(kf_randomized_iter_wrapped.split(X, y)),
  1238. list(kf_randomized_iter_wrapped.split(X, y)),
  1239. )
  1240. try:
  1241. splits_are_equal = True
  1242. np.testing.assert_equal(
  1243. list(kf_iter_wrapped.split(X, y)),
  1244. list(kf_randomized_iter_wrapped.split(X, y)),
  1245. )
  1246. except AssertionError:
  1247. splits_are_equal = False
  1248. assert not splits_are_equal, (
  1249. "If the splits are randomized, "
  1250. "successive calls to split should yield different results"
  1251. )
  1252. @pytest.mark.parametrize("kfold", [GroupKFold, StratifiedGroupKFold])
  1253. def test_group_kfold(kfold):
  1254. rng = np.random.RandomState(0)
  1255. # Parameters of the test
  1256. n_groups = 15
  1257. n_samples = 1000
  1258. n_splits = 5
  1259. X = y = np.ones(n_samples)
  1260. # Construct the test data
  1261. tolerance = 0.05 * n_samples # 5 percent error allowed
  1262. groups = rng.randint(0, n_groups, n_samples)
  1263. ideal_n_groups_per_fold = n_samples // n_splits
  1264. len(np.unique(groups))
  1265. # Get the test fold indices from the test set indices of each fold
  1266. folds = np.zeros(n_samples)
  1267. lkf = kfold(n_splits=n_splits)
  1268. for i, (_, test) in enumerate(lkf.split(X, y, groups)):
  1269. folds[test] = i
  1270. # Check that folds have approximately the same size
  1271. assert len(folds) == len(groups)
  1272. for i in np.unique(folds):
  1273. assert tolerance >= abs(sum(folds == i) - ideal_n_groups_per_fold)
  1274. # Check that each group appears only in 1 fold
  1275. for group in np.unique(groups):
  1276. assert len(np.unique(folds[groups == group])) == 1
  1277. # Check that no group is on both sides of the split
  1278. groups = np.asarray(groups, dtype=object)
  1279. for train, test in lkf.split(X, y, groups):
  1280. assert len(np.intersect1d(groups[train], groups[test])) == 0
  1281. # Construct the test data
  1282. groups = np.array(
  1283. [
  1284. "Albert",
  1285. "Jean",
  1286. "Bertrand",
  1287. "Michel",
  1288. "Jean",
  1289. "Francis",
  1290. "Robert",
  1291. "Michel",
  1292. "Rachel",
  1293. "Lois",
  1294. "Michelle",
  1295. "Bernard",
  1296. "Marion",
  1297. "Laura",
  1298. "Jean",
  1299. "Rachel",
  1300. "Franck",
  1301. "John",
  1302. "Gael",
  1303. "Anna",
  1304. "Alix",
  1305. "Robert",
  1306. "Marion",
  1307. "David",
  1308. "Tony",
  1309. "Abel",
  1310. "Becky",
  1311. "Madmood",
  1312. "Cary",
  1313. "Mary",
  1314. "Alexandre",
  1315. "David",
  1316. "Francis",
  1317. "Barack",
  1318. "Abdoul",
  1319. "Rasha",
  1320. "Xi",
  1321. "Silvia",
  1322. ]
  1323. )
  1324. n_groups = len(np.unique(groups))
  1325. n_samples = len(groups)
  1326. n_splits = 5
  1327. tolerance = 0.05 * n_samples # 5 percent error allowed
  1328. ideal_n_groups_per_fold = n_samples // n_splits
  1329. X = y = np.ones(n_samples)
  1330. # Get the test fold indices from the test set indices of each fold
  1331. folds = np.zeros(n_samples)
  1332. for i, (_, test) in enumerate(lkf.split(X, y, groups)):
  1333. folds[test] = i
  1334. # Check that folds have approximately the same size
  1335. assert len(folds) == len(groups)
  1336. for i in np.unique(folds):
  1337. assert tolerance >= abs(sum(folds == i) - ideal_n_groups_per_fold)
  1338. # Check that each group appears only in 1 fold
  1339. with warnings.catch_warnings():
  1340. warnings.simplefilter("ignore", FutureWarning)
  1341. for group in np.unique(groups):
  1342. assert len(np.unique(folds[groups == group])) == 1
  1343. # Check that no group is on both sides of the split
  1344. groups = np.asarray(groups, dtype=object)
  1345. for train, test in lkf.split(X, y, groups):
  1346. assert len(np.intersect1d(groups[train], groups[test])) == 0
  1347. # groups can also be a list
  1348. cv_iter = list(lkf.split(X, y, groups.tolist()))
  1349. for (train1, test1), (train2, test2) in zip(lkf.split(X, y, groups), cv_iter):
  1350. assert_array_equal(train1, train2)
  1351. assert_array_equal(test1, test2)
  1352. # Should fail if there are more folds than groups
  1353. groups = np.array([1, 1, 1, 2, 2])
  1354. X = y = np.ones(len(groups))
  1355. with pytest.raises(ValueError, match="Cannot have number of splits.*greater"):
  1356. next(GroupKFold(n_splits=3).split(X, y, groups))
  1357. def test_time_series_cv():
  1358. X = [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12], [13, 14]]
  1359. # Should fail if there are more folds than samples
  1360. with pytest.raises(ValueError, match="Cannot have number of folds.*greater"):
  1361. next(TimeSeriesSplit(n_splits=7).split(X))
  1362. tscv = TimeSeriesSplit(2)
  1363. # Manually check that Time Series CV preserves the data
  1364. # ordering on toy datasets
  1365. splits = tscv.split(X[:-1])
  1366. train, test = next(splits)
  1367. assert_array_equal(train, [0, 1])
  1368. assert_array_equal(test, [2, 3])
  1369. train, test = next(splits)
  1370. assert_array_equal(train, [0, 1, 2, 3])
  1371. assert_array_equal(test, [4, 5])
  1372. splits = TimeSeriesSplit(2).split(X)
  1373. train, test = next(splits)
  1374. assert_array_equal(train, [0, 1, 2])
  1375. assert_array_equal(test, [3, 4])
  1376. train, test = next(splits)
  1377. assert_array_equal(train, [0, 1, 2, 3, 4])
  1378. assert_array_equal(test, [5, 6])
  1379. # Check get_n_splits returns the correct number of splits
  1380. splits = TimeSeriesSplit(2).split(X)
  1381. n_splits_actual = len(list(splits))
  1382. assert n_splits_actual == tscv.get_n_splits()
  1383. assert n_splits_actual == 2
  1384. def _check_time_series_max_train_size(splits, check_splits, max_train_size):
  1385. for (train, test), (check_train, check_test) in zip(splits, check_splits):
  1386. assert_array_equal(test, check_test)
  1387. assert len(check_train) <= max_train_size
  1388. suffix_start = max(len(train) - max_train_size, 0)
  1389. assert_array_equal(check_train, train[suffix_start:])
  1390. def test_time_series_max_train_size():
  1391. X = np.zeros((6, 1))
  1392. splits = TimeSeriesSplit(n_splits=3).split(X)
  1393. check_splits = TimeSeriesSplit(n_splits=3, max_train_size=3).split(X)
  1394. _check_time_series_max_train_size(splits, check_splits, max_train_size=3)
  1395. # Test for the case where the size of a fold is greater than max_train_size
  1396. check_splits = TimeSeriesSplit(n_splits=3, max_train_size=2).split(X)
  1397. _check_time_series_max_train_size(splits, check_splits, max_train_size=2)
  1398. # Test for the case where the size of each fold is less than max_train_size
  1399. check_splits = TimeSeriesSplit(n_splits=3, max_train_size=5).split(X)
  1400. _check_time_series_max_train_size(splits, check_splits, max_train_size=2)
  1401. def test_time_series_test_size():
  1402. X = np.zeros((10, 1))
  1403. # Test alone
  1404. splits = TimeSeriesSplit(n_splits=3, test_size=3).split(X)
  1405. train, test = next(splits)
  1406. assert_array_equal(train, [0])
  1407. assert_array_equal(test, [1, 2, 3])
  1408. train, test = next(splits)
  1409. assert_array_equal(train, [0, 1, 2, 3])
  1410. assert_array_equal(test, [4, 5, 6])
  1411. train, test = next(splits)
  1412. assert_array_equal(train, [0, 1, 2, 3, 4, 5, 6])
  1413. assert_array_equal(test, [7, 8, 9])
  1414. # Test with max_train_size
  1415. splits = TimeSeriesSplit(n_splits=2, test_size=2, max_train_size=4).split(X)
  1416. train, test = next(splits)
  1417. assert_array_equal(train, [2, 3, 4, 5])
  1418. assert_array_equal(test, [6, 7])
  1419. train, test = next(splits)
  1420. assert_array_equal(train, [4, 5, 6, 7])
  1421. assert_array_equal(test, [8, 9])
  1422. # Should fail with not enough data points for configuration
  1423. with pytest.raises(ValueError, match="Too many splits.*with test_size"):
  1424. splits = TimeSeriesSplit(n_splits=5, test_size=2).split(X)
  1425. next(splits)
  1426. def test_time_series_gap():
  1427. X = np.zeros((10, 1))
  1428. # Test alone
  1429. splits = TimeSeriesSplit(n_splits=2, gap=2).split(X)
  1430. train, test = next(splits)
  1431. assert_array_equal(train, [0, 1])
  1432. assert_array_equal(test, [4, 5, 6])
  1433. train, test = next(splits)
  1434. assert_array_equal(train, [0, 1, 2, 3, 4])
  1435. assert_array_equal(test, [7, 8, 9])
  1436. # Test with max_train_size
  1437. splits = TimeSeriesSplit(n_splits=3, gap=2, max_train_size=2).split(X)
  1438. train, test = next(splits)
  1439. assert_array_equal(train, [0, 1])
  1440. assert_array_equal(test, [4, 5])
  1441. train, test = next(splits)
  1442. assert_array_equal(train, [2, 3])
  1443. assert_array_equal(test, [6, 7])
  1444. train, test = next(splits)
  1445. assert_array_equal(train, [4, 5])
  1446. assert_array_equal(test, [8, 9])
  1447. # Test with test_size
  1448. splits = TimeSeriesSplit(n_splits=2, gap=2, max_train_size=4, test_size=2).split(X)
  1449. train, test = next(splits)
  1450. assert_array_equal(train, [0, 1, 2, 3])
  1451. assert_array_equal(test, [6, 7])
  1452. train, test = next(splits)
  1453. assert_array_equal(train, [2, 3, 4, 5])
  1454. assert_array_equal(test, [8, 9])
  1455. # Test with additional test_size
  1456. splits = TimeSeriesSplit(n_splits=2, gap=2, test_size=3).split(X)
  1457. train, test = next(splits)
  1458. assert_array_equal(train, [0, 1])
  1459. assert_array_equal(test, [4, 5, 6])
  1460. train, test = next(splits)
  1461. assert_array_equal(train, [0, 1, 2, 3, 4])
  1462. assert_array_equal(test, [7, 8, 9])
  1463. # Verify proper error is thrown
  1464. with pytest.raises(ValueError, match="Too many splits.*and gap"):
  1465. splits = TimeSeriesSplit(n_splits=4, gap=2).split(X)
  1466. next(splits)
  1467. def test_nested_cv():
  1468. # Test if nested cross validation works with different combinations of cv
  1469. rng = np.random.RandomState(0)
  1470. X, y = make_classification(n_samples=15, n_classes=2, random_state=0)
  1471. groups = rng.randint(0, 5, 15)
  1472. cvs = [
  1473. LeaveOneGroupOut(),
  1474. StratifiedKFold(n_splits=2),
  1475. LeaveOneOut(),
  1476. GroupKFold(n_splits=3),
  1477. StratifiedKFold(),
  1478. StratifiedGroupKFold(),
  1479. StratifiedShuffleSplit(n_splits=3, random_state=0),
  1480. ]
  1481. for inner_cv, outer_cv in combinations_with_replacement(cvs, 2):
  1482. gs = GridSearchCV(
  1483. DummyClassifier(),
  1484. param_grid={"strategy": ["stratified", "most_frequent"]},
  1485. cv=inner_cv,
  1486. error_score="raise",
  1487. )
  1488. cross_val_score(
  1489. gs, X=X, y=y, groups=groups, cv=outer_cv, fit_params={"groups": groups}
  1490. )
  1491. def test_build_repr():
  1492. class MockSplitter:
  1493. def __init__(self, a, b=0, c=None):
  1494. self.a = a
  1495. self.b = b
  1496. self.c = c
  1497. def __repr__(self):
  1498. return _build_repr(self)
  1499. assert repr(MockSplitter(5, 6)) == "MockSplitter(a=5, b=6, c=None)"
  1500. @pytest.mark.parametrize(
  1501. "CVSplitter", (ShuffleSplit, GroupShuffleSplit, StratifiedShuffleSplit)
  1502. )
  1503. def test_shuffle_split_empty_trainset(CVSplitter):
  1504. cv = CVSplitter(test_size=0.99)
  1505. X, y = [[1]], [0] # 1 sample
  1506. with pytest.raises(
  1507. ValueError,
  1508. match=(
  1509. "With n_samples=1, test_size=0.99 and train_size=None, "
  1510. "the resulting train set will be empty"
  1511. ),
  1512. ):
  1513. next(cv.split(X, y, groups=[1]))
  1514. def test_train_test_split_empty_trainset():
  1515. (X,) = [[1]] # 1 sample
  1516. with pytest.raises(
  1517. ValueError,
  1518. match=(
  1519. "With n_samples=1, test_size=0.99 and train_size=None, "
  1520. "the resulting train set will be empty"
  1521. ),
  1522. ):
  1523. train_test_split(X, test_size=0.99)
  1524. X = [[1], [1], [1]] # 3 samples, ask for more than 2 thirds
  1525. with pytest.raises(
  1526. ValueError,
  1527. match=(
  1528. "With n_samples=3, test_size=0.67 and train_size=None, "
  1529. "the resulting train set will be empty"
  1530. ),
  1531. ):
  1532. train_test_split(X, test_size=0.67)
  1533. def test_leave_one_out_empty_trainset():
  1534. # LeaveOneGroup out expect at least 2 groups so no need to check
  1535. cv = LeaveOneOut()
  1536. X, y = [[1]], [0] # 1 sample
  1537. with pytest.raises(ValueError, match="Cannot perform LeaveOneOut with n_samples=1"):
  1538. next(cv.split(X, y))
  1539. def test_leave_p_out_empty_trainset():
  1540. # No need to check LeavePGroupsOut
  1541. cv = LeavePOut(p=2)
  1542. X, y = [[1], [2]], [0, 3] # 2 samples
  1543. with pytest.raises(
  1544. ValueError, match="p=2 must be strictly less than the number of samples=2"
  1545. ):
  1546. next(cv.split(X, y, groups=[1, 2]))
  1547. @pytest.mark.parametrize("Klass", (KFold, StratifiedKFold, StratifiedGroupKFold))
  1548. def test_random_state_shuffle_false(Klass):
  1549. # passing a non-default random_state when shuffle=False makes no sense
  1550. with pytest.raises(ValueError, match="has no effect since shuffle is False"):
  1551. Klass(3, shuffle=False, random_state=0)
  1552. @pytest.mark.parametrize(
  1553. "cv, expected",
  1554. [
  1555. (KFold(), True),
  1556. (KFold(shuffle=True, random_state=123), True),
  1557. (StratifiedKFold(), True),
  1558. (StratifiedKFold(shuffle=True, random_state=123), True),
  1559. (StratifiedGroupKFold(shuffle=True, random_state=123), True),
  1560. (StratifiedGroupKFold(), True),
  1561. (RepeatedKFold(random_state=123), True),
  1562. (RepeatedStratifiedKFold(random_state=123), True),
  1563. (ShuffleSplit(random_state=123), True),
  1564. (GroupShuffleSplit(random_state=123), True),
  1565. (StratifiedShuffleSplit(random_state=123), True),
  1566. (GroupKFold(), True),
  1567. (TimeSeriesSplit(), True),
  1568. (LeaveOneOut(), True),
  1569. (LeaveOneGroupOut(), True),
  1570. (LeavePGroupsOut(n_groups=2), True),
  1571. (LeavePOut(p=2), True),
  1572. (KFold(shuffle=True, random_state=None), False),
  1573. (KFold(shuffle=True, random_state=None), False),
  1574. (StratifiedKFold(shuffle=True, random_state=np.random.RandomState(0)), False),
  1575. (StratifiedKFold(shuffle=True, random_state=np.random.RandomState(0)), False),
  1576. (RepeatedKFold(random_state=None), False),
  1577. (RepeatedKFold(random_state=np.random.RandomState(0)), False),
  1578. (RepeatedStratifiedKFold(random_state=None), False),
  1579. (RepeatedStratifiedKFold(random_state=np.random.RandomState(0)), False),
  1580. (ShuffleSplit(random_state=None), False),
  1581. (ShuffleSplit(random_state=np.random.RandomState(0)), False),
  1582. (GroupShuffleSplit(random_state=None), False),
  1583. (GroupShuffleSplit(random_state=np.random.RandomState(0)), False),
  1584. (StratifiedShuffleSplit(random_state=None), False),
  1585. (StratifiedShuffleSplit(random_state=np.random.RandomState(0)), False),
  1586. ],
  1587. )
  1588. def test_yields_constant_splits(cv, expected):
  1589. assert _yields_constant_splits(cv) == expected
  1590. @pytest.mark.parametrize("cv", ALL_SPLITTERS, ids=[str(cv) for cv in ALL_SPLITTERS])
  1591. def test_splitter_get_metadata_routing(cv):
  1592. """Check get_metadata_routing returns the correct MetadataRouter."""
  1593. assert hasattr(cv, "get_metadata_routing")
  1594. metadata = cv.get_metadata_routing()
  1595. if cv in GROUP_SPLITTERS:
  1596. assert metadata.split.requests["groups"] is True
  1597. elif cv in NO_GROUP_SPLITTERS:
  1598. assert not metadata.split.requests
  1599. assert_request_is_empty(metadata, exclude=["split"])
  1600. @pytest.mark.parametrize("cv", ALL_SPLITTERS, ids=[str(cv) for cv in ALL_SPLITTERS])
  1601. def test_splitter_set_split_request(cv):
  1602. """Check set_split_request is defined for group splitters and not for others."""
  1603. if cv in GROUP_SPLITTERS:
  1604. assert hasattr(cv, "set_split_request")
  1605. elif cv in NO_GROUP_SPLITTERS:
  1606. assert not hasattr(cv, "set_split_request")