
  1. """
  2. The :mod:`sklearn.model_selection._split` module includes classes and
  3. functions to split the data based on a preset strategy.
  4. """
  5. # Author: Alexandre Gramfort <alexandre.gramfort@inria.fr>
  6. # Gael Varoquaux <gael.varoquaux@normalesup.org>
  7. # Olivier Grisel <olivier.grisel@ensta.org>
  8. # Raghav RV <rvraghav93@gmail.com>
  9. # Leandro Hermida <hermidal@cs.umd.edu>
  10. # Rodion Martynov <marrodion@gmail.com>
  11. # License: BSD 3 clause
  12. import numbers
  13. import warnings
  14. from abc import ABCMeta, abstractmethod
  15. from collections import defaultdict
  16. from collections.abc import Iterable
  17. from inspect import signature
  18. from itertools import chain, combinations
  19. from math import ceil, floor
  20. import numpy as np
  21. from scipy.special import comb
  22. from ..utils import (
  23. _approximate_mode,
  24. _safe_indexing,
  25. check_random_state,
  26. indexable,
  27. metadata_routing,
  28. )
  29. from ..utils._param_validation import Interval, RealNotInt, validate_params
  30. from ..utils.metadata_routing import _MetadataRequester
  31. from ..utils.multiclass import type_of_target
  32. from ..utils.validation import _num_samples, check_array, column_or_1d
  33. __all__ = [
  34. "BaseCrossValidator",
  35. "KFold",
  36. "GroupKFold",
  37. "LeaveOneGroupOut",
  38. "LeaveOneOut",
  39. "LeavePGroupsOut",
  40. "LeavePOut",
  41. "RepeatedStratifiedKFold",
  42. "RepeatedKFold",
  43. "ShuffleSplit",
  44. "GroupShuffleSplit",
  45. "StratifiedKFold",
  46. "StratifiedGroupKFold",
  47. "StratifiedShuffleSplit",
  48. "PredefinedSplit",
  49. "train_test_split",
  50. "check_cv",
  51. ]


class GroupsConsumerMixin(_MetadataRequester):
    """A mixin that requests ``groups`` by default.

    This mixin makes the object request ``groups`` by default (as ``True``).

    .. versionadded:: 1.3
    """

    __metadata_request__split = {"groups": True}
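    # For example (editorial note, not from the original source):
    # ``GroupKFold`` below inherits from this mixin, so metadata routing
    # passes ``groups`` through to its ``split`` method by default.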


class BaseCrossValidator(_MetadataRequester, metaclass=ABCMeta):
    """Base class for all cross-validators.

    Implementations must define `_iter_test_masks` or `_iter_test_indices`.
    """

    # This indicates that by default CV splitters don't have a "groups" kwarg,
    # unless indicated by inheriting from ``GroupsConsumerMixin``.
    # This also prevents ``set_split_request`` from being generated for
    # splitters which don't support ``groups``.
    __metadata_request__split = {"groups": metadata_routing.UNUSED}

    def split(self, X, y=None, groups=None):
        """Generate indices to split data into training and test set.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Training data, where `n_samples` is the number of samples
            and `n_features` is the number of features.

        y : array-like of shape (n_samples,)
            The target variable for supervised learning problems.

        groups : array-like of shape (n_samples,), default=None
            Group labels for the samples used while splitting the dataset into
            train/test set.

        Yields
        ------
        train : ndarray
            The training set indices for that split.

        test : ndarray
            The testing set indices for that split.
        """
        X, y, groups = indexable(X, y, groups)
        indices = np.arange(_num_samples(X))
        for test_index in self._iter_test_masks(X, y, groups):
            train_index = indices[np.logical_not(test_index)]
            test_index = indices[test_index]
            yield train_index, test_index

    # Since subclasses must implement either _iter_test_masks or
    # _iter_test_indices, neither can be abstract.
    def _iter_test_masks(self, X=None, y=None, groups=None):
        """Generates boolean masks corresponding to test sets.

        By default, delegates to _iter_test_indices(X, y, groups)
        """
        for test_index in self._iter_test_indices(X, y, groups):
            test_mask = np.zeros(_num_samples(X), dtype=bool)
            test_mask[test_index] = True
            yield test_mask

    def _iter_test_indices(self, X=None, y=None, groups=None):
        """Generates integer indices corresponding to test sets."""
        raise NotImplementedError

    @abstractmethod
    def get_n_splits(self, X=None, y=None, groups=None):
        """Returns the number of splitting iterations in the cross-validator."""

    def __repr__(self):
        return _build_repr(self)
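
# Editorial sketch (not part of the module): a minimal custom splitter only
# needs `_iter_test_indices` (or `_iter_test_masks`) plus `get_n_splits`;
# the base class's `split` derives the train/test index pairs from them.
# The class name below is hypothetical.
#
#     class EvenOddSplit(BaseCrossValidator):
#         def _iter_test_indices(self, X, y=None, groups=None):
#             n_samples = _num_samples(X)
#             yield np.arange(0, n_samples, 2)  # fold 0 tests even positions
#             yield np.arange(1, n_samples, 2)  # fold 1 tests odd positions
#
#         def get_n_splits(self, X=None, y=None, groups=None):
#             return 2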


class LeaveOneOut(BaseCrossValidator):
    """Leave-One-Out cross-validator

    Provides train/test indices to split data in train/test sets. Each
    sample is used once as a test set (singleton) while the remaining
    samples form the training set.

    Note: ``LeaveOneOut()`` is equivalent to ``KFold(n_splits=n)`` and
    ``LeavePOut(p=1)`` where ``n`` is the number of samples.

    Due to the high number of test sets (which is the same as the
    number of samples) this cross-validation method can be very costly.
    For large datasets one should favor :class:`KFold`, :class:`ShuffleSplit`
    or :class:`StratifiedKFold`.

    Read more in the :ref:`User Guide <leave_one_out>`.

    Examples
    --------
    >>> import numpy as np
    >>> from sklearn.model_selection import LeaveOneOut
    >>> X = np.array([[1, 2], [3, 4]])
    >>> y = np.array([1, 2])
    >>> loo = LeaveOneOut()
    >>> loo.get_n_splits(X)
    2
    >>> print(loo)
    LeaveOneOut()
    >>> for i, (train_index, test_index) in enumerate(loo.split(X)):
    ...     print(f"Fold {i}:")
    ...     print(f"  Train: index={train_index}")
    ...     print(f"  Test:  index={test_index}")
    Fold 0:
      Train: index=[1]
      Test:  index=[0]
    Fold 1:
      Train: index=[0]
      Test:  index=[1]

    See Also
    --------
    LeaveOneGroupOut : For splitting the data according to explicit,
        domain-specific stratification of the dataset.

    GroupKFold : K-fold iterator variant with non-overlapping groups.
    """

    def _iter_test_indices(self, X, y=None, groups=None):
        n_samples = _num_samples(X)
        if n_samples <= 1:
            raise ValueError(
                "Cannot perform LeaveOneOut with n_samples={}.".format(n_samples)
            )
        return range(n_samples)

    def get_n_splits(self, X, y=None, groups=None):
        """Returns the number of splitting iterations in the cross-validator

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Training data, where `n_samples` is the number of samples
            and `n_features` is the number of features.

        y : object
            Always ignored, exists for compatibility.

        groups : object
            Always ignored, exists for compatibility.

        Returns
        -------
        n_splits : int
            Returns the number of splitting iterations in the cross-validator.
        """
        if X is None:
            raise ValueError("The 'X' parameter should not be None.")
        return _num_samples(X)


class LeavePOut(BaseCrossValidator):
    """Leave-P-Out cross-validator

    Provides train/test indices to split data in train/test sets. This results
    in testing on all distinct samples of size p, while the remaining n - p
    samples form the training set in each iteration.

    Note: ``LeavePOut(p)`` is NOT equivalent to
    ``KFold(n_splits=n_samples // p)`` which creates non-overlapping test sets.

    Due to the high number of iterations which grows combinatorially with the
    number of samples this cross-validation method can be very costly. For
    large datasets one should favor :class:`KFold`, :class:`StratifiedKFold`
    or :class:`ShuffleSplit`.

    Read more in the :ref:`User Guide <leave_p_out>`.

    Parameters
    ----------
    p : int
        Size of the test sets. Must be strictly less than the number of
        samples.

    Examples
    --------
    >>> import numpy as np
    >>> from sklearn.model_selection import LeavePOut
    >>> X = np.array([[1, 2], [3, 4], [5, 6], [7, 8]])
    >>> y = np.array([1, 2, 3, 4])
    >>> lpo = LeavePOut(2)
    >>> lpo.get_n_splits(X)
    6
    >>> print(lpo)
    LeavePOut(p=2)
    >>> for i, (train_index, test_index) in enumerate(lpo.split(X)):
    ...     print(f"Fold {i}:")
    ...     print(f"  Train: index={train_index}")
    ...     print(f"  Test:  index={test_index}")
    Fold 0:
      Train: index=[2 3]
      Test:  index=[0 1]
    Fold 1:
      Train: index=[1 3]
      Test:  index=[0 2]
    Fold 2:
      Train: index=[1 2]
      Test:  index=[0 3]
    Fold 3:
      Train: index=[0 3]
      Test:  index=[1 2]
    Fold 4:
      Train: index=[0 2]
      Test:  index=[1 3]
    Fold 5:
      Train: index=[0 1]
      Test:  index=[2 3]
    """

    def __init__(self, p):
        self.p = p

    def _iter_test_indices(self, X, y=None, groups=None):
        n_samples = _num_samples(X)
        if n_samples <= self.p:
            raise ValueError(
                "p={} must be strictly less than the number of samples={}".format(
                    self.p, n_samples
                )
            )
        for combination in combinations(range(n_samples), self.p):
            yield np.array(combination)

    def get_n_splits(self, X, y=None, groups=None):
        """Returns the number of splitting iterations in the cross-validator

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Training data, where `n_samples` is the number of samples
            and `n_features` is the number of features.

        y : object
            Always ignored, exists for compatibility.

        groups : object
            Always ignored, exists for compatibility.
        """
        if X is None:
            raise ValueError("The 'X' parameter should not be None.")
        return int(comb(_num_samples(X), self.p, exact=True))


class _BaseKFold(BaseCrossValidator, metaclass=ABCMeta):
    """Base class for KFold, GroupKFold, and StratifiedKFold"""

    @abstractmethod
    def __init__(self, n_splits, *, shuffle, random_state):
        if not isinstance(n_splits, numbers.Integral):
            raise ValueError(
                "The number of folds must be of Integral type. "
                "%s of type %s was passed." % (n_splits, type(n_splits))
            )
        n_splits = int(n_splits)

        if n_splits <= 1:
            raise ValueError(
                "k-fold cross-validation requires at least one"
                " train/test split by setting n_splits=2 or more,"
                " got n_splits={0}.".format(n_splits)
            )

        if not isinstance(shuffle, bool):
            raise TypeError("shuffle must be True or False; got {0}".format(shuffle))

        if not shuffle and random_state is not None:  # None is the default
            raise ValueError(
                (
                    "Setting a random_state has no effect since shuffle is "
                    "False. You should leave "
                    "random_state to its default (None), or set shuffle=True."
                ),
            )

        self.n_splits = n_splits
        self.shuffle = shuffle
        self.random_state = random_state

    def split(self, X, y=None, groups=None):
        """Generate indices to split data into training and test set.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Training data, where `n_samples` is the number of samples
            and `n_features` is the number of features.

        y : array-like of shape (n_samples,), default=None
            The target variable for supervised learning problems.

        groups : array-like of shape (n_samples,), default=None
            Group labels for the samples used while splitting the dataset into
            train/test set.

        Yields
        ------
        train : ndarray
            The training set indices for that split.

        test : ndarray
            The testing set indices for that split.
        """
        X, y, groups = indexable(X, y, groups)
        n_samples = _num_samples(X)
        if self.n_splits > n_samples:
            raise ValueError(
                (
                    "Cannot have number of splits n_splits={0} greater"
                    " than the number of samples: n_samples={1}."
                ).format(self.n_splits, n_samples)
            )

        for train, test in super().split(X, y, groups):
            yield train, test

    def get_n_splits(self, X=None, y=None, groups=None):
        """Returns the number of splitting iterations in the cross-validator

        Parameters
        ----------
        X : object
            Always ignored, exists for compatibility.

        y : object
            Always ignored, exists for compatibility.

        groups : object
            Always ignored, exists for compatibility.

        Returns
        -------
        n_splits : int
            Returns the number of splitting iterations in the cross-validator.
        """
        return self.n_splits


class KFold(_BaseKFold):
    """K-Folds cross-validator

    Provides train/test indices to split data in train/test sets. Split
    dataset into k consecutive folds (without shuffling by default).

    Each fold is then used once as a validation set while the k - 1 remaining
    folds form the training set.

    Read more in the :ref:`User Guide <k_fold>`.

    For visualisation of cross-validation behaviour and
    comparison between common scikit-learn split methods
    refer to :ref:`sphx_glr_auto_examples_model_selection_plot_cv_indices.py`

    Parameters
    ----------
    n_splits : int, default=5
        Number of folds. Must be at least 2.

        .. versionchanged:: 0.22
            ``n_splits`` default value changed from 3 to 5.

    shuffle : bool, default=False
        Whether to shuffle the data before splitting into batches.
        Note that the samples within each split will not be shuffled.

    random_state : int, RandomState instance or None, default=None
        When `shuffle` is True, `random_state` affects the ordering of the
        indices, which controls the randomness of each fold. Otherwise, this
        parameter has no effect.
        Pass an int for reproducible output across multiple function calls.
        See :term:`Glossary <random_state>`.

    Examples
    --------
    >>> import numpy as np
    >>> from sklearn.model_selection import KFold
    >>> X = np.array([[1, 2], [3, 4], [1, 2], [3, 4]])
    >>> y = np.array([1, 2, 3, 4])
    >>> kf = KFold(n_splits=2)
    >>> kf.get_n_splits(X)
    2
    >>> print(kf)
    KFold(n_splits=2, random_state=None, shuffle=False)
    >>> for i, (train_index, test_index) in enumerate(kf.split(X)):
    ...     print(f"Fold {i}:")
    ...     print(f"  Train: index={train_index}")
    ...     print(f"  Test:  index={test_index}")
    Fold 0:
      Train: index=[2 3]
      Test:  index=[0 1]
    Fold 1:
      Train: index=[0 1]
      Test:  index=[2 3]

    Notes
    -----
    The first ``n_samples % n_splits`` folds have size
    ``n_samples // n_splits + 1``, other folds have size
    ``n_samples // n_splits``, where ``n_samples`` is the number of samples.

    Randomized CV splitters may return different results for each call of
    split. You can make the results identical by setting `random_state`
    to an integer.

    See Also
    --------
    StratifiedKFold : Takes class information into account to avoid building
        folds with imbalanced class distributions (for binary or multiclass
        classification tasks).

    GroupKFold : K-fold iterator variant with non-overlapping groups.

    RepeatedKFold : Repeats K-Fold n times.
    """

    def __init__(self, n_splits=5, *, shuffle=False, random_state=None):
        super().__init__(n_splits=n_splits, shuffle=shuffle, random_state=random_state)

    def _iter_test_indices(self, X, y=None, groups=None):
        n_samples = _num_samples(X)
        indices = np.arange(n_samples)
        if self.shuffle:
            check_random_state(self.random_state).shuffle(indices)

        n_splits = self.n_splits
        fold_sizes = np.full(n_splits, n_samples // n_splits, dtype=int)
        fold_sizes[: n_samples % n_splits] += 1
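        # Editorial example: with n_samples=10 and n_splits=4, the base size
        # is 10 // 4 == 2 and the first 10 % 4 == 2 folds get one extra
        # sample, so fold_sizes == [3, 3, 2, 2].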
        current = 0
        for fold_size in fold_sizes:
            start, stop = current, current + fold_size
            yield indices[start:stop]
            current = stop


class GroupKFold(GroupsConsumerMixin, _BaseKFold):
    """K-fold iterator variant with non-overlapping groups.

    Each group will appear exactly once in the test set across all folds (the
    number of distinct groups has to be at least equal to the number of folds).

    The folds are approximately balanced in the sense that the number of
    distinct groups is approximately the same in each fold.

    Read more in the :ref:`User Guide <group_k_fold>`.

    For visualisation of cross-validation behaviour and
    comparison between common scikit-learn split methods
    refer to :ref:`sphx_glr_auto_examples_model_selection_plot_cv_indices.py`

    Parameters
    ----------
    n_splits : int, default=5
        Number of folds. Must be at least 2.

        .. versionchanged:: 0.22
            ``n_splits`` default value changed from 3 to 5.

    Notes
    -----
    Groups appear in an arbitrary order throughout the folds.

    Examples
    --------
    >>> import numpy as np
    >>> from sklearn.model_selection import GroupKFold
    >>> X = np.array([[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12]])
    >>> y = np.array([1, 2, 3, 4, 5, 6])
    >>> groups = np.array([0, 0, 2, 2, 3, 3])
    >>> group_kfold = GroupKFold(n_splits=2)
    >>> group_kfold.get_n_splits(X, y, groups)
    2
    >>> print(group_kfold)
    GroupKFold(n_splits=2)
    >>> for i, (train_index, test_index) in enumerate(group_kfold.split(X, y, groups)):
    ...     print(f"Fold {i}:")
    ...     print(f"  Train: index={train_index}, group={groups[train_index]}")
    ...     print(f"  Test:  index={test_index}, group={groups[test_index]}")
    Fold 0:
      Train: index=[2 3], group=[2 2]
      Test:  index=[0 1 4 5], group=[0 0 3 3]
    Fold 1:
      Train: index=[0 1 4 5], group=[0 0 3 3]
      Test:  index=[2 3], group=[2 2]

    See Also
    --------
    LeaveOneGroupOut : For splitting the data according to explicit
        domain-specific stratification of the dataset.

    StratifiedKFold : Takes class information into account to avoid building
        folds with imbalanced class proportions (for binary or multiclass
        classification tasks).
    """

    def __init__(self, n_splits=5):
        super().__init__(n_splits, shuffle=False, random_state=None)

    def _iter_test_indices(self, X, y, groups):
        if groups is None:
            raise ValueError("The 'groups' parameter should not be None.")
        groups = check_array(groups, input_name="groups", ensure_2d=False, dtype=None)

        unique_groups, groups = np.unique(groups, return_inverse=True)
        n_groups = len(unique_groups)

        if self.n_splits > n_groups:
            raise ValueError(
                "Cannot have number of splits n_splits=%d greater"
                " than the number of groups: %d." % (self.n_splits, n_groups)
            )

        # Weight groups by their number of occurrences
        n_samples_per_group = np.bincount(groups)

        # Distribute the most frequent groups first
        indices = np.argsort(n_samples_per_group)[::-1]
        n_samples_per_group = n_samples_per_group[indices]

        # Total weight of each fold
        n_samples_per_fold = np.zeros(self.n_splits)

        # Mapping from group index to fold index
        group_to_fold = np.zeros(len(unique_groups))

        # Distribute samples by adding the largest weight to the lightest fold
        for group_index, weight in enumerate(n_samples_per_group):
            lightest_fold = np.argmin(n_samples_per_fold)
            n_samples_per_fold[lightest_fold] += weight
            group_to_fold[indices[group_index]] = lightest_fold
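        # Editorial example: with group sizes [5, 3, 3, 2] and n_splits=2,
        # the greedy pass assigns 5 -> fold 0, 3 -> fold 1, 3 -> fold 1,
        # 2 -> fold 0, giving fold weights [7, 6].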

        indices = group_to_fold[groups]

        for f in range(self.n_splits):
            yield np.where(indices == f)[0]

    def split(self, X, y=None, groups=None):
        """Generate indices to split data into training and test set.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Training data, where `n_samples` is the number of samples
            and `n_features` is the number of features.

        y : array-like of shape (n_samples,), default=None
            The target variable for supervised learning problems.

        groups : array-like of shape (n_samples,)
            Group labels for the samples used while splitting the dataset into
            train/test set.

        Yields
        ------
        train : ndarray
            The training set indices for that split.

        test : ndarray
            The testing set indices for that split.
        """
        return super().split(X, y, groups)


class StratifiedKFold(_BaseKFold):
    """Stratified K-Folds cross-validator.

    Provides train/test indices to split data in train/test sets.

    This cross-validation object is a variation of KFold that returns
    stratified folds. The folds are made by preserving the percentage of
    samples for each class.

    Read more in the :ref:`User Guide <stratified_k_fold>`.

    For visualisation of cross-validation behaviour and
    comparison between common scikit-learn split methods
    refer to :ref:`sphx_glr_auto_examples_model_selection_plot_cv_indices.py`

    Parameters
    ----------
    n_splits : int, default=5
        Number of folds. Must be at least 2.

        .. versionchanged:: 0.22
            ``n_splits`` default value changed from 3 to 5.

    shuffle : bool, default=False
        Whether to shuffle each class's samples before splitting into batches.
        Note that the samples within each split will not be shuffled.

    random_state : int, RandomState instance or None, default=None
        When `shuffle` is True, `random_state` affects the ordering of the
        indices, which controls the randomness of each fold for each class.
        Otherwise, leave `random_state` as `None`.
        Pass an int for reproducible output across multiple function calls.
        See :term:`Glossary <random_state>`.

    Examples
    --------
    >>> import numpy as np
    >>> from sklearn.model_selection import StratifiedKFold
    >>> X = np.array([[1, 2], [3, 4], [1, 2], [3, 4]])
    >>> y = np.array([0, 0, 1, 1])
    >>> skf = StratifiedKFold(n_splits=2)
    >>> skf.get_n_splits(X, y)
    2
    >>> print(skf)
    StratifiedKFold(n_splits=2, random_state=None, shuffle=False)
    >>> for i, (train_index, test_index) in enumerate(skf.split(X, y)):
    ...     print(f"Fold {i}:")
    ...     print(f"  Train: index={train_index}")
    ...     print(f"  Test:  index={test_index}")
    Fold 0:
      Train: index=[1 3]
      Test:  index=[0 2]
    Fold 1:
      Train: index=[0 2]
      Test:  index=[1 3]

    Notes
    -----
    The implementation is designed to:

    * Generate test sets such that all contain the same distribution of
      classes, or as close as possible.
    * Be invariant to class label: relabelling ``y = ["Happy", "Sad"]`` to
      ``y = [1, 0]`` should not change the indices generated.
    * Preserve order dependencies in the dataset ordering, when
      ``shuffle=False``: all samples from class k in some test set were
      contiguous in y, or separated in y by samples from classes other than k.
    * Generate test sets where the smallest and largest differ by at most one
      sample.

    .. versionchanged:: 0.22
        The previous implementation did not follow the last constraint.

    See Also
    --------
    RepeatedStratifiedKFold : Repeats Stratified K-Fold n times.
    """

    def __init__(self, n_splits=5, *, shuffle=False, random_state=None):
        super().__init__(n_splits=n_splits, shuffle=shuffle, random_state=random_state)

    def _make_test_folds(self, X, y=None):
        rng = check_random_state(self.random_state)
        y = np.asarray(y)
        type_of_target_y = type_of_target(y)
        allowed_target_types = ("binary", "multiclass")
        if type_of_target_y not in allowed_target_types:
            raise ValueError(
                "Supported target types are: {}. Got {!r} instead.".format(
                    allowed_target_types, type_of_target_y
                )
            )

        y = column_or_1d(y)

        _, y_idx, y_inv = np.unique(y, return_index=True, return_inverse=True)
        # y_inv encodes y according to lexicographic order. We invert y_idx to
        # map the classes so that they are encoded by order of appearance:
        # 0 represents the first label appearing in y, 1 the second, etc.
        _, class_perm = np.unique(y_idx, return_inverse=True)
        y_encoded = class_perm[y_inv]

        n_classes = len(y_idx)
        y_counts = np.bincount(y_encoded)
        min_groups = np.min(y_counts)
        if np.all(self.n_splits > y_counts):
            raise ValueError(
                "n_splits=%d cannot be greater than the"
                " number of members in each class." % (self.n_splits)
            )
        if self.n_splits > min_groups:
            warnings.warn(
                "The least populated class in y has only %d"
                " members, which is less than n_splits=%d."
                % (min_groups, self.n_splits),
                UserWarning,
            )

        # Determine the optimal number of samples from each class in each
        # fold, using round robin over the sorted y. (This can be done
        # directly from counts, but that code is less readable.)
        y_order = np.sort(y_encoded)
        allocation = np.asarray(
            [
                np.bincount(y_order[i :: self.n_splits], minlength=n_classes)
                for i in range(self.n_splits)
            ]
        )
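        # Editorial example (shuffle=False): y = [0, 0, 0, 1, 1] and
        # n_splits=2 give y_order = [0, 0, 0, 1, 1]; the round robin takes
        # y_order[0::2] = [0, 0, 1] and y_order[1::2] = [0, 1], so
        # allocation == [[2, 1], [1, 1]]: fold 0 receives two samples of
        # class 0 and one of class 1, fold 1 one of each.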

        # To maintain the data order dependencies as best as possible within
        # the stratification constraint, we assign samples from each class in
        # blocks (and then mess that up when shuffle=True).
        test_folds = np.empty(len(y), dtype="i")
        for k in range(n_classes):
            # since the kth column of allocation stores the number of samples
            # of class k in each test set, this generates blocks of fold
            # indices corresponding to the allocation for class k.
            folds_for_class = np.arange(self.n_splits).repeat(allocation[:, k])
            if self.shuffle:
                rng.shuffle(folds_for_class)
            test_folds[y_encoded == k] = folds_for_class
        return test_folds

    def _iter_test_masks(self, X, y=None, groups=None):
        test_folds = self._make_test_folds(X, y)
        for i in range(self.n_splits):
            yield test_folds == i

    def split(self, X, y, groups=None):
        """Generate indices to split data into training and test set.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Training data, where `n_samples` is the number of samples
            and `n_features` is the number of features.

            Note that providing ``y`` is sufficient to generate the splits and
            hence ``np.zeros(n_samples)`` may be used as a placeholder for
            ``X`` instead of actual training data.

        y : array-like of shape (n_samples,)
            The target variable for supervised learning problems.
            Stratification is done based on the y labels.

        groups : object
            Always ignored, exists for compatibility.

        Yields
        ------
        train : ndarray
            The training set indices for that split.

        test : ndarray
            The testing set indices for that split.

        Notes
        -----
        Randomized CV splitters may return different results for each call of
        split. You can make the results identical by setting `random_state`
        to an integer.
        """
        y = check_array(y, input_name="y", ensure_2d=False, dtype=None)
        return super().split(X, y, groups)


class StratifiedGroupKFold(GroupsConsumerMixin, _BaseKFold):
    """Stratified K-Folds iterator variant with non-overlapping groups.

    This cross-validation object is a variation of StratifiedKFold that
    attempts to return stratified folds with non-overlapping groups. The
    folds are made by preserving the percentage of samples for each class.

    Each group will appear exactly once in the test set across all folds (the
    number of distinct groups has to be at least equal to the number of folds).

    The difference between :class:`~sklearn.model_selection.GroupKFold`
    and :class:`~sklearn.model_selection.StratifiedGroupKFold` is that
    the former attempts to create balanced folds such that the number of
    distinct groups is approximately the same in each fold, whereas
    StratifiedGroupKFold attempts to create folds which preserve the
    percentage of samples for each class as much as possible given the
    constraint of non-overlapping groups between splits.

    Read more in the :ref:`User Guide <cross_validation>`.

    For visualisation of cross-validation behaviour and
    comparison between common scikit-learn split methods
    refer to :ref:`sphx_glr_auto_examples_model_selection_plot_cv_indices.py`

    Parameters
    ----------
    n_splits : int, default=5
        Number of folds. Must be at least 2.

    shuffle : bool, default=False
        Whether to shuffle each class's samples before splitting into batches.
        Note that the samples within each split will not be shuffled.
        This implementation can only shuffle groups that have approximately the
        same y distribution, no global shuffle will be performed.

    random_state : int or RandomState instance, default=None
        When `shuffle` is True, `random_state` affects the ordering of the
        indices, which controls the randomness of each fold for each class.
        Otherwise, leave `random_state` as `None`.
        Pass an int for reproducible output across multiple function calls.
        See :term:`Glossary <random_state>`.

    Examples
    --------
    >>> import numpy as np
    >>> from sklearn.model_selection import StratifiedGroupKFold
    >>> X = np.ones((17, 2))
    >>> y = np.array([0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0])
    >>> groups = np.array([1, 1, 2, 2, 3, 3, 3, 4, 5, 5, 5, 5, 6, 6, 7, 8, 8])
    >>> sgkf = StratifiedGroupKFold(n_splits=3)
    >>> sgkf.get_n_splits(X, y)
    3
    >>> print(sgkf)
    StratifiedGroupKFold(n_splits=3, random_state=None, shuffle=False)
    >>> for i, (train_index, test_index) in enumerate(sgkf.split(X, y, groups)):
    ...     print(f"Fold {i}:")
    ...     print(f"  Train: index={train_index}")
    ...     print(f"         group={groups[train_index]}")
    ...     print(f"  Test:  index={test_index}")
    ...     print(f"         group={groups[test_index]}")
    Fold 0:
      Train: index=[ 0  1  2  3  7  8  9 10 11 15 16]
             group=[1 1 2 2 4 5 5 5 5 8 8]
      Test:  index=[ 4  5  6 12 13 14]
             group=[3 3 3 6 6 7]
    Fold 1:
      Train: index=[ 4  5  6  7  8  9 10 11 12 13 14]
             group=[3 3 3 4 5 5 5 5 6 6 7]
      Test:  index=[ 0  1  2  3 15 16]
             group=[1 1 2 2 8 8]
    Fold 2:
      Train: index=[ 0  1  2  3  4  5  6 12 13 14 15 16]
             group=[1 1 2 2 3 3 3 6 6 7 8 8]
      Test:  index=[ 7  8  9 10 11]
             group=[4 5 5 5 5]

    Notes
    -----
    The implementation is designed to:

    * Mimic the behavior of StratifiedKFold as much as possible for trivial
      groups (e.g. when each group contains only one sample).
    * Be invariant to class label: relabelling ``y = ["Happy", "Sad"]`` to
      ``y = [1, 0]`` should not change the indices generated.
    * Stratify based on samples as much as possible while keeping the
      non-overlapping groups constraint. That means that in some cases when
      there is a small number of groups containing a large number of samples
      the stratification will not be possible and the behavior will be close
      to GroupKFold.

    See Also
    --------
    StratifiedKFold : Takes class information into account to build folds which
        retain class distributions (for binary or multiclass classification
        tasks).

    GroupKFold : K-fold iterator variant with non-overlapping groups.
    """

    def __init__(self, n_splits=5, shuffle=False, random_state=None):
        super().__init__(n_splits=n_splits, shuffle=shuffle, random_state=random_state)

    def _iter_test_indices(self, X, y, groups):
        # Implementation is based on this kaggle kernel:
        # https://www.kaggle.com/jakubwasikowski/stratified-group-k-fold-cross-validation
        # and is subject to the Apache 2.0 License. You may obtain a copy of
        # the License at http://www.apache.org/licenses/LICENSE-2.0
        # Changelist:
        # - Refactored function to a class following scikit-learn KFold
        #   interface.
        # - Added a heuristic for assigning a group to the least populated
        #   fold in cases when all other criteria are equal.
        # - Switched from using python ``Counter`` to ``np.unique`` to get
        #   class distribution.
        # - Added scikit-learn checks for input: checking that target is
        #   binary or multiclass, checking passed random state, checking that
        #   number of splits is less than number of members in each class,
        #   checking that least populated class has more members than there
        #   are splits.
        rng = check_random_state(self.random_state)
        y = np.asarray(y)
        type_of_target_y = type_of_target(y)
        allowed_target_types = ("binary", "multiclass")
        if type_of_target_y not in allowed_target_types:
            raise ValueError(
                "Supported target types are: {}. Got {!r} instead.".format(
                    allowed_target_types, type_of_target_y
                )
            )

        y = column_or_1d(y)
        _, y_inv, y_cnt = np.unique(y, return_inverse=True, return_counts=True)
        if np.all(self.n_splits > y_cnt):
            raise ValueError(
                "n_splits=%d cannot be greater than the"
                " number of members in each class." % (self.n_splits)
            )
        n_smallest_class = np.min(y_cnt)
        if self.n_splits > n_smallest_class:
            warnings.warn(
                "The least populated class in y has only %d"
                " members, which is less than n_splits=%d."
                % (n_smallest_class, self.n_splits),
                UserWarning,
            )
        n_classes = len(y_cnt)

        _, groups_inv, groups_cnt = np.unique(
            groups, return_inverse=True, return_counts=True
        )
        y_counts_per_group = np.zeros((len(groups_cnt), n_classes))
        for class_idx, group_idx in zip(y_inv, groups_inv):
            y_counts_per_group[group_idx, class_idx] += 1

        y_counts_per_fold = np.zeros((self.n_splits, n_classes))
        groups_per_fold = defaultdict(set)

        if self.shuffle:
            rng.shuffle(y_counts_per_group)

        # Stable sort to keep shuffled order for groups with the same
        # class distribution variance
        sorted_groups_idx = np.argsort(
            -np.std(y_counts_per_group, axis=1), kind="mergesort"
        )
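        # Groups whose class distribution deviates most from uniform (largest
        # per-group std) are the hardest to place, so they are assigned to
        # folds first.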

        for group_idx in sorted_groups_idx:
            group_y_counts = y_counts_per_group[group_idx]
            best_fold = self._find_best_fold(
                y_counts_per_fold=y_counts_per_fold,
                y_cnt=y_cnt,
                group_y_counts=group_y_counts,
            )
            y_counts_per_fold[best_fold] += group_y_counts
            groups_per_fold[best_fold].add(group_idx)

        for i in range(self.n_splits):
            test_indices = [
                idx
                for idx, group_idx in enumerate(groups_inv)
                if group_idx in groups_per_fold[i]
            ]
            yield test_indices

    def _find_best_fold(self, y_counts_per_fold, y_cnt, group_y_counts):
        best_fold = None
        min_eval = np.inf
        min_samples_in_fold = np.inf
        for i in range(self.n_splits):
            y_counts_per_fold[i] += group_y_counts

            # Summarise the distribution over classes in each proposed fold
            std_per_class = np.std(y_counts_per_fold / y_cnt.reshape(1, -1), axis=0)

            y_counts_per_fold[i] -= group_y_counts

            fold_eval = np.mean(std_per_class)
            samples_in_fold = np.sum(y_counts_per_fold[i])
            is_current_fold_better = (
                fold_eval < min_eval
                or np.isclose(fold_eval, min_eval)
                and samples_in_fold < min_samples_in_fold
            )
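            # i.e. prefer the fold where adding this group keeps the
            # per-class proportions most even across folds; on (near-)ties,
            # prefer the fold that currently holds fewer samples.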
            if is_current_fold_better:
                min_eval = fold_eval
                min_samples_in_fold = samples_in_fold
                best_fold = i
        return best_fold


class TimeSeriesSplit(_BaseKFold):
    """Time Series cross-validator

    Provides train/test indices to split time series data samples
    that are observed at fixed time intervals, in train/test sets.
    In each split, test indices must be higher than in the previous split,
    and thus shuffling in this cross-validator is inappropriate.

    This cross-validation object is a variation of :class:`KFold`.
    In the kth split, it returns the first k folds as the train set and the
    (k+1)th fold as the test set.

    Note that unlike standard cross-validation methods, successive
    training sets are supersets of those that come before them.

    Read more in the :ref:`User Guide <time_series_split>`.

    For visualisation of cross-validation behaviour and
    comparison between common scikit-learn split methods
    refer to :ref:`sphx_glr_auto_examples_model_selection_plot_cv_indices.py`

    .. versionadded:: 0.18

    Parameters
    ----------
    n_splits : int, default=5
        Number of splits. Must be at least 2.

        .. versionchanged:: 0.22
            ``n_splits`` default value changed from 3 to 5.

    max_train_size : int, default=None
        Maximum size for a single training set.

    test_size : int, default=None
        Used to limit the size of the test set. Defaults to
        ``n_samples // (n_splits + 1)``, which is the maximum allowed value
        with ``gap=0``.

        .. versionadded:: 0.24

    gap : int, default=0
        Number of samples to exclude from the end of each train set before
        the test set.

        .. versionadded:: 0.24

    Examples
    --------
    >>> import numpy as np
    >>> from sklearn.model_selection import TimeSeriesSplit
    >>> X = np.array([[1, 2], [3, 4], [1, 2], [3, 4], [1, 2], [3, 4]])
    >>> y = np.array([1, 2, 3, 4, 5, 6])
    >>> tscv = TimeSeriesSplit()
    >>> print(tscv)
    TimeSeriesSplit(gap=0, max_train_size=None, n_splits=5, test_size=None)
    >>> for i, (train_index, test_index) in enumerate(tscv.split(X)):
    ...     print(f"Fold {i}:")
    ...     print(f"  Train: index={train_index}")
    ...     print(f"  Test:  index={test_index}")
    Fold 0:
      Train: index=[0]
      Test:  index=[1]
    Fold 1:
      Train: index=[0 1]
      Test:  index=[2]
    Fold 2:
      Train: index=[0 1 2]
      Test:  index=[3]
    Fold 3:
      Train: index=[0 1 2 3]
      Test:  index=[4]
    Fold 4:
      Train: index=[0 1 2 3 4]
      Test:  index=[5]
    >>> # Fix test_size to 2 with 12 samples
    >>> X = np.random.randn(12, 2)
    >>> y = np.random.randint(0, 2, 12)
    >>> tscv = TimeSeriesSplit(n_splits=3, test_size=2)
    >>> for i, (train_index, test_index) in enumerate(tscv.split(X)):
    ...     print(f"Fold {i}:")
    ...     print(f"  Train: index={train_index}")
    ...     print(f"  Test:  index={test_index}")
    Fold 0:
      Train: index=[0 1 2 3 4 5]
      Test:  index=[6 7]
    Fold 1:
      Train: index=[0 1 2 3 4 5 6 7]
      Test:  index=[8 9]
    Fold 2:
      Train: index=[0 1 2 3 4 5 6 7 8 9]
      Test:  index=[10 11]
    >>> # Add in a 2 period gap
    >>> tscv = TimeSeriesSplit(n_splits=3, test_size=2, gap=2)
    >>> for i, (train_index, test_index) in enumerate(tscv.split(X)):
    ...     print(f"Fold {i}:")
    ...     print(f"  Train: index={train_index}")
    ...     print(f"  Test:  index={test_index}")
    Fold 0:
      Train: index=[0 1 2 3]
      Test:  index=[6 7]
    Fold 1:
      Train: index=[0 1 2 3 4 5]
      Test:  index=[8 9]
    Fold 2:
      Train: index=[0 1 2 3 4 5 6 7]
      Test:  index=[10 11]

    Notes
    -----
    The training set has size ``i * n_samples // (n_splits + 1)
    + n_samples % (n_splits + 1)`` in the ``i`` th split,
    with a test set of size ``n_samples // (n_splits + 1)`` by default,
    where ``n_samples`` is the number of samples.
    """

    def __init__(self, n_splits=5, *, max_train_size=None, test_size=None, gap=0):
        super().__init__(n_splits, shuffle=False, random_state=None)
        self.max_train_size = max_train_size
        self.test_size = test_size
        self.gap = gap

    def split(self, X, y=None, groups=None):
        """Generate indices to split data into training and test set.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Training data, where `n_samples` is the number of samples
            and `n_features` is the number of features.

        y : array-like of shape (n_samples,)
            Always ignored, exists for compatibility.

        groups : array-like of shape (n_samples,)
            Always ignored, exists for compatibility.

        Yields
        ------
        train : ndarray
            The training set indices for that split.

        test : ndarray
            The testing set indices for that split.
        """
        X, y, groups = indexable(X, y, groups)
        n_samples = _num_samples(X)
        n_splits = self.n_splits
        n_folds = n_splits + 1
        gap = self.gap
        test_size = (
            self.test_size if self.test_size is not None else n_samples // n_folds
        )

        # Make sure we have enough samples for the given split parameters
        if n_folds > n_samples:
            raise ValueError(
                f"Cannot have number of folds={n_folds} greater"
                f" than the number of samples={n_samples}."
            )
        if n_samples - gap - (test_size * n_splits) <= 0:
            raise ValueError(
                f"Too many splits={n_splits} for number of samples"
                f"={n_samples} with test_size={test_size} and gap={gap}."
            )
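        # Equivalently: the first training window,
        # n_samples - gap - n_splits * test_size, must contain at least
        # one sample.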

        indices = np.arange(n_samples)
        test_starts = range(n_samples - n_splits * test_size, n_samples, test_size)
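        # Editorial example: with n_samples=6, n_splits=5, gap=0 and the
        # default test_size == 6 // 6 == 1, test_starts is range(1, 6),
        # which reproduces the five folds in the docstring above.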
        for test_start in test_starts:
            train_end = test_start - gap
            if self.max_train_size and self.max_train_size < train_end:
                yield (
                    indices[train_end - self.max_train_size : train_end],
                    indices[test_start : test_start + test_size],
                )
            else:
                yield (
                    indices[:train_end],
                    indices[test_start : test_start + test_size],
                )


class LeaveOneGroupOut(GroupsConsumerMixin, BaseCrossValidator):
    """Leave One Group Out cross-validator

    Provides train/test indices to split data such that each training set is
    comprised of all samples except ones belonging to one specific group.
    Arbitrary domain specific group information is provided via an array of
    integers that encodes the group of each sample.

    For instance the groups could be the year of collection of the samples
    and thus allow for cross-validation against time-based splits.

    Read more in the :ref:`User Guide <leave_one_group_out>`.

    Notes
    -----
    Splits are ordered according to the index of the group left out. The first
    split has testing set consisting of the group whose index in `groups` is
    lowest, and so on.

    Examples
    --------
    >>> import numpy as np
    >>> from sklearn.model_selection import LeaveOneGroupOut
    >>> X = np.array([[1, 2], [3, 4], [5, 6], [7, 8]])
    >>> y = np.array([1, 2, 1, 2])
    >>> groups = np.array([1, 1, 2, 2])
    >>> logo = LeaveOneGroupOut()
    >>> logo.get_n_splits(X, y, groups)
    2
    >>> logo.get_n_splits(groups=groups)  # 'groups' is always required
    2
    >>> print(logo)
    LeaveOneGroupOut()
    >>> for i, (train_index, test_index) in enumerate(logo.split(X, y, groups)):
    ...     print(f"Fold {i}:")
    ...     print(f"  Train: index={train_index}, group={groups[train_index]}")
    ...     print(f"  Test:  index={test_index}, group={groups[test_index]}")
    Fold 0:
      Train: index=[2 3], group=[2 2]
      Test:  index=[0 1], group=[1 1]
    Fold 1:
      Train: index=[0 1], group=[1 1]
      Test:  index=[2 3], group=[2 2]

    See Also
    --------
    GroupKFold : K-fold iterator variant with non-overlapping groups.
    """

    def _iter_test_masks(self, X, y, groups):
        if groups is None:
            raise ValueError("The 'groups' parameter should not be None.")
        # We make a copy of groups to avoid side-effects during iteration
        groups = check_array(
            groups, input_name="groups", copy=True, ensure_2d=False, dtype=None
        )
        unique_groups = np.unique(groups)
        if len(unique_groups) <= 1:
            raise ValueError(
                "The groups parameter contains fewer than 2 unique groups "
                "(%s). LeaveOneGroupOut expects at least 2." % unique_groups
            )
        for i in unique_groups:
            yield groups == i
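        # Editorial example: with groups = [1, 1, 2, 2] the loop yields the
        # boolean test masks [True, True, False, False] and then
        # [False, False, True, True], one per unique group.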

    def get_n_splits(self, X=None, y=None, groups=None):
        """Returns the number of splitting iterations in the cross-validator

        Parameters
        ----------
        X : object
            Always ignored, exists for compatibility.

        y : object
            Always ignored, exists for compatibility.

        groups : array-like of shape (n_samples,)
            Group labels for the samples used while splitting the dataset into
            train/test set. This 'groups' parameter must always be specified to
            calculate the number of splits, though the other parameters can be
            omitted.

        Returns
        -------
        n_splits : int
            Returns the number of splitting iterations in the cross-validator.
        """
        if groups is None:
            raise ValueError("The 'groups' parameter should not be None.")
        groups = check_array(groups, input_name="groups", ensure_2d=False, dtype=None)
        return len(np.unique(groups))

    def split(self, X, y=None, groups=None):
        """Generate indices to split data into training and test set.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Training data, where `n_samples` is the number of samples
            and `n_features` is the number of features.

        y : array-like of shape (n_samples,), default=None
            The target variable for supervised learning problems.

        groups : array-like of shape (n_samples,)
            Group labels for the samples used while splitting the dataset into
            train/test set.

        Yields
        ------
        train : ndarray
            The training set indices for that split.

        test : ndarray
            The testing set indices for that split.
        """
        return super().split(X, y, groups)
class LeavePGroupsOut(GroupsConsumerMixin, BaseCrossValidator):
    """Leave P Group(s) Out cross-validator

    Provides train/test indices to split data according to a third-party
    provided group. This group information can be used to encode arbitrary
    domain specific stratifications of the samples as integers.

    For instance the groups could be the year of collection of the samples
    and thus allow for cross-validation against time-based splits.

    The difference between LeavePGroupsOut and LeaveOneGroupOut is that
    the former builds the test sets with all the samples assigned to
    ``p`` different values of the groups while the latter uses samples
    all assigned the same group.

    Read more in the :ref:`User Guide <leave_p_groups_out>`.

    Parameters
    ----------
    n_groups : int
        Number of groups (``p``) to leave out in the test split.

    Examples
    --------
    >>> import numpy as np
    >>> from sklearn.model_selection import LeavePGroupsOut
    >>> X = np.array([[1, 2], [3, 4], [5, 6]])
    >>> y = np.array([1, 2, 1])
    >>> groups = np.array([1, 2, 3])
    >>> lpgo = LeavePGroupsOut(n_groups=2)
    >>> lpgo.get_n_splits(X, y, groups)
    3
    >>> lpgo.get_n_splits(groups=groups)  # 'groups' is always required
    3
    >>> print(lpgo)
    LeavePGroupsOut(n_groups=2)
    >>> for i, (train_index, test_index) in enumerate(lpgo.split(X, y, groups)):
    ...     print(f"Fold {i}:")
    ...     print(f"  Train: index={train_index}, group={groups[train_index]}")
    ...     print(f"  Test:  index={test_index}, group={groups[test_index]}")
    Fold 0:
      Train: index=[2], group=[3]
      Test:  index=[0 1], group=[1 2]
    Fold 1:
      Train: index=[1], group=[2]
      Test:  index=[0 2], group=[1 3]
    Fold 2:
      Train: index=[0], group=[1]
      Test:  index=[1 2], group=[2 3]

    See Also
    --------
    GroupKFold : K-fold iterator variant with non-overlapping groups.
    """

    def __init__(self, n_groups):
        self.n_groups = n_groups

    def _iter_test_masks(self, X, y, groups):
        if groups is None:
            raise ValueError("The 'groups' parameter should not be None.")
        groups = check_array(
            groups, input_name="groups", copy=True, ensure_2d=False, dtype=None
        )
        unique_groups = np.unique(groups)
        if self.n_groups >= len(unique_groups):
            raise ValueError(
                "The groups parameter contains fewer than (or equal to) "
                "n_groups (%d) unique groups (%s). LeavePGroupsOut expects "
                "at least n_groups + 1 (%d) unique groups to be present"
                % (self.n_groups, unique_groups, self.n_groups + 1)
            )
        combi = combinations(range(len(unique_groups)), self.n_groups)
        for indices in combi:
            test_index = np.zeros(_num_samples(X), dtype=bool)
            for group in unique_groups[np.array(indices)]:
                test_index[groups == group] = True
            yield test_index

    def get_n_splits(self, X=None, y=None, groups=None):
        """Returns the number of splitting iterations in the cross-validator.

        Parameters
        ----------
        X : object
            Always ignored, exists for compatibility.

        y : object
            Always ignored, exists for compatibility.

        groups : array-like of shape (n_samples,)
            Group labels for the samples used while splitting the dataset into
            train/test set. This 'groups' parameter must always be specified to
            calculate the number of splits, though the other parameters can be
            omitted.

        Returns
        -------
        n_splits : int
            Returns the number of splitting iterations in the cross-validator.
        """
        if groups is None:
            raise ValueError("The 'groups' parameter should not be None.")
        groups = check_array(groups, input_name="groups", ensure_2d=False, dtype=None)
        return int(comb(len(np.unique(groups)), self.n_groups, exact=True))

    def split(self, X, y=None, groups=None):
        """Generate indices to split data into training and test set.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Training data, where `n_samples` is the number of samples
            and `n_features` is the number of features.

        y : array-like of shape (n_samples,), default=None
            The target variable for supervised learning problems.

        groups : array-like of shape (n_samples,)
            Group labels for the samples used while splitting the dataset into
            train/test set.

        Yields
        ------
        train : ndarray
            The training set indices for that split.

        test : ndarray
            The testing set indices for that split.
        """
        return super().split(X, y, groups)
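

# The number of splits returned by LeavePGroupsOut.get_n_splits grows
# combinatorially, as comb(n_unique_groups, n_groups). A minimal sanity-check
# sketch of that count (the helper below is illustrative only, not part of
# the public API):
def _demo_leave_p_groups_out_count():
    # 5 singleton groups and p=2 -> comb(5, 2) == 10 candidate test sets.
    groups = np.array([0, 1, 2, 3, 4])
    assert LeavePGroupsOut(n_groups=2).get_n_splits(groups=groups) == 10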
class _RepeatedSplits(_MetadataRequester, metaclass=ABCMeta):
    """Repeated splits for an arbitrary randomized CV splitter.

    Repeats splits for cross-validators n times with different randomization
    in each repetition.

    Parameters
    ----------
    cv : callable
        Cross-validator class.

    n_repeats : int, default=10
        Number of times the cross-validator needs to be repeated.

    random_state : int, RandomState instance or None, default=None
        Passes `random_state` to the arbitrary repeating cross-validator.
        Pass an int for reproducible output across multiple function calls.
        See :term:`Glossary <random_state>`.

    **cvargs : additional params
        Constructor parameters for cv. Must not contain random_state
        and shuffle.
    """

    # This indicates that by default CV splitters don't have a "groups" kwarg,
    # unless indicated by inheriting from ``GroupsConsumerMixin``.
    # This also prevents ``set_split_request`` from being generated for
    # splitters which don't support ``groups``.
    __metadata_request__split = {"groups": metadata_routing.UNUSED}

    def __init__(self, cv, *, n_repeats=10, random_state=None, **cvargs):
        if not isinstance(n_repeats, numbers.Integral):
            raise ValueError("Number of repetitions must be of Integral type.")
        if n_repeats <= 0:
            raise ValueError("Number of repetitions must be greater than 0.")
        if any(key in cvargs for key in ("random_state", "shuffle")):
            raise ValueError("cvargs must not contain random_state or shuffle.")
        self.cv = cv
        self.n_repeats = n_repeats
        self.random_state = random_state
        self.cvargs = cvargs

    def split(self, X, y=None, groups=None):
        """Generate indices to split data into training and test set.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Training data, where `n_samples` is the number of samples
            and `n_features` is the number of features.

        y : array-like of shape (n_samples,)
            The target variable for supervised learning problems.

        groups : array-like of shape (n_samples,), default=None
            Group labels for the samples used while splitting the dataset into
            train/test set.

        Yields
        ------
        train : ndarray
            The training set indices for that split.

        test : ndarray
            The testing set indices for that split.
        """
        n_repeats = self.n_repeats
        rng = check_random_state(self.random_state)
        for idx in range(n_repeats):
            cv = self.cv(random_state=rng, shuffle=True, **self.cvargs)
            for train_index, test_index in cv.split(X, y, groups):
                yield train_index, test_index

    def get_n_splits(self, X=None, y=None, groups=None):
        """Returns the number of splitting iterations in the cross-validator.

        Parameters
        ----------
        X : object
            Always ignored, exists for compatibility.
            ``np.zeros(n_samples)`` may be used as a placeholder.

        y : object
            Always ignored, exists for compatibility.
            ``np.zeros(n_samples)`` may be used as a placeholder.

        groups : array-like of shape (n_samples,), default=None
            Group labels for the samples used while splitting the dataset into
            train/test set.

        Returns
        -------
        n_splits : int
            Returns the number of splitting iterations in the cross-validator.
        """
        rng = check_random_state(self.random_state)
        cv = self.cv(random_state=rng, shuffle=True, **self.cvargs)
        return cv.get_n_splits(X, y, groups) * self.n_repeats

    def __repr__(self):
        return _build_repr(self)
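

# ``_RepeatedSplits`` re-instantiates the wrapped splitter once per repetition
# with ``shuffle=True`` and a shared RNG, so every repetition reshuffles
# differently while an integer ``random_state`` keeps the whole sequence
# reproducible across calls. A sketch of that contract using the public
# RepeatedKFold subclass defined below (the helper name is illustrative only):
def _demo_repeated_splits_contract():
    X = np.arange(12).reshape(6, 2)
    rkf = RepeatedKFold(n_splits=3, n_repeats=2, random_state=0)
    # 3 folds x 2 repetitions = 6 (train, test) pairs.
    splits_a = [(tr.tolist(), te.tolist()) for tr, te in rkf.split(X)]
    splits_b = [(tr.tolist(), te.tolist()) for tr, te in rkf.split(X)]
    assert len(splits_a) == rkf.get_n_splits(X) == 6
    # Identical on every call because random_state is an integer.
    assert splits_a == splits_b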
class RepeatedKFold(_RepeatedSplits):
    """Repeated K-Fold cross validator.

    Repeats K-Fold n times with different randomization in each repetition.

    Read more in the :ref:`User Guide <repeated_k_fold>`.

    Parameters
    ----------
    n_splits : int, default=5
        Number of folds. Must be at least 2.

    n_repeats : int, default=10
        Number of times cross-validator needs to be repeated.

    random_state : int, RandomState instance or None, default=None
        Controls the randomness of each repeated cross-validation instance.
        Pass an int for reproducible output across multiple function calls.
        See :term:`Glossary <random_state>`.

    Examples
    --------
    >>> import numpy as np
    >>> from sklearn.model_selection import RepeatedKFold
    >>> X = np.array([[1, 2], [3, 4], [1, 2], [3, 4]])
    >>> y = np.array([0, 0, 1, 1])
    >>> rkf = RepeatedKFold(n_splits=2, n_repeats=2, random_state=2652124)
    >>> rkf.get_n_splits(X, y)
    4
    >>> print(rkf)
    RepeatedKFold(n_repeats=2, n_splits=2, random_state=2652124)
    >>> for i, (train_index, test_index) in enumerate(rkf.split(X)):
    ...     print(f"Fold {i}:")
    ...     print(f"  Train: index={train_index}")
    ...     print(f"  Test:  index={test_index}")
    ...
    Fold 0:
      Train: index=[0 1]
      Test:  index=[2 3]
    Fold 1:
      Train: index=[2 3]
      Test:  index=[0 1]
    Fold 2:
      Train: index=[1 2]
      Test:  index=[0 3]
    Fold 3:
      Train: index=[0 3]
      Test:  index=[1 2]

    Notes
    -----
    Randomized CV splitters may return different results for each call of
    split. You can make the results identical by setting `random_state`
    to an integer.

    See Also
    --------
    RepeatedStratifiedKFold : Repeats Stratified K-Fold n times.
    """

    def __init__(self, *, n_splits=5, n_repeats=10, random_state=None):
        super().__init__(
            KFold, n_repeats=n_repeats, random_state=random_state, n_splits=n_splits
        )
class RepeatedStratifiedKFold(_RepeatedSplits):
    """Repeated Stratified K-Fold cross validator.

    Repeats Stratified K-Fold n times with different randomization in each
    repetition.

    Read more in the :ref:`User Guide <repeated_k_fold>`.

    Parameters
    ----------
    n_splits : int, default=5
        Number of folds. Must be at least 2.

    n_repeats : int, default=10
        Number of times cross-validator needs to be repeated.

    random_state : int, RandomState instance or None, default=None
        Controls the generation of the random states for each repetition.
        Pass an int for reproducible output across multiple function calls.
        See :term:`Glossary <random_state>`.

    Examples
    --------
    >>> import numpy as np
    >>> from sklearn.model_selection import RepeatedStratifiedKFold
    >>> X = np.array([[1, 2], [3, 4], [1, 2], [3, 4]])
    >>> y = np.array([0, 0, 1, 1])
    >>> rskf = RepeatedStratifiedKFold(n_splits=2, n_repeats=2,
    ...     random_state=36851234)
    >>> rskf.get_n_splits(X, y)
    4
    >>> print(rskf)
    RepeatedStratifiedKFold(n_repeats=2, n_splits=2, random_state=36851234)
    >>> for i, (train_index, test_index) in enumerate(rskf.split(X, y)):
    ...     print(f"Fold {i}:")
    ...     print(f"  Train: index={train_index}")
    ...     print(f"  Test:  index={test_index}")
    ...
    Fold 0:
      Train: index=[1 2]
      Test:  index=[0 3]
    Fold 1:
      Train: index=[0 3]
      Test:  index=[1 2]
    Fold 2:
      Train: index=[1 3]
      Test:  index=[0 2]
    Fold 3:
      Train: index=[0 2]
      Test:  index=[1 3]

    Notes
    -----
    Randomized CV splitters may return different results for each call of
    split. You can make the results identical by setting `random_state`
    to an integer.

    See Also
    --------
    RepeatedKFold : Repeats K-Fold n times.
    """

    def __init__(self, *, n_splits=5, n_repeats=10, random_state=None):
        super().__init__(
            StratifiedKFold,
            n_repeats=n_repeats,
            random_state=random_state,
            n_splits=n_splits,
        )
class BaseShuffleSplit(_MetadataRequester, metaclass=ABCMeta):
    """Base class for ShuffleSplit and StratifiedShuffleSplit."""

    # This indicates that by default CV splitters don't have a "groups" kwarg,
    # unless indicated by inheriting from ``GroupsConsumerMixin``.
    # This also prevents ``set_split_request`` from being generated for
    # splitters which don't support ``groups``.
    __metadata_request__split = {"groups": metadata_routing.UNUSED}

    def __init__(
        self, n_splits=10, *, test_size=None, train_size=None, random_state=None
    ):
        self.n_splits = n_splits
        self.test_size = test_size
        self.train_size = train_size
        self.random_state = random_state
        self._default_test_size = 0.1

    def split(self, X, y=None, groups=None):
        """Generate indices to split data into training and test set.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Training data, where `n_samples` is the number of samples
            and `n_features` is the number of features.

        y : array-like of shape (n_samples,)
            The target variable for supervised learning problems.

        groups : array-like of shape (n_samples,), default=None
            Group labels for the samples used while splitting the dataset into
            train/test set.

        Yields
        ------
        train : ndarray
            The training set indices for that split.

        test : ndarray
            The testing set indices for that split.

        Notes
        -----
        Randomized CV splitters may return different results for each call of
        split. You can make the results identical by setting `random_state`
        to an integer.
        """
        X, y, groups = indexable(X, y, groups)
        for train, test in self._iter_indices(X, y, groups):
            yield train, test

    @abstractmethod
    def _iter_indices(self, X, y=None, groups=None):
        """Generate (train, test) indices."""

    def get_n_splits(self, X=None, y=None, groups=None):
        """Returns the number of splitting iterations in the cross-validator.

        Parameters
        ----------
        X : object
            Always ignored, exists for compatibility.

        y : object
            Always ignored, exists for compatibility.

        groups : object
            Always ignored, exists for compatibility.

        Returns
        -------
        n_splits : int
            Returns the number of splitting iterations in the cross-validator.
        """
        return self.n_splits

    def __repr__(self):
        return _build_repr(self)
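

# Subclasses of BaseShuffleSplit only need to implement ``_iter_indices``;
# ``split``, ``get_n_splits`` and ``__repr__`` come from the base class. A
# minimal sketch of a custom splitter built on that hook (a toy, deterministic
# "shuffle" used for illustration only):
def _demo_base_shuffle_split_hook():
    class _FirstKOut(BaseShuffleSplit):
        # Always holds out the first ``test_size`` samples; a real subclass
        # would draw a random permutation, as ShuffleSplit does below.
        def _iter_indices(self, X, y=None, groups=None):
            n_samples = _num_samples(X)
            for _ in range(self.n_splits):
                yield np.arange(self.test_size, n_samples), np.arange(self.test_size)

    cv = _FirstKOut(n_splits=2, test_size=2)
    splits = list(cv.split(np.zeros((5, 1))))
    assert len(splits) == cv.get_n_splits() == 2
    train, test = splits[0]
    assert test.tolist() == [0, 1] and train.tolist() == [2, 3, 4]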
class ShuffleSplit(BaseShuffleSplit):
    """Random permutation cross-validator

    Yields indices to split data into training and test sets.

    Note: contrary to other cross-validation strategies, random splits
    do not guarantee that all folds will be different, although this is
    still very likely for sizeable datasets.

    Read more in the :ref:`User Guide <ShuffleSplit>`.

    For visualisation of cross-validation behaviour and
    comparison between common scikit-learn split methods
    refer to :ref:`sphx_glr_auto_examples_model_selection_plot_cv_indices.py`.

    Parameters
    ----------
    n_splits : int, default=10
        Number of re-shuffling & splitting iterations.

    test_size : float or int, default=None
        If float, should be between 0.0 and 1.0 and represent the proportion
        of the dataset to include in the test split. If int, represents the
        absolute number of test samples. If None, the value is set to the
        complement of the train size. If ``train_size`` is also None, it will
        be set to 0.1.

    train_size : float or int, default=None
        If float, should be between 0.0 and 1.0 and represent the
        proportion of the dataset to include in the train split. If
        int, represents the absolute number of train samples. If None,
        the value is automatically set to the complement of the test size.

    random_state : int, RandomState instance or None, default=None
        Controls the randomness of the training and testing indices produced.
        Pass an int for reproducible output across multiple function calls.
        See :term:`Glossary <random_state>`.

    Examples
    --------
    >>> import numpy as np
    >>> from sklearn.model_selection import ShuffleSplit
    >>> X = np.array([[1, 2], [3, 4], [5, 6], [7, 8], [3, 4], [5, 6]])
    >>> y = np.array([1, 2, 1, 2, 1, 2])
    >>> rs = ShuffleSplit(n_splits=5, test_size=.25, random_state=0)
    >>> rs.get_n_splits(X)
    5
    >>> print(rs)
    ShuffleSplit(n_splits=5, random_state=0, test_size=0.25, train_size=None)
    >>> for i, (train_index, test_index) in enumerate(rs.split(X)):
    ...     print(f"Fold {i}:")
    ...     print(f"  Train: index={train_index}")
    ...     print(f"  Test:  index={test_index}")
    Fold 0:
      Train: index=[1 3 0 4]
      Test:  index=[5 2]
    Fold 1:
      Train: index=[4 0 2 5]
      Test:  index=[1 3]
    Fold 2:
      Train: index=[1 2 4 0]
      Test:  index=[3 5]
    Fold 3:
      Train: index=[3 4 1 0]
      Test:  index=[5 2]
    Fold 4:
      Train: index=[3 5 1 0]
      Test:  index=[2 4]
    >>> # Specify train and test size
    >>> rs = ShuffleSplit(n_splits=5, train_size=0.5, test_size=.25,
    ...     random_state=0)
    >>> for i, (train_index, test_index) in enumerate(rs.split(X)):
    ...     print(f"Fold {i}:")
    ...     print(f"  Train: index={train_index}")
    ...     print(f"  Test:  index={test_index}")
    Fold 0:
      Train: index=[1 3 0]
      Test:  index=[5 2]
    Fold 1:
      Train: index=[4 0 2]
      Test:  index=[1 3]
    Fold 2:
      Train: index=[1 2 4]
      Test:  index=[3 5]
    Fold 3:
      Train: index=[3 4 1]
      Test:  index=[5 2]
    Fold 4:
      Train: index=[3 5 1]
      Test:  index=[2 4]
    """

    def __init__(
        self, n_splits=10, *, test_size=None, train_size=None, random_state=None
    ):
        super().__init__(
            n_splits=n_splits,
            test_size=test_size,
            train_size=train_size,
            random_state=random_state,
        )
        self._default_test_size = 0.1

    def _iter_indices(self, X, y=None, groups=None):
        n_samples = _num_samples(X)
        n_train, n_test = _validate_shuffle_split(
            n_samples,
            self.test_size,
            self.train_size,
            default_test_size=self._default_test_size,
        )
        rng = check_random_state(self.random_state)
        for i in range(self.n_splits):
            # random partition
            permutation = rng.permutation(n_samples)
            ind_test = permutation[:n_test]
            ind_train = permutation[n_test : (n_test + n_train)]
            yield ind_train, ind_test
class GroupShuffleSplit(GroupsConsumerMixin, ShuffleSplit):
    """Shuffle-Group(s)-Out cross-validation iterator

    Provides randomized train/test indices to split data according to a
    third-party provided group. This group information can be used to encode
    arbitrary domain specific stratifications of the samples as integers.

    For instance the groups could be the year of collection of the samples
    and thus allow for cross-validation against time-based splits.

    The difference between LeavePGroupsOut and GroupShuffleSplit is that
    the former generates splits using all subsets of size ``p`` unique groups,
    whereas GroupShuffleSplit generates a user-determined number of random
    test splits, each with a user-determined fraction of unique groups.

    For example, a less computationally intensive alternative to
    ``LeavePGroupsOut(n_groups=10)`` would be
    ``GroupShuffleSplit(test_size=10, n_splits=100)``.

    Note: The parameters ``test_size`` and ``train_size`` refer to groups, and
    not to samples, as in ShuffleSplit.

    Read more in the :ref:`User Guide <group_shuffle_split>`.

    For visualisation of cross-validation behaviour and
    comparison between common scikit-learn split methods
    refer to :ref:`sphx_glr_auto_examples_model_selection_plot_cv_indices.py`.

    Parameters
    ----------
    n_splits : int, default=5
        Number of re-shuffling & splitting iterations.

    test_size : float, int, default=0.2
        If float, should be between 0.0 and 1.0 and represent the proportion
        of groups to include in the test split (rounded up). If int,
        represents the absolute number of test groups. If None, the value is
        set to the complement of the train size. The default of 0.2 applies
        only if ``train_size`` is unspecified; otherwise ``test_size``
        complements the specified ``train_size``.

    train_size : float or int, default=None
        If float, should be between 0.0 and 1.0 and represent the
        proportion of the groups to include in the train split. If
        int, represents the absolute number of train groups. If None,
        the value is automatically set to the complement of the test size.

    random_state : int, RandomState instance or None, default=None
        Controls the randomness of the training and testing indices produced.
        Pass an int for reproducible output across multiple function calls.
        See :term:`Glossary <random_state>`.

    Examples
    --------
    >>> import numpy as np
    >>> from sklearn.model_selection import GroupShuffleSplit
    >>> X = np.ones(shape=(8, 2))
    >>> y = np.ones(shape=(8, 1))
    >>> groups = np.array([1, 1, 2, 2, 2, 3, 3, 3])
    >>> print(groups.shape)
    (8,)
    >>> gss = GroupShuffleSplit(n_splits=2, train_size=.7, random_state=42)
    >>> gss.get_n_splits()
    2
    >>> print(gss)
    GroupShuffleSplit(n_splits=2, random_state=42, test_size=None, train_size=0.7)
    >>> for i, (train_index, test_index) in enumerate(gss.split(X, y, groups)):
    ...     print(f"Fold {i}:")
    ...     print(f"  Train: index={train_index}, group={groups[train_index]}")
    ...     print(f"  Test:  index={test_index}, group={groups[test_index]}")
    Fold 0:
      Train: index=[2 3 4 5 6 7], group=[2 2 2 3 3 3]
      Test:  index=[0 1], group=[1 1]
    Fold 1:
      Train: index=[0 1 5 6 7], group=[1 1 3 3 3]
      Test:  index=[2 3 4], group=[2 2 2]

    See Also
    --------
    ShuffleSplit : Shuffles samples to create independent test/train sets.
    LeavePGroupsOut : Train set leaves out all possible subsets of `p` groups.
    """

    def __init__(
        self, n_splits=5, *, test_size=None, train_size=None, random_state=None
    ):
        super().__init__(
            n_splits=n_splits,
            test_size=test_size,
            train_size=train_size,
            random_state=random_state,
        )
        self._default_test_size = 0.2

    def _iter_indices(self, X, y, groups):
        if groups is None:
            raise ValueError("The 'groups' parameter should not be None.")
        groups = check_array(groups, input_name="groups", ensure_2d=False, dtype=None)
        classes, group_indices = np.unique(groups, return_inverse=True)
        for group_train, group_test in super()._iter_indices(X=classes):
            # these are the indices of classes in the partition
            # invert them into data indices
            train = np.flatnonzero(np.isin(group_indices, group_train))
            test = np.flatnonzero(np.isin(group_indices, group_test))
            yield train, test

    def split(self, X, y=None, groups=None):
        """Generate indices to split data into training and test set.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Training data, where `n_samples` is the number of samples
            and `n_features` is the number of features.

        y : array-like of shape (n_samples,), default=None
            The target variable for supervised learning problems.

        groups : array-like of shape (n_samples,)
            Group labels for the samples used while splitting the dataset into
            train/test set.

        Yields
        ------
        train : ndarray
            The training set indices for that split.

        test : ndarray
            The testing set indices for that split.

        Notes
        -----
        Randomized CV splitters may return different results for each call of
        split. You can make the results identical by setting `random_state`
        to an integer.
        """
        return super().split(X, y, groups)
class StratifiedShuffleSplit(BaseShuffleSplit):
    """Stratified ShuffleSplit cross-validator

    Provides train/test indices to split data in train/test sets.

    This cross-validation object is a merge of StratifiedKFold and
    ShuffleSplit, which returns stratified randomized folds. The folds
    are made by preserving the percentage of samples for each class.

    Note: like the ShuffleSplit strategy, stratified random splits
    do not guarantee that all folds will be different, although this is
    still very likely for sizeable datasets.

    Read more in the :ref:`User Guide <stratified_shuffle_split>`.

    For visualisation of cross-validation behaviour and
    comparison between common scikit-learn split methods
    refer to :ref:`sphx_glr_auto_examples_model_selection_plot_cv_indices.py`.

    Parameters
    ----------
    n_splits : int, default=10
        Number of re-shuffling & splitting iterations.

    test_size : float or int, default=None
        If float, should be between 0.0 and 1.0 and represent the proportion
        of the dataset to include in the test split. If int, represents the
        absolute number of test samples. If None, the value is set to the
        complement of the train size. If ``train_size`` is also None, it will
        be set to 0.1.

    train_size : float or int, default=None
        If float, should be between 0.0 and 1.0 and represent the
        proportion of the dataset to include in the train split. If
        int, represents the absolute number of train samples. If None,
        the value is automatically set to the complement of the test size.

    random_state : int, RandomState instance or None, default=None
        Controls the randomness of the training and testing indices produced.
        Pass an int for reproducible output across multiple function calls.
        See :term:`Glossary <random_state>`.

    Examples
    --------
    >>> import numpy as np
    >>> from sklearn.model_selection import StratifiedShuffleSplit
    >>> X = np.array([[1, 2], [3, 4], [1, 2], [3, 4], [1, 2], [3, 4]])
    >>> y = np.array([0, 0, 0, 1, 1, 1])
    >>> sss = StratifiedShuffleSplit(n_splits=5, test_size=0.5, random_state=0)
    >>> sss.get_n_splits(X, y)
    5
    >>> print(sss)
    StratifiedShuffleSplit(n_splits=5, random_state=0, ...)
    >>> for i, (train_index, test_index) in enumerate(sss.split(X, y)):
    ...     print(f"Fold {i}:")
    ...     print(f"  Train: index={train_index}")
    ...     print(f"  Test:  index={test_index}")
    Fold 0:
      Train: index=[5 2 3]
      Test:  index=[4 1 0]
    Fold 1:
      Train: index=[5 1 4]
      Test:  index=[0 2 3]
    Fold 2:
      Train: index=[5 0 2]
      Test:  index=[4 3 1]
    Fold 3:
      Train: index=[4 1 0]
      Test:  index=[2 3 5]
    Fold 4:
      Train: index=[0 5 1]
      Test:  index=[3 4 2]
    """

    def __init__(
        self, n_splits=10, *, test_size=None, train_size=None, random_state=None
    ):
        super().__init__(
            n_splits=n_splits,
            test_size=test_size,
            train_size=train_size,
            random_state=random_state,
        )
        self._default_test_size = 0.1

    def _iter_indices(self, X, y, groups=None):
        n_samples = _num_samples(X)
        y = check_array(y, input_name="y", ensure_2d=False, dtype=None)
        n_train, n_test = _validate_shuffle_split(
            n_samples,
            self.test_size,
            self.train_size,
            default_test_size=self._default_test_size,
        )

        if y.ndim == 2:
            # for multi-label y, map each distinct row to a string repr
            # using join because str(row) uses an ellipsis if len(row) > 1000
            y = np.array([" ".join(row.astype("str")) for row in y])

        classes, y_indices = np.unique(y, return_inverse=True)
        n_classes = classes.shape[0]

        class_counts = np.bincount(y_indices)
        if np.min(class_counts) < 2:
            raise ValueError(
                "The least populated class in y has only 1"
                " member, which is too few. The minimum"
                " number of groups for any class cannot"
                " be less than 2."
            )

        if n_train < n_classes:
            raise ValueError(
                "The train_size = %d should be greater or "
                "equal to the number of classes = %d" % (n_train, n_classes)
            )
        if n_test < n_classes:
            raise ValueError(
                "The test_size = %d should be greater or "
                "equal to the number of classes = %d" % (n_test, n_classes)
            )

        # Find the sorted list of instances for each class:
        # (np.unique above performs a sort, so code is O(n logn) already)
        class_indices = np.split(
            np.argsort(y_indices, kind="mergesort"), np.cumsum(class_counts)[:-1]
        )

        rng = check_random_state(self.random_state)

        for _ in range(self.n_splits):
            # if there are ties in the class-counts, we want
            # to make sure to break them anew in each iteration
            n_i = _approximate_mode(class_counts, n_train, rng)
            class_counts_remaining = class_counts - n_i
            t_i = _approximate_mode(class_counts_remaining, n_test, rng)

            train = []
            test = []

            for i in range(n_classes):
                permutation = rng.permutation(class_counts[i])
                perm_indices_class_i = class_indices[i].take(permutation, mode="clip")

                train.extend(perm_indices_class_i[: n_i[i]])
                test.extend(perm_indices_class_i[n_i[i] : n_i[i] + t_i[i]])

            train = rng.permutation(train)
            test = rng.permutation(test)

            yield train, test

    def split(self, X, y, groups=None):
        """Generate indices to split data into training and test set.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Training data, where `n_samples` is the number of samples
            and `n_features` is the number of features.

            Note that providing ``y`` is sufficient to generate the splits and
            hence ``np.zeros(n_samples)`` may be used as a placeholder for
            ``X`` instead of actual training data.

        y : array-like of shape (n_samples,) or (n_samples, n_labels)
            The target variable for supervised learning problems.
            Stratification is done based on the y labels.

        groups : object
            Always ignored, exists for compatibility.

        Yields
        ------
        train : ndarray
            The training set indices for that split.

        test : ndarray
            The testing set indices for that split.

        Notes
        -----
        Randomized CV splitters may return different results for each call of
        split. You can make the results identical by setting `random_state`
        to an integer.
        """
        y = check_array(y, input_name="y", ensure_2d=False, dtype=None)
        return super().split(X, y, groups)
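

# Worked example of the per-class allocation in ``_iter_indices`` above: with
# class_counts=[4, 2] and n_train=3, the proportional shares are
# [4, 2] / 6 * 3 == [2.0, 1.0], so ``_approximate_mode`` deterministically
# allocates 2 train slots to class 0 and 1 to class 1; only fractional
# remainders are broken at random via ``rng``. A quick sketch (illustrative
# helper only, not part of the module's API):
def _demo_stratified_allocation():
    rng = check_random_state(0)
    n_i = _approximate_mode(np.array([4, 2]), 3, rng)
    assert n_i.tolist() == [2, 1]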
def _validate_shuffle_split(n_samples, test_size, train_size, default_test_size=None):
    """
    Validation helper to check if the train/test sizes are meaningful w.r.t. the
    size of the data (n_samples).
    """
    if test_size is None and train_size is None:
        test_size = default_test_size

    test_size_type = np.asarray(test_size).dtype.kind
    train_size_type = np.asarray(train_size).dtype.kind

    if (
        test_size_type == "i"
        and (test_size >= n_samples or test_size <= 0)
        or test_size_type == "f"
        and (test_size <= 0 or test_size >= 1)
    ):
        raise ValueError(
            "test_size={0} should be either positive and smaller"
            " than the number of samples {1} or a float in the "
            "(0, 1) range".format(test_size, n_samples)
        )

    if (
        train_size_type == "i"
        and (train_size >= n_samples or train_size <= 0)
        or train_size_type == "f"
        and (train_size <= 0 or train_size >= 1)
    ):
        raise ValueError(
            "train_size={0} should be either positive and smaller"
            " than the number of samples {1} or a float in the "
            "(0, 1) range".format(train_size, n_samples)
        )

    if train_size is not None and train_size_type not in ("i", "f"):
        raise ValueError("Invalid value for train_size: {}".format(train_size))
    if test_size is not None and test_size_type not in ("i", "f"):
        raise ValueError("Invalid value for test_size: {}".format(test_size))

    if train_size_type == "f" and test_size_type == "f" and train_size + test_size > 1:
        raise ValueError(
            "The sum of test_size and train_size = {}, should be in the (0, 1)"
            " range. Reduce test_size and/or train_size.".format(train_size + test_size)
        )

    if test_size_type == "f":
        n_test = ceil(test_size * n_samples)
    elif test_size_type == "i":
        n_test = float(test_size)

    if train_size_type == "f":
        n_train = floor(train_size * n_samples)
    elif train_size_type == "i":
        n_train = float(train_size)

    if train_size is None:
        n_train = n_samples - n_test
    elif test_size is None:
        n_test = n_samples - n_train

    if n_train + n_test > n_samples:
        raise ValueError(
            "The sum of train_size and test_size = %d, "
            "should be smaller than the number of "
            "samples %d. Reduce test_size and/or "
            "train_size." % (n_train + n_test, n_samples)
        )

    n_train, n_test = int(n_train), int(n_test)

    if n_train == 0:
        raise ValueError(
            "With n_samples={}, test_size={} and train_size={}, the "
            "resulting train set will be empty. Adjust any of the "
            "aforementioned parameters.".format(n_samples, test_size, train_size)
        )

    return n_train, n_test
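

# Worked example of the resolution above: with n_samples=10, test_size=0.25
# and train_size=None, n_test == ceil(0.25 * 10) == 3, and the train side is
# filled with the complement, n_train == 10 - 3 == 7. A quick sketch
# (illustrative helper only):
def _demo_validate_shuffle_split():
    n_train, n_test = _validate_shuffle_split(
        10, test_size=0.25, train_size=None, default_test_size=0.1
    )
    assert (n_train, n_test) == (7, 3)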
class PredefinedSplit(BaseCrossValidator):
    """Predefined split cross-validator

    Provides train/test indices to split data into train/test sets using a
    predefined scheme specified by the user with the ``test_fold`` parameter.

    Read more in the :ref:`User Guide <predefined_split>`.

    .. versionadded:: 0.16

    Parameters
    ----------
    test_fold : array-like of shape (n_samples,)
        The entry ``test_fold[i]`` represents the index of the test set that
        sample ``i`` belongs to. It is possible to exclude sample ``i`` from
        any test set (i.e. include sample ``i`` in every training set) by
        setting ``test_fold[i]`` equal to -1.

    Examples
    --------
    >>> import numpy as np
    >>> from sklearn.model_selection import PredefinedSplit
    >>> X = np.array([[1, 2], [3, 4], [1, 2], [3, 4]])
    >>> y = np.array([0, 0, 1, 1])
    >>> test_fold = [0, 1, -1, 1]
    >>> ps = PredefinedSplit(test_fold)
    >>> ps.get_n_splits()
    2
    >>> print(ps)
    PredefinedSplit(test_fold=array([ 0,  1, -1,  1]))
    >>> for i, (train_index, test_index) in enumerate(ps.split()):
    ...     print(f"Fold {i}:")
    ...     print(f"  Train: index={train_index}")
    ...     print(f"  Test:  index={test_index}")
    Fold 0:
      Train: index=[1 2 3]
      Test:  index=[0]
    Fold 1:
      Train: index=[0 2]
      Test:  index=[1 3]
    """

    def __init__(self, test_fold):
        self.test_fold = np.array(test_fold, dtype=int)
        self.test_fold = column_or_1d(self.test_fold)
        self.unique_folds = np.unique(self.test_fold)
        self.unique_folds = self.unique_folds[self.unique_folds != -1]

    def split(self, X=None, y=None, groups=None):
        """Generate indices to split data into training and test set.

        Parameters
        ----------
        X : object
            Always ignored, exists for compatibility.

        y : object
            Always ignored, exists for compatibility.

        groups : object
            Always ignored, exists for compatibility.

        Yields
        ------
        train : ndarray
            The training set indices for that split.

        test : ndarray
            The testing set indices for that split.
        """
        ind = np.arange(len(self.test_fold))
        for test_index in self._iter_test_masks():
            train_index = ind[np.logical_not(test_index)]
            test_index = ind[test_index]
            yield train_index, test_index

    def _iter_test_masks(self):
        """Generates boolean masks corresponding to test sets."""
        for f in self.unique_folds:
            test_index = np.where(self.test_fold == f)[0]
            test_mask = np.zeros(len(self.test_fold), dtype=bool)
            test_mask[test_index] = True
            yield test_mask

    def get_n_splits(self, X=None, y=None, groups=None):
        """Returns the number of splitting iterations in the cross-validator.

        Parameters
        ----------
        X : object
            Always ignored, exists for compatibility.

        y : object
            Always ignored, exists for compatibility.

        groups : object
            Always ignored, exists for compatibility.

        Returns
        -------
        n_splits : int
            Returns the number of splitting iterations in the cross-validator.
        """
        return len(self.unique_folds)
class _CVIterableWrapper(BaseCrossValidator):
    """Wrapper class for old style cv objects and iterables."""

    def __init__(self, cv):
        self.cv = list(cv)

    def get_n_splits(self, X=None, y=None, groups=None):
        """Returns the number of splitting iterations in the cross-validator.

        Parameters
        ----------
        X : object
            Always ignored, exists for compatibility.

        y : object
            Always ignored, exists for compatibility.

        groups : object
            Always ignored, exists for compatibility.

        Returns
        -------
        n_splits : int
            Returns the number of splitting iterations in the cross-validator.
        """
        return len(self.cv)

    def split(self, X=None, y=None, groups=None):
        """Generate indices to split data into training and test set.

        Parameters
        ----------
        X : object
            Always ignored, exists for compatibility.

        y : object
            Always ignored, exists for compatibility.

        groups : object
            Always ignored, exists for compatibility.

        Yields
        ------
        train : ndarray
            The training set indices for that split.

        test : ndarray
            The testing set indices for that split.
        """
        for train, test in self.cv:
            yield train, test
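

# ``_CVIterableWrapper`` is what ``check_cv`` below returns for a plain
# iterable of (train, test) pairs; materializing the iterable into a list
# makes it reusable across repeated ``split`` calls, which a bare generator
# would not be. A sketch (illustrative helper only):
def _demo_cv_iterable_wrapper():
    pairs = [(np.array([0, 1]), np.array([2])), (np.array([2]), np.array([0, 1]))]
    wrapped = _CVIterableWrapper(pairs)
    assert wrapped.get_n_splits() == 2
    # Unlike a generator, the wrapper can be iterated more than once.
    assert len(list(wrapped.split())) == len(list(wrapped.split())) == 2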
def check_cv(cv=5, y=None, *, classifier=False):
    """Input checker utility for building a cross-validator.

    Parameters
    ----------
    cv : int, cross-validation generator or an iterable, default=5
        Determines the cross-validation splitting strategy.
        Possible inputs for cv are:

        - None, to use the default 5-fold cross validation,
        - integer, to specify the number of folds,
        - :term:`CV splitter`,
        - an iterable that generates (train, test) splits as arrays of indices.

        For integer/None inputs, if classifier is True and ``y`` is either
        binary or multiclass, :class:`StratifiedKFold` is used. In all other
        cases, :class:`KFold` is used.

        Refer to the :ref:`User Guide <cross_validation>` for the various
        cross-validation strategies that can be used here.

        .. versionchanged:: 0.22
            ``cv`` default value changed from 3-fold to 5-fold.

    y : array-like, default=None
        The target variable for supervised learning problems.

    classifier : bool, default=False
        Whether the task is a classification task, in which case
        stratified KFold will be used.

    Returns
    -------
    checked_cv : a cross-validator instance.
        The return value is a cross-validator which generates the train/test
        splits via the ``split`` method.
    """
    cv = 5 if cv is None else cv
    if isinstance(cv, numbers.Integral):
        if (
            classifier
            and (y is not None)
            and (type_of_target(y, input_name="y") in ("binary", "multiclass"))
        ):
            return StratifiedKFold(cv)
        else:
            return KFold(cv)

    if not hasattr(cv, "split") or isinstance(cv, str):
        if not isinstance(cv, Iterable) or isinstance(cv, str):
            raise ValueError(
                "Expected cv as an integer, cross-validation "
                "object (from sklearn.model_selection) "
                "or an iterable. Got %s." % cv
            )
        return _CVIterableWrapper(cv)

    return cv  # New style cv objects are passed without any modification
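

# How the dispatch in ``check_cv`` plays out in practice (a sketch; the
# helper name is illustrative only):
def _demo_check_cv_dispatch():
    y_binary = np.array([0, 1, 0, 1, 0, 1])
    # Integer cv + classifier + binary target -> stratified K-fold.
    assert isinstance(check_cv(3, y_binary, classifier=True), StratifiedKFold)
    # Integer cv + continuous target -> plain K-fold.
    assert isinstance(check_cv(3, np.linspace(0, 1, 6), classifier=False), KFold)
    # An iterable of (train, test) pairs gets wrapped.
    pairs = [(np.array([0, 1]), np.array([2]))]
    assert isinstance(check_cv(pairs), _CVIterableWrapper)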
@validate_params(
    {
        "test_size": [
            Interval(RealNotInt, 0, 1, closed="neither"),
            Interval(numbers.Integral, 1, None, closed="left"),
            None,
        ],
        "train_size": [
            Interval(RealNotInt, 0, 1, closed="neither"),
            Interval(numbers.Integral, 1, None, closed="left"),
            None,
        ],
        "random_state": ["random_state"],
        "shuffle": ["boolean"],
        "stratify": ["array-like", None],
    },
    prefer_skip_nested_validation=True,
)
def train_test_split(
    *arrays,
    test_size=None,
    train_size=None,
    random_state=None,
    shuffle=True,
    stratify=None,
):
    """Split arrays or matrices into random train and test subsets.

    Quick utility that wraps input validation and
    ``next(ShuffleSplit().split(X, y))`` into a single call for splitting
    (and optionally subsampling) data in a one-liner.

    Read more in the :ref:`User Guide <cross_validation>`.

    Parameters
    ----------
    *arrays : sequence of indexables with same length / shape[0]
        Allowed inputs are lists, numpy arrays, scipy-sparse
        matrices or pandas dataframes.

    test_size : float or int, default=None
        If float, should be between 0.0 and 1.0 and represent the proportion
        of the dataset to include in the test split. If int, represents the
        absolute number of test samples. If None, the value is set to the
        complement of the train size. If ``train_size`` is also None, it will
        be set to 0.25.

    train_size : float or int, default=None
        If float, should be between 0.0 and 1.0 and represent the
        proportion of the dataset to include in the train split. If
        int, represents the absolute number of train samples. If None,
        the value is automatically set to the complement of the test size.

    random_state : int, RandomState instance or None, default=None
        Controls the shuffling applied to the data before applying the split.
        Pass an int for reproducible output across multiple function calls.
        See :term:`Glossary <random_state>`.

    shuffle : bool, default=True
        Whether or not to shuffle the data before splitting. If shuffle=False
        then stratify must be None.

    stratify : array-like, default=None
        If not None, data is split in a stratified fashion, using this as
        the class labels.
        Read more in the :ref:`User Guide <stratification>`.

    Returns
    -------
    splitting : list, length=2 * len(arrays)
        List containing train-test split of inputs.

        .. versionadded:: 0.16
            If the input is sparse, the output will be a
            ``scipy.sparse.csr_matrix``. Else, output type is the same as the
            input type.

    Examples
    --------
    >>> import numpy as np
    >>> from sklearn.model_selection import train_test_split
    >>> X, y = np.arange(10).reshape((5, 2)), range(5)
    >>> X
    array([[0, 1],
           [2, 3],
           [4, 5],
           [6, 7],
           [8, 9]])
    >>> list(y)
    [0, 1, 2, 3, 4]
    >>> X_train, X_test, y_train, y_test = train_test_split(
    ...     X, y, test_size=0.33, random_state=42)
    ...
    >>> X_train
    array([[4, 5],
           [0, 1],
           [6, 7]])
    >>> y_train
    [2, 0, 3]
    >>> X_test
    array([[2, 3],
           [8, 9]])
    >>> y_test
    [1, 4]
    >>> train_test_split(y, shuffle=False)
    [[0, 1, 2], [3, 4]]
    """
    n_arrays = len(arrays)
    if n_arrays == 0:
        raise ValueError("At least one array required as input")

    arrays = indexable(*arrays)

    n_samples = _num_samples(arrays[0])
    n_train, n_test = _validate_shuffle_split(
        n_samples, test_size, train_size, default_test_size=0.25
    )

    if shuffle is False:
        if stratify is not None:
            raise ValueError(
                "Stratified train/test split is not implemented for shuffle=False"
            )
        train = np.arange(n_train)
        test = np.arange(n_train, n_train + n_test)
    else:
        if stratify is not None:
            CVClass = StratifiedShuffleSplit
        else:
            CVClass = ShuffleSplit

        cv = CVClass(test_size=n_test, train_size=n_train, random_state=random_state)
        train, test = next(cv.split(X=arrays[0], y=stratify))

    return list(
        chain.from_iterable(
            (_safe_indexing(a, train), _safe_indexing(a, test)) for a in arrays
        )
    )


# Tell nose that train_test_split is not a test.
# (Needed for external libraries that may use nose.)
# Use setattr to avoid mypy errors when monkeypatching.
setattr(train_test_split, "__test__", False)
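

# When ``stratify`` is passed, ``train_test_split`` routes the labels through
# StratifiedShuffleSplit above, so class proportions survive the split. A
# short sketch (illustrative helper only; an 8:2 label ratio stays 4:1 in
# each half):
def _demo_train_test_split_stratify():
    X = np.arange(20).reshape(10, 2)
    y = np.array([0] * 8 + [1] * 2)
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.5, stratify=y, random_state=0
    )
    assert sorted(y_train.tolist()) == sorted(y_test.tolist()) == [0, 0, 0, 0, 1]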
def _pprint(params, offset=0, printer=repr):
    """Pretty print the dictionary 'params'

    Parameters
    ----------
    params : dict
        The dictionary to pretty print

    offset : int, default=0
        The offset in characters to add at the beginning of each line.

    printer : callable, default=repr
        The function to convert entries to strings, typically
        the builtin str or repr
    """
    # Do a multi-line justified repr:
    options = np.get_printoptions()
    np.set_printoptions(precision=5, threshold=64, edgeitems=2)
    params_list = list()
    this_line_length = offset
    line_sep = ",\n" + (1 + offset // 2) * " "
    for i, (k, v) in enumerate(sorted(params.items())):
        if isinstance(v, float):
            # use str for representing floating point numbers
            # this way we get consistent representation across
            # architectures and versions.
            this_repr = "%s=%s" % (k, str(v))
        else:
            # use repr of the rest
            this_repr = "%s=%s" % (k, printer(v))
        if len(this_repr) > 500:
            this_repr = this_repr[:300] + "..." + this_repr[-100:]
        if i > 0:
            if this_line_length + len(this_repr) >= 75 or "\n" in this_repr:
                params_list.append(line_sep)
                this_line_length = len(line_sep)
            else:
                params_list.append(", ")
                this_line_length += 2
        params_list.append(this_repr)
        this_line_length += len(this_repr)
    np.set_printoptions(**options)
    lines = "".join(params_list)
    # Strip trailing space to avoid nightmare in doctests
    lines = "\n".join(line.rstrip(" ") for line in lines.split("\n"))
    return lines
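

# The justified repr produced by ``_pprint`` sorts keys and joins them with
# ", ", wrapping near column 75; floats go through ``str`` for a stable
# cross-platform representation. A quick sketch (illustrative helper only):
def _demo_pprint():
    params = {"n_splits": 5, "random_state": 0, "test_size": 0.25}
    assert _pprint(params) == "n_splits=5, random_state=0, test_size=0.25"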
def _build_repr(self):
    # XXX This is copied from BaseEstimator's get_params
    cls = self.__class__
    init = getattr(cls.__init__, "deprecated_original", cls.__init__)
    # Ignore varargs, kw and default values and pop self
    init_signature = signature(init)
    # Consider the constructor parameters excluding 'self'
    if init is object.__init__:
        args = []
    else:
        args = sorted(
            [
                p.name
                for p in init_signature.parameters.values()
                if p.name != "self" and p.kind != p.VAR_KEYWORD
            ]
        )
    class_name = self.__class__.__name__
    params = dict()
    for key in args:
        # We need deprecation warnings to always be on in order to
        # catch deprecated param values.
        # This is set in utils/__init__.py but it gets overwritten
        # when running under python3 somehow.
        warnings.simplefilter("always", FutureWarning)
        try:
            with warnings.catch_warnings(record=True) as w:
                value = getattr(self, key, None)
                if value is None and hasattr(self, "cvargs"):
                    value = self.cvargs.get(key, None)
            if len(w) and w[0].category == FutureWarning:
                # if the parameter is deprecated, don't show it
                continue
        finally:
            warnings.filters.pop(0)
        params[key] = value
    return "%s(%s)" % (class_name, _pprint(params, offset=len(class_name)))
def _yields_constant_splits(cv):
    # Return True if calling cv.split() always returns the same splits.
    # We assume that if a cv doesn't have a shuffle parameter, it shuffles by
    # default (e.g. ShuffleSplit). If it actually doesn't shuffle (e.g.
    # LeaveOneOut), then it won't have a random_state parameter anyway, in
    # which case `random_state` defaults to 0 here and the function returns
    # True.
    shuffle = getattr(cv, "shuffle", True)
    random_state = getattr(cv, "random_state", 0)
    return isinstance(random_state, numbers.Integral) or not shuffle
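

# Concrete consequences of the heuristic above (a sketch exercising only the
# two attributes ``_yields_constant_splits`` inspects):
def _demo_yields_constant_splits():
    # No shuffling -> deterministic regardless of random_state.
    assert _yields_constant_splits(KFold(shuffle=False))
    # Shuffling with an integer seed -> same splits on every call.
    assert _yields_constant_splits(ShuffleSplit(random_state=0))
    # Shuffling with random_state=None -> splits differ between calls.
    assert not _yields_constant_splits(ShuffleSplit(random_state=None))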