  1. """Test the validation module"""
  2. import os
  3. import re
  4. import sys
  5. import tempfile
  6. import warnings
  7. from functools import partial
  8. from io import StringIO
  9. from time import sleep
  10. import numpy as np
  11. import pytest
  12. from scipy.sparse import coo_matrix, csr_matrix, issparse
  13. from sklearn.base import BaseEstimator, clone
  14. from sklearn.cluster import KMeans
  15. from sklearn.datasets import (
  16. load_diabetes,
  17. load_digits,
  18. load_iris,
  19. make_classification,
  20. make_multilabel_classification,
  21. make_regression,
  22. )
  23. from sklearn.ensemble import RandomForestClassifier
  24. from sklearn.exceptions import FitFailedWarning
  25. from sklearn.impute import SimpleImputer
  26. from sklearn.linear_model import (
  27. LogisticRegression,
  28. PassiveAggressiveClassifier,
  29. Ridge,
  30. RidgeClassifier,
  31. SGDClassifier,
  32. )
  33. from sklearn.metrics import (
  34. accuracy_score,
  35. check_scoring,
  36. confusion_matrix,
  37. explained_variance_score,
  38. make_scorer,
  39. mean_squared_error,
  40. precision_recall_fscore_support,
  41. precision_score,
  42. r2_score,
  43. )
  44. from sklearn.model_selection import (
  45. GridSearchCV,
  46. GroupKFold,
  47. GroupShuffleSplit,
  48. KFold,
  49. LeaveOneGroupOut,
  50. LeaveOneOut,
  51. LeavePGroupsOut,
  52. ShuffleSplit,
  53. StratifiedKFold,
  54. cross_val_predict,
  55. cross_val_score,
  56. cross_validate,
  57. learning_curve,
  58. permutation_test_score,
  59. validation_curve,
  60. )
  61. from sklearn.model_selection._validation import (
  62. _check_is_permutation,
  63. _fit_and_score,
  64. _score,
  65. )
  66. from sklearn.model_selection.tests.common import OneTimeSplitter
  67. from sklearn.model_selection.tests.test_search import FailingClassifier
  68. from sklearn.multiclass import OneVsRestClassifier
  69. from sklearn.neighbors import KNeighborsClassifier
  70. from sklearn.neural_network import MLPRegressor
  71. from sklearn.pipeline import Pipeline
  72. from sklearn.preprocessing import LabelEncoder, scale
  73. from sklearn.svm import SVC, LinearSVC
  74. from sklearn.utils import shuffle
  75. from sklearn.utils._mocking import CheckingClassifier, MockDataFrame
  76. from sklearn.utils._testing import (
  77. assert_allclose,
  78. assert_almost_equal,
  79. assert_array_almost_equal,
  80. assert_array_equal,
  81. )
  82. from sklearn.utils.validation import _num_samples


class MockImprovingEstimator(BaseEstimator):
    """Dummy classifier to test the learning curve"""

    def __init__(self, n_max_train_sizes):
        self.n_max_train_sizes = n_max_train_sizes
        self.train_sizes = 0
        self.X_subset = None

    def fit(self, X_subset, y_subset=None):
        self.X_subset = X_subset
        self.train_sizes = X_subset.shape[0]
        return self

    def predict(self, X):
        raise NotImplementedError

    def score(self, X=None, Y=None):
        # training score becomes worse (2 -> 1), test score better (0 -> 1)
        if self._is_training_data(X):
            return 2.0 - float(self.train_sizes) / self.n_max_train_sizes
        else:
            return float(self.train_sizes) / self.n_max_train_sizes

    def _is_training_data(self, X):
        return X is self.X_subset
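

# Illustrative numbers for the mock above (assuming n_max_train_sizes=10):
# after fit() on 5 samples, score() returns 2.0 - 5/10 == 1.5 on the training
# data and 5/10 == 0.5 on anything else, so the two learning curves move
# linearly towards 1.0 as the training set grows.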


class MockIncrementalImprovingEstimator(MockImprovingEstimator):
    """Dummy classifier that provides partial_fit"""

    def __init__(self, n_max_train_sizes, expected_fit_params=None):
        super().__init__(n_max_train_sizes)
        self.x = None
        self.expected_fit_params = expected_fit_params

    def _is_training_data(self, X):
        return self.x in X

    def partial_fit(self, X, y=None, **params):
        self.train_sizes += X.shape[0]
        self.x = X[0]
        if self.expected_fit_params:
            missing = set(self.expected_fit_params) - set(params)
            if missing:
                raise AssertionError(
                    f"Expected fit parameter(s) {list(missing)} not seen."
                )
            for key, value in params.items():
                if key in self.expected_fit_params and _num_samples(
                    value
                ) != _num_samples(X):
                    raise AssertionError(
                        f"Fit parameter {key} has length {_num_samples(value)}"
                        f"; expected {_num_samples(X)}."
                    )


class MockEstimatorWithParameter(BaseEstimator):
    """Dummy classifier to test the validation curve"""

    def __init__(self, param=0.5):
        self.X_subset = None
        self.param = param

    def fit(self, X_subset, y_subset):
        self.X_subset = X_subset
        self.train_sizes = X_subset.shape[0]
        return self

    def predict(self, X):
        raise NotImplementedError

    def score(self, X=None, y=None):
        return self.param if self._is_training_data(X) else 1 - self.param

    def _is_training_data(self, X):
        return X is self.X_subset
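

# For instance, with param=0.3 the mock above scores 0.3 on the exact split it
# was fit on and 0.7 on any other data, giving validation_curve a deterministic
# train/test gap to check.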


class MockEstimatorWithSingleFitCallAllowed(MockEstimatorWithParameter):
    """Dummy classifier that disallows repeated calls of fit method"""

    def fit(self, X_subset, y_subset):
        assert not hasattr(self, "fit_called_"), "fit is called the second time"
        self.fit_called_ = True
        return super().fit(X_subset, y_subset)

    def predict(self, X):
        raise NotImplementedError


class MockClassifier:
    """Dummy classifier to test the cross-validation"""

    def __init__(self, a=0, allow_nd=False):
        self.a = a
        self.allow_nd = allow_nd

    def fit(
        self,
        X,
        Y=None,
        sample_weight=None,
        class_prior=None,
        sparse_sample_weight=None,
        sparse_param=None,
        dummy_int=None,
        dummy_str=None,
        dummy_obj=None,
        callback=None,
    ):
        """The dummy arguments are to test that this fit function can
        accept non-array arguments through cross-validation, such as:
            - int
            - str (this is actually array-like)
            - object
            - function
        """
        self.dummy_int = dummy_int
        self.dummy_str = dummy_str
        self.dummy_obj = dummy_obj
        if callback is not None:
            callback(self)
        if self.allow_nd:
            X = X.reshape(len(X), -1)
        if X.ndim >= 3 and not self.allow_nd:
            raise ValueError("X cannot be 3d")
        if sample_weight is not None:
            assert sample_weight.shape[0] == X.shape[0], (
                "MockClassifier extra fit_param "
                "sample_weight.shape[0] is {0}, should be {1}".format(
                    sample_weight.shape[0], X.shape[0]
                )
            )
        if class_prior is not None:
            assert class_prior.shape[0] == len(np.unique(y)), (
                "MockClassifier extra fit_param class_prior.shape[0]"
                " is {0}, should be {1}".format(class_prior.shape[0], len(np.unique(y)))
            )
        if sparse_sample_weight is not None:
            fmt = (
                "MockClassifier extra fit_param sparse_sample_weight"
                ".shape[0] is {0}, should be {1}"
            )
            assert sparse_sample_weight.shape[0] == X.shape[0], fmt.format(
                sparse_sample_weight.shape[0], X.shape[0]
            )
        if sparse_param is not None:
            fmt = (
                "MockClassifier extra fit_param sparse_param.shape "
                "is ({0}, {1}), should be ({2}, {3})"
            )
            assert sparse_param.shape == P_sparse.shape, fmt.format(
                sparse_param.shape[0],
                sparse_param.shape[1],
                P_sparse.shape[0],
                P_sparse.shape[1],
            )
        return self

    def predict(self, T):
        if self.allow_nd:
            T = T.reshape(len(T), -1)
        return T[:, 0]

    def predict_proba(self, T):
        return T

    def score(self, X=None, Y=None):
        return 1.0 / (1 + np.abs(self.a))

    def get_params(self, deep=False):
        return {"a": self.a, "allow_nd": self.allow_nd}


# XXX: use 2D array, since 1D X is being detected as a single sample in
# check_consistent_length
X = np.ones((10, 2))
X_sparse = coo_matrix(X)
y = np.array([0, 0, 1, 1, 2, 2, 3, 3, 4, 4])
# The number of samples per class needs to be at least n_splits,
# for StratifiedKFold(n_splits=3)
y2 = np.array([1, 1, 1, 2, 2, 2, 3, 3, 3, 3])
P_sparse = coo_matrix(np.eye(5))


def test_cross_val_score():
    clf = MockClassifier()

    for a in range(-10, 10):
        clf.a = a
        # Smoke test
        scores = cross_val_score(clf, X, y2)
        assert_array_equal(scores, clf.score(X, y2))

        # test with multioutput y
        multioutput_y = np.column_stack([y2, y2[::-1]])
        scores = cross_val_score(clf, X_sparse, multioutput_y)
        assert_array_equal(scores, clf.score(X_sparse, multioutput_y))

        scores = cross_val_score(clf, X_sparse, y2)
        assert_array_equal(scores, clf.score(X_sparse, y2))

        # test with multioutput y
        scores = cross_val_score(clf, X_sparse, multioutput_y)
        assert_array_equal(scores, clf.score(X_sparse, multioutput_y))

    # test with X and y as list
    list_check = lambda x: isinstance(x, list)
    clf = CheckingClassifier(check_X=list_check)
    scores = cross_val_score(clf, X.tolist(), y2.tolist(), cv=3)

    clf = CheckingClassifier(check_y=list_check)
    scores = cross_val_score(clf, X, y2.tolist(), cv=3)

    with pytest.raises(ValueError):
        cross_val_score(clf, X, y2, scoring="sklearn")

    # test with 3d X
    X_3d = X[:, :, np.newaxis]
    clf = MockClassifier(allow_nd=True)
    scores = cross_val_score(clf, X_3d, y2)

    clf = MockClassifier(allow_nd=False)
    with pytest.raises(ValueError):
        cross_val_score(clf, X_3d, y2, error_score="raise")


def test_cross_validate_many_jobs():
    # regression test for #12154: cv='warn' with n_jobs>1 triggered a copy of
    # the parameters, leading to a failure in check_cv because cv was compared
    # with `is 'warn'` instead of `== 'warn'`.
    X, y = load_iris(return_X_y=True)
    clf = SVC(gamma="auto")
    grid = GridSearchCV(clf, param_grid={"C": [1, 10]})
    cross_validate(grid, X, y, n_jobs=2)


def test_cross_validate_invalid_scoring_param():
    X, y = make_classification(random_state=0)
    estimator = MockClassifier()

    # Test the errors
    error_message_regexp = ".*must be unique strings.*"

    # List/tuple of callables should raise a message advising users to use
    # dict of names to callables mapping
    with pytest.raises(ValueError, match=error_message_regexp):
        cross_validate(
            estimator,
            X,
            y,
            scoring=(make_scorer(precision_score), make_scorer(accuracy_score)),
        )
    with pytest.raises(ValueError, match=error_message_regexp):
        cross_validate(estimator, X, y, scoring=(make_scorer(precision_score),))

    # So should empty lists/tuples
    with pytest.raises(ValueError, match=error_message_regexp + "Empty list.*"):
        cross_validate(estimator, X, y, scoring=())

    # So should duplicated entries
    with pytest.raises(ValueError, match=error_message_regexp + "Duplicate.*"):
        cross_validate(estimator, X, y, scoring=("f1_micro", "f1_micro"))

    # Nested lists should raise a generic error message
    with pytest.raises(ValueError, match=error_message_regexp):
        cross_validate(estimator, X, y, scoring=[[make_scorer(precision_score)]])

    # Empty dict should raise an invalid scoring error
    with pytest.raises(ValueError, match="An empty dict"):
        cross_validate(estimator, X, y, scoring=(dict()))

    multiclass_scorer = make_scorer(precision_recall_fscore_support)

    # Multiclass scorers that return multiple values are not supported yet;
    # this is the warning message we're expecting to see
    warning_message = (
        "Scoring failed. The score on this train-test "
        f"partition for these parameters will be set to {np.nan}. "
        "Details: \n"
    )
    with pytest.warns(UserWarning, match=warning_message):
        cross_validate(estimator, X, y, scoring=multiclass_scorer)
    with pytest.warns(UserWarning, match=warning_message):
        cross_validate(estimator, X, y, scoring={"foo": multiclass_scorer})


def test_cross_validate_nested_estimator():
    # Non-regression test to ensure that nested
    # estimators are properly returned in a list
    # https://github.com/scikit-learn/scikit-learn/pull/17745
    (X, y) = load_iris(return_X_y=True)
    pipeline = Pipeline(
        [
            ("imputer", SimpleImputer()),
            ("classifier", MockClassifier()),
        ]
    )
    results = cross_validate(pipeline, X, y, return_estimator=True)
    estimators = results["estimator"]

    assert isinstance(estimators, list)
    assert all(isinstance(estimator, Pipeline) for estimator in estimators)


@pytest.mark.parametrize("use_sparse", [False, True])
def test_cross_validate(use_sparse: bool):
    # Compute train and test mse/r2 scores
    cv = KFold()

    # Regression
    X_reg, y_reg = make_regression(n_samples=30, random_state=0)
    reg = Ridge(random_state=0)

    # Classification
    X_clf, y_clf = make_classification(n_samples=30, random_state=0)
    clf = SVC(kernel="linear", random_state=0)

    if use_sparse:
        X_reg = csr_matrix(X_reg)
        X_clf = csr_matrix(X_clf)

    for X, y, est in ((X_reg, y_reg, reg), (X_clf, y_clf, clf)):
        # It's okay to evaluate regression metrics on classification too
        mse_scorer = check_scoring(est, scoring="neg_mean_squared_error")
        r2_scorer = check_scoring(est, scoring="r2")
        train_mse_scores = []
        test_mse_scores = []
        train_r2_scores = []
        test_r2_scores = []
        fitted_estimators = []

        for train, test in cv.split(X, y):
            est = clone(est).fit(X[train], y[train])
            train_mse_scores.append(mse_scorer(est, X[train], y[train]))
            train_r2_scores.append(r2_scorer(est, X[train], y[train]))
            test_mse_scores.append(mse_scorer(est, X[test], y[test]))
            test_r2_scores.append(r2_scorer(est, X[test], y[test]))
            fitted_estimators.append(est)

        train_mse_scores = np.array(train_mse_scores)
        test_mse_scores = np.array(test_mse_scores)
        train_r2_scores = np.array(train_r2_scores)
        test_r2_scores = np.array(test_r2_scores)
        fitted_estimators = np.array(fitted_estimators)

        scores = (
            train_mse_scores,
            test_mse_scores,
            train_r2_scores,
            test_r2_scores,
            fitted_estimators,
        )

        # To ensure that the test does not suffer from
        # large statistical fluctuations due to slicing small datasets,
        # we pass the cross-validation instance
        check_cross_validate_single_metric(est, X, y, scores, cv)
        check_cross_validate_multi_metric(est, X, y, scores, cv)


def check_cross_validate_single_metric(clf, X, y, scores, cv):
    (
        train_mse_scores,
        test_mse_scores,
        train_r2_scores,
        test_r2_scores,
        fitted_estimators,
    ) = scores

    # Test single metric evaluation when scoring is string or singleton list
    for return_train_score, dict_len in ((True, 4), (False, 3)):
        # Single metric passed as a string
        if return_train_score:
            mse_scores_dict = cross_validate(
                clf,
                X,
                y,
                scoring="neg_mean_squared_error",
                return_train_score=True,
                cv=cv,
            )
            assert_array_almost_equal(mse_scores_dict["train_score"], train_mse_scores)
        else:
            mse_scores_dict = cross_validate(
                clf,
                X,
                y,
                scoring="neg_mean_squared_error",
                return_train_score=False,
                cv=cv,
            )
        assert isinstance(mse_scores_dict, dict)
        assert len(mse_scores_dict) == dict_len
        assert_array_almost_equal(mse_scores_dict["test_score"], test_mse_scores)

        # Single metric passed as a list
        if return_train_score:
            # It must be True by default - deprecated
            r2_scores_dict = cross_validate(
                clf, X, y, scoring=["r2"], return_train_score=True, cv=cv
            )
            assert_array_almost_equal(r2_scores_dict["train_r2"], train_r2_scores, True)
        else:
            r2_scores_dict = cross_validate(
                clf, X, y, scoring=["r2"], return_train_score=False, cv=cv
            )
        assert isinstance(r2_scores_dict, dict)
        assert len(r2_scores_dict) == dict_len
        assert_array_almost_equal(r2_scores_dict["test_r2"], test_r2_scores)

    # Test return_estimator option
    mse_scores_dict = cross_validate(
        clf, X, y, scoring="neg_mean_squared_error", return_estimator=True, cv=cv
    )
    for k, est in enumerate(mse_scores_dict["estimator"]):
        est_coef = est.coef_.copy()
        if issparse(est_coef):
            est_coef = est_coef.toarray()

        fitted_est_coef = fitted_estimators[k].coef_.copy()
        if issparse(fitted_est_coef):
            fitted_est_coef = fitted_est_coef.toarray()

        assert_almost_equal(est_coef, fitted_est_coef)
        assert_almost_equal(est.intercept_, fitted_estimators[k].intercept_)


def check_cross_validate_multi_metric(clf, X, y, scores, cv):
    # Test multimetric evaluation when scoring is a list / dict
    (
        train_mse_scores,
        test_mse_scores,
        train_r2_scores,
        test_r2_scores,
        fitted_estimators,
    ) = scores

    def custom_scorer(clf, X, y):
        y_pred = clf.predict(X)
        return {
            "r2": r2_score(y, y_pred),
            "neg_mean_squared_error": -mean_squared_error(y, y_pred),
        }
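
    # A callable scorer may return a dict of metric name -> value, as above;
    # cross_validate then reports each entry under test_<name> (plus
    # train_<name> when return_train_score=True), so the custom scorer must
    # produce the same result keys as the string and dict scorings below.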
    all_scoring = (
        ("r2", "neg_mean_squared_error"),
        {
            "r2": make_scorer(r2_score),
            "neg_mean_squared_error": "neg_mean_squared_error",
        },
        custom_scorer,
    )

    keys_sans_train = {
        "test_r2",
        "test_neg_mean_squared_error",
        "fit_time",
        "score_time",
    }
    keys_with_train = keys_sans_train.union(
        {"train_r2", "train_neg_mean_squared_error"}
    )

    for return_train_score in (True, False):
        for scoring in all_scoring:
            if return_train_score:
                # return_train_score must be True by default - deprecated
                cv_results = cross_validate(
                    clf, X, y, scoring=scoring, return_train_score=True, cv=cv
                )
                assert_array_almost_equal(cv_results["train_r2"], train_r2_scores)
                assert_array_almost_equal(
                    cv_results["train_neg_mean_squared_error"], train_mse_scores
                )
            else:
                cv_results = cross_validate(
                    clf, X, y, scoring=scoring, return_train_score=False, cv=cv
                )
            assert isinstance(cv_results, dict)
            assert set(cv_results.keys()) == (
                keys_with_train if return_train_score else keys_sans_train
            )
            assert_array_almost_equal(cv_results["test_r2"], test_r2_scores)
            assert_array_almost_equal(
                cv_results["test_neg_mean_squared_error"], test_mse_scores
            )

            # Make sure all the arrays are of np.ndarray type
            assert type(cv_results["test_r2"]) == np.ndarray
            assert type(cv_results["test_neg_mean_squared_error"]) == np.ndarray
            assert type(cv_results["fit_time"]) == np.ndarray
            assert type(cv_results["score_time"]) == np.ndarray

            # Ensure all the times are within sane limits
            assert np.all(cv_results["fit_time"] >= 0)
            assert np.all(cv_results["fit_time"] < 10)
            assert np.all(cv_results["score_time"] >= 0)
            assert np.all(cv_results["score_time"] < 10)


def test_cross_val_score_predict_groups():
    # Check that the ValueError raised when groups is None propagates to
    # cross_val_score and cross_val_predict, and that groups is correctly
    # passed to the cv object
    X, y = make_classification(n_samples=20, n_classes=2, random_state=0)
    clf = SVC(kernel="linear")

    group_cvs = [
        LeaveOneGroupOut(),
        LeavePGroupsOut(2),
        GroupKFold(),
        GroupShuffleSplit(),
    ]
    error_message = "The 'groups' parameter should not be None."
    for cv in group_cvs:
        with pytest.raises(ValueError, match=error_message):
            cross_val_score(estimator=clf, X=X, y=y, cv=cv)
        with pytest.raises(ValueError, match=error_message):
            cross_val_predict(estimator=clf, X=X, y=y, cv=cv)


@pytest.mark.filterwarnings("ignore: Using or importing the ABCs from")
def test_cross_val_score_pandas():
    # check cross_val_score doesn't destroy pandas dataframe
    types = [(MockDataFrame, MockDataFrame)]
    try:
        from pandas import DataFrame, Series

        types.append((Series, DataFrame))
    except ImportError:
        pass
    for TargetType, InputFeatureType in types:
        # X dataframe, y series
        # 3 fold cross val is used so we need at least 3 samples per class
        X_df, y_ser = InputFeatureType(X), TargetType(y2)
        check_df = lambda x: isinstance(x, InputFeatureType)
        check_series = lambda x: isinstance(x, TargetType)
        clf = CheckingClassifier(check_X=check_df, check_y=check_series)
        cross_val_score(clf, X_df, y_ser, cv=3)


def test_cross_val_score_mask():
    # test that cross_val_score works with boolean masks
    svm = SVC(kernel="linear")
    iris = load_iris()
    X, y = iris.data, iris.target
    kfold = KFold(5)
    scores_indices = cross_val_score(svm, X, y, cv=kfold)
    kfold = KFold(5)
    cv_masks = []
    for train, test in kfold.split(X, y):
        mask_train = np.zeros(len(y), dtype=bool)
        mask_test = np.zeros(len(y), dtype=bool)
        mask_train[train] = 1
        mask_test[test] = 1
        # append the boolean masks, not the index arrays they were built from
        cv_masks.append((mask_train, mask_test))
    scores_masks = cross_val_score(svm, X, y, cv=cv_masks)
    assert_array_equal(scores_indices, scores_masks)


def test_cross_val_score_precomputed():
    # test for svm with precomputed kernel
    svm = SVC(kernel="precomputed")
    iris = load_iris()
    X, y = iris.data, iris.target
    linear_kernel = np.dot(X, X.T)
    score_precomputed = cross_val_score(svm, linear_kernel, y)
    svm = SVC(kernel="linear")
    score_linear = cross_val_score(svm, X, y)
    assert_array_almost_equal(score_precomputed, score_linear)

    # test with callable
    svm = SVC(kernel=lambda x, y: np.dot(x, y.T))
    score_callable = cross_val_score(svm, X, y)
    assert_array_almost_equal(score_precomputed, score_callable)

    # Error raised for non-square X
    svm = SVC(kernel="precomputed")
    with pytest.raises(ValueError):
        cross_val_score(svm, X, y)

    # test error is raised when the precomputed kernel is not array-like
    # or sparse
    with pytest.raises(ValueError):
        cross_val_score(svm, linear_kernel.tolist(), y)


def test_cross_val_score_fit_params():
    clf = MockClassifier()
    n_samples = X.shape[0]
    n_classes = len(np.unique(y))

    W_sparse = coo_matrix(
        (np.array([1]), (np.array([1]), np.array([0]))), shape=(10, 1)
    )
    P_sparse = coo_matrix(np.eye(5))

    DUMMY_INT = 42
    DUMMY_STR = "42"
    DUMMY_OBJ = object()

    def assert_fit_params(clf):
        # Function to test that the values are passed correctly to the
        # classifier arguments for non-array type
        assert clf.dummy_int == DUMMY_INT
        assert clf.dummy_str == DUMMY_STR
        assert clf.dummy_obj == DUMMY_OBJ

    fit_params = {
        "sample_weight": np.ones(n_samples),
        "class_prior": np.full(n_classes, 1.0 / n_classes),
        "sparse_sample_weight": W_sparse,
        "sparse_param": P_sparse,
        "dummy_int": DUMMY_INT,
        "dummy_str": DUMMY_STR,
        "dummy_obj": DUMMY_OBJ,
        "callback": assert_fit_params,
    }
    cross_val_score(clf, X, y, fit_params=fit_params)


def test_cross_val_score_score_func():
    clf = MockClassifier()
    _score_func_args = []

    def score_func(y_test, y_predict):
        _score_func_args.append((y_test, y_predict))
        return 1.0

    with warnings.catch_warnings(record=True):
        scoring = make_scorer(score_func)
        score = cross_val_score(clf, X, y, scoring=scoring, cv=3)
    assert_array_equal(score, [1.0, 1.0, 1.0])
    # Test that score function is called only 3 times (for cv=3)
    assert len(_score_func_args) == 3


def test_cross_val_score_errors():
    class BrokenEstimator:
        pass

    with pytest.raises(TypeError):
        cross_val_score(BrokenEstimator(), X)


def test_cross_val_score_with_score_func_classification():
    iris = load_iris()
    clf = SVC(kernel="linear")

    # Default score (should be the accuracy score)
    scores = cross_val_score(clf, iris.data, iris.target)
    assert_array_almost_equal(scores, [0.97, 1.0, 0.97, 0.97, 1.0], 2)

    # Correct classification score (aka. zero / one score) - should be the
    # same as the default estimator score
    zo_scores = cross_val_score(clf, iris.data, iris.target, scoring="accuracy")
    assert_array_almost_equal(zo_scores, [0.97, 1.0, 0.97, 0.97, 1.0], 2)

    # F1 score (classes are balanced, so f1_score should equal the zero/one
    # score)
    f1_scores = cross_val_score(clf, iris.data, iris.target, scoring="f1_weighted")
    assert_array_almost_equal(f1_scores, [0.97, 1.0, 0.97, 0.97, 1.0], 2)


def test_cross_val_score_with_score_func_regression():
    X, y = make_regression(n_samples=30, n_features=20, n_informative=5, random_state=0)
    reg = Ridge()

    # Default score of the Ridge regression estimator
    scores = cross_val_score(reg, X, y)
    assert_array_almost_equal(scores, [0.94, 0.97, 0.97, 0.99, 0.92], 2)

    # R2 score (aka. coefficient of determination) - should be the
    # same as the default estimator score
    r2_scores = cross_val_score(reg, X, y, scoring="r2")
    assert_array_almost_equal(r2_scores, [0.94, 0.97, 0.97, 0.99, 0.92], 2)

    # Mean squared error; this is a loss function, so "scores" are negative
    neg_mse_scores = cross_val_score(reg, X, y, scoring="neg_mean_squared_error")
    expected_neg_mse = np.array([-763.07, -553.16, -274.38, -273.26, -1681.99])
    assert_array_almost_equal(neg_mse_scores, expected_neg_mse, 2)

    # Explained variance
    scoring = make_scorer(explained_variance_score)
    ev_scores = cross_val_score(reg, X, y, scoring=scoring)
    assert_array_almost_equal(ev_scores, [0.94, 0.97, 0.97, 0.99, 0.92], 2)


def test_permutation_score():
    iris = load_iris()
    X = iris.data
    X_sparse = coo_matrix(X)
    y = iris.target
    svm = SVC(kernel="linear")
    cv = StratifiedKFold(2)

    score, scores, pvalue = permutation_test_score(
        svm, X, y, n_permutations=30, cv=cv, scoring="accuracy"
    )
    assert score > 0.9
    assert_almost_equal(pvalue, 0.0, 1)

    score_group, _, pvalue_group = permutation_test_score(
        svm,
        X,
        y,
        n_permutations=30,
        cv=cv,
        scoring="accuracy",
        groups=np.ones(y.size),
        random_state=0,
    )
    assert score_group == score
    assert pvalue_group == pvalue

    # check that we obtain the same results with a sparse representation
    svm_sparse = SVC(kernel="linear")
    cv_sparse = StratifiedKFold(2)
    score_group, _, pvalue_group = permutation_test_score(
        svm_sparse,
        X_sparse,
        y,
        n_permutations=30,
        cv=cv_sparse,
        scoring="accuracy",
        groups=np.ones(y.size),
        random_state=0,
    )
    assert score_group == score
    assert pvalue_group == pvalue

    # test with custom scoring object
    def custom_score(y_true, y_pred):
        return ((y_true == y_pred).sum() - (y_true != y_pred).sum()) / y_true.shape[0]
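
    # custom_score is (n_correct - n_incorrect) / n_samples, i.e. accuracy
    # minus error rate, so it ranges over [-1, 1]; e.g. an accuracy of 0.965
    # gives 0.965 - 0.035 = 0.93, the value asserted below.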
    scorer = make_scorer(custom_score)
    score, _, pvalue = permutation_test_score(
        svm, X, y, n_permutations=100, scoring=scorer, cv=cv, random_state=0
    )
    assert_almost_equal(score, 0.93, 2)
    assert_almost_equal(pvalue, 0.01, 3)

    # set random y
    y = np.mod(np.arange(len(y)), 3)

    score, scores, pvalue = permutation_test_score(
        svm, X, y, n_permutations=30, cv=cv, scoring="accuracy"
    )

    assert score < 0.5
    assert pvalue > 0.2


def test_permutation_test_score_allow_nans():
    # Check that permutation_test_score allows input data with NaNs
    X = np.arange(200, dtype=np.float64).reshape(10, -1)
    X[2, :] = np.nan
    y = np.repeat([0, 1], X.shape[0] // 2)
    p = Pipeline(
        [
            ("imputer", SimpleImputer(strategy="mean", missing_values=np.nan)),
            ("classifier", MockClassifier()),
        ]
    )
    permutation_test_score(p, X, y)


def test_permutation_test_score_fit_params():
    X = np.arange(100).reshape(10, 10)
    y = np.array([0] * 5 + [1] * 5)
    clf = CheckingClassifier(expected_sample_weight=True)

    err_msg = r"Expected sample_weight to be passed"
    with pytest.raises(AssertionError, match=err_msg):
        permutation_test_score(clf, X, y)

    err_msg = r"sample_weight.shape == \(1,\), expected \(8,\)!"
    with pytest.raises(ValueError, match=err_msg):
        permutation_test_score(clf, X, y, fit_params={"sample_weight": np.ones(1)})
    permutation_test_score(clf, X, y, fit_params={"sample_weight": np.ones(10)})


def test_cross_val_score_allow_nans():
    # Check that cross_val_score allows input data with NaNs
    X = np.arange(200, dtype=np.float64).reshape(10, -1)
    X[2, :] = np.nan
    y = np.repeat([0, 1], X.shape[0] // 2)
    p = Pipeline(
        [
            ("imputer", SimpleImputer(strategy="mean", missing_values=np.nan)),
            ("classifier", MockClassifier()),
        ]
    )
    cross_val_score(p, X, y)


def test_cross_val_score_multilabel():
    X = np.array(
        [
            [-3, 4],
            [2, 4],
            [3, 3],
            [0, 2],
            [-3, 1],
            [-2, 1],
            [0, 0],
            [-2, -1],
            [-1, -2],
            [1, -2],
        ]
    )
    y = np.array(
        [[1, 1], [0, 1], [0, 1], [0, 1], [1, 1], [0, 1], [1, 0], [1, 1], [1, 0], [0, 0]]
    )
    clf = KNeighborsClassifier(n_neighbors=1)
    scoring_micro = make_scorer(precision_score, average="micro")
    scoring_macro = make_scorer(precision_score, average="macro")
    scoring_samples = make_scorer(precision_score, average="samples")
    score_micro = cross_val_score(clf, X, y, scoring=scoring_micro)
    score_macro = cross_val_score(clf, X, y, scoring=scoring_macro)
    score_samples = cross_val_score(clf, X, y, scoring=scoring_samples)
    assert_almost_equal(score_micro, [1, 1 / 2, 3 / 4, 1 / 2, 1 / 3])
    assert_almost_equal(score_macro, [1, 1 / 2, 3 / 4, 1 / 2, 1 / 4])
    assert_almost_equal(score_samples, [1, 1 / 2, 3 / 4, 1 / 2, 1 / 4])


def test_cross_val_predict():
    X, y = load_diabetes(return_X_y=True)
    cv = KFold()

    est = Ridge()

    # Naive loop (should be same as cross_val_predict):
    preds2 = np.zeros_like(y)
    for train, test in cv.split(X, y):
        est.fit(X[train], y[train])
        preds2[test] = est.predict(X[test])

    preds = cross_val_predict(est, X, y, cv=cv)
    assert_array_almost_equal(preds, preds2)

    preds = cross_val_predict(est, X, y)
    assert len(preds) == len(y)

    cv = LeaveOneOut()
    preds = cross_val_predict(est, X, y, cv=cv)
    assert len(preds) == len(y)

    Xsp = X.copy()
    Xsp *= Xsp > np.median(Xsp)
    Xsp = coo_matrix(Xsp)
    preds = cross_val_predict(est, Xsp, y)
    assert_array_almost_equal(len(preds), len(y))

    preds = cross_val_predict(KMeans(n_init="auto"), X)
    assert len(preds) == len(y)

    class BadCV:
        def split(self, X, y=None, groups=None):
            for i in range(4):
                yield np.array([0, 1, 2, 3]), np.array([4, 5, 6, 7, 8])

    with pytest.raises(ValueError):
        cross_val_predict(est, X, y, cv=BadCV())

    X, y = load_iris(return_X_y=True)

    warning_message = (
        r"Number of classes in training fold \(2\) does "
        r"not match total number of classes \(3\). "
        "Results may not be appropriate for your use case."
    )
    with pytest.warns(RuntimeWarning, match=warning_message):
        cross_val_predict(
            LogisticRegression(solver="liblinear"),
            X,
            y,
            method="predict_proba",
            cv=KFold(2),
        )


def test_cross_val_predict_decision_function_shape():
    X, y = make_classification(n_classes=2, n_samples=50, random_state=0)

    preds = cross_val_predict(
        LogisticRegression(solver="liblinear"), X, y, method="decision_function"
    )
    assert preds.shape == (50,)

    X, y = load_iris(return_X_y=True)

    preds = cross_val_predict(
        LogisticRegression(solver="liblinear"), X, y, method="decision_function"
    )
    assert preds.shape == (150, 3)

    # This specifically tests imbalanced splits for binary
    # classification with decision_function. This is only
    # applicable to classifiers that can be fit on a single
    # class.
    X = X[:100]
    y = y[:100]
    error_message = (
        "Only 1 class/es in training fold,"
        " but 2 in overall dataset. This"
        " is not supported for decision_function"
        " with imbalanced folds. To fix "
        "this, use a cross-validation technique "
        "resulting in properly stratified folds"
    )
    with pytest.raises(ValueError, match=error_message):
        cross_val_predict(
            RidgeClassifier(), X, y, method="decision_function", cv=KFold(2)
        )

    X, y = load_digits(return_X_y=True)
    est = SVC(kernel="linear", decision_function_shape="ovo")

    preds = cross_val_predict(est, X, y, method="decision_function")
    assert preds.shape == (1797, 45)

    ind = np.argsort(y)
    X, y = X[ind], y[ind]
    error_message_regexp = (
        r"Output shape \(599L?, 21L?\) of "
        "decision_function does not match number of "
        r"classes \(7\) in fold. Irregular "
        "decision_function .*"
    )
    with pytest.raises(ValueError, match=error_message_regexp):
        cross_val_predict(est, X, y, cv=KFold(n_splits=3), method="decision_function")


def test_cross_val_predict_predict_proba_shape():
    X, y = make_classification(n_classes=2, n_samples=50, random_state=0)

    preds = cross_val_predict(
        LogisticRegression(solver="liblinear"), X, y, method="predict_proba"
    )
    assert preds.shape == (50, 2)

    X, y = load_iris(return_X_y=True)

    preds = cross_val_predict(
        LogisticRegression(solver="liblinear"), X, y, method="predict_proba"
    )
    assert preds.shape == (150, 3)


def test_cross_val_predict_predict_log_proba_shape():
    X, y = make_classification(n_classes=2, n_samples=50, random_state=0)

    preds = cross_val_predict(
        LogisticRegression(solver="liblinear"), X, y, method="predict_log_proba"
    )
    assert preds.shape == (50, 2)

    X, y = load_iris(return_X_y=True)

    preds = cross_val_predict(
        LogisticRegression(solver="liblinear"), X, y, method="predict_log_proba"
    )
    assert preds.shape == (150, 3)


def test_cross_val_predict_input_types():
    iris = load_iris()
    X, y = iris.data, iris.target
    X_sparse = coo_matrix(X)
    multioutput_y = np.column_stack([y, y[::-1]])

    clf = Ridge(fit_intercept=False, random_state=0)
    # 3 fold cv is used --> at least 3 samples per class
    # Smoke test
    predictions = cross_val_predict(clf, X, y)
    assert predictions.shape == (150,)

    # test with multioutput y
    predictions = cross_val_predict(clf, X_sparse, multioutput_y)
    assert predictions.shape == (150, 2)

    predictions = cross_val_predict(clf, X_sparse, y)
    assert_array_equal(predictions.shape, (150,))

    # test with multioutput y
    predictions = cross_val_predict(clf, X_sparse, multioutput_y)
    assert_array_equal(predictions.shape, (150, 2))

    # test with X and y as list
    list_check = lambda x: isinstance(x, list)
    clf = CheckingClassifier(check_X=list_check)
    predictions = cross_val_predict(clf, X.tolist(), y.tolist())

    clf = CheckingClassifier(check_y=list_check)
    predictions = cross_val_predict(clf, X, y.tolist())

    # test with X and y as list and non empty method
    predictions = cross_val_predict(
        LogisticRegression(solver="liblinear"),
        X.tolist(),
        y.tolist(),
        method="decision_function",
    )
    predictions = cross_val_predict(
        LogisticRegression(solver="liblinear"),
        X,
        y.tolist(),
        method="decision_function",
    )

    # test with 3d X
    X_3d = X[:, :, np.newaxis]
    check_3d = lambda x: x.ndim == 3
    clf = CheckingClassifier(check_X=check_3d)
    predictions = cross_val_predict(clf, X_3d, y)
    assert_array_equal(predictions.shape, (150,))


@pytest.mark.filterwarnings("ignore: Using or importing the ABCs from")
# python3.7 deprecation warnings in pandas via matplotlib :-/
def test_cross_val_predict_pandas():
    # check cross_val_predict doesn't destroy pandas dataframe
    types = [(MockDataFrame, MockDataFrame)]
    try:
        from pandas import DataFrame, Series

        types.append((Series, DataFrame))
    except ImportError:
        pass
    for TargetType, InputFeatureType in types:
        # X dataframe, y series
        X_df, y_ser = InputFeatureType(X), TargetType(y2)
        check_df = lambda x: isinstance(x, InputFeatureType)
        check_series = lambda x: isinstance(x, TargetType)
        clf = CheckingClassifier(check_X=check_df, check_y=check_series)
        cross_val_predict(clf, X_df, y_ser, cv=3)


def test_cross_val_predict_unbalanced():
    X, y = make_classification(
        n_samples=100,
        n_features=2,
        n_redundant=0,
        n_informative=2,
        n_clusters_per_class=1,
        random_state=1,
    )
    # Change the first sample to a new class
    y[0] = 2
    clf = LogisticRegression(random_state=1, solver="liblinear")
    cv = StratifiedKFold(n_splits=2)
    train, test = list(cv.split(X, y))
    yhat_proba = cross_val_predict(clf, X, y, cv=cv, method="predict_proba")
    assert y[test[0]][0] == 2  # sanity check for further assertions
    assert np.all(yhat_proba[test[0]][:, 2] == 0)
    assert np.all(yhat_proba[test[0]][:, 0:1] > 0)
    assert np.all(yhat_proba[test[1]] > 0)
    assert_array_almost_equal(yhat_proba.sum(axis=1), np.ones(y.shape), decimal=12)


def test_cross_val_predict_y_none():
    # ensure that cross_val_predict works when y is None
    mock_classifier = MockClassifier()
    rng = np.random.RandomState(42)
    X = rng.rand(100, 10)
    y_hat = cross_val_predict(mock_classifier, X, y=None, cv=5, method="predict")
    assert_allclose(X[:, 0], y_hat)
    y_hat_proba = cross_val_predict(
        mock_classifier, X, y=None, cv=5, method="predict_proba"
    )
    assert_allclose(X, y_hat_proba)


def test_cross_val_score_sparse_fit_params():
    iris = load_iris()
    X, y = iris.data, iris.target
    clf = MockClassifier()
    fit_params = {"sparse_sample_weight": coo_matrix(np.eye(X.shape[0]))}
    a = cross_val_score(clf, X, y, fit_params=fit_params, cv=3)
    assert_array_equal(a, np.ones(3))


def test_learning_curve():
    n_samples = 30
    n_splits = 3
    X, y = make_classification(
        n_samples=n_samples,
        n_features=1,
        n_informative=1,
        n_redundant=0,
        n_classes=2,
        n_clusters_per_class=1,
        random_state=0,
    )
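    # With 3-fold CV each training split holds at most
    # n_samples * (n_splits - 1) / n_splits == 30 * 2 / 3 == 20 samples, which
    # is the ceiling handed to the mock estimator; the requested train_sizes
    # fractions 0.1-1.0 therefore map to the absolute sizes 2-20 asserted below.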
    estimator = MockImprovingEstimator(n_samples * ((n_splits - 1) / n_splits))
    for shuffle_train in [False, True]:
        with warnings.catch_warnings(record=True) as w:
            (
                train_sizes,
                train_scores,
                test_scores,
                fit_times,
                score_times,
            ) = learning_curve(
                estimator,
                X,
                y,
                cv=KFold(n_splits=n_splits),
                train_sizes=np.linspace(0.1, 1.0, 10),
                shuffle=shuffle_train,
                return_times=True,
            )
        if len(w) > 0:
            raise RuntimeError("Unexpected warning: %r" % w[0].message)
        assert train_scores.shape == (10, 3)
        assert test_scores.shape == (10, 3)
        assert fit_times.shape == (10, 3)
        assert score_times.shape == (10, 3)
        assert_array_equal(train_sizes, np.linspace(2, 20, 10))
        assert_array_almost_equal(train_scores.mean(axis=1), np.linspace(1.9, 1.0, 10))
        assert_array_almost_equal(test_scores.mean(axis=1), np.linspace(0.1, 1.0, 10))

        # Cannot use assert_array_almost_equal for fit and score times because
        # the values are hardware-dependent
        assert fit_times.dtype == "float64"
        assert score_times.dtype == "float64"

        # Test a custom cv splitter that can iterate only once
        with warnings.catch_warnings(record=True) as w:
            train_sizes2, train_scores2, test_scores2 = learning_curve(
                estimator,
                X,
                y,
                cv=OneTimeSplitter(n_splits=n_splits, n_samples=n_samples),
                train_sizes=np.linspace(0.1, 1.0, 10),
                shuffle=shuffle_train,
            )
        if len(w) > 0:
            raise RuntimeError("Unexpected warning: %r" % w[0].message)
        assert_array_almost_equal(train_scores2, train_scores)
        assert_array_almost_equal(test_scores2, test_scores)


def test_learning_curve_unsupervised():
    X, _ = make_classification(
        n_samples=30,
        n_features=1,
        n_informative=1,
        n_redundant=0,
        n_classes=2,
        n_clusters_per_class=1,
        random_state=0,
    )
    estimator = MockImprovingEstimator(20)
    train_sizes, train_scores, test_scores = learning_curve(
        estimator, X, y=None, cv=3, train_sizes=np.linspace(0.1, 1.0, 10)
    )
    assert_array_equal(train_sizes, np.linspace(2, 20, 10))
    assert_array_almost_equal(train_scores.mean(axis=1), np.linspace(1.9, 1.0, 10))
    assert_array_almost_equal(test_scores.mean(axis=1), np.linspace(0.1, 1.0, 10))


def test_learning_curve_verbose():
    X, y = make_classification(
        n_samples=30,
        n_features=1,
        n_informative=1,
        n_redundant=0,
        n_classes=2,
        n_clusters_per_class=1,
        random_state=0,
    )
    estimator = MockImprovingEstimator(20)

    old_stdout = sys.stdout
    sys.stdout = StringIO()
    try:
        train_sizes, train_scores, test_scores = learning_curve(
            estimator, X, y, cv=3, verbose=1
        )
    finally:
        out = sys.stdout.getvalue()
        sys.stdout.close()
        sys.stdout = old_stdout

    assert "[learning_curve]" in out


def test_learning_curve_incremental_learning_not_possible():
    X, y = make_classification(
        n_samples=2,
        n_features=1,
        n_informative=1,
        n_redundant=0,
        n_classes=2,
        n_clusters_per_class=1,
        random_state=0,
    )
    # The mockup does not have partial_fit()
    estimator = MockImprovingEstimator(1)
    with pytest.raises(ValueError):
        learning_curve(estimator, X, y, exploit_incremental_learning=True)


def test_learning_curve_incremental_learning():
    X, y = make_classification(
        n_samples=30,
        n_features=1,
        n_informative=1,
        n_redundant=0,
        n_classes=2,
        n_clusters_per_class=1,
        random_state=0,
    )
    estimator = MockIncrementalImprovingEstimator(20)
    for shuffle_train in [False, True]:
        train_sizes, train_scores, test_scores = learning_curve(
            estimator,
            X,
            y,
            cv=3,
            exploit_incremental_learning=True,
            train_sizes=np.linspace(0.1, 1.0, 10),
            shuffle=shuffle_train,
        )
        assert_array_equal(train_sizes, np.linspace(2, 20, 10))
        assert_array_almost_equal(train_scores.mean(axis=1), np.linspace(1.9, 1.0, 10))
        assert_array_almost_equal(test_scores.mean(axis=1), np.linspace(0.1, 1.0, 10))


def test_learning_curve_incremental_learning_unsupervised():
    X, _ = make_classification(
        n_samples=30,
        n_features=1,
        n_informative=1,
        n_redundant=0,
        n_classes=2,
        n_clusters_per_class=1,
        random_state=0,
    )
    estimator = MockIncrementalImprovingEstimator(20)
    train_sizes, train_scores, test_scores = learning_curve(
        estimator,
        X,
        y=None,
        cv=3,
        exploit_incremental_learning=True,
        train_sizes=np.linspace(0.1, 1.0, 10),
    )
    assert_array_equal(train_sizes, np.linspace(2, 20, 10))
    assert_array_almost_equal(train_scores.mean(axis=1), np.linspace(1.9, 1.0, 10))
    assert_array_almost_equal(test_scores.mean(axis=1), np.linspace(0.1, 1.0, 10))


def test_learning_curve_batch_and_incremental_learning_are_equal():
    X, y = make_classification(
        n_samples=30,
        n_features=1,
        n_informative=1,
        n_redundant=0,
        n_classes=2,
        n_clusters_per_class=1,
        random_state=0,
    )
    train_sizes = np.linspace(0.2, 1.0, 5)
    estimator = PassiveAggressiveClassifier(max_iter=1, tol=None, shuffle=False)

    train_sizes_inc, train_scores_inc, test_scores_inc = learning_curve(
        estimator,
        X,
        y,
        train_sizes=train_sizes,
        cv=3,
        exploit_incremental_learning=True,
    )
    train_sizes_batch, train_scores_batch, test_scores_batch = learning_curve(
        estimator,
        X,
        y,
        cv=3,
        train_sizes=train_sizes,
        exploit_incremental_learning=False,
    )

    assert_array_equal(train_sizes_inc, train_sizes_batch)
    assert_array_almost_equal(
        train_scores_inc.mean(axis=1), train_scores_batch.mean(axis=1)
    )
    assert_array_almost_equal(
        test_scores_inc.mean(axis=1), test_scores_batch.mean(axis=1)
    )


def test_learning_curve_n_sample_range_out_of_bounds():
    X, y = make_classification(
        n_samples=30,
        n_features=1,
        n_informative=1,
        n_redundant=0,
        n_classes=2,
        n_clusters_per_class=1,
        random_state=0,
    )
    estimator = MockImprovingEstimator(20)
    with pytest.raises(ValueError):
        learning_curve(estimator, X, y, cv=3, train_sizes=[0, 1])
    with pytest.raises(ValueError):
        learning_curve(estimator, X, y, cv=3, train_sizes=[0.0, 1.0])
    with pytest.raises(ValueError):
        learning_curve(estimator, X, y, cv=3, train_sizes=[0.1, 1.1])
    with pytest.raises(ValueError):
        learning_curve(estimator, X, y, cv=3, train_sizes=[0, 20])
    with pytest.raises(ValueError):
        learning_curve(estimator, X, y, cv=3, train_sizes=[1, 21])


def test_learning_curve_remove_duplicate_sample_sizes():
    X, y = make_classification(
        n_samples=3,
        n_features=1,
        n_informative=1,
        n_redundant=0,
        n_classes=2,
        n_clusters_per_class=1,
        random_state=0,
    )
    estimator = MockImprovingEstimator(2)
    warning_message = (
        "Removed duplicate entries from 'train_sizes'. Number of ticks "
        "will be less than the size of 'train_sizes': 2 instead of 3."
    )
    with pytest.warns(RuntimeWarning, match=warning_message):
        train_sizes, _, _ = learning_curve(
            estimator, X, y, cv=3, train_sizes=np.linspace(0.33, 1.0, 3)
        )
    assert_array_equal(train_sizes, [1, 2])


def test_learning_curve_with_boolean_indices():
    X, y = make_classification(
        n_samples=30,
        n_features=1,
        n_informative=1,
        n_redundant=0,
        n_classes=2,
        n_clusters_per_class=1,
        random_state=0,
    )
    estimator = MockImprovingEstimator(20)
    cv = KFold(n_splits=3)
    train_sizes, train_scores, test_scores = learning_curve(
        estimator, X, y, cv=cv, train_sizes=np.linspace(0.1, 1.0, 10)
    )
    assert_array_equal(train_sizes, np.linspace(2, 20, 10))
    assert_array_almost_equal(train_scores.mean(axis=1), np.linspace(1.9, 1.0, 10))
    assert_array_almost_equal(test_scores.mean(axis=1), np.linspace(0.1, 1.0, 10))


def test_learning_curve_with_shuffle():
    # The following test case was designed to verify the code changes made in
    # pull request #7506.
    X = np.array(
        [
            [1, 2],
            [3, 4],
            [5, 6],
            [7, 8],
            [11, 12],
            [13, 14],
            [15, 16],
            [17, 18],
            [19, 20],
            [7, 8],
            [9, 10],
            [11, 12],
            [13, 14],
            [15, 16],
            [17, 18],
        ]
    )
    y = np.array([1, 1, 1, 2, 3, 4, 1, 1, 2, 3, 4, 1, 2, 3, 4])
    groups = np.array([1, 1, 1, 1, 1, 1, 3, 3, 3, 3, 3, 4, 4, 4, 4])
    # Splits on these groups fail without shuffle as the first iteration
    # of the learning curve doesn't contain label 4 in the training set.
    estimator = PassiveAggressiveClassifier(max_iter=5, tol=None, shuffle=False)
    cv = GroupKFold(n_splits=2)
    train_sizes_batch, train_scores_batch, test_scores_batch = learning_curve(
        estimator,
        X,
        y,
        cv=cv,
        n_jobs=1,
        train_sizes=np.linspace(0.3, 1.0, 3),
        groups=groups,
        shuffle=True,
        random_state=2,
    )
    assert_array_almost_equal(
        train_scores_batch.mean(axis=1), np.array([0.75, 0.3, 0.36111111])
    )
    assert_array_almost_equal(
        test_scores_batch.mean(axis=1), np.array([0.36111111, 0.25, 0.25])
    )
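    # Without shuffle=True the first train slice lacks label 4 (see the
    # comment above), so the fit fails and error_score="raise" surfaces it
    # as a ValueError.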
    with pytest.raises(ValueError):
        learning_curve(
            estimator,
            X,
            y,
            cv=cv,
            n_jobs=1,
            train_sizes=np.linspace(0.3, 1.0, 3),
            groups=groups,
            error_score="raise",
        )
    train_sizes_inc, train_scores_inc, test_scores_inc = learning_curve(
        estimator,
        X,
        y,
        cv=cv,
        n_jobs=1,
        train_sizes=np.linspace(0.3, 1.0, 3),
        groups=groups,
        shuffle=True,
        random_state=2,
        exploit_incremental_learning=True,
    )
    assert_array_almost_equal(
        train_scores_inc.mean(axis=1), train_scores_batch.mean(axis=1)
    )
    assert_array_almost_equal(
        test_scores_inc.mean(axis=1), test_scores_batch.mean(axis=1)
    )


def test_learning_curve_fit_params():
    X = np.arange(100).reshape(10, 10)
    y = np.array([0] * 5 + [1] * 5)
    clf = CheckingClassifier(expected_sample_weight=True)
    err_msg = r"Expected sample_weight to be passed"
    with pytest.raises(AssertionError, match=err_msg):
        learning_curve(clf, X, y, error_score="raise")
    err_msg = r"sample_weight.shape == \(1,\), expected \(2,\)!"
    with pytest.raises(ValueError, match=err_msg):
        learning_curve(
            clf, X, y, error_score="raise", fit_params={"sample_weight": np.ones(1)}
        )
    learning_curve(
        clf, X, y, error_score="raise", fit_params={"sample_weight": np.ones(10)}
    )


def test_learning_curve_incremental_learning_fit_params():
    X, y = make_classification(
        n_samples=30,
        n_features=1,
        n_informative=1,
        n_redundant=0,
        n_classes=2,
        n_clusters_per_class=1,
        random_state=0,
    )
    estimator = MockIncrementalImprovingEstimator(20, ["sample_weight"])
    err_msg = r"Expected fit parameter\(s\) \['sample_weight'\] not seen."
    with pytest.raises(AssertionError, match=err_msg):
        learning_curve(
            estimator,
            X,
            y,
            cv=3,
            exploit_incremental_learning=True,
            train_sizes=np.linspace(0.1, 1.0, 10),
            error_score="raise",
        )
    err_msg = "Fit parameter sample_weight has length 3; expected"
    with pytest.raises(AssertionError, match=err_msg):
        learning_curve(
            estimator,
            X,
            y,
            cv=3,
            exploit_incremental_learning=True,
            train_sizes=np.linspace(0.1, 1.0, 10),
            error_score="raise",
            fit_params={"sample_weight": np.ones(3)},
        )
    learning_curve(
        estimator,
        X,
        y,
        cv=3,
        exploit_incremental_learning=True,
        train_sizes=np.linspace(0.1, 1.0, 10),
        error_score="raise",
        fit_params={"sample_weight": np.ones(2)},
    )


def test_validation_curve():
    X, y = make_classification(
        n_samples=2,
        n_features=1,
        n_informative=1,
        n_redundant=0,
        n_classes=2,
        n_clusters_per_class=1,
        random_state=0,
    )
    param_range = np.linspace(0, 1, 10)
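    # MockEstimatorWithParameter (defined earlier in this file) is assumed to
    # score `param` on the training split and `1 - param` on the validation
    # split, which is what the two assertions at the end mirror.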
    with warnings.catch_warnings(record=True) as w:
        train_scores, test_scores = validation_curve(
            MockEstimatorWithParameter(),
            X,
            y,
            param_name="param",
            param_range=param_range,
            cv=2,
        )
    if len(w) > 0:
        raise RuntimeError("Unexpected warning: %r" % w[0].message)
    assert_array_almost_equal(train_scores.mean(axis=1), param_range)
    assert_array_almost_equal(test_scores.mean(axis=1), 1 - param_range)


def test_validation_curve_clone_estimator():
    X, y = make_classification(
        n_samples=2,
        n_features=1,
        n_informative=1,
        n_redundant=0,
        n_classes=2,
        n_clusters_per_class=1,
        random_state=0,
    )
    param_range = np.linspace(1, 0, 10)
    _, _ = validation_curve(
        MockEstimatorWithSingleFitCallAllowed(),
        X,
        y,
        param_name="param",
        param_range=param_range,
        cv=2,
    )


def test_validation_curve_cv_splits_consistency():
    n_samples = 100
    n_splits = 5
    X, y = make_classification(n_samples=n_samples, random_state=0)
    scores1 = validation_curve(
        SVC(kernel="linear", random_state=0),
        X,
        y,
        param_name="C",
        param_range=[0.1, 0.1, 0.2, 0.2],
        cv=OneTimeSplitter(n_splits=n_splits, n_samples=n_samples),
    )
    # The OneTimeSplitter is a non-reentrant cv splitter: its splits can be
    # consumed only once. Unless `split` is called anew for each parameter
    # setting, the following must produce identical results for the two
    # settings that share the same C value.
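    # `scores1` is a (train_scores, test_scores) pair, each of shape
    # (n_params, n_splits) = (4, 5). np.hstack glues them into a (4, 10)
    # array; reordering the rows as (0, 2, 1, 3) puts the first occurrence
    # of each C value in the top half and the duplicate occurrence in the
    # bottom half, so np.vsplit yields two (2, 10) blocks that must match.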
    assert_array_almost_equal(*np.vsplit(np.hstack(scores1)[(0, 2, 1, 3), :], 2))
    scores2 = validation_curve(
        SVC(kernel="linear", random_state=0),
        X,
        y,
        param_name="C",
        param_range=[0.1, 0.1, 0.2, 0.2],
        cv=KFold(n_splits=n_splits, shuffle=True),
    )
    # For scores2, compare the scores of the 1st and 2nd parameter settings:
    # since both use C=0.1, they must be consistent unless the train/test
    # folds differ between the settings.
    assert_array_almost_equal(*np.vsplit(np.hstack(scores2)[(0, 2, 1, 3), :], 2))
    scores3 = validation_curve(
        SVC(kernel="linear", random_state=0),
        X,
        y,
        param_name="C",
        param_range=[0.1, 0.1, 0.2, 0.2],
        cv=KFold(n_splits=n_splits),
    )
    # OneTimeSplitter is basically unshuffled KFold(n_splits=5). Sanity check.
    assert_array_almost_equal(np.array(scores3), np.array(scores1))


def test_validation_curve_fit_params():
    X = np.arange(100).reshape(10, 10)
    y = np.array([0] * 5 + [1] * 5)
    clf = CheckingClassifier(expected_sample_weight=True)
    err_msg = r"Expected sample_weight to be passed"
    with pytest.raises(AssertionError, match=err_msg):
        validation_curve(
            clf,
            X,
            y,
            param_name="foo_param",
            param_range=[1, 2, 3],
            error_score="raise",
        )
    err_msg = r"sample_weight.shape == \(1,\), expected \(8,\)!"
    with pytest.raises(ValueError, match=err_msg):
        validation_curve(
            clf,
            X,
            y,
            param_name="foo_param",
            param_range=[1, 2, 3],
            error_score="raise",
            fit_params={"sample_weight": np.ones(1)},
        )
    validation_curve(
        clf,
        X,
        y,
        param_name="foo_param",
        param_range=[1, 2, 3],
        error_score="raise",
        fit_params={"sample_weight": np.ones(10)},
    )


def test_check_is_permutation():
    rng = np.random.RandomState(0)
    p = np.arange(100)
    rng.shuffle(p)
    assert _check_is_permutation(p, 100)
    assert not _check_is_permutation(np.delete(p, 23), 100)
    p[0] = 23
    assert not _check_is_permutation(p, 100)
    # Check that additional duplicate indices are caught
    assert not _check_is_permutation(np.hstack((p, 0)), 100)


def test_cross_val_predict_sparse_prediction():
    # Check that cross_val_predict gives the same result for sparse and
    # dense input
    X, y = make_multilabel_classification(
        n_classes=2,
        n_labels=1,
        allow_unlabeled=False,
        return_indicator=True,
        random_state=1,
    )
    X_sparse = csr_matrix(X)
    y_sparse = csr_matrix(y)
    classif = OneVsRestClassifier(SVC(kernel="linear"))
    preds = cross_val_predict(classif, X, y, cv=10)
    preds_sparse = cross_val_predict(classif, X_sparse, y_sparse, cv=10)
    preds_sparse = preds_sparse.toarray()
    assert_array_almost_equal(preds_sparse, preds)


def check_cross_val_predict_binary(est, X, y, method):
    """Helper for tests of cross_val_predict with binary classification"""
    cv = KFold(n_splits=3, shuffle=False)
    # Generate expected outputs
    if y.ndim == 1:
        exp_shape = (len(X),) if method == "decision_function" else (len(X), 2)
    else:
        exp_shape = y.shape
    expected_predictions = np.zeros(exp_shape)
    for train, test in cv.split(X, y):
        est = clone(est).fit(X[train], y[train])
        expected_predictions[test] = getattr(est, method)(X[test])
    # Check actual outputs for several representations of y
    for tg in [y, y + 1, y - 2, y.astype("str")]:
        assert_allclose(
            cross_val_predict(est, X, tg, method=method, cv=cv), expected_predictions
        )


def check_cross_val_predict_multiclass(est, X, y, method):
    """Helper for tests of cross_val_predict with multiclass classification"""
    cv = KFold(n_splits=3, shuffle=False)
    # Generate expected outputs
    float_min = np.finfo(np.float64).min
    default_values = {
        "decision_function": float_min,
        "predict_log_proba": float_min,
        "predict_proba": 0,
    }
    expected_predictions = np.full(
        (len(X), len(set(y))), default_values[method], dtype=np.float64
    )
    _, y_enc = np.unique(y, return_inverse=True)
    for train, test in cv.split(X, y_enc):
        est = clone(est).fit(X[train], y_enc[train])
        fold_preds = getattr(est, method)(X[test])
        i_cols_fit = np.unique(y_enc[train])
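        # Scatter this fold's predictions into the rows indexed by `test` and
        # the columns of the classes actually seen during fit; columns for
        # unseen classes keep the method's default fill value.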
        expected_predictions[np.ix_(test, i_cols_fit)] = fold_preds
    # Check actual outputs for several representations of y
    for tg in [y, y + 1, y - 2, y.astype("str")]:
        assert_allclose(
            cross_val_predict(est, X, tg, method=method, cv=cv), expected_predictions
        )


def check_cross_val_predict_multilabel(est, X, y, method):
    """Check the output of cross_val_predict for 2D targets using
    estimators which provide predictions as a list with one
    element per class.
    """
    cv = KFold(n_splits=3, shuffle=False)
    # Create empty arrays of the correct size to hold outputs
    float_min = np.finfo(np.float64).min
    default_values = {
        "decision_function": float_min,
        "predict_log_proba": float_min,
        "predict_proba": 0,
    }
    n_targets = y.shape[1]
    expected_preds = []
    for i_col in range(n_targets):
        n_classes_in_label = len(set(y[:, i_col]))
        if n_classes_in_label == 2 and method == "decision_function":
            exp_shape = (len(X),)
        else:
            exp_shape = (len(X), n_classes_in_label)
        expected_preds.append(
            np.full(exp_shape, default_values[method], dtype=np.float64)
        )
    # Generate expected outputs
    y_enc_cols = [
        np.unique(y[:, i], return_inverse=True)[1][:, np.newaxis]
        for i in range(y.shape[1])
    ]
    y_enc = np.concatenate(y_enc_cols, axis=1)
    for train, test in cv.split(X, y_enc):
        est = clone(est).fit(X[train], y_enc[train])
        fold_preds = getattr(est, method)(X[test])
        for i_col in range(n_targets):
            fold_cols = np.unique(y_enc[train][:, i_col])
            if expected_preds[i_col].ndim == 1:
                # Decision function with <=2 classes
                expected_preds[i_col][test] = fold_preds[i_col]
            else:
                idx = np.ix_(test, fold_cols)
                expected_preds[i_col][idx] = fold_preds[i_col]
    # Check actual outputs for several representations of y
    for tg in [y, y + 1, y - 2, y.astype("str")]:
        cv_predict_output = cross_val_predict(est, X, tg, method=method, cv=cv)
        assert len(cv_predict_output) == len(expected_preds)
        for i in range(len(cv_predict_output)):
            assert_allclose(cv_predict_output[i], expected_preds[i])


def check_cross_val_predict_with_method_binary(est):
    # This test includes the decision_function with two classes.
    # This is a special case: it has only one column of output.
    X, y = make_classification(n_classes=2, random_state=0)
    for method in ["decision_function", "predict_proba", "predict_log_proba"]:
        check_cross_val_predict_binary(est, X, y, method)


def check_cross_val_predict_with_method_multiclass(est):
    iris = load_iris()
    X, y = iris.data, iris.target
    X, y = shuffle(X, y, random_state=0)
    for method in ["decision_function", "predict_proba", "predict_log_proba"]:
        check_cross_val_predict_multiclass(est, X, y, method)


def test_cross_val_predict_with_method():
    check_cross_val_predict_with_method_binary(LogisticRegression(solver="liblinear"))
    check_cross_val_predict_with_method_multiclass(
        LogisticRegression(solver="liblinear")
    )


def test_cross_val_predict_method_checking():
    # Regression test for issue #9639. Tests that cross_val_predict does not
    # check estimator methods (e.g. predict_proba) before fitting.
    iris = load_iris()
    X, y = iris.data, iris.target
    X, y = shuffle(X, y, random_state=0)
    for method in ["decision_function", "predict_proba", "predict_log_proba"]:
        est = SGDClassifier(loss="log_loss", random_state=2)
        check_cross_val_predict_multiclass(est, X, y, method)


def test_gridsearchcv_cross_val_predict_with_method():
    iris = load_iris()
    X, y = iris.data, iris.target
    X, y = shuffle(X, y, random_state=0)
    est = GridSearchCV(
        LogisticRegression(random_state=42, solver="liblinear"), {"C": [0.1, 1]}, cv=2
    )
    for method in ["decision_function", "predict_proba", "predict_log_proba"]:
        check_cross_val_predict_multiclass(est, X, y, method)


def test_cross_val_predict_with_method_multilabel_ovr():
    # OVR makes multilabel predictions, but only as arrays of binary
    # indicator columns. The output of predict_proba is a 2D array with
    # shape (n_samples, n_classes).
    n_samp = 100
    n_classes = 4
    X, y = make_multilabel_classification(
        n_samples=n_samp, n_labels=3, n_classes=n_classes, n_features=5, random_state=42
    )
    est = OneVsRestClassifier(LogisticRegression(solver="liblinear", random_state=0))
    for method in ["predict_proba", "decision_function"]:
        check_cross_val_predict_binary(est, X, y, method=method)


class RFWithDecisionFunction(RandomForestClassifier):
    # None of the current multioutput-multiclass estimators have
    # decision_function methods. Create a mock decision function
    # to test cross_val_predict's handling of this case.
    def decision_function(self, X):
        probs = self.predict_proba(X)
        msg = "This helper should only be used on multioutput-multiclass tasks"
        assert isinstance(probs, list), msg
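        # For a binary column, keep only the positive-class probability so
        # the mock mimics the 1-D decision_function output of a binary
        # classifier.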
        probs = [p[:, -1] if p.shape[1] == 2 else p for p in probs]
        return probs


def test_cross_val_predict_with_method_multilabel_rf():
    # The RandomForest allows multiple classes in each label.
    # Output of predict_proba is a list of outputs of predict_proba
    # for each individual label.
    n_classes = 4
    X, y = make_multilabel_classification(
        n_samples=100, n_labels=3, n_classes=n_classes, n_features=5, random_state=42
    )
    y[:, 0] += y[:, 1]  # Put three classes in the first column
    for method in ["predict_proba", "predict_log_proba", "decision_function"]:
        est = RFWithDecisionFunction(n_estimators=5, random_state=0)
        with warnings.catch_warnings():
            # Suppress "RuntimeWarning: divide by zero encountered in log"
            warnings.simplefilter("ignore")
            check_cross_val_predict_multilabel(est, X, y, method=method)


def test_cross_val_predict_with_method_rare_class():
    # Test a multiclass problem where one class will be missing from
    # one of the CV training sets.
    rng = np.random.RandomState(0)
    X = rng.normal(0, 1, size=(14, 10))
    y = np.array([0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 3])
    est = LogisticRegression(solver="liblinear")
    for method in ["predict_proba", "predict_log_proba", "decision_function"]:
        with warnings.catch_warnings():
            # Suppress warning about too few examples of a class
            warnings.simplefilter("ignore")
            check_cross_val_predict_multiclass(est, X, y, method)


def test_cross_val_predict_with_method_multilabel_rf_rare_class():
    # The RandomForest allows anything for the contents of the labels.
    # Output of predict_proba is a list of outputs of predict_proba
    # for each individual label.
    # In this test, the first label has a class with a single example.
    # We'll have one CV fold where the training data don't include it.
    rng = np.random.RandomState(0)
    X = rng.normal(0, 1, size=(5, 10))
    y = np.array([[0, 0], [1, 1], [2, 1], [0, 1], [1, 0]])
    for method in ["predict_proba", "predict_log_proba"]:
        est = RFWithDecisionFunction(n_estimators=5, random_state=0)
        with warnings.catch_warnings():
            # Suppress "RuntimeWarning: divide by zero encountered in log"
            warnings.simplefilter("ignore")
            check_cross_val_predict_multilabel(est, X, y, method=method)


def get_expected_predictions(X, y, cv, classes, est, method):
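    """Naively recompute per-fold predictions as a reference for
    cross_val_predict, padding the output to `classes` columns so that
    classes unseen in a training fold keep a default value.
    """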
    expected_predictions = np.zeros([len(y), classes])
    func = getattr(est, method)
    for train, test in cv.split(X, y):
        est.fit(X[train], y[train])
        expected_predictions_ = func(X[test])
        # Build a full-width block for this fold first, to avoid fancy
        # two-dimensional indexing into the output array
        if method == "predict_proba":
            exp_pred_test = np.zeros((len(test), classes))
        else:
            exp_pred_test = np.full(
                (len(test), classes), np.finfo(expected_predictions.dtype).min
            )
        exp_pred_test[:, est.classes_] = expected_predictions_
        expected_predictions[test] = exp_pred_test
    return expected_predictions


def test_cross_val_predict_class_subset():
    X = np.arange(200).reshape(100, 2)
    y = np.array([x // 10 for x in range(100)])
    classes = 10
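    # The labels are sorted (ten consecutive samples per class), so with an
    # unshuffled KFold every training fold is missing some of the 10
    # classes; this exercises the class-subset handling of cross_val_predict.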
    kfold3 = KFold(n_splits=3)
    kfold4 = KFold(n_splits=4)
    le = LabelEncoder()
    methods = ["decision_function", "predict_proba", "predict_log_proba"]
    for method in methods:
        est = LogisticRegression(solver="liblinear")
        # Test with n_splits=3
        predictions = cross_val_predict(est, X, y, method=method, cv=kfold3)
        # Runs a naive loop (should be same as cross_val_predict):
        expected_predictions = get_expected_predictions(
            X, y, kfold3, classes, est, method
        )
        assert_array_almost_equal(expected_predictions, predictions)
        # Test with n_splits=4
        predictions = cross_val_predict(est, X, y, method=method, cv=kfold4)
        expected_predictions = get_expected_predictions(
            X, y, kfold4, classes, est, method
        )
        assert_array_almost_equal(expected_predictions, predictions)
        # Testing unordered labels
        y = shuffle(np.repeat(range(10), 10), random_state=0)
        predictions = cross_val_predict(est, X, y, method=method, cv=kfold3)
        y = le.fit_transform(y)
        expected_predictions = get_expected_predictions(
            X, y, kfold3, classes, est, method
        )
        assert_array_almost_equal(expected_predictions, predictions)


def test_score_memmap():
    # Ensure a scalar score of memmap type is accepted
    iris = load_iris()
    X, y = iris.data, iris.target
    clf = MockClassifier()
    tf = tempfile.NamedTemporaryFile(mode="wb", delete=False)
    tf.write(b"Hello world!!!!!")
    tf.close()
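    # The file holds 16 bytes, i.e. two float64 values: `scores` is a
    # length-2 memmap (not a scalar), while `score` is a shape-() scalar
    # memmap, which is the only form a scorer may legitimately return.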
    scores = np.memmap(tf.name, dtype=np.float64)
    score = np.memmap(tf.name, shape=(), mode="r", dtype=np.float64)
    try:
        cross_val_score(clf, X, y, scoring=lambda est, X, y: score)
        with pytest.raises(ValueError):
            cross_val_score(clf, X, y, scoring=lambda est, X, y: scores)
    finally:
        # Best effort to release the mmap file handles before deleting the
        # backing file under Windows
        scores, score = None, None
        for _ in range(3):
            try:
                os.unlink(tf.name)
                break
            except OSError:
                sleep(1.0)


@pytest.mark.filterwarnings("ignore: Using or importing the ABCs from")
def test_permutation_test_score_pandas():
    # Check that permutation_test_score doesn't destroy pandas dataframes
    types = [(MockDataFrame, MockDataFrame)]
    try:
        from pandas import DataFrame, Series

        types.append((Series, DataFrame))
    except ImportError:
        pass
    for TargetType, InputFeatureType in types:
        # X dataframe, y series
        iris = load_iris()
        X, y = iris.data, iris.target
        X_df, y_ser = InputFeatureType(X), TargetType(y)
        check_df = lambda x: isinstance(x, InputFeatureType)
        check_series = lambda x: isinstance(x, TargetType)
        clf = CheckingClassifier(check_X=check_df, check_y=check_series)
        permutation_test_score(clf, X_df, y_ser)


def test_fit_and_score_failing():
    # Create a failing classifier to deliberately fail
    failing_clf = FailingClassifier(FailingClassifier.FAILING_PARAMETER)
    # dummy X data
    X = np.arange(1, 10)
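    # The positional args of the private _fit_and_score helper used here are,
    # in order: estimator, X, y, scorer, train, test, verbose, parameters,
    # fit_params.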
    fit_and_score_args = [failing_clf, X, None, dict(), None, None, 0, None, None]
    # With error_score='raise', the fit failure should be re-raised as-is
    fit_and_score_kwargs = {"error_score": "raise"}
    with pytest.raises(ValueError, match="Failing classifier failed as required"):
        _fit_and_score(*fit_and_score_args, **fit_and_score_kwargs)
    # Check that functions upstream validate the error_score param before
    # passing it on to _fit_and_score
    error_message_cross_validate = (
        "The 'error_score' parameter of cross_validate must be .*. Got .* instead."
    )
    with pytest.raises(ValueError, match=error_message_cross_validate):
        cross_val_score(failing_clf, X, cv=3, error_score="unvalid-string")
    assert failing_clf.score() == 0.0  # FailingClassifier coverage


def test_fit_and_score_working():
    X, y = make_classification(n_samples=30, random_state=0)
    clf = SVC(kernel="linear", random_state=0)
    train, test = next(ShuffleSplit().split(X))
    # Test the return_parameters option
    fit_and_score_args = [clf, X, y, dict(), train, test, 0]
    fit_and_score_kwargs = {
        "parameters": {"max_iter": 100, "tol": 0.1},
        "fit_params": None,
        "return_parameters": True,
    }
    result = _fit_and_score(*fit_and_score_args, **fit_and_score_kwargs)
    assert result["parameters"] == fit_and_score_kwargs["parameters"]


class DataDependentFailingClassifier(BaseEstimator):
    def __init__(self, max_x_value=None):
        self.max_x_value = max_x_value

    def fit(self, X, y=None):
        num_values_too_high = (X > self.max_x_value).sum()
        if num_values_too_high:
            raise ValueError(
                f"Classifier fit failed with {num_values_too_high} values too high"
            )

    def score(self, X=None, Y=None):
        return 0.0


@pytest.mark.parametrize("error_score", [np.nan, 0])
def test_cross_validate_some_failing_fits_warning(error_score):
    # Create a classifier that fails on data containing values above 8
    failing_clf = DataDependentFailingClassifier(max_x_value=8)
    # dummy X data
    X = np.arange(1, 10)
    y = np.ones(9)
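    # With an unshuffled 3-fold split of [1..9], the value 9 lands in the
    # training set of two of the three splits, so exactly 2 of 3 fits fail.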
    # A non-"raise" error_score turns the failures into a FitFailedWarning
    cross_validate_args = [failing_clf, X, y]
    cross_validate_kwargs = {"cv": 3, "error_score": error_score}
    # Check that the warning message matches the expected pattern
    individual_fit_error_message = (
        "ValueError: Classifier fit failed with 1 values too high"
    )
    warning_message = re.compile(
        (
            "2 fits failed.+total of 3.+The score on these"
            " train-test partitions for these parameters will be set to"
            f" {cross_validate_kwargs['error_score']}.+{individual_fit_error_message}"
        ),
        flags=re.DOTALL,
    )
    with pytest.warns(FitFailedWarning, match=warning_message):
        cross_validate(*cross_validate_args, **cross_validate_kwargs)


@pytest.mark.parametrize("error_score", [np.nan, 0])
def test_cross_validate_all_failing_fits_error(error_score):
    # Create a failing classifier to deliberately fail
    failing_clf = FailingClassifier(FailingClassifier.FAILING_PARAMETER)
    # dummy X data
    X = np.arange(1, 10)
    y = np.ones(9)
    cross_validate_args = [failing_clf, X, y]
    cross_validate_kwargs = {"cv": 7, "error_score": error_score}
    individual_fit_error_message = "ValueError: Failing classifier failed as required"
    error_message = re.compile(
        (
            "All the 7 fits failed.+your model is misconfigured.+"
            f"{individual_fit_error_message}"
        ),
        flags=re.DOTALL,
    )
    with pytest.raises(ValueError, match=error_message):
        cross_validate(*cross_validate_args, **cross_validate_kwargs)


def _failing_scorer(estimator, X, y, error_msg):
    raise ValueError(error_msg)


@pytest.mark.filterwarnings("ignore:lbfgs failed to converge")
@pytest.mark.parametrize("error_score", [np.nan, 0, "raise"])
def test_cross_val_score_failing_scorer(error_score):
    # Check that an estimator can fail during scoring in `cross_val_score`
    # and that we can optionally replace the score with `error_score`
    X, y = load_iris(return_X_y=True)
    clf = LogisticRegression(max_iter=5).fit(X, y)
    error_msg = "This scorer is supposed to fail!!!"
    failing_scorer = partial(_failing_scorer, error_msg=error_msg)
    if error_score == "raise":
        with pytest.raises(ValueError, match=error_msg):
            cross_val_score(
                clf, X, y, cv=3, scoring=failing_scorer, error_score=error_score
            )
    else:
        warning_msg = (
            "Scoring failed. The score on this train-test partition for "
            f"these parameters will be set to {error_score}"
        )
        with pytest.warns(UserWarning, match=warning_msg):
            scores = cross_val_score(
                clf, X, y, cv=3, scoring=failing_scorer, error_score=error_score
            )
            assert_allclose(scores, error_score)


@pytest.mark.filterwarnings("ignore:lbfgs failed to converge")
@pytest.mark.parametrize("error_score", [np.nan, 0, "raise"])
@pytest.mark.parametrize("return_train_score", [True, False])
@pytest.mark.parametrize("with_multimetric", [False, True])
def test_cross_validate_failing_scorer(
    error_score, return_train_score, with_multimetric
):
    # Check that an estimator can fail during scoring in `cross_validate` and
    # that we can optionally replace the score with `error_score`. In the
    # multimetric case, also check the result of a non-failing scorer where
    # the other scorers are failing.
    X, y = load_iris(return_X_y=True)
    clf = LogisticRegression(max_iter=5).fit(X, y)
    error_msg = "This scorer is supposed to fail!!!"
    failing_scorer = partial(_failing_scorer, error_msg=error_msg)
    if with_multimetric:
        non_failing_scorer = make_scorer(mean_squared_error)
        scoring = {
            "score_1": failing_scorer,
            "score_2": non_failing_scorer,
            "score_3": failing_scorer,
        }
    else:
        scoring = failing_scorer
    if error_score == "raise":
        with pytest.raises(ValueError, match=error_msg):
            cross_validate(
                clf,
                X,
                y,
                cv=3,
                scoring=scoring,
                return_train_score=return_train_score,
                error_score=error_score,
            )
    else:
        warning_msg = (
            "Scoring failed. The score on this train-test partition for "
            f"these parameters will be set to {error_score}"
        )
        with pytest.warns(UserWarning, match=warning_msg):
            results = cross_validate(
                clf,
                X,
                y,
                cv=3,
                scoring=scoring,
                return_train_score=return_train_score,
                error_score=error_score,
            )
            for key in results:
                if "_score" in key:
                    if "_score_2" in key:
                        # Check the test (and optionally train) score for the
                        # scorer that should be non-failing
                        for i in results[key]:
                            assert isinstance(i, float)
                    else:
                        # Check the test (and optionally train) score for all
                        # scorers that should be assigned to `error_score`
                        assert_allclose(results[key], error_score)


def three_params_scorer(i, j, k):
    return 3.4213


@pytest.mark.parametrize(
    "train_score, scorer, verbose, split_prg, cdt_prg, expected",
    [
        (
            False,
            three_params_scorer,
            2,
            (1, 3),
            (0, 1),
            r"\[CV\] END ...................................................."
            r" total time=   0.\ds",
        ),
        (
            True,
            {"sc1": three_params_scorer, "sc2": three_params_scorer},
            3,
            (1, 3),
            (0, 1),
            r"\[CV 2/3\] END sc1: \(train=3.421, test=3.421\) sc2: "
            r"\(train=3.421, test=3.421\) total time=   0.\ds",
        ),
        (
            False,
            {"sc1": three_params_scorer, "sc2": three_params_scorer},
            10,
            (1, 3),
            (0, 1),
            r"\[CV 2/3; 1/1\] END ....... sc1: \(test=3.421\) sc2: \(test=3.421\)"
            r" total time=   0.\ds",
        ),
    ],
)
def test_fit_and_score_verbosity(
    capsys, train_score, scorer, verbose, split_prg, cdt_prg, expected
):
    X, y = make_classification(n_samples=30, random_state=0)
    clf = SVC(kernel="linear", random_state=0)
    train, test = next(ShuffleSplit().split(X))
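    # `expected` is a regex for the progress line printed by _fit_and_score:
    # verbose=2 prints only the elapsed time; with return_train_score=True
    # and verbose=3 each scorer's train/test scores appear; verbose > 9 also
    # shows candidate progress (e.g. "2/3; 1/1").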
    # Exercise _fit_and_score's printing for the parametrized verbosity level
    fit_and_score_args = [clf, X, y, scorer, train, test, verbose, None, None]
    fit_and_score_kwargs = {
        "return_train_score": train_score,
        "split_progress": split_prg,
        "candidate_progress": cdt_prg,
    }
    _fit_and_score(*fit_and_score_args, **fit_and_score_kwargs)
    out, _ = capsys.readouterr()
    outlines = out.split("\n")
    if len(outlines) > 2:
        assert re.match(expected, outlines[1])
    else:
        assert re.match(expected, outlines[0])


def test_score():
    error_message = "scoring must return a number, got None"

    def two_params_scorer(estimator, X_test):
        return None

    fit_and_score_args = [None, None, None, two_params_scorer]
    with pytest.raises(ValueError, match=error_message):
        _score(*fit_and_score_args, error_score=np.nan)


def test_callable_multimetric_confusion_matrix_cross_validate():
    def custom_scorer(clf, X, y):
        y_pred = clf.predict(X)
        cm = confusion_matrix(y, y_pred)
        return {"tn": cm[0, 0], "fp": cm[0, 1], "fn": cm[1, 0], "tp": cm[1, 1]}

    X, y = make_classification(n_samples=40, n_features=4, random_state=42)
    est = LinearSVC(dual="auto", random_state=42)
    est.fit(X, y)
    cv_results = cross_validate(est, X, y, cv=5, scoring=custom_scorer)
    score_names = ["tn", "fp", "fn", "tp"]
    for name in score_names:
        assert "test_{}".format(name) in cv_results


def test_learning_curve_partial_fit_regressors():
    """Check that regressors with partial_fit are supported.

    Non-regression test for #22981.
    """
    X, y = make_regression(random_state=42)
    # Does not error
    learning_curve(MLPRegressor(), X, y, exploit_incremental_learning=True, cv=2)


def test_cross_validate_return_indices(global_random_seed):
    """Check the behaviour of `return_indices` in `cross_validate`."""
    X, y = load_iris(return_X_y=True)
    X = scale(X)  # scale features for better convergence
    estimator = LogisticRegression()
    cv = KFold(n_splits=3, shuffle=True, random_state=global_random_seed)
    cv_results = cross_validate(estimator, X, y, cv=cv, n_jobs=2, return_indices=False)
    assert "indices" not in cv_results
    cv_results = cross_validate(estimator, X, y, cv=cv, n_jobs=2, return_indices=True)
    assert "indices" in cv_results
    train_indices = cv_results["indices"]["train"]
    test_indices = cv_results["indices"]["test"]
    assert len(train_indices) == cv.n_splits
    assert len(test_indices) == cv.n_splits
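    # With the 150 iris samples and 3 folds, each split trains on 100
    # samples and tests on 50.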
    assert_array_equal([indices.size for indices in train_indices], 100)
    assert_array_equal([indices.size for indices in test_indices], 50)
    for split_idx, (expected_train_idx, expected_test_idx) in enumerate(cv.split(X, y)):
        assert_array_equal(train_indices[split_idx], expected_train_idx)
        assert_array_equal(test_indices[split_idx], expected_test_idx)