"""Test the split module"""

import re
import warnings
from itertools import combinations, combinations_with_replacement, permutations

import numpy as np
import pytest
from scipy import stats
from scipy.sparse import (
    coo_matrix,
    csc_matrix,
    csr_matrix,
    issparse,
)
from scipy.special import comb

from sklearn.datasets import load_digits, make_classification
from sklearn.dummy import DummyClassifier
from sklearn.model_selection import (
    GridSearchCV,
    GroupKFold,
    GroupShuffleSplit,
    KFold,
    LeaveOneGroupOut,
    LeaveOneOut,
    LeavePGroupsOut,
    LeavePOut,
    PredefinedSplit,
    RepeatedKFold,
    RepeatedStratifiedKFold,
    ShuffleSplit,
    StratifiedGroupKFold,
    StratifiedKFold,
    StratifiedShuffleSplit,
    TimeSeriesSplit,
    check_cv,
    cross_val_score,
    train_test_split,
)
from sklearn.model_selection._split import (
    _build_repr,
    _validate_shuffle_split,
    _yields_constant_splits,
)
from sklearn.svm import SVC
from sklearn.tests.test_metadata_routing import assert_request_is_empty
from sklearn.utils._mocking import MockDataFrame
from sklearn.utils._testing import (
    assert_allclose,
    assert_array_almost_equal,
    assert_array_equal,
    ignore_warnings,
)
from sklearn.utils.validation import _num_samples

NO_GROUP_SPLITTERS = [
    KFold(),
    StratifiedKFold(),
    TimeSeriesSplit(),
    LeaveOneOut(),
    LeavePOut(p=2),
    ShuffleSplit(),
    StratifiedShuffleSplit(test_size=0.5),
    PredefinedSplit([1, 1, 2, 2]),
    RepeatedKFold(),
    RepeatedStratifiedKFold(),
]

GROUP_SPLITTERS = [
    GroupKFold(),
    LeavePGroupsOut(n_groups=1),
    StratifiedGroupKFold(),
    LeaveOneGroupOut(),
    GroupShuffleSplit(),
]

ALL_SPLITTERS = NO_GROUP_SPLITTERS + GROUP_SPLITTERS  # type: ignore

X = np.ones(10)
y = np.arange(10) // 2
P_sparse = coo_matrix(np.eye(5))
test_groups = (
    np.array([1, 1, 1, 1, 2, 2, 2, 3, 3, 3, 3, 3]),
    np.array([0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3]),
    np.array([0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2]),
    np.array([1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4]),
    [1, 1, 1, 1, 2, 2, 2, 3, 3, 3, 3, 3],
    ["1", "1", "1", "1", "2", "2", "2", "3", "3", "3", "3", "3"],
)
digits = load_digits()
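# Shared module-level fixtures: `test_groups` provides several encodings of group
# labels (int arrays, plain lists, strings) exercised by the group-aware splitter
# tests below, and `digits` is reused by
# test_kfold_can_detect_dependent_samples_on_digits.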


@ignore_warnings
def test_cross_validator_with_default_params():
    n_samples = 4
    n_unique_groups = 4
    n_splits = 2
    p = 2
    n_shuffle_splits = 10  # (the default value)

    X = np.array([[1, 2], [3, 4], [5, 6], [7, 8]])
    X_1d = np.array([1, 2, 3, 4])
    y = np.array([1, 1, 2, 2])
    groups = np.array([1, 2, 3, 4])

    loo = LeaveOneOut()
    lpo = LeavePOut(p)
    kf = KFold(n_splits)
    skf = StratifiedKFold(n_splits)
    lolo = LeaveOneGroupOut()
    lopo = LeavePGroupsOut(p)
    ss = ShuffleSplit(random_state=0)
    ps = PredefinedSplit([1, 1, 2, 2])  # n_splits = number of unique folds = 2
    sgkf = StratifiedGroupKFold(n_splits)

    loo_repr = "LeaveOneOut()"
    lpo_repr = "LeavePOut(p=2)"
    kf_repr = "KFold(n_splits=2, random_state=None, shuffle=False)"
    skf_repr = "StratifiedKFold(n_splits=2, random_state=None, shuffle=False)"
    lolo_repr = "LeaveOneGroupOut()"
    lopo_repr = "LeavePGroupsOut(n_groups=2)"
    ss_repr = (
        "ShuffleSplit(n_splits=10, random_state=0, test_size=None, train_size=None)"
    )
    ps_repr = "PredefinedSplit(test_fold=array([1, 1, 2, 2]))"
    sgkf_repr = "StratifiedGroupKFold(n_splits=2, random_state=None, shuffle=False)"

    n_splits_expected = [
        n_samples,
        comb(n_samples, p),
        n_splits,
        n_splits,
        n_unique_groups,
        comb(n_unique_groups, p),
        n_shuffle_splits,
        2,
        n_splits,
    ]

    for i, (cv, cv_repr) in enumerate(
        zip(
            [loo, lpo, kf, skf, lolo, lopo, ss, ps, sgkf],
            [
                loo_repr,
                lpo_repr,
                kf_repr,
                skf_repr,
                lolo_repr,
                lopo_repr,
                ss_repr,
                ps_repr,
                sgkf_repr,
            ],
        )
    ):
        # Test if get_n_splits works correctly
        assert n_splits_expected[i] == cv.get_n_splits(X, y, groups)

        # Test if the cross-validator works as expected even if
        # the data is 1d
        np.testing.assert_equal(
            list(cv.split(X, y, groups)), list(cv.split(X_1d, y, groups))
        )

        # Test that train, test indices returned are integers
        for train, test in cv.split(X, y, groups):
            assert np.asarray(train).dtype.kind == "i"
            assert np.asarray(test).dtype.kind == "i"

        # Test if the repr works without any errors
        assert cv_repr == repr(cv)

    # ValueError for get_n_splits methods
    msg = "The 'X' parameter should not be None."
    with pytest.raises(ValueError, match=msg):
        loo.get_n_splits(None, y, groups)
    with pytest.raises(ValueError, match=msg):
        lpo.get_n_splits(None, y, groups)


def test_2d_y():
    # smoke test for 2d y and multi-label
    n_samples = 30
    rng = np.random.RandomState(1)
    X = rng.randint(0, 3, size=(n_samples, 2))
    y = rng.randint(0, 3, size=(n_samples,))
    y_2d = y.reshape(-1, 1)
    y_multilabel = rng.randint(0, 2, size=(n_samples, 3))
    groups = rng.randint(0, 3, size=(n_samples,))
    splitters = [
        LeaveOneOut(),
        LeavePOut(p=2),
        KFold(),
        StratifiedKFold(),
        RepeatedKFold(),
        RepeatedStratifiedKFold(),
        StratifiedGroupKFold(),
        ShuffleSplit(),
        StratifiedShuffleSplit(test_size=0.5),
        GroupShuffleSplit(),
        LeaveOneGroupOut(),
        LeavePGroupsOut(n_groups=2),
        GroupKFold(n_splits=3),
        TimeSeriesSplit(),
        PredefinedSplit(test_fold=groups),
    ]
    for splitter in splitters:
        list(splitter.split(X, y, groups))
        list(splitter.split(X, y_2d, groups))
        try:
            list(splitter.split(X, y_multilabel, groups))
        except ValueError as e:
            allowed_target_types = ("binary", "multiclass")
            msg = "Supported target types are: {}. Got 'multilabel".format(
                allowed_target_types
            )
            assert msg in str(e)


def check_valid_split(train, test, n_samples=None):
    # Use python sets to get more informative assertion failure messages
    train, test = set(train), set(test)

    # Train and test split should not overlap
    assert train.intersection(test) == set()

    if n_samples is not None:
        # Check that the union of train and test split covers all the indices
        assert train.union(test) == set(range(n_samples))


def check_cv_coverage(cv, X, y, groups, expected_n_splits):
    n_samples = _num_samples(X)
    # Check that all the samples appear at least once in a test fold
    assert cv.get_n_splits(X, y, groups) == expected_n_splits

    collected_test_samples = set()
    iterations = 0
    for train, test in cv.split(X, y, groups):
        check_valid_split(train, test, n_samples=n_samples)
        iterations += 1
        collected_test_samples.update(test)

    # Check that the accumulated test samples cover the whole dataset
    assert iterations == expected_n_splits
    if n_samples is not None:
        assert collected_test_samples == set(range(n_samples))


def test_kfold_valueerrors():
    X1 = np.array([[1, 2], [3, 4], [5, 6]])
    X2 = np.array([[1, 2], [3, 4], [5, 6], [7, 8], [9, 10]])
    # Check that errors are raised if there are not enough samples
    with pytest.raises(ValueError):
        next(KFold(4).split(X1))

    # Check that a warning is raised if the least populated class has too few
    # members.
    y = np.array([3, 3, -1, -1, 3])

    skf_3 = StratifiedKFold(3)
    with pytest.warns(Warning, match="The least populated class"):
        next(skf_3.split(X2, y))

    sgkf_3 = StratifiedGroupKFold(3)
    naive_groups = np.arange(len(y))
    with pytest.warns(Warning, match="The least populated class"):
        next(sgkf_3.split(X2, y, naive_groups))

    # Check that despite the warning the folds are still computed even
    # though all the classes are not necessarily represented on each
    # side of the split at each split
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        check_cv_coverage(skf_3, X2, y, groups=None, expected_n_splits=3)

    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        check_cv_coverage(sgkf_3, X2, y, groups=naive_groups, expected_n_splits=3)

    # Check that errors are raised if all n_groups for individual
    # classes are less than n_splits.
    y = np.array([3, 3, -1, -1, 2])

    with pytest.raises(ValueError):
        next(skf_3.split(X2, y))
    with pytest.raises(ValueError):
        next(sgkf_3.split(X2, y))

    # Error when number of folds is <= 1
    with pytest.raises(ValueError):
        KFold(0)
    with pytest.raises(ValueError):
        KFold(1)

    error_string = "k-fold cross-validation requires at least one train/test split"
    with pytest.raises(ValueError, match=error_string):
        StratifiedKFold(0)
    with pytest.raises(ValueError, match=error_string):
        StratifiedKFold(1)
    with pytest.raises(ValueError, match=error_string):
        StratifiedGroupKFold(0)
    with pytest.raises(ValueError, match=error_string):
        StratifiedGroupKFold(1)

    # When n_splits is not integer:
    with pytest.raises(ValueError):
        KFold(1.5)
    with pytest.raises(ValueError):
        KFold(2.0)
    with pytest.raises(ValueError):
        StratifiedKFold(1.5)
    with pytest.raises(ValueError):
        StratifiedKFold(2.0)
    with pytest.raises(ValueError):
        StratifiedGroupKFold(1.5)
    with pytest.raises(ValueError):
        StratifiedGroupKFold(2.0)

    # When shuffle is not a bool:
    with pytest.raises(TypeError):
        KFold(n_splits=4, shuffle=None)


def test_kfold_indices():
    # Check all indices are returned in the test folds
    X1 = np.ones(18)
    kf = KFold(3)
    check_cv_coverage(kf, X1, y=None, groups=None, expected_n_splits=3)

    # Check all indices are returned in the test folds even when equal-sized
    # folds are not possible
    X2 = np.ones(17)
    kf = KFold(3)
    check_cv_coverage(kf, X2, y=None, groups=None, expected_n_splits=3)

    # Check if get_n_splits returns the number of folds
    assert 5 == KFold(5).get_n_splits(X2)


def test_kfold_no_shuffle():
    # Manually check that KFold preserves the data ordering on toy datasets
    X2 = [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10]]

    splits = KFold(2).split(X2[:-1])
    train, test = next(splits)
    assert_array_equal(test, [0, 1])
    assert_array_equal(train, [2, 3])

    train, test = next(splits)
    assert_array_equal(test, [2, 3])
    assert_array_equal(train, [0, 1])

    splits = KFold(2).split(X2)
    train, test = next(splits)
    assert_array_equal(test, [0, 1, 2])
    assert_array_equal(train, [3, 4])

    train, test = next(splits)
    assert_array_equal(test, [3, 4])
    assert_array_equal(train, [0, 1, 2])


def test_stratified_kfold_no_shuffle():
    # Manually check that StratifiedKFold preserves the data ordering as much
    # as possible on toy datasets in order to avoid hiding sample dependencies
    # when possible
    X, y = np.ones(4), [1, 1, 0, 0]
    splits = StratifiedKFold(2).split(X, y)
    train, test = next(splits)
    assert_array_equal(test, [0, 2])
    assert_array_equal(train, [1, 3])

    train, test = next(splits)
    assert_array_equal(test, [1, 3])
    assert_array_equal(train, [0, 2])

    X, y = np.ones(7), [1, 1, 1, 0, 0, 0, 0]
    splits = StratifiedKFold(2).split(X, y)

    train, test = next(splits)
    assert_array_equal(test, [0, 1, 3, 4])
    assert_array_equal(train, [2, 5, 6])

    train, test = next(splits)
    assert_array_equal(test, [2, 5, 6])
    assert_array_equal(train, [0, 1, 3, 4])

    # Check if get_n_splits returns the number of folds
    assert 5 == StratifiedKFold(5).get_n_splits(X, y)

    # Make sure string labels are also supported
    X = np.ones(7)
    y1 = ["1", "1", "1", "0", "0", "0", "0"]
    y2 = [1, 1, 1, 0, 0, 0, 0]
    np.testing.assert_equal(
        list(StratifiedKFold(2).split(X, y1)), list(StratifiedKFold(2).split(X, y2))
    )

    # Check equivalence to KFold
    y = [0, 1, 0, 1, 0, 1, 0, 1]
    X = np.ones_like(y)
    np.testing.assert_equal(
        list(StratifiedKFold(3).split(X, y)), list(KFold(3).split(X, y))
    )


@pytest.mark.parametrize("shuffle", [False, True])
@pytest.mark.parametrize("k", [4, 5, 6, 7, 8, 9, 10])
@pytest.mark.parametrize("kfold", [StratifiedKFold, StratifiedGroupKFold])
def test_stratified_kfold_ratios(k, shuffle, kfold):
    # Check that stratified kfold preserves class ratios in individual splits
    # Repeat with shuffling turned off and on
    n_samples = 1000
    X = np.ones(n_samples)
    y = np.array(
        [4] * int(0.10 * n_samples)
        + [0] * int(0.89 * n_samples)
        + [1] * int(0.01 * n_samples)
    )
    # ensure perfect stratification with StratifiedGroupKFold
    groups = np.arange(len(y))

    distr = np.bincount(y) / len(y)

    test_sizes = []
    random_state = None if not shuffle else 0
    skf = kfold(k, random_state=random_state, shuffle=shuffle)
    for train, test in skf.split(X, y, groups=groups):
        assert_allclose(np.bincount(y[train]) / len(train), distr, atol=0.02)
        assert_allclose(np.bincount(y[test]) / len(test), distr, atol=0.02)
        test_sizes.append(len(test))
    assert np.ptp(test_sizes) <= 1


@pytest.mark.parametrize("shuffle", [False, True])
@pytest.mark.parametrize("k", [4, 6, 7])
@pytest.mark.parametrize("kfold", [StratifiedKFold, StratifiedGroupKFold])
def test_stratified_kfold_label_invariance(k, shuffle, kfold):
    # Check that stratified kfold gives the same indices regardless of labels
    n_samples = 100
    y = np.array(
        [2] * int(0.10 * n_samples)
        + [0] * int(0.89 * n_samples)
        + [1] * int(0.01 * n_samples)
    )
    X = np.ones(len(y))
    # ensure perfect stratification with StratifiedGroupKFold
    groups = np.arange(len(y))

    def get_splits(y):
        random_state = None if not shuffle else 0
        return [
            (list(train), list(test))
            for train, test in kfold(
                k, random_state=random_state, shuffle=shuffle
            ).split(X, y, groups=groups)
        ]

    splits_base = get_splits(y)
    for perm in permutations([0, 1, 2]):
        y_perm = np.take(perm, y)
        splits_perm = get_splits(y_perm)
        assert splits_perm == splits_base


def test_kfold_balance():
    # Check that KFold returns folds with balanced sizes
    for i in range(11, 17):
        kf = KFold(5).split(X=np.ones(i))
        sizes = [len(test) for _, test in kf]

        assert (np.max(sizes) - np.min(sizes)) <= 1
        assert np.sum(sizes) == i


@pytest.mark.parametrize("kfold", [StratifiedKFold, StratifiedGroupKFold])
def test_stratifiedkfold_balance(kfold):
    # Check that KFold returns folds with balanced sizes (only when
    # stratification is possible)
    # Repeat with shuffling turned off and on
    X = np.ones(17)
    y = [0] * 3 + [1] * 14
    # ensure perfect stratification with StratifiedGroupKFold
    groups = np.arange(len(y))

    for shuffle in (True, False):
        cv = kfold(3, shuffle=shuffle)
        for i in range(11, 17):
            skf = cv.split(X[:i], y[:i], groups[:i])
            sizes = [len(test) for _, test in skf]

            assert (np.max(sizes) - np.min(sizes)) <= 1
            assert np.sum(sizes) == i


def test_shuffle_kfold():
    # Check the indices are shuffled properly
    kf = KFold(3)
    kf2 = KFold(3, shuffle=True, random_state=0)
    kf3 = KFold(3, shuffle=True, random_state=1)

    X = np.ones(300)

    all_folds = np.zeros(300)
    for (tr1, te1), (tr2, te2), (tr3, te3) in zip(
        kf.split(X), kf2.split(X), kf3.split(X)
    ):
        for tr_a, tr_b in combinations((tr1, tr2, tr3), 2):
            # Assert that there is no complete overlap
            assert len(np.intersect1d(tr_a, tr_b)) != len(tr1)

        # Set all test indices in successive iterations of kf2 to 1
        all_folds[te2] = 1

    # Check that all indices are returned in the different test folds
    assert sum(all_folds) == 300


@pytest.mark.parametrize("kfold", [KFold, StratifiedKFold, StratifiedGroupKFold])
def test_shuffle_kfold_stratifiedkfold_reproducibility(kfold):
    X = np.ones(15)  # Divisible by 3
    y = [0] * 7 + [1] * 8
    groups_1 = np.arange(len(y))
    X2 = np.ones(16)  # Not divisible by 3
    y2 = [0] * 8 + [1] * 8
    groups_2 = np.arange(len(y2))

    # Check that when the shuffle is True, multiple split calls produce the
    # same split when random_state is int
    kf = kfold(3, shuffle=True, random_state=0)

    np.testing.assert_equal(
        list(kf.split(X, y, groups_1)), list(kf.split(X, y, groups_1))
    )

    # Check that when the shuffle is True, multiple split calls often
    # (not always) produce different splits when random_state is
    # RandomState instance or None
    kf = kfold(3, shuffle=True, random_state=np.random.RandomState(0))
    for data in zip((X, X2), (y, y2), (groups_1, groups_2)):
        # Test if the two splits are different
        for (_, test_a), (_, test_b) in zip(kf.split(*data), kf.split(*data)):
            # cv.split(...) returns an array of tuples, each tuple
            # consisting of an array with train indices and test indices
            # Ensure that the splits for data are not the same
            # when random state is not set
            with pytest.raises(AssertionError):
                np.testing.assert_array_equal(test_a, test_b)


def test_shuffle_stratifiedkfold():
    # Check that shuffling is happening when requested, and for proper
    # sample coverage
    X_40 = np.ones(40)
    y = [0] * 20 + [1] * 20
    kf0 = StratifiedKFold(5, shuffle=True, random_state=0)
    kf1 = StratifiedKFold(5, shuffle=True, random_state=1)
    for (_, test0), (_, test1) in zip(kf0.split(X_40, y), kf1.split(X_40, y)):
        assert set(test0) != set(test1)
    check_cv_coverage(kf0, X_40, y, groups=None, expected_n_splits=5)

    # Ensure that we shuffle each class's samples with different
    # random_state in StratifiedKFold
    # See https://github.com/scikit-learn/scikit-learn/pull/13124
    X = np.arange(10)
    y = [0] * 5 + [1] * 5
    kf1 = StratifiedKFold(5, shuffle=True, random_state=0)
    kf2 = StratifiedKFold(5, shuffle=True, random_state=1)
    test_set1 = sorted([tuple(s[1]) for s in kf1.split(X, y)])
    test_set2 = sorted([tuple(s[1]) for s in kf2.split(X, y)])
    assert test_set1 != test_set2


def test_kfold_can_detect_dependent_samples_on_digits():  # see #2372
    # The digits samples are dependent: they are apparently grouped by authors
    # although we don't have any information on the groups segment locations
    # for this data. We can highlight this fact by computing k-fold cross-
    # validation with and without shuffling: we observe that the shuffling case
    # wrongly makes the IID assumption and is therefore too optimistic: it
    # estimates a much higher accuracy (around 0.93) than the non-shuffling
    # variant (around 0.81).

    X, y = digits.data[:600], digits.target[:600]
    model = SVC(C=10, gamma=0.005)

    n_splits = 3

    cv = KFold(n_splits=n_splits, shuffle=False)
    mean_score = cross_val_score(model, X, y, cv=cv).mean()
    assert 0.92 > mean_score
    assert mean_score > 0.80

    # Shuffling the data artificially breaks the dependency and hides the
    # overfitting of the model with regards to the writing style of the authors
    # by yielding a seriously overestimated score:

    cv = KFold(n_splits, shuffle=True, random_state=0)
    mean_score = cross_val_score(model, X, y, cv=cv).mean()
    assert mean_score > 0.92

    cv = KFold(n_splits, shuffle=True, random_state=1)
    mean_score = cross_val_score(model, X, y, cv=cv).mean()
    assert mean_score > 0.92

    # Similarly, StratifiedKFold should try to shuffle the data as little
    # as possible (while respecting the balanced class constraints)
    # and thus be able to detect the dependency by not overestimating
    # the CV score either. As the digits dataset is approximately balanced
    # the estimated mean score is close to the score measured with
    # non-shuffled KFold

    cv = StratifiedKFold(n_splits)
    mean_score = cross_val_score(model, X, y, cv=cv).mean()
    assert 0.94 > mean_score
    assert mean_score > 0.80


def test_stratified_group_kfold_trivial():
    sgkf = StratifiedGroupKFold(n_splits=3)
    # Trivial example - groups with the same distribution
    y = np.array([1] * 6 + [0] * 12)
    X = np.ones_like(y).reshape(-1, 1)
    groups = np.asarray((1, 2, 3, 4, 5, 6, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6))
    distr = np.bincount(y) / len(y)
    test_sizes = []
    for train, test in sgkf.split(X, y, groups):
        # check group constraint
        assert np.intersect1d(groups[train], groups[test]).size == 0
        # check y distribution
        assert_allclose(np.bincount(y[train]) / len(train), distr, atol=0.02)
        assert_allclose(np.bincount(y[test]) / len(test), distr, atol=0.02)
        test_sizes.append(len(test))
    assert np.ptp(test_sizes) <= 1


def test_stratified_group_kfold_approximate():
    # Not perfect stratification (even though it is possible) because of
    # iteration over groups
    sgkf = StratifiedGroupKFold(n_splits=3)
    y = np.array([1] * 6 + [0] * 12)
    X = np.ones_like(y).reshape(-1, 1)
    groups = np.array([1, 2, 3, 3, 4, 4, 1, 1, 2, 2, 3, 4, 5, 5, 5, 6, 6, 6])
    expected = np.asarray([[0.833, 0.166], [0.666, 0.333], [0.5, 0.5]])
    test_sizes = []
    for (train, test), expect_dist in zip(sgkf.split(X, y, groups), expected):
        # check group constraint
        assert np.intersect1d(groups[train], groups[test]).size == 0
        split_dist = np.bincount(y[test]) / len(test)
        assert_allclose(split_dist, expect_dist, atol=0.001)
        test_sizes.append(len(test))
    assert np.ptp(test_sizes) <= 1


@pytest.mark.parametrize(
    "y, groups, expected",
    [
        (
            np.array([0] * 6 + [1] * 6),
            np.array([1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6]),
            np.asarray([[0.5, 0.5], [0.5, 0.5], [0.5, 0.5]]),
        ),
        (
            np.array([0] * 9 + [1] * 3),
            np.array([1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 5, 6]),
            np.asarray([[0.75, 0.25], [0.75, 0.25], [0.75, 0.25]]),
        ),
    ],
)
def test_stratified_group_kfold_homogeneous_groups(y, groups, expected):
    sgkf = StratifiedGroupKFold(n_splits=3)
    X = np.ones_like(y).reshape(-1, 1)
    for (train, test), expect_dist in zip(sgkf.split(X, y, groups), expected):
        # check group constraint
        assert np.intersect1d(groups[train], groups[test]).size == 0
        split_dist = np.bincount(y[test]) / len(test)
        assert_allclose(split_dist, expect_dist, atol=0.001)


@pytest.mark.parametrize("cls_distr", [(0.4, 0.6), (0.3, 0.7), (0.2, 0.8), (0.8, 0.2)])
@pytest.mark.parametrize("n_groups", [5, 30, 70])
def test_stratified_group_kfold_against_group_kfold(cls_distr, n_groups):
    # Check that, given a sufficient number of samples, StratifiedGroupKFold
    # produces better stratified folds than regular GroupKFold
    n_splits = 5
    sgkf = StratifiedGroupKFold(n_splits=n_splits)
    gkf = GroupKFold(n_splits=n_splits)
    rng = np.random.RandomState(0)
    n_points = 1000
    y = rng.choice(2, size=n_points, p=cls_distr)
    X = np.ones_like(y).reshape(-1, 1)
    g = rng.choice(n_groups, n_points)
    sgkf_folds = sgkf.split(X, y, groups=g)
    gkf_folds = gkf.split(X, y, groups=g)
    sgkf_entr = 0
    gkf_entr = 0
    for (sgkf_train, sgkf_test), (_, gkf_test) in zip(sgkf_folds, gkf_folds):
        # check group constraint
        assert np.intersect1d(g[sgkf_train], g[sgkf_test]).size == 0
        sgkf_distr = np.bincount(y[sgkf_test]) / len(sgkf_test)
        gkf_distr = np.bincount(y[gkf_test]) / len(gkf_test)
        sgkf_entr += stats.entropy(sgkf_distr, qk=cls_distr)
        gkf_entr += stats.entropy(gkf_distr, qk=cls_distr)
    sgkf_entr /= n_splits
    gkf_entr /= n_splits
    assert sgkf_entr <= gkf_entr


def test_shuffle_split():
    # Equivalent ways of specifying a test set of 2 out of 10 samples (a float
    # fraction, a Python int and a NumPy integer) should yield the same splits
    ss1 = ShuffleSplit(test_size=0.2, random_state=0).split(X)
    ss2 = ShuffleSplit(test_size=2, random_state=0).split(X)
    ss3 = ShuffleSplit(test_size=np.int32(2), random_state=0).split(X)
    ss4 = ShuffleSplit(test_size=int(2), random_state=0).split(X)
    for t1, t2, t3, t4 in zip(ss1, ss2, ss3, ss4):
        assert_array_equal(t1[0], t2[0])
        assert_array_equal(t2[0], t3[0])
        assert_array_equal(t3[0], t4[0])
        assert_array_equal(t1[1], t2[1])
        assert_array_equal(t2[1], t3[1])
        assert_array_equal(t3[1], t4[1])


@pytest.mark.parametrize("split_class", [ShuffleSplit, StratifiedShuffleSplit])
@pytest.mark.parametrize(
    "train_size, exp_train, exp_test", [(None, 9, 1), (8, 8, 2), (0.8, 8, 2)]
)
def test_shuffle_split_default_test_size(split_class, train_size, exp_train, exp_test):
    # Check that the default value has the expected behavior, i.e. 0.1 if both
    # unspecified or complement train_size unless both are specified.
    X = np.ones(10)
    y = np.ones(10)

    X_train, X_test = next(split_class(train_size=train_size).split(X, y))

    assert len(X_train) == exp_train
    assert len(X_test) == exp_test


@pytest.mark.parametrize(
    "train_size, exp_train, exp_test", [(None, 8, 2), (7, 7, 3), (0.7, 7, 3)]
)
def test_group_shuffle_split_default_test_size(train_size, exp_train, exp_test):
    # Check that the default value has the expected behavior, i.e. 0.2 if both
    # unspecified or complement train_size unless both are specified.
    X = np.ones(10)
    y = np.ones(10)
    groups = range(10)

    X_train, X_test = next(GroupShuffleSplit(train_size=train_size).split(X, y, groups))

    assert len(X_train) == exp_train
    assert len(X_test) == exp_test


@ignore_warnings
def test_stratified_shuffle_split_init():
    X = np.arange(7)
    y = np.asarray([0, 1, 1, 1, 2, 2, 2])
    # Check that error is raised if there is a class with only one sample
    with pytest.raises(ValueError):
        next(StratifiedShuffleSplit(3, test_size=0.2).split(X, y))

    # Check that error is raised if the test set size is smaller than n_classes
    with pytest.raises(ValueError):
        next(StratifiedShuffleSplit(3, test_size=2).split(X, y))
    # Check that error is raised if the train set size is smaller than
    # n_classes
    with pytest.raises(ValueError):
        next(StratifiedShuffleSplit(3, test_size=3, train_size=2).split(X, y))

    X = np.arange(9)
    y = np.asarray([0, 0, 0, 1, 1, 1, 2, 2, 2])

    # Train size or test size too small
    with pytest.raises(ValueError):
        next(StratifiedShuffleSplit(train_size=2).split(X, y))
    with pytest.raises(ValueError):
        next(StratifiedShuffleSplit(test_size=2).split(X, y))


def test_stratified_shuffle_split_respects_test_size():
    y = np.array([0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2])
    test_size = 5
    train_size = 10
    sss = StratifiedShuffleSplit(
        6, test_size=test_size, train_size=train_size, random_state=0
    ).split(np.ones(len(y)), y)
    for train, test in sss:
        assert len(train) == train_size
        assert len(test) == test_size


def test_stratified_shuffle_split_iter():
    ys = [
        np.array([1, 1, 1, 1, 2, 2, 2, 3, 3, 3, 3, 3]),
        np.array([0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3]),
        np.array([0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2] * 2),
        np.array([1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4]),
        np.array([-1] * 800 + [1] * 50),
        np.concatenate([[i] * (100 + i) for i in range(11)]),
        [1, 1, 1, 1, 2, 2, 2, 3, 3, 3, 3, 3],
        ["1", "1", "1", "1", "2", "2", "2", "3", "3", "3", "3", "3"],
    ]

    for y in ys:
        sss = StratifiedShuffleSplit(6, test_size=0.33, random_state=0).split(
            np.ones(len(y)), y
        )
        y = np.asanyarray(y)  # To make it indexable for y[train]
        # this is how test-size is computed internally
        # in _validate_shuffle_split
        test_size = np.ceil(0.33 * len(y))
        train_size = len(y) - test_size
        for train, test in sss:
            assert_array_equal(np.unique(y[train]), np.unique(y[test]))
            # Check that the folds keep the class proportions
            p_train = np.bincount(np.unique(y[train], return_inverse=True)[1]) / float(
                len(y[train])
            )
            p_test = np.bincount(np.unique(y[test], return_inverse=True)[1]) / float(
                len(y[test])
            )
            assert_array_almost_equal(p_train, p_test, 1)
            assert len(train) + len(test) == y.size
            assert len(train) == train_size
            assert len(test) == test_size
            assert_array_equal(np.intersect1d(train, test), [])


def test_stratified_shuffle_split_even():
    # Test the StratifiedShuffleSplit, indices are drawn with an
    # equal chance
    n_folds = 5
    n_splits = 1000

    def assert_counts_are_ok(idx_counts, p):
        # Here we test that the distribution of the counts
        # per index is close enough to a binomial
        threshold = 0.05 / n_splits
        bf = stats.binom(n_splits, p)
        for count in idx_counts:
            prob = bf.pmf(count)
            assert (
                prob > threshold
            ), "An index is not drawn with chance corresponding to even draws"

    for n_samples in (6, 22):
        groups = np.array((n_samples // 2) * [0, 1])
        splits = StratifiedShuffleSplit(
            n_splits=n_splits, test_size=1.0 / n_folds, random_state=0
        )

        train_counts = [0] * n_samples
        test_counts = [0] * n_samples
        n_splits_actual = 0
        for train, test in splits.split(X=np.ones(n_samples), y=groups):
            n_splits_actual += 1
            for counter, ids in [(train_counts, train), (test_counts, test)]:
                for id in ids:
                    counter[id] += 1
        assert n_splits_actual == n_splits

        n_train, n_test = _validate_shuffle_split(
            n_samples, test_size=1.0 / n_folds, train_size=1.0 - (1.0 / n_folds)
        )

        assert len(train) == n_train
        assert len(test) == n_test
        assert len(set(train).intersection(test)) == 0

        group_counts = np.unique(groups)
        assert splits.test_size == 1.0 / n_folds
        assert n_train + n_test == len(groups)
        assert len(group_counts) == 2
        ex_test_p = float(n_test) / n_samples
        ex_train_p = float(n_train) / n_samples

        assert_counts_are_ok(train_counts, ex_train_p)
        assert_counts_are_ok(test_counts, ex_test_p)


def test_stratified_shuffle_split_overlap_train_test_bug():
    # See https://github.com/scikit-learn/scikit-learn/issues/6121 for
    # the original bug report
    y = [0, 1, 2, 3] * 3 + [4, 5] * 5
    X = np.ones_like(y)

    sss = StratifiedShuffleSplit(n_splits=1, test_size=0.5, random_state=0)

    train, test = next(sss.split(X=X, y=y))

    # no overlap
    assert_array_equal(np.intersect1d(train, test), [])

    # complete partition
    assert_array_equal(np.union1d(train, test), np.arange(len(y)))


def test_stratified_shuffle_split_multilabel():
    # fix for issue 9037
    for y in [
        np.array([[0, 1], [1, 0], [1, 0], [0, 1]]),
        np.array([[0, 1], [1, 1], [1, 1], [0, 1]]),
    ]:
        X = np.ones_like(y)
        sss = StratifiedShuffleSplit(n_splits=1, test_size=0.5, random_state=0)
        train, test = next(sss.split(X=X, y=y))
        y_train = y[train]
        y_test = y[test]

        # no overlap
        assert_array_equal(np.intersect1d(train, test), [])

        # complete partition
        assert_array_equal(np.union1d(train, test), np.arange(len(y)))

        # correct stratification of entire rows
        # (by design, here y[:, 0] uniquely determines the entire row of y)
        expected_ratio = np.mean(y[:, 0])
        assert expected_ratio == np.mean(y_train[:, 0])
        assert expected_ratio == np.mean(y_test[:, 0])


def test_stratified_shuffle_split_multilabel_many_labels():
    # fix in PR #9922: for multilabel data with > 1000 labels, str(row)
    # truncates with an ellipsis for elements in positions 4 through
    # len(row) - 4, so labels were not being correctly split using the powerset
    # method for transforming a multilabel problem to a multiclass one; this
    # test checks that this problem is fixed.
    row_with_many_zeros = [1, 0, 1] + [0] * 1000 + [1, 0, 1]
    row_with_many_ones = [1, 0, 1] + [1] * 1000 + [1, 0, 1]
    y = np.array([row_with_many_zeros] * 10 + [row_with_many_ones] * 100)
    X = np.ones_like(y)

    sss = StratifiedShuffleSplit(n_splits=1, test_size=0.5, random_state=0)
    train, test = next(sss.split(X=X, y=y))
    y_train = y[train]
    y_test = y[test]

    # correct stratification of entire rows
    # (by design, here y[:, 4] uniquely determines the entire row of y)
    expected_ratio = np.mean(y[:, 4])
    assert expected_ratio == np.mean(y_train[:, 4])
    assert expected_ratio == np.mean(y_test[:, 4])


def test_predefinedsplit_with_kfold_split():
    # Check that PredefinedSplit can reproduce a split generated by KFold.
    folds = np.full(10, -1.0)
    kf_train = []
    kf_test = []
    for i, (train_ind, test_ind) in enumerate(KFold(5, shuffle=True).split(X)):
        kf_train.append(train_ind)
        kf_test.append(test_ind)
        folds[test_ind] = i
    ps = PredefinedSplit(folds)
    # n_splits is simply the number of unique folds
    assert len(np.unique(folds)) == ps.get_n_splits()
    ps_train, ps_test = zip(*ps.split())
    assert_array_equal(ps_train, kf_train)
    assert_array_equal(ps_test, kf_test)


def test_group_shuffle_split():
    for groups_i in test_groups:
        X = y = np.ones(len(groups_i))
        n_splits = 6
        test_size = 1.0 / 3
        slo = GroupShuffleSplit(n_splits, test_size=test_size, random_state=0)

        # Make sure the repr works
        repr(slo)

        # Test that the length is correct
        assert slo.get_n_splits(X, y, groups=groups_i) == n_splits

        l_unique = np.unique(groups_i)
        l = np.asarray(groups_i)

        for train, test in slo.split(X, y, groups=groups_i):
            # First test: no train group is in the test set and vice versa
            l_train_unique = np.unique(l[train])
            l_test_unique = np.unique(l[test])
            assert not np.any(np.isin(l[train], l_test_unique))
            assert not np.any(np.isin(l[test], l_train_unique))

            # Second test: train and test add up to all the data
            assert l[train].size + l[test].size == l.size

            # Third test: train and test are disjoint
            assert_array_equal(np.intersect1d(train, test), [])

            # Fourth test:
            # unique train and test groups are correct, +- 1 for rounding error
            assert abs(len(l_test_unique) - round(test_size * len(l_unique))) <= 1
            assert (
                abs(len(l_train_unique) - round((1.0 - test_size) * len(l_unique))) <= 1
            )


def test_leave_one_p_group_out():
    logo = LeaveOneGroupOut()
    lpgo_1 = LeavePGroupsOut(n_groups=1)
    lpgo_2 = LeavePGroupsOut(n_groups=2)

    # Make sure the repr works
    assert repr(logo) == "LeaveOneGroupOut()"
    assert repr(lpgo_1) == "LeavePGroupsOut(n_groups=1)"
    assert repr(lpgo_2) == "LeavePGroupsOut(n_groups=2)"
    assert repr(LeavePGroupsOut(n_groups=3)) == "LeavePGroupsOut(n_groups=3)"

    for j, (cv, p_groups_out) in enumerate(((logo, 1), (lpgo_1, 1), (lpgo_2, 2))):
        for i, groups_i in enumerate(test_groups):
            n_groups = len(np.unique(groups_i))
            n_splits = n_groups if p_groups_out == 1 else n_groups * (n_groups - 1) / 2
            X = y = np.ones(len(groups_i))

            # Test that the length is correct
            assert cv.get_n_splits(X, y, groups=groups_i) == n_splits

            groups_arr = np.asarray(groups_i)

            # Split using the original list / array / list of string groups_i
            for train, test in cv.split(X, y, groups=groups_i):
                # First test: no train group is in the test set and vice versa
                assert_array_equal(
                    np.intersect1d(groups_arr[train], groups_arr[test]).tolist(), []
                )

                # Second test: train and test add up to all the data
                assert len(train) + len(test) == len(groups_i)

                # Third test:
                # The number of groups in test must be equal to p_groups_out
                assert np.unique(groups_arr[test]).shape[0] == p_groups_out

    # check get_n_splits() with dummy parameters
    assert logo.get_n_splits(None, None, ["a", "b", "c", "b", "c"]) == 3
    assert logo.get_n_splits(groups=[1.0, 1.1, 1.0, 1.2]) == 3
    assert lpgo_2.get_n_splits(None, None, np.arange(4)) == 6
    assert lpgo_1.get_n_splits(groups=np.arange(4)) == 4

    # raise ValueError if a `groups` parameter is illegal
    with pytest.raises(ValueError):
        logo.get_n_splits(None, None, [0.0, np.nan, 0.0])
    with pytest.raises(ValueError):
        lpgo_2.get_n_splits(None, None, [0.0, np.inf, 0.0])

    msg = "The 'groups' parameter should not be None."
    with pytest.raises(ValueError, match=msg):
        logo.get_n_splits(None, None, None)
    with pytest.raises(ValueError, match=msg):
        lpgo_1.get_n_splits(None, None, None)


def test_leave_group_out_changing_groups():
    # Check that LeaveOneGroupOut and LeavePGroupsOut work normally if
    # the groups variable is changed before calling split
    groups = np.array([0, 1, 2, 1, 1, 2, 0, 0])
    X = np.ones(len(groups))
    groups_changing = np.array(groups, copy=True)
    lolo = LeaveOneGroupOut().split(X, groups=groups)
    lolo_changing = LeaveOneGroupOut().split(X, groups=groups)
    lplo = LeavePGroupsOut(n_groups=2).split(X, groups=groups)
    lplo_changing = LeavePGroupsOut(n_groups=2).split(X, groups=groups)
    groups_changing[:] = 0
    for llo, llo_changing in [(lolo, lolo_changing), (lplo, lplo_changing)]:
        for (train, test), (train_chan, test_chan) in zip(llo, llo_changing):
            assert_array_equal(train, train_chan)
            assert_array_equal(test, test_chan)

    # n_splits = number of 2 (p) group combinations of the unique groups = 3C2 = 3
    assert 3 == LeavePGroupsOut(n_groups=2).get_n_splits(X, y=X, groups=groups)
    # n_splits = number of unique groups (C(uniq_lbls, 1) = n_unique_groups)
    assert 3 == LeaveOneGroupOut().get_n_splits(X, y=X, groups=groups)


def test_leave_group_out_order_dependence():
    # Check that LeaveOneGroupOut orders the splits according to the index
    # of the group left out.
    groups = np.array([2, 2, 0, 0, 1, 1])
    X = np.ones(len(groups))

    splits = iter(LeaveOneGroupOut().split(X, groups=groups))

    expected_indices = [
        ([0, 1, 4, 5], [2, 3]),
        ([0, 1, 2, 3], [4, 5]),
        ([2, 3, 4, 5], [0, 1]),
    ]

    for expected_train, expected_test in expected_indices:
        train, test = next(splits)
        assert_array_equal(train, expected_train)
        assert_array_equal(test, expected_test)


def test_leave_one_p_group_out_error_on_fewer_number_of_groups():
    X = y = groups = np.ones(0)
    msg = re.escape("Found array with 0 sample(s)")
    with pytest.raises(ValueError, match=msg):
        next(LeaveOneGroupOut().split(X, y, groups))

    X = y = groups = np.ones(1)
    msg = re.escape(
        f"The groups parameter contains fewer than 2 unique groups ({groups})."
        " LeaveOneGroupOut expects at least 2."
    )
    with pytest.raises(ValueError, match=msg):
        next(LeaveOneGroupOut().split(X, y, groups))

    X = y = groups = np.ones(1)
    msg = re.escape(
        "The groups parameter contains fewer than (or equal to) n_groups "
        f"(3) numbers of unique groups ({groups}). LeavePGroupsOut expects "
        "that at least n_groups + 1 (4) unique groups "
        "be present"
    )
    with pytest.raises(ValueError, match=msg):
        next(LeavePGroupsOut(n_groups=3).split(X, y, groups))

    X = y = groups = np.arange(3)
    msg = re.escape(
        "The groups parameter contains fewer than (or equal to) n_groups "
        f"(3) numbers of unique groups ({groups}). LeavePGroupsOut expects "
        "that at least n_groups + 1 (4) unique groups "
        "be present"
    )
    with pytest.raises(ValueError, match=msg):
        next(LeavePGroupsOut(n_groups=3).split(X, y, groups))


@ignore_warnings
def test_repeated_cv_value_errors():
    # n_repeats is not integer or <= 0
    for cv in (RepeatedKFold, RepeatedStratifiedKFold):
        with pytest.raises(ValueError):
            cv(n_repeats=0)
        with pytest.raises(ValueError):
            cv(n_repeats=1.5)


@pytest.mark.parametrize("RepeatedCV", [RepeatedKFold, RepeatedStratifiedKFold])
def test_repeated_cv_repr(RepeatedCV):
    n_splits, n_repeats = 2, 6
    repeated_cv = RepeatedCV(n_splits=n_splits, n_repeats=n_repeats)
    repeated_cv_repr = "{}(n_repeats=6, n_splits=2, random_state=None)".format(
        repeated_cv.__class__.__name__
    )
    assert repeated_cv_repr == repr(repeated_cv)


def test_repeated_kfold_deterministic_split():
    X = [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10]]
    random_state = 258173307
    rkf = RepeatedKFold(n_splits=2, n_repeats=2, random_state=random_state)

    # split should produce the same, deterministic splits on each call
    for _ in range(3):
        splits = rkf.split(X)
        train, test = next(splits)
        assert_array_equal(train, [2, 4])
        assert_array_equal(test, [0, 1, 3])

        train, test = next(splits)
        assert_array_equal(train, [0, 1, 3])
        assert_array_equal(test, [2, 4])

        train, test = next(splits)
        assert_array_equal(train, [0, 1])
        assert_array_equal(test, [2, 3, 4])

        train, test = next(splits)
        assert_array_equal(train, [2, 3, 4])
        assert_array_equal(test, [0, 1])

        with pytest.raises(StopIteration):
            next(splits)


def test_get_n_splits_for_repeated_kfold():
    n_splits = 3
    n_repeats = 4
    rkf = RepeatedKFold(n_splits=n_splits, n_repeats=n_repeats)
    expected_n_splits = n_splits * n_repeats
    assert expected_n_splits == rkf.get_n_splits()


def test_get_n_splits_for_repeated_stratified_kfold():
    n_splits = 3
    n_repeats = 4
    rskf = RepeatedStratifiedKFold(n_splits=n_splits, n_repeats=n_repeats)
    expected_n_splits = n_splits * n_repeats
    assert expected_n_splits == rskf.get_n_splits()


def test_repeated_stratified_kfold_deterministic_split():
    X = [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10]]
    y = [1, 1, 1, 0, 0]
    random_state = 1944695409
    rskf = RepeatedStratifiedKFold(n_splits=2, n_repeats=2, random_state=random_state)

    # split should produce the same, deterministic splits on each call
    for _ in range(3):
        splits = rskf.split(X, y)
        train, test = next(splits)
        assert_array_equal(train, [1, 4])
        assert_array_equal(test, [0, 2, 3])

        train, test = next(splits)
        assert_array_equal(train, [0, 2, 3])
        assert_array_equal(test, [1, 4])

        train, test = next(splits)
        assert_array_equal(train, [2, 3])
        assert_array_equal(test, [0, 1, 4])

        train, test = next(splits)
        assert_array_equal(train, [0, 1, 4])
        assert_array_equal(test, [2, 3])

        with pytest.raises(StopIteration):
            next(splits)


def test_train_test_split_errors():
    pytest.raises(ValueError, train_test_split)

    pytest.raises(ValueError, train_test_split, range(3), train_size=1.1)

    pytest.raises(ValueError, train_test_split, range(3), test_size=0.6, train_size=0.6)
    pytest.raises(
        ValueError,
        train_test_split,
        range(3),
        test_size=np.float32(0.6),
        train_size=np.float32(0.6),
    )
    pytest.raises(ValueError, train_test_split, range(3), test_size="wrong_type")
    pytest.raises(ValueError, train_test_split, range(3), test_size=2, train_size=4)
    pytest.raises(TypeError, train_test_split, range(3), some_argument=1.1)
    pytest.raises(ValueError, train_test_split, range(3), range(42))
    pytest.raises(ValueError, train_test_split, range(10), shuffle=False, stratify=True)

    with pytest.raises(
        ValueError,
        match=r"train_size=11 should be either positive and "
        r"smaller than the number of samples 10 or a "
        r"float in the \(0, 1\) range",
    ):
        train_test_split(range(10), train_size=11, test_size=1)


@pytest.mark.parametrize(
    "train_size, exp_train, exp_test", [(None, 7, 3), (8, 8, 2), (0.8, 8, 2)]
)
def test_train_test_split_default_test_size(train_size, exp_train, exp_test):
    # Check that the default value has the expected behavior, i.e. complement
    # train_size unless both are specified.
    X_train, X_test = train_test_split(X, train_size=train_size)

    assert len(X_train) == exp_train
    assert len(X_test) == exp_test


def test_train_test_split():
    X = np.arange(100).reshape((10, 10))
    X_s = coo_matrix(X)
    y = np.arange(10)

    # simple test
    split = train_test_split(X, y, test_size=None, train_size=0.5)
    X_train, X_test, y_train, y_test = split
    assert len(y_test) == len(y_train)
    # test correspondence of X and y
    assert_array_equal(X_train[:, 0], y_train * 10)
    assert_array_equal(X_test[:, 0], y_test * 10)

    # don't convert lists to anything else by default
    split = train_test_split(X, X_s, y.tolist())
    X_train, X_test, X_s_train, X_s_test, y_train, y_test = split
    assert isinstance(y_train, list)
    assert isinstance(y_test, list)

    # allow nd-arrays
    X_4d = np.arange(10 * 5 * 3 * 2).reshape(10, 5, 3, 2)
    y_3d = np.arange(10 * 7 * 11).reshape(10, 7, 11)
    split = train_test_split(X_4d, y_3d)
    assert split[0].shape == (7, 5, 3, 2)
    assert split[1].shape == (3, 5, 3, 2)
    assert split[2].shape == (7, 7, 11)
    assert split[3].shape == (3, 7, 11)

    # test stratification option
    y = np.array([1, 1, 1, 1, 2, 2, 2, 2])
    for test_size, exp_test_size in zip([2, 4, 0.25, 0.5, 0.75], [2, 4, 2, 4, 6]):
        train, test = train_test_split(
            y, test_size=test_size, stratify=y, random_state=0
        )
        assert len(test) == exp_test_size
        assert len(test) + len(train) == len(y)
        # check the 1:1 ratio of ones and twos in the data is preserved
        assert np.sum(train == 1) == np.sum(train == 2)

    # test unshuffled split
    y = np.arange(10)
    for test_size in [2, 0.2]:
        train, test = train_test_split(y, shuffle=False, test_size=test_size)
        assert_array_equal(test, [8, 9])
        assert_array_equal(train, [0, 1, 2, 3, 4, 5, 6, 7])


def test_train_test_split_32bit_overflow():
    """Check for integer overflow on 32-bit platforms.

    Non-regression test for:
    https://github.com/scikit-learn/scikit-learn/issues/20774
    """

    # A number 'n' big enough for the expression 'n * n * train_size' to
    # overflow a signed 32-bit integer
    big_number = 100000

    # The definition of 'y' is part of the reproduction: the population of at
    # least one class should be of the same order of magnitude as the size of X
    X = np.arange(big_number)
    y = X > (0.99 * big_number)

    split = train_test_split(X, y, stratify=y, train_size=0.25)
    X_train, X_test, y_train, y_test = split

    assert X_train.size + X_test.size == big_number
    assert y_train.size + y_test.size == big_number


@ignore_warnings
def test_train_test_split_pandas():
    # check train_test_split doesn't destroy pandas dataframe
    types = [MockDataFrame]
    try:
        from pandas import DataFrame

        types.append(DataFrame)
    except ImportError:
        pass
    for InputFeatureType in types:
        # X dataframe
        X_df = InputFeatureType(X)
        X_train, X_test = train_test_split(X_df)
        assert isinstance(X_train, InputFeatureType)
        assert isinstance(X_test, InputFeatureType)


def test_train_test_split_sparse():
    # check that train_test_split converts scipy sparse matrices
    # to csr, as stated in the documentation
    X = np.arange(100).reshape((10, 10))
    sparse_types = [csr_matrix, csc_matrix, coo_matrix]
    for InputFeatureType in sparse_types:
        X_s = InputFeatureType(X)
        X_train, X_test = train_test_split(X_s)
        assert issparse(X_train) and X_train.format == "csr"
        assert issparse(X_test) and X_test.format == "csr"


def test_train_test_split_mock_pandas():
    # X mock dataframe
    X_df = MockDataFrame(X)
    X_train, X_test = train_test_split(X_df)
    assert isinstance(X_train, MockDataFrame)
    assert isinstance(X_test, MockDataFrame)
    X_train_arr, X_test_arr = train_test_split(X_df)


def test_train_test_split_list_input():
    # Check that when y is a list / list of string labels, it works.
    X = np.ones(7)
    y1 = ["1"] * 4 + ["0"] * 3
    y2 = np.hstack((np.ones(4), np.zeros(3)))
    y3 = y2.tolist()

    for stratify in (True, False):
        X_train1, X_test1, y_train1, y_test1 = train_test_split(
            X, y1, stratify=y1 if stratify else None, random_state=0
        )
        X_train2, X_test2, y_train2, y_test2 = train_test_split(
            X, y2, stratify=y2 if stratify else None, random_state=0
        )
        X_train3, X_test3, y_train3, y_test3 = train_test_split(
            X, y3, stratify=y3 if stratify else None, random_state=0
        )

        np.testing.assert_equal(X_train1, X_train2)
        np.testing.assert_equal(y_train2, y_train3)
        np.testing.assert_equal(X_test1, X_test3)
        np.testing.assert_equal(y_test3, y_test2)


@pytest.mark.parametrize(
    "test_size, train_size",
    [(2.0, None), (1.0, None), (0.1, 0.95), (None, 1j), (11, None), (10, None), (8, 3)],
)
def test_shufflesplit_errors(test_size, train_size):
    with pytest.raises(ValueError):
        next(ShuffleSplit(test_size=test_size, train_size=train_size).split(X))


def test_shufflesplit_reproducible():
    # Check that iterating twice on the ShuffleSplit gives the same
    # sequence of train-test when the random_state is given
    ss = ShuffleSplit(random_state=21)
    assert_array_equal([a for a, b in ss.split(X)], [a for a, b in ss.split(X)])


def test_stratifiedshufflesplit_list_input():
    # Check that when y is a list / list of string labels, it works.
    sss = StratifiedShuffleSplit(test_size=2, random_state=42)
    X = np.ones(7)
    y1 = ["1"] * 4 + ["0"] * 3
    y2 = np.hstack((np.ones(4), np.zeros(3)))
    y3 = y2.tolist()

    np.testing.assert_equal(list(sss.split(X, y1)), list(sss.split(X, y2)))
    np.testing.assert_equal(list(sss.split(X, y3)), list(sss.split(X, y2)))


def test_train_test_split_allow_nans():
    # Check that train_test_split allows input data with NaNs
    X = np.arange(200, dtype=np.float64).reshape(10, -1)
    X[2, :] = np.nan
    y = np.repeat([0, 1], X.shape[0] / 2)
    train_test_split(X, y, test_size=0.2, random_state=42)


def test_check_cv():
    X = np.ones(9)
    cv = check_cv(3, classifier=False)
    # Use numpy.testing.assert_equal which recursively compares
    # lists of lists
    np.testing.assert_equal(list(KFold(3).split(X)), list(cv.split(X)))

    y_binary = np.array([0, 1, 0, 1, 0, 0, 1, 1, 1])
    cv = check_cv(3, y_binary, classifier=True)
    np.testing.assert_equal(
        list(StratifiedKFold(3).split(X, y_binary)), list(cv.split(X, y_binary))
    )

    y_multiclass = np.array([0, 1, 0, 1, 2, 1, 2, 0, 2])
    cv = check_cv(3, y_multiclass, classifier=True)
    np.testing.assert_equal(
        list(StratifiedKFold(3).split(X, y_multiclass)), list(cv.split(X, y_multiclass))
    )
    # also works with 2d multiclass
    y_multiclass_2d = y_multiclass.reshape(-1, 1)
    cv = check_cv(3, y_multiclass_2d, classifier=True)
    np.testing.assert_equal(
        list(StratifiedKFold(3).split(X, y_multiclass_2d)),
        list(cv.split(X, y_multiclass_2d)),
    )

    assert not np.all(
        next(StratifiedKFold(3).split(X, y_multiclass_2d))[0]
        == next(KFold(3).split(X, y_multiclass_2d))[0]
    )

    X = np.ones(5)
    y_multilabel = np.array(
        [[0, 0, 0, 0], [0, 1, 1, 0], [0, 0, 0, 1], [1, 1, 0, 1], [0, 0, 1, 0]]
    )
    cv = check_cv(3, y_multilabel, classifier=True)
    np.testing.assert_equal(list(KFold(3).split(X)), list(cv.split(X)))

    y_multioutput = np.array([[1, 2], [0, 3], [0, 0], [3, 1], [2, 0]])
    cv = check_cv(3, y_multioutput, classifier=True)
    np.testing.assert_equal(list(KFold(3).split(X)), list(cv.split(X)))

    with pytest.raises(ValueError):
        check_cv(cv="lolo")


def test_cv_iterable_wrapper():
    kf_iter = KFold().split(X, y)
    kf_iter_wrapped = check_cv(kf_iter)
    # Since the wrapped iterable is enlisted and stored,
    # split can be called any number of times to produce
    # consistent results.
    np.testing.assert_equal(
        list(kf_iter_wrapped.split(X, y)), list(kf_iter_wrapped.split(X, y))
    )
    # If the splits are randomized, successive calls to split yield different
    # results
    kf_randomized_iter = KFold(shuffle=True, random_state=0).split(X, y)
    kf_randomized_iter_wrapped = check_cv(kf_randomized_iter)
    # numpy's assert_array_equal properly compares nested lists
    np.testing.assert_equal(
        list(kf_randomized_iter_wrapped.split(X, y)),
        list(kf_randomized_iter_wrapped.split(X, y)),
    )

    try:
        splits_are_equal = True
        np.testing.assert_equal(
            list(kf_iter_wrapped.split(X, y)),
            list(kf_randomized_iter_wrapped.split(X, y)),
        )
    except AssertionError:
        splits_are_equal = False
    assert not splits_are_equal, (
        "If the splits are randomized, "
        "successive calls to split should yield different results"
    )


@pytest.mark.parametrize("kfold", [GroupKFold, StratifiedGroupKFold])
def test_group_kfold(kfold):
    rng = np.random.RandomState(0)

    # Parameters of the test
    n_groups = 15
    n_samples = 1000
    n_splits = 5

    X = y = np.ones(n_samples)

    # Construct the test data
    tolerance = 0.05 * n_samples  # 5 percent error allowed
    groups = rng.randint(0, n_groups, n_samples)

    ideal_n_groups_per_fold = n_samples // n_splits

    # Get the test fold indices from the test set indices of each fold
    folds = np.zeros(n_samples)
    lkf = kfold(n_splits=n_splits)
    for i, (_, test) in enumerate(lkf.split(X, y, groups)):
        folds[test] = i

    # Check that folds have approximately the same size
    assert len(folds) == len(groups)
    for i in np.unique(folds):
        assert tolerance >= abs(sum(folds == i) - ideal_n_groups_per_fold)

    # Check that each group appears only in 1 fold
    for group in np.unique(groups):
        assert len(np.unique(folds[groups == group])) == 1

    # Check that no group is on both sides of the split
    groups = np.asarray(groups, dtype=object)
    for train, test in lkf.split(X, y, groups):
        assert len(np.intersect1d(groups[train], groups[test])) == 0
- # Construct the test data
- groups = np.array(
- [
- "Albert",
- "Jean",
- "Bertrand",
- "Michel",
- "Jean",
- "Francis",
- "Robert",
- "Michel",
- "Rachel",
- "Lois",
- "Michelle",
- "Bernard",
- "Marion",
- "Laura",
- "Jean",
- "Rachel",
- "Franck",
- "John",
- "Gael",
- "Anna",
- "Alix",
- "Robert",
- "Marion",
- "David",
- "Tony",
- "Abel",
- "Becky",
- "Madmood",
- "Cary",
- "Mary",
- "Alexandre",
- "David",
- "Francis",
- "Barack",
- "Abdoul",
- "Rasha",
- "Xi",
- "Silvia",
- ]
- )
- n_groups = len(np.unique(groups))
- n_samples = len(groups)
- n_splits = 5
- tolerance = 0.05 * n_samples # 5 percent error allowed
- ideal_n_groups_per_fold = n_samples // n_splits
- X = y = np.ones(n_samples)
- # Get the test fold indices from the test set indices of each fold
- folds = np.zeros(n_samples)
- for i, (_, test) in enumerate(lkf.split(X, y, groups)):
- folds[test] = i
- # Check that folds have approximately the same size
- assert len(folds) == len(groups)
- for i in np.unique(folds):
- assert tolerance >= abs(sum(folds == i) - ideal_n_groups_per_fold)
- # Check that each group appears only in 1 fold
- with warnings.catch_warnings():
- warnings.simplefilter("ignore", FutureWarning)
- for group in np.unique(groups):
- assert len(np.unique(folds[groups == group])) == 1
- # Check that no group is on both sides of the split
- groups = np.asarray(groups, dtype=object)
- for train, test in lkf.split(X, y, groups):
- assert len(np.intersect1d(groups[train], groups[test])) == 0
- # groups can also be a list
- cv_iter = list(lkf.split(X, y, groups.tolist()))
- for (train1, test1), (train2, test2) in zip(lkf.split(X, y, groups), cv_iter):
- assert_array_equal(train1, train2)
- assert_array_equal(test1, test2)
- # Should fail if there are more folds than groups
- groups = np.array([1, 1, 1, 2, 2])
- X = y = np.ones(len(groups))
- with pytest.raises(ValueError, match="Cannot have number of splits.*greater"):
- next(GroupKFold(n_splits=3).split(X, y, groups))
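The core guarantee verified above, in a standalone sketch (assumed toy data): GroupKFold never lets a group straddle the train/test boundary of a split.

import numpy as np
from sklearn.model_selection import GroupKFold

rng = np.random.RandomState(0)
X = y = np.ones(30)
groups = rng.randint(0, 6, 30)            # 6 group ids spread over 30 samples

for train, test in GroupKFold(n_splits=3).split(X, y, groups):
    # no group id appears on both sides of the same split
    assert not set(groups[train]) & set(groups[test])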
- def test_time_series_cv():
- X = [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12], [13, 14]]
- # Should fail if there are more folds than samples
- with pytest.raises(ValueError, match="Cannot have number of folds.*greater"):
- next(TimeSeriesSplit(n_splits=7).split(X))
- tscv = TimeSeriesSplit(2)
- # Manually check that Time Series CV preserves the data
- # ordering on toy datasets
- splits = tscv.split(X[:-1])
- train, test = next(splits)
- assert_array_equal(train, [0, 1])
- assert_array_equal(test, [2, 3])
- train, test = next(splits)
- assert_array_equal(train, [0, 1, 2, 3])
- assert_array_equal(test, [4, 5])
- splits = TimeSeriesSplit(2).split(X)
- train, test = next(splits)
- assert_array_equal(train, [0, 1, 2])
- assert_array_equal(test, [3, 4])
- train, test = next(splits)
- assert_array_equal(train, [0, 1, 2, 3, 4])
- assert_array_equal(test, [5, 6])
- # Check get_n_splits returns the correct number of splits
- splits = TimeSeriesSplit(2).split(X)
- n_splits_actual = len(list(splits))
- assert n_splits_actual == tscv.get_n_splits()
- assert n_splits_actual == 2
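To make the expected indices above easier to read, a quick sketch (toy data only) of TimeSeriesSplit's expanding window: the test fold always comes after the corresponding training indices, and the training window grows at each split.

import numpy as np
from sklearn.model_selection import TimeSeriesSplit

X = np.arange(7).reshape(-1, 1)
for train, test in TimeSeriesSplit(n_splits=2).split(X):
    print(train, test)
# [0 1 2] [3 4]
# [0 1 2 3 4] [5 6]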
- def _check_time_series_max_train_size(splits, check_splits, max_train_size):
- for (train, test), (check_train, check_test) in zip(splits, check_splits):
- assert_array_equal(test, check_test)
- assert len(check_train) <= max_train_size
- suffix_start = max(len(train) - max_train_size, 0)
- assert_array_equal(check_train, train[suffix_start:])
- def test_time_series_max_train_size():
- X = np.zeros((6, 1))
- # Note: ``splits`` is a generator, so it is re-created before each comparison
- splits = TimeSeriesSplit(n_splits=3).split(X)
- check_splits = TimeSeriesSplit(n_splits=3, max_train_size=3).split(X)
- _check_time_series_max_train_size(splits, check_splits, max_train_size=3)
- # Test for the case where the size of a fold is greater than max_train_size
- splits = TimeSeriesSplit(n_splits=3).split(X)
- check_splits = TimeSeriesSplit(n_splits=3, max_train_size=2).split(X)
- _check_time_series_max_train_size(splits, check_splits, max_train_size=2)
- # Test for the case where the size of each fold is less than max_train_size
- splits = TimeSeriesSplit(n_splits=3).split(X)
- check_splits = TimeSeriesSplit(n_splits=3, max_train_size=5).split(X)
- _check_time_series_max_train_size(splits, check_splits, max_train_size=5)
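A brief sketch (illustrative only) of the max_train_size behaviour being compared: the cap turns the expanding window into a rolling window that keeps only the most recent training samples.

import numpy as np
from sklearn.model_selection import TimeSeriesSplit

X = np.zeros((6, 1))
for train, test in TimeSeriesSplit(n_splits=3, max_train_size=2).split(X):
    print(train, test)
# [1 2] [3]
# [2 3] [4]
# [3 4] [5]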
- def test_time_series_test_size():
- X = np.zeros((10, 1))
- # Test with test_size alone
- splits = TimeSeriesSplit(n_splits=3, test_size=3).split(X)
- train, test = next(splits)
- assert_array_equal(train, [0])
- assert_array_equal(test, [1, 2, 3])
- train, test = next(splits)
- assert_array_equal(train, [0, 1, 2, 3])
- assert_array_equal(test, [4, 5, 6])
- train, test = next(splits)
- assert_array_equal(train, [0, 1, 2, 3, 4, 5, 6])
- assert_array_equal(test, [7, 8, 9])
- # Test with max_train_size
- splits = TimeSeriesSplit(n_splits=2, test_size=2, max_train_size=4).split(X)
- train, test = next(splits)
- assert_array_equal(train, [2, 3, 4, 5])
- assert_array_equal(test, [6, 7])
- train, test = next(splits)
- assert_array_equal(train, [4, 5, 6, 7])
- assert_array_equal(test, [8, 9])
- # Should fail when there are not enough data points for the configuration
- with pytest.raises(ValueError, match="Too many splits.*with test_size"):
- splits = TimeSeriesSplit(n_splits=5, test_size=2).split(X)
- next(splits)
- def test_time_series_gap():
- X = np.zeros((10, 1))
- # Test with gap alone
- splits = TimeSeriesSplit(n_splits=2, gap=2).split(X)
- train, test = next(splits)
- assert_array_equal(train, [0, 1])
- assert_array_equal(test, [4, 5, 6])
- train, test = next(splits)
- assert_array_equal(train, [0, 1, 2, 3, 4])
- assert_array_equal(test, [7, 8, 9])
- # Test with max_train_size
- splits = TimeSeriesSplit(n_splits=3, gap=2, max_train_size=2).split(X)
- train, test = next(splits)
- assert_array_equal(train, [0, 1])
- assert_array_equal(test, [4, 5])
- train, test = next(splits)
- assert_array_equal(train, [2, 3])
- assert_array_equal(test, [6, 7])
- train, test = next(splits)
- assert_array_equal(train, [4, 5])
- assert_array_equal(test, [8, 9])
- # Test with gap, max_train_size and test_size
- splits = TimeSeriesSplit(n_splits=2, gap=2, max_train_size=4, test_size=2).split(X)
- train, test = next(splits)
- assert_array_equal(train, [0, 1, 2, 3])
- assert_array_equal(test, [6, 7])
- train, test = next(splits)
- assert_array_equal(train, [2, 3, 4, 5])
- assert_array_equal(test, [8, 9])
- # Test with gap and test_size but no max_train_size
- splits = TimeSeriesSplit(n_splits=2, gap=2, test_size=3).split(X)
- train, test = next(splits)
- assert_array_equal(train, [0, 1])
- assert_array_equal(test, [4, 5, 6])
- train, test = next(splits)
- assert_array_equal(train, [0, 1, 2, 3, 4])
- assert_array_equal(test, [7, 8, 9])
- # Verify that the proper error is raised
- with pytest.raises(ValueError, match="Too many splits.*and gap"):
- splits = TimeSeriesSplit(n_splits=4, gap=2).split(X)
- next(splits)
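For reference, a minimal sketch of the gap parameter checked above: gap drops the given number of samples between the end of each training window and the start of its test window.

import numpy as np
from sklearn.model_selection import TimeSeriesSplit

X = np.zeros((10, 1))
for train, test in TimeSeriesSplit(n_splits=2, gap=2).split(X):
    print(train, test)
# [0 1] [4 5 6]
# [0 1 2 3 4] [7 8 9]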
- def test_nested_cv():
- # Test if nested cross validation works with different combinations of cv
- rng = np.random.RandomState(0)
- X, y = make_classification(n_samples=15, n_classes=2, random_state=0)
- groups = rng.randint(0, 5, 15)
- cvs = [
- LeaveOneGroupOut(),
- StratifiedKFold(n_splits=2),
- LeaveOneOut(),
- GroupKFold(n_splits=3),
- StratifiedKFold(),
- StratifiedGroupKFold(),
- StratifiedShuffleSplit(n_splits=3, random_state=0),
- ]
- for inner_cv, outer_cv in combinations_with_replacement(cvs, 2):
- gs = GridSearchCV(
- DummyClassifier(),
- param_grid={"strategy": ["stratified", "most_frequent"]},
- cv=inner_cv,
- error_score="raise",
- )
- cross_val_score(
- gs, X=X, y=y, groups=groups, cv=outer_cv, fit_params={"groups": groups}
- )
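The nested pattern exercised by the combinations above, shown as a hedged standalone sketch (DummyClassifier and KFold are stand-ins; any estimator and splitters would do): the inner CV tunes hyper-parameters inside GridSearchCV, while the outer CV scores the tuned model.

import numpy as np
from sklearn.datasets import make_classification
from sklearn.dummy import DummyClassifier
from sklearn.model_selection import GridSearchCV, KFold, cross_val_score

X, y = make_classification(n_samples=30, n_classes=2, random_state=0)

gs = GridSearchCV(
    DummyClassifier(),
    param_grid={"strategy": ["stratified", "most_frequent"]},
    cv=KFold(n_splits=3),      # inner loop: model selection
)
scores = cross_val_score(gs, X, y, cv=KFold(n_splits=3))  # outer loop: evaluation
print(scores.shape)  # (3,)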
- def test_build_repr():
- class MockSplitter:
- def __init__(self, a, b=0, c=None):
- self.a = a
- self.b = b
- self.c = c
- def __repr__(self):
- return _build_repr(self)
- assert repr(MockSplitter(5, 6)) == "MockSplitter(a=5, b=6, c=None)"
- @pytest.mark.parametrize(
- "CVSplitter", (ShuffleSplit, GroupShuffleSplit, StratifiedShuffleSplit)
- )
- def test_shuffle_split_empty_trainset(CVSplitter):
- cv = CVSplitter(test_size=0.99)
- X, y = [[1]], [0] # 1 sample
- with pytest.raises(
- ValueError,
- match=(
- "With n_samples=1, test_size=0.99 and train_size=None, "
- "the resulting train set will be empty"
- ),
- ):
- next(cv.split(X, y, groups=[1]))
- def test_train_test_split_empty_trainset():
- (X,) = [[1]] # 1 sample
- with pytest.raises(
- ValueError,
- match=(
- "With n_samples=1, test_size=0.99 and train_size=None, "
- "the resulting train set will be empty"
- ),
- ):
- train_test_split(X, test_size=0.99)
- X = [[1], [1], [1]] # 3 samples, ask for more than 2 thirds
- with pytest.raises(
- ValueError,
- match=(
- "With n_samples=3, test_size=0.67 and train_size=None, "
- "the resulting train set will be empty"
- ),
- ):
- train_test_split(X, test_size=0.67)
- def test_leave_one_out_empty_trainset():
- # LeaveOneGroupOut expects at least 2 groups, so there is no need to check it here
- cv = LeaveOneOut()
- X, y = [[1]], [0] # 1 sample
- with pytest.raises(ValueError, match="Cannot perform LeaveOneOut with n_samples=1"):
- next(cv.split(X, y))
- def test_leave_p_out_empty_trainset():
- # No need to check LeavePGroupsOut
- cv = LeavePOut(p=2)
- X, y = [[1], [2]], [0, 3] # 2 samples
- with pytest.raises(
- ValueError, match="p=2 must be strictly less than the number of samples=2"
- ):
- next(cv.split(X, y, groups=[1, 2]))
- @pytest.mark.parametrize("Klass", (KFold, StratifiedKFold, StratifiedGroupKFold))
- def test_random_state_shuffle_false(Klass):
- # passing a non-default random_state when shuffle=False makes no sense
- with pytest.raises(ValueError, match="has no effect since shuffle is False"):
- Klass(3, shuffle=False, random_state=0)
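The validation behaviour being tested, as a short sketch: constructing any of these splitters with shuffle=False and an explicit random_state raises immediately, since the seed would otherwise be silently ignored.

from sklearn.model_selection import KFold

try:
    KFold(n_splits=3, shuffle=False, random_state=0)
except ValueError as exc:
    print(exc)  # "... has no effect since shuffle is False ..."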
- @pytest.mark.parametrize(
- "cv, expected",
- [
- (KFold(), True),
- (KFold(shuffle=True, random_state=123), True),
- (StratifiedKFold(), True),
- (StratifiedKFold(shuffle=True, random_state=123), True),
- (StratifiedGroupKFold(shuffle=True, random_state=123), True),
- (StratifiedGroupKFold(), True),
- (RepeatedKFold(random_state=123), True),
- (RepeatedStratifiedKFold(random_state=123), True),
- (ShuffleSplit(random_state=123), True),
- (GroupShuffleSplit(random_state=123), True),
- (StratifiedShuffleSplit(random_state=123), True),
- (GroupKFold(), True),
- (TimeSeriesSplit(), True),
- (LeaveOneOut(), True),
- (LeaveOneGroupOut(), True),
- (LeavePGroupsOut(n_groups=2), True),
- (LeavePOut(p=2), True),
- (KFold(shuffle=True, random_state=None), False),
- (StratifiedKFold(shuffle=True, random_state=np.random.RandomState(0)), False),
- (RepeatedKFold(random_state=None), False),
- (RepeatedKFold(random_state=np.random.RandomState(0)), False),
- (RepeatedStratifiedKFold(random_state=None), False),
- (RepeatedStratifiedKFold(random_state=np.random.RandomState(0)), False),
- (ShuffleSplit(random_state=None), False),
- (ShuffleSplit(random_state=np.random.RandomState(0)), False),
- (GroupShuffleSplit(random_state=None), False),
- (GroupShuffleSplit(random_state=np.random.RandomState(0)), False),
- (StratifiedShuffleSplit(random_state=None), False),
- (StratifiedShuffleSplit(random_state=np.random.RandomState(0)), False),
- ],
- )
- def test_yields_constant_splits(cv, expected):
- assert _yields_constant_splits(cv) == expected
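A sketch of the property the private helper classifies (public API only; _yields_constant_splits itself is internal): a splitter "yields constant splits" when repeated calls to split return identical folds, which holds for unshuffled splitters and for shuffled ones seeded with a fixed integer, but not when random_state is None or a RandomState instance.

import numpy as np
from sklearn.model_selection import KFold

X = np.arange(6)
cv = KFold(n_splits=3, shuffle=True, random_state=123)
first = [test.tolist() for _, test in cv.split(X)]
second = [test.tolist() for _, test in cv.split(X)]
print(first == second)  # True: the int seed is re-used on every call

# With random_state=None (or a RandomState instance) each call reshuffles,
# so successive calls may produce different folds.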
- @pytest.mark.parametrize("cv", ALL_SPLITTERS, ids=[str(cv) for cv in ALL_SPLITTERS])
- def test_splitter_get_metadata_routing(cv):
- """Check get_metadata_routing returns the correct MetadataRouter."""
- assert hasattr(cv, "get_metadata_routing")
- metadata = cv.get_metadata_routing()
- if cv in GROUP_SPLITTERS:
- assert metadata.split.requests["groups"] is True
- elif cv in NO_GROUP_SPLITTERS:
- assert not metadata.split.requests
- assert_request_is_empty(metadata, exclude=["split"])
- @pytest.mark.parametrize("cv", ALL_SPLITTERS, ids=[str(cv) for cv in ALL_SPLITTERS])
- def test_splitter_set_split_request(cv):
- """Check set_split_request is defined for group splitters and not for others."""
- if cv in GROUP_SPLITTERS:
- assert hasattr(cv, "set_split_request")
- elif cv in NO_GROUP_SPLITTERS:
- assert not hasattr(cv, "set_split_request")
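For context, a hedged sketch of the surface these two tests probe (requires a scikit-learn version with metadata routing support; internals may differ across releases): group-aware splitters request groups for split and therefore expose set_split_request, whereas plain splitters do not.

from sklearn.model_selection import GroupKFold, KFold

print(hasattr(GroupKFold(), "set_split_request"))  # True: groups is routable
print(hasattr(KFold(), "set_split_request"))       # False: no metadata to route
routing = GroupKFold().get_metadata_routing()      # describes the "groups" request for split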
|