# test_search.py

  1. """Test the search module"""
  2. import pickle
  3. import re
  4. import sys
  5. from collections.abc import Iterable, Sized
  6. from functools import partial
  7. from io import StringIO
  8. from itertools import chain, product
  9. from types import GeneratorType
  10. import numpy as np
  11. import pytest
  12. import scipy.sparse as sp
  13. from scipy.stats import bernoulli, expon, uniform
  14. from sklearn.base import BaseEstimator, ClassifierMixin, is_classifier
  15. from sklearn.cluster import KMeans
  16. from sklearn.datasets import (
  17. make_blobs,
  18. make_classification,
  19. make_multilabel_classification,
  20. )
  21. from sklearn.ensemble import HistGradientBoostingClassifier
  22. from sklearn.experimental import enable_halving_search_cv # noqa
  23. from sklearn.impute import SimpleImputer
  24. from sklearn.linear_model import LinearRegression, Ridge, SGDClassifier
  25. from sklearn.metrics import (
  26. accuracy_score,
  27. confusion_matrix,
  28. f1_score,
  29. make_scorer,
  30. r2_score,
  31. recall_score,
  32. roc_auc_score,
  33. )
  34. from sklearn.metrics.pairwise import euclidean_distances
  35. from sklearn.model_selection import (
  36. GridSearchCV,
  37. GroupKFold,
  38. GroupShuffleSplit,
  39. HalvingGridSearchCV,
  40. KFold,
  41. LeaveOneGroupOut,
  42. LeavePGroupsOut,
  43. ParameterGrid,
  44. ParameterSampler,
  45. RandomizedSearchCV,
  46. StratifiedKFold,
  47. StratifiedShuffleSplit,
  48. train_test_split,
  49. )
  50. from sklearn.model_selection._search import BaseSearchCV
  51. from sklearn.model_selection._validation import FitFailedWarning
  52. from sklearn.model_selection.tests.common import OneTimeSplitter
  53. from sklearn.neighbors import KernelDensity, KNeighborsClassifier, LocalOutlierFactor
  54. from sklearn.pipeline import Pipeline
  55. from sklearn.svm import SVC, LinearSVC
  56. from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor
  57. from sklearn.utils._mocking import CheckingClassifier, MockDataFrame
  58. from sklearn.utils._testing import (
  59. MinimalClassifier,
  60. MinimalRegressor,
  61. MinimalTransformer,
  62. assert_allclose,
  63. assert_almost_equal,
  64. assert_array_almost_equal,
  65. assert_array_equal,
  66. ignore_warnings,
  67. )
  68. # Neither of the following two estimators inherit from BaseEstimator,
  69. # to test hyperparameter search on user-defined classifiers.
  70. class MockClassifier:
  71. """Dummy classifier to test the parameter search algorithms"""
  72. def __init__(self, foo_param=0):
  73. self.foo_param = foo_param
  74. def fit(self, X, Y):
  75. assert len(X) == len(Y)
  76. self.classes_ = np.unique(Y)
  77. return self
  78. def predict(self, T):
  79. return T.shape[0]
  80. def transform(self, X):
  81. return X + self.foo_param
  82. def inverse_transform(self, X):
  83. return X - self.foo_param
  84. predict_proba = predict
  85. predict_log_proba = predict
  86. decision_function = predict
  87. def score(self, X=None, Y=None):
  88. if self.foo_param > 1:
  89. score = 1.0
  90. else:
  91. score = 0.0
  92. return score
  93. def get_params(self, deep=False):
  94. return {"foo_param": self.foo_param}
  95. def set_params(self, **params):
  96. self.foo_param = params["foo_param"]
  97. return self
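

# Editorial sketch (not part of the original suite): MockClassifier shows the
# minimal duck-typed contract the searches rely on, namely fit/get_params/
# set_params plus either a score method or an explicit `scoring` argument.
# Any such object can be tuned, e.g. with the small X, y arrays defined below:
#
#     gs = GridSearchCV(MockClassifier(), {"foo_param": [1, 2, 3]}, cv=3)
#     gs.fit(X, y)
#     gs.best_estimator_.foo_param  # == 2: params 2 and 3 tie at score 1.0,
#                                   # and the first tied candidate wins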


class LinearSVCNoScore(LinearSVC):
    """A LinearSVC classifier that has no score method."""

    @property
    def score(self):
        raise AttributeError


X = np.array([[-1, -1], [-2, -1], [1, 1], [2, 1]])
y = np.array([1, 1, 2, 2])


def assert_grid_iter_equals_getitem(grid):
    assert list(grid) == [grid[i] for i in range(len(grid))]


@pytest.mark.parametrize("klass", [ParameterGrid, partial(ParameterSampler, n_iter=10)])
@pytest.mark.parametrize(
    "input, error_type, error_message",
    [
        (0, TypeError, r"Parameter .* a dict or a list, got: 0 of type int"),
        ([{"foo": [0]}, 0], TypeError, r"Parameter .* is not a dict \(0\)"),
        (
            {"foo": 0},
            TypeError,
            r"Parameter (grid|distribution) for parameter 'foo' (is not|needs to be) "
            r"(a list or a numpy array|iterable or a distribution).*",
        ),
    ],
)
def test_validate_parameter_input(klass, input, error_type, error_message):
    with pytest.raises(error_type, match=error_message):
        klass(input)


def test_parameter_grid():
    # Test basic properties of ParameterGrid.
    params1 = {"foo": [1, 2, 3]}
    grid1 = ParameterGrid(params1)
    assert isinstance(grid1, Iterable)
    assert isinstance(grid1, Sized)
    assert len(grid1) == 3
    assert_grid_iter_equals_getitem(grid1)

    params2 = {"foo": [4, 2], "bar": ["ham", "spam", "eggs"]}
    grid2 = ParameterGrid(params2)
    assert len(grid2) == 6

    # loop to assert we can iterate over the grid multiple times
    for i in range(2):
        # tuple + chain transforms {"a": 1, "b": 2} to ("a", 1, "b", 2)
        points = set(tuple(chain(*(sorted(p.items())))) for p in grid2)
        assert points == set(
            ("bar", x, "foo", y) for x, y in product(params2["bar"], params2["foo"])
        )
    assert_grid_iter_equals_getitem(grid2)

    # Special case: empty grid (useful to get default estimator settings)
    empty = ParameterGrid({})
    assert len(empty) == 1
    assert list(empty) == [{}]
    assert_grid_iter_equals_getitem(empty)
    with pytest.raises(IndexError):
        empty[1]

    has_empty = ParameterGrid([{"C": [1, 10]}, {}, {"C": [0.5]}])
    assert len(has_empty) == 4
    assert list(has_empty) == [{"C": 1}, {"C": 10}, {}, {"C": 0.5}]
    assert_grid_iter_equals_getitem(has_empty)
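

# For reference, a sketch of the expansion semantics asserted above: each dict
# contributes the cross-product of its value lists, and a list of dicts
# concatenates the resulting sub-grids in order, e.g.
#
#     list(ParameterGrid([{"a": [1, 2]}, {"b": ["x"]}]))
#     # -> [{"a": 1}, {"a": 2}, {"b": "x"}]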


def test_grid_search():
    # Test that the best estimator contains the right value for foo_param
    clf = MockClassifier()
    grid_search = GridSearchCV(clf, {"foo_param": [1, 2, 3]}, cv=3, verbose=3)
    # make sure it selects the smallest parameter in case of ties
    old_stdout = sys.stdout
    sys.stdout = StringIO()
    grid_search.fit(X, y)
    sys.stdout = old_stdout
    assert grid_search.best_estimator_.foo_param == 2

    assert_array_equal(grid_search.cv_results_["param_foo_param"].data, [1, 2, 3])

    # Smoke test the score etc:
    grid_search.score(X, y)
    grid_search.predict_proba(X)
    grid_search.decision_function(X)
    grid_search.transform(X)

    # Test exception handling on scoring
    grid_search.scoring = "sklearn"
    with pytest.raises(ValueError):
        grid_search.fit(X, y)


def test_grid_search_pipeline_steps():
    # check that parameters that are estimators are cloned before fitting
    pipe = Pipeline([("regressor", LinearRegression())])
    param_grid = {"regressor": [LinearRegression(), Ridge()]}
    grid_search = GridSearchCV(pipe, param_grid, cv=2)
    grid_search.fit(X, y)
    regressor_results = grid_search.cv_results_["param_regressor"]
    assert isinstance(regressor_results[0], LinearRegression)
    assert isinstance(regressor_results[1], Ridge)
    assert not hasattr(regressor_results[0], "coef_")
    assert not hasattr(regressor_results[1], "coef_")
    assert regressor_results[0] is not grid_search.best_estimator_
    assert regressor_results[1] is not grid_search.best_estimator_
    # check that we didn't modify the parameter grid that was passed
    assert not hasattr(param_grid["regressor"][0], "coef_")
    assert not hasattr(param_grid["regressor"][1], "coef_")


@pytest.mark.parametrize("SearchCV", [GridSearchCV, RandomizedSearchCV])
def test_SearchCV_with_fit_params(SearchCV):
    X = np.arange(100).reshape(10, 10)
    y = np.array([0] * 5 + [1] * 5)
    clf = CheckingClassifier(expected_fit_params=["spam", "eggs"])
    searcher = SearchCV(clf, {"foo_param": [1, 2, 3]}, cv=2, error_score="raise")

    # The CheckingClassifier generates an assertion error if
    # a parameter is missing or has length != len(X).
    err_msg = r"Expected fit parameter\(s\) \['eggs'\] not seen."
    with pytest.raises(AssertionError, match=err_msg):
        searcher.fit(X, y, spam=np.ones(10))

    err_msg = "Fit parameter spam has length 1; expected"
    with pytest.raises(AssertionError, match=err_msg):
        searcher.fit(X, y, spam=np.ones(1), eggs=np.zeros(10))
    searcher.fit(X, y, spam=np.ones(10), eggs=np.zeros(10))
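

# Editorial note: keyword arguments given to `fit` are routed to the underlying
# estimator's `fit` on every CV split, and sample-aligned parameters are sliced
# together with X and y, which is why the checks above expect length len(X).
# A hypothetical sketch with a common fit parameter:
#
#     searcher.fit(X, y, sample_weight=np.ones(len(X)))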


@ignore_warnings
def test_grid_search_no_score():
    # Test grid-search on classifier that has no score function.
    clf = LinearSVC(dual="auto", random_state=0)
    X, y = make_blobs(random_state=0, centers=2)
    Cs = [0.1, 1, 10]
    clf_no_score = LinearSVCNoScore(dual="auto", random_state=0)

    grid_search = GridSearchCV(clf, {"C": Cs}, scoring="accuracy")
    grid_search.fit(X, y)

    grid_search_no_score = GridSearchCV(clf_no_score, {"C": Cs}, scoring="accuracy")
    # smoketest grid search
    grid_search_no_score.fit(X, y)

    # check that best params are equal
    assert grid_search_no_score.best_params_ == grid_search.best_params_
    # check that we can call score and that it gives the correct result
    assert grid_search.score(X, y) == grid_search_no_score.score(X, y)

    # giving no scoring function raises an error
    grid_search_no_score = GridSearchCV(clf_no_score, {"C": Cs})
    with pytest.raises(TypeError, match="no scoring"):
        grid_search_no_score.fit([[1]])


def test_grid_search_score_method():
    X, y = make_classification(n_samples=100, n_classes=2, flip_y=0.2, random_state=0)
    clf = LinearSVC(dual="auto", random_state=0)
    grid = {"C": [0.1]}

    search_no_scoring = GridSearchCV(clf, grid, scoring=None).fit(X, y)
    search_accuracy = GridSearchCV(clf, grid, scoring="accuracy").fit(X, y)
    search_no_score_method_auc = GridSearchCV(
        LinearSVCNoScore(dual="auto"), grid, scoring="roc_auc"
    ).fit(X, y)
    search_auc = GridSearchCV(clf, grid, scoring="roc_auc").fit(X, y)

    # Check warning only occurs in situation where behavior changed:
    # estimator requires score method to compete with scoring parameter
    score_no_scoring = search_no_scoring.score(X, y)
    score_accuracy = search_accuracy.score(X, y)
    score_no_score_auc = search_no_score_method_auc.score(X, y)
    score_auc = search_auc.score(X, y)

    # ensure the test is sane
    assert score_auc < 1.0
    assert score_accuracy < 1.0
    assert score_auc != score_accuracy

    assert_almost_equal(score_accuracy, score_no_scoring)
    assert_almost_equal(score_auc, score_no_score_auc)


def test_grid_search_groups():
    # Check if ValueError (when groups is None) propagates to GridSearchCV
    # And also check if groups is correctly passed to the cv object
    rng = np.random.RandomState(0)

    X, y = make_classification(n_samples=15, n_classes=2, random_state=0)
    groups = rng.randint(0, 3, 15)

    clf = LinearSVC(dual="auto", random_state=0)
    grid = {"C": [1]}

    group_cvs = [
        LeaveOneGroupOut(),
        LeavePGroupsOut(2),
        GroupKFold(n_splits=3),
        GroupShuffleSplit(),
    ]
    error_msg = "The 'groups' parameter should not be None."
    for cv in group_cvs:
        gs = GridSearchCV(clf, grid, cv=cv)
        with pytest.raises(ValueError, match=error_msg):
            gs.fit(X, y)
        gs.fit(X, y, groups=groups)

    non_group_cvs = [StratifiedKFold(), StratifiedShuffleSplit()]
    for cv in non_group_cvs:
        gs = GridSearchCV(clf, grid, cv=cv)
        # Should not raise an error
        gs.fit(X, y)


def test_classes__property():
    # Test that classes_ property matches best_estimator_.classes_
    X = np.arange(100).reshape(10, 10)
    y = np.array([0] * 5 + [1] * 5)
    Cs = [0.1, 1, 10]

    grid_search = GridSearchCV(LinearSVC(dual="auto", random_state=0), {"C": Cs})
    grid_search.fit(X, y)
    assert_array_equal(grid_search.best_estimator_.classes_, grid_search.classes_)

    # Test that regressors do not have a classes_ attribute
    grid_search = GridSearchCV(Ridge(), {"alpha": [1.0, 2.0]})
    grid_search.fit(X, y)
    assert not hasattr(grid_search, "classes_")

    # Test that the grid searcher has no classes_ attribute before it's fit
    grid_search = GridSearchCV(LinearSVC(dual="auto", random_state=0), {"C": Cs})
    assert not hasattr(grid_search, "classes_")

    # Test that the grid searcher has no classes_ attribute without a refit
    grid_search = GridSearchCV(
        LinearSVC(dual="auto", random_state=0), {"C": Cs}, refit=False
    )
    grid_search.fit(X, y)
    assert not hasattr(grid_search, "classes_")
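

# Editorial note: `classes_` on a search object is a property delegating to
# `best_estimator_.classes_`, so it only exists after fitting with refit
# enabled and only when the refitted estimator is a classifier.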


def test_trivial_cv_results_attr():
    # Test search over a "grid" with only one point.
    clf = MockClassifier()
    grid_search = GridSearchCV(clf, {"foo_param": [1]}, cv=3)
    grid_search.fit(X, y)
    assert hasattr(grid_search, "cv_results_")

    random_search = RandomizedSearchCV(clf, {"foo_param": [0]}, n_iter=1, cv=3)
    random_search.fit(X, y)
    assert hasattr(random_search, "cv_results_")


def test_no_refit():
    # Test that GSCV can be used for model selection alone without refitting
    clf = MockClassifier()
    for scoring in [None, ["accuracy", "precision"]]:
        grid_search = GridSearchCV(clf, {"foo_param": [1, 2, 3]}, refit=False, cv=3)
        grid_search.fit(X, y)
        assert (
            not hasattr(grid_search, "best_estimator_")
            and hasattr(grid_search, "best_index_")
            and hasattr(grid_search, "best_params_")
        )

        # Make sure the functions predict/transform etc. raise meaningful
        # error messages
        for fn_name in (
            "predict",
            "predict_proba",
            "predict_log_proba",
            "transform",
            "inverse_transform",
        ):
            error_msg = (
                f"`refit=False`. {fn_name} is available only after "
                "refitting on the best parameters"
            )
            with pytest.raises(AttributeError, match=error_msg):
                getattr(grid_search, fn_name)(X)

    # Test that an invalid refit param raises appropriate error messages
    error_msg = (
        "For multi-metric scoring, the parameter refit must be set to a scorer key"
    )
    for refit in [True, "recall", "accuracy"]:
        with pytest.raises(ValueError, match=error_msg):
            GridSearchCV(
                clf, {}, refit=refit, scoring={"acc": "accuracy", "prec": "precision"}
            ).fit(X, y)


def test_grid_search_error():
    # Test that grid search will capture errors on data with different length
    X_, y_ = make_classification(n_samples=200, n_features=100, random_state=0)

    clf = LinearSVC(dual="auto")
    cv = GridSearchCV(clf, {"C": [0.1, 1.0]})
    with pytest.raises(ValueError):
        cv.fit(X_[:180], y_)


def test_grid_search_one_grid_point():
    X_, y_ = make_classification(n_samples=200, n_features=100, random_state=0)
    param_dict = {"C": [1.0], "kernel": ["rbf"], "gamma": [0.1]}

    clf = SVC(gamma="auto")
    cv = GridSearchCV(clf, param_dict)
    cv.fit(X_, y_)

    clf = SVC(C=1.0, kernel="rbf", gamma=0.1)
    clf.fit(X_, y_)

    assert_array_equal(clf.dual_coef_, cv.best_estimator_.dual_coef_)


def test_grid_search_when_param_grid_includes_range():
    # Test that the best estimator contains the right value for foo_param
    clf = MockClassifier()
    grid_search = GridSearchCV(clf, {"foo_param": range(1, 4)}, cv=3)
    grid_search.fit(X, y)
    assert grid_search.best_estimator_.foo_param == 2


def test_grid_search_bad_param_grid():
    X, y = make_classification(n_samples=10, n_features=5, random_state=0)
    param_dict = {"C": 1}
    clf = SVC(gamma="auto")

    error_msg = re.escape(
        "Parameter grid for parameter 'C' needs to be a list or "
        "a numpy array, but got 1 (of type int) instead. Single "
        "values need to be wrapped in a list with one element."
    )
    search = GridSearchCV(clf, param_dict)
    with pytest.raises(TypeError, match=error_msg):
        search.fit(X, y)

    param_dict = {"C": []}
    clf = SVC()
    error_msg = re.escape(
        "Parameter grid for parameter 'C' need to be a non-empty sequence, got: []"
    )
    search = GridSearchCV(clf, param_dict)
    with pytest.raises(ValueError, match=error_msg):
        search.fit(X, y)

    param_dict = {"C": "1,2,3"}
    clf = SVC(gamma="auto")
    error_msg = re.escape(
        "Parameter grid for parameter 'C' needs to be a list or a numpy array, "
        "but got '1,2,3' (of type str) instead. Single values need to be "
        "wrapped in a list with one element."
    )
    search = GridSearchCV(clf, param_dict)
    with pytest.raises(TypeError, match=error_msg):
        search.fit(X, y)

    param_dict = {"C": np.ones((3, 2))}
    clf = SVC()
    search = GridSearchCV(clf, param_dict)
    with pytest.raises(ValueError):
        search.fit(X, y)


def test_grid_search_sparse():
    # Test that grid search works with both dense and sparse matrices
    X_, y_ = make_classification(n_samples=200, n_features=100, random_state=0)

    clf = LinearSVC(dual="auto")
    cv = GridSearchCV(clf, {"C": [0.1, 1.0]})
    cv.fit(X_[:180], y_[:180])
    y_pred = cv.predict(X_[180:])
    C = cv.best_estimator_.C

    X_ = sp.csr_matrix(X_)
    clf = LinearSVC(dual="auto")
    cv = GridSearchCV(clf, {"C": [0.1, 1.0]})
    cv.fit(X_[:180].tocoo(), y_[:180])
    y_pred2 = cv.predict(X_[180:])
    C2 = cv.best_estimator_.C

    assert np.mean(y_pred == y_pred2) >= 0.9
    assert C == C2


def test_grid_search_sparse_scoring():
    X_, y_ = make_classification(n_samples=200, n_features=100, random_state=0)

    clf = LinearSVC(dual="auto")
    cv = GridSearchCV(clf, {"C": [0.1, 1.0]}, scoring="f1")
    cv.fit(X_[:180], y_[:180])
    y_pred = cv.predict(X_[180:])
    C = cv.best_estimator_.C

    X_ = sp.csr_matrix(X_)
    clf = LinearSVC(dual="auto")
    cv = GridSearchCV(clf, {"C": [0.1, 1.0]}, scoring="f1")
    cv.fit(X_[:180], y_[:180])
    y_pred2 = cv.predict(X_[180:])
    C2 = cv.best_estimator_.C

    assert_array_equal(y_pred, y_pred2)
    assert C == C2
    # Smoke test the score
    # np.testing.assert_allclose(f1_score(cv.predict(X_[:180]), y[:180]),
    #                            cv.score(X_[:180], y[:180]))

    # test loss where greater is worse
    def f1_loss(y_true_, y_pred_):
        return -f1_score(y_true_, y_pred_)

    F1Loss = make_scorer(f1_loss, greater_is_better=False)
    cv = GridSearchCV(clf, {"C": [0.1, 1.0]}, scoring=F1Loss)
    cv.fit(X_[:180], y_[:180])
    y_pred3 = cv.predict(X_[180:])
    C3 = cv.best_estimator_.C

    assert C == C3
    assert_array_equal(y_pred, y_pred3)
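

# Reminder on the sign convention used above: `greater_is_better=False` makes
# `make_scorer` negate the metric so that larger scorer values always mean
# better models. Here f1_loss returns -f1, the scorer flips it back to f1, and
# the candidate ranking therefore matches plain scoring="f1", which is exactly
# what the assertions verify.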


def test_grid_search_precomputed_kernel():
    # Test that grid search works when the input features are given in the
    # form of a precomputed kernel matrix
    X_, y_ = make_classification(n_samples=200, n_features=100, random_state=0)

    # compute the training kernel matrix corresponding to the linear kernel
    K_train = np.dot(X_[:180], X_[:180].T)
    y_train = y_[:180]

    clf = SVC(kernel="precomputed")
    cv = GridSearchCV(clf, {"C": [0.1, 1.0]})
    cv.fit(K_train, y_train)

    assert cv.best_score_ >= 0

    # compute the test kernel matrix
    K_test = np.dot(X_[180:], X_[:180].T)
    y_test = y_[180:]

    y_pred = cv.predict(K_test)

    assert np.mean(y_pred == y_test) >= 0

    # test error is raised when the precomputed kernel is not array-like
    # or sparse
    with pytest.raises(ValueError):
        cv.fit(K_train.tolist(), y_train)


def test_grid_search_precomputed_kernel_error_nonsquare():
    # Test that grid search returns an error with a non-square precomputed
    # training kernel matrix
    K_train = np.zeros((10, 20))
    y_train = np.ones((10,))
    clf = SVC(kernel="precomputed")
    cv = GridSearchCV(clf, {"C": [0.1, 1.0]})
    with pytest.raises(ValueError):
        cv.fit(K_train, y_train)


class BrokenClassifier(BaseEstimator):
    """Broken classifier that cannot be fit twice"""

    def __init__(self, parameter=None):
        self.parameter = parameter

    def fit(self, X, y):
        assert not hasattr(self, "has_been_fit_")
        self.has_been_fit_ = True

    def predict(self, X):
        return np.zeros(X.shape[0])


@ignore_warnings
def test_refit():
    # Regression test for bug in refitting
    # Simulates re-fitting a broken estimator; this used to break with
    # sparse SVMs.
    X = np.arange(100).reshape(10, 10)
    y = np.array([0] * 5 + [1] * 5)

    clf = GridSearchCV(
        BrokenClassifier(), [{"parameter": [0, 1]}], scoring="precision", refit=True
    )
    clf.fit(X, y)


def test_refit_callable():
    """
    Test refit=callable, which adds flexibility in identifying the
    "best" estimator.
    """

    def refit_callable(cv_results):
        """
        A dummy function to test the `refit=callable` interface.
        Returns the index of the model with the lowest `mean_test_score`.
        """
        # Fit a dummy clf with `refit=True` to get a list of keys in
        # clf.cv_results_.
        X, y = make_classification(n_samples=100, n_features=4, random_state=42)
        clf = GridSearchCV(
            LinearSVC(dual="auto", random_state=42),
            {"C": [0.01, 0.1, 1]},
            scoring="precision",
            refit=True,
        )
        clf.fit(X, y)
        # Ensure that `best_index_ != 0` for this dummy clf
        assert clf.best_index_ != 0

        # Assert every key matches those in `cv_results`
        for key in clf.cv_results_.keys():
            assert key in cv_results

        return cv_results["mean_test_score"].argmin()

    X, y = make_classification(n_samples=100, n_features=4, random_state=42)
    clf = GridSearchCV(
        LinearSVC(dual="auto", random_state=42),
        {"C": [0.01, 0.1, 1]},
        scoring="precision",
        refit=refit_callable,
    )
    clf.fit(X, y)

    assert clf.best_index_ == 0
    # Ensure `best_score_` is disabled when using `refit=callable`
    assert not hasattr(clf, "best_score_")


def test_refit_callable_invalid_type():
    """
    Test that the implementation catches the error when 'best_index_'
    returns an invalid result.
    """

    def refit_callable_invalid_type(cv_results):
        """
        A dummy function whose returned 'best_index_' is not an integer.
        """
        return None

    X, y = make_classification(n_samples=100, n_features=4, random_state=42)

    clf = GridSearchCV(
        LinearSVC(dual="auto", random_state=42),
        {"C": [0.1, 1]},
        scoring="precision",
        refit=refit_callable_invalid_type,
    )
    with pytest.raises(TypeError, match="best_index_ returned is not an integer"):
        clf.fit(X, y)


@pytest.mark.parametrize("out_bound_value", [-1, 2])
@pytest.mark.parametrize("search_cv", [RandomizedSearchCV, GridSearchCV])
def test_refit_callable_out_bound(out_bound_value, search_cv):
    """
    Test that the implementation catches the error when 'best_index_'
    returns an out-of-bounds result.
    """

    def refit_callable_out_bound(cv_results):
        """
        A dummy function whose returned 'best_index_' is out of bounds.
        """
        return out_bound_value

    X, y = make_classification(n_samples=100, n_features=4, random_state=42)

    clf = search_cv(
        LinearSVC(dual="auto", random_state=42),
        {"C": [0.1, 1]},
        scoring="precision",
        refit=refit_callable_out_bound,
    )
    with pytest.raises(IndexError, match="best_index_ index out of range"):
        clf.fit(X, y)


def test_refit_callable_multi_metric():
    """
    Test refit=callable in a multiple-metric evaluation setting.
    """

    def refit_callable(cv_results):
        """
        A dummy function to test the `refit=callable` interface.
        Returns the index of the model with the lowest `mean_test_prec`.
        """
        assert "mean_test_prec" in cv_results
        return cv_results["mean_test_prec"].argmin()

    X, y = make_classification(n_samples=100, n_features=4, random_state=42)
    scoring = {"Accuracy": make_scorer(accuracy_score), "prec": "precision"}
    clf = GridSearchCV(
        LinearSVC(dual="auto", random_state=42),
        {"C": [0.01, 0.1, 1]},
        scoring=scoring,
        refit=refit_callable,
    )
    clf.fit(X, y)

    assert clf.best_index_ == 0
    # Ensure `best_score_` is disabled when using `refit=callable`
    assert not hasattr(clf, "best_score_")


def test_gridsearch_nd():
    # Pass n-dimensional arrays for X and y to GridSearchCV
    X_4d = np.arange(10 * 5 * 3 * 2).reshape(10, 5, 3, 2)
    y_3d = np.arange(10 * 7 * 11).reshape(10, 7, 11)

    def check_X(x):
        return x.shape[1:] == (5, 3, 2)

    def check_y(x):
        return x.shape[1:] == (7, 11)

    clf = CheckingClassifier(
        check_X=check_X,
        check_y=check_y,
        methods_to_check=["fit"],
    )
    grid_search = GridSearchCV(clf, {"foo_param": [1, 2, 3]})
    grid_search.fit(X_4d, y_3d).score(X, y)
    assert hasattr(grid_search, "cv_results_")


def test_X_as_list():
    # Pass X as list in GridSearchCV
    X = np.arange(100).reshape(10, 10)
    y = np.array([0] * 5 + [1] * 5)

    clf = CheckingClassifier(
        check_X=lambda x: isinstance(x, list),
        methods_to_check=["fit"],
    )
    cv = KFold(n_splits=3)
    grid_search = GridSearchCV(clf, {"foo_param": [1, 2, 3]}, cv=cv)
    grid_search.fit(X.tolist(), y).score(X, y)
    assert hasattr(grid_search, "cv_results_")


def test_y_as_list():
    # Pass y as list in GridSearchCV
    X = np.arange(100).reshape(10, 10)
    y = np.array([0] * 5 + [1] * 5)

    clf = CheckingClassifier(
        check_y=lambda x: isinstance(x, list),
        methods_to_check=["fit"],
    )
    cv = KFold(n_splits=3)
    grid_search = GridSearchCV(clf, {"foo_param": [1, 2, 3]}, cv=cv)
    grid_search.fit(X, y.tolist()).score(X, y)
    assert hasattr(grid_search, "cv_results_")


@ignore_warnings
def test_pandas_input():
    # check cross_val_score doesn't destroy pandas dataframe
    types = [(MockDataFrame, MockDataFrame)]
    try:
        from pandas import DataFrame, Series

        types.append((DataFrame, Series))
    except ImportError:
        pass

    X = np.arange(100).reshape(10, 10)
    y = np.array([0] * 5 + [1] * 5)

    for InputFeatureType, TargetType in types:
        # X dataframe, y series
        X_df, y_ser = InputFeatureType(X), TargetType(y)

        def check_df(x):
            return isinstance(x, InputFeatureType)

        def check_series(x):
            return isinstance(x, TargetType)

        clf = CheckingClassifier(check_X=check_df, check_y=check_series)

        grid_search = GridSearchCV(clf, {"foo_param": [1, 2, 3]})
        grid_search.fit(X_df, y_ser).score(X_df, y_ser)
        grid_search.predict(X_df)
        assert hasattr(grid_search, "cv_results_")


def test_unsupervised_grid_search():
    # test grid-search with unsupervised estimator
    X, y = make_blobs(n_samples=50, random_state=0)
    km = KMeans(random_state=0, init="random", n_init=1)

    # Multi-metric evaluation unsupervised
    scoring = ["adjusted_rand_score", "fowlkes_mallows_score"]
    for refit in ["adjusted_rand_score", "fowlkes_mallows_score"]:
        grid_search = GridSearchCV(
            km, param_grid=dict(n_clusters=[2, 3, 4]), scoring=scoring, refit=refit
        )
        grid_search.fit(X, y)
        # Both ARI and FMS can find the right number :)
        assert grid_search.best_params_["n_clusters"] == 3

    # Single metric evaluation unsupervised
    grid_search = GridSearchCV(
        km, param_grid=dict(n_clusters=[2, 3, 4]), scoring="fowlkes_mallows_score"
    )
    grid_search.fit(X, y)
    assert grid_search.best_params_["n_clusters"] == 3

    # Now without a score, and without y
    grid_search = GridSearchCV(km, param_grid=dict(n_clusters=[2, 3, 4]))
    grid_search.fit(X)
    assert grid_search.best_params_["n_clusters"] == 4


def test_gridsearch_no_predict():
    # test grid-search with an estimator without predict.
    # slight duplication of a test from KDE
    def custom_scoring(estimator, X):
        return 42 if estimator.bandwidth == 0.1 else 0

    X, _ = make_blobs(cluster_std=0.1, random_state=1, centers=[[0, 1], [1, 0], [0, 0]])
    search = GridSearchCV(
        KernelDensity(),
        param_grid=dict(bandwidth=[0.01, 0.1, 1]),
        scoring=custom_scoring,
    )
    search.fit(X)
    assert search.best_params_["bandwidth"] == 0.1
    assert search.best_score_ == 42
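

# The pattern relied on above: a callable passed as `scoring` receives the
# fitted estimator and the evaluation data (plus y, when available) and returns
# a single number, so estimators without `predict`, such as KernelDensity, can
# still be searched. A minimal sketch of an alternative scorer:
#
#     def log_likelihood_scoring(estimator, X, y=None):
#         return estimator.score(X)  # total log-likelihood under the KDE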


def test_param_sampler():
    # test basic properties of param sampler
    param_distributions = {"kernel": ["rbf", "linear"], "C": uniform(0, 1)}
    sampler = ParameterSampler(
        param_distributions=param_distributions, n_iter=10, random_state=0
    )
    samples = [x for x in sampler]
    assert len(samples) == 10
    for sample in samples:
        assert sample["kernel"] in ["rbf", "linear"]
        assert 0 <= sample["C"] <= 1

    # test that repeated calls yield identical parameters
    param_distributions = {"C": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]}
    sampler = ParameterSampler(
        param_distributions=param_distributions, n_iter=3, random_state=0
    )
    assert [x for x in sampler] == [x for x in sampler]

    param_distributions = {"C": uniform(0, 1)}
    sampler = ParameterSampler(
        param_distributions=param_distributions, n_iter=10, random_state=0
    )
    assert [x for x in sampler] == [x for x in sampler]
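

# Background for the assertions above: ParameterSampler draws via `rvs()` from
# any scipy.stats-style frozen distribution and samples uniformly from plain
# lists; with an integer `random_state`, every fresh iteration over the sampler
# replays the same draws, hence the repeated-iteration equality checks.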


def check_cv_results_array_types(search, param_keys, score_keys):
    # Check that the arrays in the search's `cv_results_` have the right types
    cv_results = search.cv_results_
    assert all(isinstance(cv_results[param], np.ma.MaskedArray) for param in param_keys)
    assert all(cv_results[key].dtype == object for key in param_keys)
    assert not any(isinstance(cv_results[key], np.ma.MaskedArray) for key in score_keys)
    assert all(
        cv_results[key].dtype == np.float64
        for key in score_keys
        if not key.startswith("rank")
    )

    scorer_keys = search.scorer_.keys() if search.multimetric_ else ["score"]

    for key in scorer_keys:
        assert cv_results["rank_test_%s" % key].dtype == np.int32


def check_cv_results_keys(cv_results, param_keys, score_keys, n_cand, extra_keys=()):
    # Test that search.cv_results_ contains all the required results
    all_keys = param_keys + score_keys + extra_keys
    assert_array_equal(sorted(cv_results.keys()), sorted(all_keys + ("params",)))
    assert all(cv_results[key].shape == (n_cand,) for key in param_keys + score_keys)
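

# The cv_results_ layout these helpers encode: one entry per candidate, where
# `param_<name>` columns are object-dtype masked arrays (masked for candidates
# that do not use that parameter), score and time columns are float64, and
# `rank_test_<metric>` columns are int32.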


def test_grid_search_cv_results():
    X, y = make_classification(n_samples=50, n_features=4, random_state=42)

    n_grid_points = 6
    params = [
        dict(kernel=["rbf"], C=[1, 10], gamma=[0.1, 1]),
        dict(kernel=["poly"], degree=[1, 2]),
    ]

    param_keys = ("param_C", "param_degree", "param_gamma", "param_kernel")
    score_keys = (
        "mean_test_score",
        "mean_train_score",
        "rank_test_score",
        "split0_test_score",
        "split1_test_score",
        "split2_test_score",
        "split0_train_score",
        "split1_train_score",
        "split2_train_score",
        "std_test_score",
        "std_train_score",
        "mean_fit_time",
        "std_fit_time",
        "mean_score_time",
        "std_score_time",
    )
    n_candidates = n_grid_points

    search = GridSearchCV(SVC(), cv=3, param_grid=params, return_train_score=True)
    search.fit(X, y)
    cv_results = search.cv_results_
    # Check if score and timing are reasonable
    assert all(cv_results["rank_test_score"] >= 1)
    assert all(np.all(cv_results[k] >= 0) for k in score_keys if k != "rank_test_score")
    assert all(
        np.all(cv_results[k] <= 1)
        for k in score_keys
        if "time" not in k and k != "rank_test_score"
    )
    # Check cv_results structure
    check_cv_results_array_types(search, param_keys, score_keys)
    check_cv_results_keys(cv_results, param_keys, score_keys, n_candidates)
    # Check masking
    cv_results = search.cv_results_

    poly_results = [
        (
            cv_results["param_C"].mask[i]
            and cv_results["param_gamma"].mask[i]
            and not cv_results["param_degree"].mask[i]
        )
        for i in range(n_candidates)
        if cv_results["param_kernel"][i] == "poly"
    ]
    assert all(poly_results)
    assert len(poly_results) == 2

    rbf_results = [
        (
            not cv_results["param_C"].mask[i]
            and not cv_results["param_gamma"].mask[i]
            and cv_results["param_degree"].mask[i]
        )
        for i in range(n_candidates)
        if cv_results["param_kernel"][i] == "rbf"
    ]
    assert all(rbf_results)
    assert len(rbf_results) == 4


def test_random_search_cv_results():
    X, y = make_classification(n_samples=50, n_features=4, random_state=42)

    n_search_iter = 30

    params = [
        {"kernel": ["rbf"], "C": expon(scale=10), "gamma": expon(scale=0.1)},
        {"kernel": ["poly"], "degree": [2, 3]},
    ]
    param_keys = ("param_C", "param_degree", "param_gamma", "param_kernel")
    score_keys = (
        "mean_test_score",
        "mean_train_score",
        "rank_test_score",
        "split0_test_score",
        "split1_test_score",
        "split2_test_score",
        "split0_train_score",
        "split1_train_score",
        "split2_train_score",
        "std_test_score",
        "std_train_score",
        "mean_fit_time",
        "std_fit_time",
        "mean_score_time",
        "std_score_time",
    )
    n_candidates = n_search_iter

    search = RandomizedSearchCV(
        SVC(),
        n_iter=n_search_iter,
        cv=3,
        param_distributions=params,
        return_train_score=True,
    )
    search.fit(X, y)
    cv_results = search.cv_results_
    # Check results structure
    check_cv_results_array_types(search, param_keys, score_keys)
    check_cv_results_keys(cv_results, param_keys, score_keys, n_candidates)
    assert all(
        (
            cv_results["param_C"].mask[i]
            and cv_results["param_gamma"].mask[i]
            and not cv_results["param_degree"].mask[i]
        )
        for i in range(n_candidates)
        if cv_results["param_kernel"][i] == "poly"
    )
    assert all(
        (
            not cv_results["param_C"].mask[i]
            and not cv_results["param_gamma"].mask[i]
            and cv_results["param_degree"].mask[i]
        )
        for i in range(n_candidates)
        if cv_results["param_kernel"][i] == "rbf"
    )


@pytest.mark.parametrize(
    "SearchCV, specialized_params",
    [
        (GridSearchCV, {"param_grid": {"C": [1, 10]}}),
        (RandomizedSearchCV, {"param_distributions": {"C": [1, 10]}, "n_iter": 2}),
    ],
)
def test_search_default_iid(SearchCV, specialized_params):
    # Check that mean/std of the test scores are unweighted across splits
    # rather than weighted by fold size (the behavior once controlled by
    # the since-removed `iid` parameter).
    # noise-free simple 2d-data
    X, y = make_blobs(
        centers=[[0, 0], [1, 0], [0, 1], [1, 1]],
        random_state=0,
        cluster_std=0.1,
        shuffle=False,
        n_samples=80,
    )

    # split dataset into two folds that are not iid
    # first one contains data of all 4 blobs, second only from two.
    mask = np.ones(X.shape[0], dtype=bool)
    mask[np.where(y == 1)[0][::2]] = 0
    mask[np.where(y == 2)[0][::2]] = 0

    # this leads to perfect classification on one fold and a score of 1/3 on
    # the other
    # create "cv" for splits
    cv = [[mask, ~mask], [~mask, mask]]

    common_params = {"estimator": SVC(), "cv": cv, "return_train_score": True}
    search = SearchCV(**common_params, **specialized_params)
    search.fit(X, y)

    test_cv_scores = np.array(
        [
            search.cv_results_["split%d_test_score" % s][0]
            for s in range(search.n_splits_)
        ]
    )
    test_mean = search.cv_results_["mean_test_score"][0]
    test_std = search.cv_results_["std_test_score"][0]

    train_cv_scores = np.array(
        [
            search.cv_results_["split%d_train_score" % s][0]
            for s in range(search.n_splits_)
        ]
    )
    train_mean = search.cv_results_["mean_train_score"][0]
    train_std = search.cv_results_["std_train_score"][0]

    assert search.cv_results_["param_C"][0] == 1
    # scores are the same as above
    assert_allclose(test_cv_scores, [1, 1.0 / 3.0])
    assert_allclose(train_cv_scores, [1, 1])
    # Unweighted mean/std is used
    assert test_mean == pytest.approx(np.mean(test_cv_scores))
    assert test_std == pytest.approx(np.std(test_cv_scores))

    # For the train scores, we do not take a weighted mean irrespective of
    # i.i.d. or not
    assert train_mean == pytest.approx(1)
    assert train_std == pytest.approx(0)
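

# Technique worth noting from the test above: `cv` accepts an explicit list of
# (train, test) index pairs, here given as boolean masks, which is how the two
# deliberately non-iid folds are constructed by hand.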


def test_grid_search_cv_results_multimetric():
    X, y = make_classification(n_samples=50, n_features=4, random_state=42)

    n_splits = 3
    params = [
        dict(kernel=["rbf"], C=[1, 10], gamma=[0.1, 1]),
        dict(kernel=["poly"], degree=[1, 2]),
    ]

    grid_searches = []
    for scoring in (
        {"accuracy": make_scorer(accuracy_score), "recall": make_scorer(recall_score)},
        "accuracy",
        "recall",
    ):
        grid_search = GridSearchCV(
            SVC(), cv=n_splits, param_grid=params, scoring=scoring, refit=False
        )
        grid_search.fit(X, y)
        grid_searches.append(grid_search)

    compare_cv_results_multimetric_with_single(*grid_searches)


def test_random_search_cv_results_multimetric():
    X, y = make_classification(n_samples=50, n_features=4, random_state=42)

    n_splits = 3
    n_search_iter = 30

    # Scipy 0.12's stats dists do not accept seed, hence we use param grid
    params = dict(C=np.logspace(-4, 1, 3), gamma=np.logspace(-5, 0, 3, base=0.1))
    for refit in (True, False):
        random_searches = []
        for scoring in (("accuracy", "recall"), "accuracy", "recall"):
            # If True, for multi-metric pass refit='accuracy'
            if refit:
                probability = True
                refit = "accuracy" if isinstance(scoring, tuple) else refit
            else:
                probability = False
            clf = SVC(probability=probability, random_state=42)
            random_search = RandomizedSearchCV(
                clf,
                n_iter=n_search_iter,
                cv=n_splits,
                param_distributions=params,
                scoring=scoring,
                refit=refit,
                random_state=0,
            )
            random_search.fit(X, y)
            random_searches.append(random_search)

        compare_cv_results_multimetric_with_single(*random_searches)
        compare_refit_methods_when_refit_with_acc(
            random_searches[0], random_searches[1], refit
        )


def compare_cv_results_multimetric_with_single(search_multi, search_acc, search_rec):
    """Compare multi-metric cv_results with the ensemble of multiple
    single metric cv_results from single metric grid/random search"""

    assert search_multi.multimetric_
    assert_array_equal(sorted(search_multi.scorer_), ("accuracy", "recall"))

    cv_results_multi = search_multi.cv_results_
    cv_results_acc_rec = {
        re.sub("_score$", "_accuracy", k): v for k, v in search_acc.cv_results_.items()
    }
    cv_results_acc_rec.update(
        {re.sub("_score$", "_recall", k): v for k, v in search_rec.cv_results_.items()}
    )

    # Check if score and timing are reasonable, also checks if the keys
    # are present
    assert all(
        np.all(cv_results_multi[k] <= 1)
        for k in (
            "mean_score_time",
            "std_score_time",
            "mean_fit_time",
            "std_fit_time",
        )
    )

    # Compare the keys, other than time keys, among multi-metric and
    # single metric grid search results. np.testing.assert_equal performs a
    # deep nested comparison of the two cv_results dicts
    np.testing.assert_equal(
        {k: v for k, v in cv_results_multi.items() if not k.endswith("_time")},
        {k: v for k, v in cv_results_acc_rec.items() if not k.endswith("_time")},
    )


def compare_refit_methods_when_refit_with_acc(search_multi, search_acc, refit):
    """Compare refit multi-metric search methods with single metric methods"""
    assert search_acc.refit == refit
    if refit:
        assert search_multi.refit == "accuracy"
    else:
        assert not search_multi.refit
        return  # search cannot predict/score without refit

    X, y = make_blobs(n_samples=100, n_features=4, random_state=42)
    for method in ("predict", "predict_proba", "predict_log_proba"):
        assert_almost_equal(
            getattr(search_multi, method)(X), getattr(search_acc, method)(X)
        )
    assert_almost_equal(search_multi.score(X, y), search_acc.score(X, y))
    for key in ("best_index_", "best_score_", "best_params_"):
        assert getattr(search_multi, key) == getattr(search_acc, key)
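

# Key-naming convention exploited by the comparison above: single-metric result
# columns end in "_score" (e.g. "mean_test_score"), while multi-metric columns
# use the scorer name as suffix (e.g. "mean_test_accuracy"), hence the
# `re.sub("_score$", ...)` renaming before the deep dict comparison.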


@pytest.mark.parametrize(
    "search_cv",
    [
        RandomizedSearchCV(
            estimator=DecisionTreeClassifier(),
            param_distributions={"max_depth": [5, 10]},
        ),
        GridSearchCV(
            estimator=DecisionTreeClassifier(), param_grid={"max_depth": [5, 10]}
        ),
    ],
)
def test_search_cv_score_samples_error(search_cv):
    X, y = make_blobs(n_samples=100, n_features=4, random_state=42)
    search_cv.fit(X, y)

    # Make sure to error out when underlying estimator does not implement
    # the method `score_samples`
    err_msg = "'DecisionTreeClassifier' object has no attribute 'score_samples'"
    with pytest.raises(AttributeError, match=err_msg):
        search_cv.score_samples(X)


@pytest.mark.parametrize(
    "search_cv",
    [
        RandomizedSearchCV(
            estimator=LocalOutlierFactor(novelty=True),
            param_distributions={"n_neighbors": [5, 10]},
            scoring="precision",
        ),
        GridSearchCV(
            estimator=LocalOutlierFactor(novelty=True),
            param_grid={"n_neighbors": [5, 10]},
            scoring="precision",
        ),
    ],
)
def test_search_cv_score_samples_method(search_cv):
    # Set parameters
    rng = np.random.RandomState(42)
    n_samples = 300
    outliers_fraction = 0.15
    n_outliers = int(outliers_fraction * n_samples)
    n_inliers = n_samples - n_outliers

    # Create dataset
    X = make_blobs(
        n_samples=n_inliers,
        n_features=2,
        centers=[[0, 0], [0, 0]],
        cluster_std=0.5,
        random_state=0,
    )[0]
    # Add some noisy points
    X = np.concatenate([X, rng.uniform(low=-6, high=6, size=(n_outliers, 2))], axis=0)

    # Define labels to be able to score the estimator with `search_cv`
    y_true = np.array([1] * n_samples)
    y_true[-n_outliers:] = -1

    # Fit on data
    search_cv.fit(X, y_true)

    # Verify that the stand alone estimator yields the same results
    # as the ones obtained with *SearchCV
    assert_allclose(
        search_cv.score_samples(X), search_cv.best_estimator_.score_samples(X)
    )


def test_search_cv_results_rank_tie_breaking():
    X, y = make_blobs(n_samples=50, random_state=42)

    # The two C values are close enough to give similar models
    # which would result in a tie of their mean cv-scores
    param_grid = {"C": [1, 1.001, 0.001]}

    grid_search = GridSearchCV(SVC(), param_grid=param_grid, return_train_score=True)
    random_search = RandomizedSearchCV(
        SVC(), n_iter=3, param_distributions=param_grid, return_train_score=True
    )

    for search in (grid_search, random_search):
        search.fit(X, y)
        cv_results = search.cv_results_
        # Check tie breaking strategy -
        # Check that there is a tie in the mean scores between the first
        # two candidates only
        assert_almost_equal(
            cv_results["mean_test_score"][0], cv_results["mean_test_score"][1]
        )
        assert_almost_equal(
            cv_results["mean_train_score"][0], cv_results["mean_train_score"][1]
        )
        assert not np.allclose(
            cv_results["mean_test_score"][1], cv_results["mean_test_score"][2]
        )
        assert not np.allclose(
            cv_results["mean_train_score"][1], cv_results["mean_train_score"][2]
        )
        # 'min' rank should be assigned to the tied candidates
        assert_almost_equal(search.cv_results_["rank_test_score"], [1, 1, 3])
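

# Ranking detail asserted above: ranks use a "min"-style method for ties, so
# tied candidates share the smallest rank of their group and the next candidate
# skips ahead (hence [1, 1, 3] rather than [1, 2, 3] or [1, 1, 2]).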


def test_search_cv_results_none_param():
    X, y = [[1], [2], [3], [4], [5]], [0, 0, 0, 0, 1]
    estimators = (DecisionTreeRegressor(), DecisionTreeClassifier())
    est_parameters = {"random_state": [0, None]}
    cv = KFold()

    for est in estimators:
        grid_search = GridSearchCV(
            est,
            est_parameters,
            cv=cv,
        ).fit(X, y)
        assert_array_equal(grid_search.cv_results_["param_random_state"], [0, None])


@ignore_warnings()
def test_search_cv_timing():
    svc = LinearSVC(dual="auto", random_state=0)

    X = [[1], [2], [3], [4]]
    y = [0, 1, 1, 0]

    gs = GridSearchCV(svc, {"C": [0, 1]}, cv=2, error_score=0)
    rs = RandomizedSearchCV(svc, {"C": [0, 1]}, cv=2, error_score=0, n_iter=2)

    for search in (gs, rs):
        search.fit(X, y)
        for key in ["mean_fit_time", "std_fit_time"]:
            # NOTE The precision of time.time on Windows is not high
            # enough for the fit/score times to be non-zero for trivial X and y
            assert np.all(search.cv_results_[key] >= 0)
            assert np.all(search.cv_results_[key] < 1)

        for key in ["mean_score_time", "std_score_time"]:
            assert search.cv_results_[key][1] >= 0
            assert search.cv_results_[key][0] == 0.0
            assert np.all(search.cv_results_[key] < 1)

        assert hasattr(search, "refit_time_")
        assert isinstance(search.refit_time_, float)
        assert search.refit_time_ >= 0


def test_grid_search_correct_score_results():
    # test that correct scores are used
    n_splits = 3
    clf = LinearSVC(dual="auto", random_state=0)
    X, y = make_blobs(random_state=0, centers=2)
    Cs = [0.1, 1, 10]
    for score in ["f1", "roc_auc"]:
        grid_search = GridSearchCV(clf, {"C": Cs}, scoring=score, cv=n_splits)
        cv_results = grid_search.fit(X, y).cv_results_

        # Test scorer names
        result_keys = list(cv_results.keys())
        expected_keys = ("mean_test_score", "rank_test_score") + tuple(
            "split%d_test_score" % cv_i for cv_i in range(n_splits)
        )
        assert all(np.isin(expected_keys, result_keys))

        cv = StratifiedKFold(n_splits=n_splits)
        n_splits = grid_search.n_splits_
        for candidate_i, C in enumerate(Cs):
            clf.set_params(C=C)
            cv_scores = np.array(
                [
                    grid_search.cv_results_["split%d_test_score" % s][candidate_i]
                    for s in range(n_splits)
                ]
            )
            for i, (train, test) in enumerate(cv.split(X, y)):
                clf.fit(X[train], y[train])
                if score == "f1":
                    correct_score = f1_score(y[test], clf.predict(X[test]))
                elif score == "roc_auc":
                    dec = clf.decision_function(X[test])
                    correct_score = roc_auc_score(y[test], dec)
                assert_almost_equal(correct_score, cv_scores[i])


def test_pickle():
    # Test that a fitted search can be pickled
    clf = MockClassifier()
    grid_search = GridSearchCV(clf, {"foo_param": [1, 2, 3]}, refit=True, cv=3)
    grid_search.fit(X, y)
    grid_search_pickled = pickle.loads(pickle.dumps(grid_search))
    assert_array_almost_equal(grid_search.predict(X), grid_search_pickled.predict(X))

    random_search = RandomizedSearchCV(
        clf, {"foo_param": [1, 2, 3]}, refit=True, n_iter=3, cv=3
    )
    random_search.fit(X, y)
    random_search_pickled = pickle.loads(pickle.dumps(random_search))
    assert_array_almost_equal(
        random_search.predict(X), random_search_pickled.predict(X)
    )


def test_grid_search_with_multioutput_data():
    # Test search with multi-output estimator
    X, y = make_multilabel_classification(return_indicator=True, random_state=0)
    est_parameters = {"max_depth": [1, 2, 3, 4]}
    cv = KFold()
    estimators = [
        DecisionTreeRegressor(random_state=0),
        DecisionTreeClassifier(random_state=0),
    ]

    # Test with grid search cv
    for est in estimators:
        grid_search = GridSearchCV(est, est_parameters, cv=cv)
        grid_search.fit(X, y)
        res_params = grid_search.cv_results_["params"]
        for cand_i in range(len(res_params)):
            est.set_params(**res_params[cand_i])
            for i, (train, test) in enumerate(cv.split(X, y)):
                est.fit(X[train], y[train])
                correct_score = est.score(X[test], y[test])
                assert_almost_equal(
                    correct_score,
                    grid_search.cv_results_["split%d_test_score" % i][cand_i],
                )

    # Test with a randomized search
    for est in estimators:
        random_search = RandomizedSearchCV(est, est_parameters, cv=cv, n_iter=3)
        random_search.fit(X, y)
        res_params = random_search.cv_results_["params"]
        for cand_i in range(len(res_params)):
            est.set_params(**res_params[cand_i])
            for i, (train, test) in enumerate(cv.split(X, y)):
                est.fit(X[train], y[train])
                correct_score = est.score(X[test], y[test])
                assert_almost_equal(
                    correct_score,
                    random_search.cv_results_["split%d_test_score" % i][cand_i],
                )


def test_predict_proba_disabled():
    # Test predict_proba when disabled on estimator.
    X = np.arange(20).reshape(5, -1)
    y = [0, 0, 1, 1, 1]
    clf = SVC(probability=False)
    gs = GridSearchCV(clf, {}, cv=2).fit(X, y)
    assert not hasattr(gs, "predict_proba")


def test_grid_search_allows_nans():
    # Test GridSearchCV with SimpleImputer
    X = np.arange(20, dtype=np.float64).reshape(5, -1)
    X[2, :] = np.nan
    y = [0, 0, 1, 1, 1]
    p = Pipeline(
        [
            ("imputer", SimpleImputer(strategy="mean", missing_values=np.nan)),
            ("classifier", MockClassifier()),
        ]
    )
    GridSearchCV(p, {"classifier__foo_param": [1, 2, 3]}, cv=2).fit(X, y)


class FailingClassifier(BaseEstimator):
    """Classifier that raises a ValueError on fit()"""

    FAILING_PARAMETER = 2

    def __init__(self, parameter=None):
        self.parameter = parameter

    def fit(self, X, y=None):
        if self.parameter == FailingClassifier.FAILING_PARAMETER:
            raise ValueError("Failing classifier failed as required")
        return self

    def predict(self, X):
        return np.zeros(X.shape[0])

    def score(self, X=None, Y=None):
        return 0.0
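

# FailingClassifier only raises for parameter == 2, so with the 3-candidate
# grids and the default 5-fold CV used below, exactly 5 of the 15 fits fail:
# that is the arithmetic behind the "5 fits failed ... total of 15" messages.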


def test_grid_search_failing_classifier():
    # GridSearchCV with error_score != 'raise'
    # Ensures that a warning is raised and the score reset where appropriate.
    X, y = make_classification(n_samples=20, n_features=10, random_state=0)
    clf = FailingClassifier()

    # refit=False because we only want to check that errors caused by fits
    # to individual folds will be caught and warnings raised instead. If
    # refit was done, then an exception would be raised on refit and not
    # caught by grid_search (expected behavior), and this would cause an
    # error in this test.
    gs = GridSearchCV(
        clf,
        [{"parameter": [0, 1, 2]}],
        scoring="accuracy",
        refit=False,
        error_score=0.0,
    )
    warning_message = re.compile(
        "5 fits failed.+total of 15.+The score on these"
        r" train-test partitions for these parameters will be set to 0\.0.+"
        "5 fits failed with the following error.+ValueError.+Failing classifier failed"
        " as required",
        flags=re.DOTALL,
    )
    with pytest.warns(FitFailedWarning, match=warning_message):
        gs.fit(X, y)
    n_candidates = len(gs.cv_results_["params"])

    # Ensure that grid scores were set to zero as required for those fits
    # that are expected to fail.
    def get_cand_scores(i):
        return np.array(
            [gs.cv_results_["split%d_test_score" % s][i] for s in range(gs.n_splits_)]
        )

    assert all(
        np.all(get_cand_scores(cand_i) == 0.0)
        for cand_i in range(n_candidates)
        if gs.cv_results_["param_parameter"][cand_i]
        == FailingClassifier.FAILING_PARAMETER
    )

    gs = GridSearchCV(
        clf,
        [{"parameter": [0, 1, 2]}],
        scoring="accuracy",
        refit=False,
        error_score=float("nan"),
    )
    warning_message = re.compile(
        "5 fits failed.+total of 15.+The score on these"
        r" train-test partitions for these parameters will be set to nan.+"
        "5 fits failed with the following error.+ValueError.+Failing classifier failed"
        " as required",
        flags=re.DOTALL,
    )
    with pytest.warns(FitFailedWarning, match=warning_message):
        gs.fit(X, y)
    n_candidates = len(gs.cv_results_["params"])
    assert all(
        np.all(np.isnan(get_cand_scores(cand_i)))
        for cand_i in range(n_candidates)
        if gs.cv_results_["param_parameter"][cand_i]
        == FailingClassifier.FAILING_PARAMETER
    )

    ranks = gs.cv_results_["rank_test_score"]

    # Check that the estimators that succeeded have lower ranks
    assert ranks[0] <= 2 and ranks[1] <= 2
    # Check that the failed estimator has the highest rank
    assert ranks[clf.FAILING_PARAMETER] == 3
    assert gs.best_index_ != clf.FAILING_PARAMETER
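    # With error_score=nan the failing candidate's mean score is nan, which
    # is ranked last, so best_index_ can never point at it.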


def test_grid_search_classifier_all_fits_fail():
    X, y = make_classification(n_samples=20, n_features=10, random_state=0)
    clf = FailingClassifier()
    gs = GridSearchCV(
        clf,
        [{"parameter": [FailingClassifier.FAILING_PARAMETER] * 3}],
        error_score=0.0,
    )
    error_message = re.compile(
        "All the 15 fits failed.+15 fits failed with the following"
        " error.+ValueError.+Failing classifier failed as required",
        flags=re.DOTALL,
    )
    with pytest.raises(ValueError, match=error_message):
        gs.fit(X, y)


def test_grid_search_failing_classifier_raise():
    # GridSearchCV with error_score == 'raise' re-raises the error
    X, y = make_classification(n_samples=20, n_features=10, random_state=0)
    clf = FailingClassifier()

    # refit=False because we want to test the behaviour of the grid search part
    gs = GridSearchCV(
        clf,
        [{"parameter": [0, 1, 2]}],
        scoring="accuracy",
        refit=False,
        error_score="raise",
    )

    # FailingClassifier issues a ValueError so this is what we look for.
    with pytest.raises(ValueError):
        gs.fit(X, y)


def test_parameters_sampler_replacement():
    # A warning should be raised if n_iter is bigger than the total
    # parameter space
    params = [
        {"first": [0, 1], "second": ["a", "b", "c"]},
        {"third": ["two", "values"]},
    ]
    sampler = ParameterSampler(params, n_iter=9)
    n_iter = 9
    grid_size = 8
    expected_warning = (
        "The total space of parameters %d is smaller "
        "than n_iter=%d. Running %d iterations. For "
        "exhaustive searches, use GridSearchCV." % (grid_size, n_iter, grid_size)
    )
    with pytest.warns(UserWarning, match=expected_warning):
        list(sampler)

    # degenerates to GridSearchCV if n_iter is the same as grid_size
    sampler = ParameterSampler(params, n_iter=8)
    samples = list(sampler)
    assert len(samples) == 8
    for values in ParameterGrid(params):
        assert values in samples
    assert len(ParameterSampler(params, n_iter=1000)) == 8

    # test sampling without replacement in a large grid
    params = {"a": range(10), "b": range(10), "c": range(10)}
    sampler = ParameterSampler(params, n_iter=99, random_state=42)
    samples = list(sampler)
    assert len(samples) == 99
    hashable_samples = ["a%db%dc%d" % (p["a"], p["b"], p["c"]) for p in samples]
    assert len(set(hashable_samples)) == 99

    # doesn't go into infinite loops
    params_distribution = {"first": bernoulli(0.5), "second": ["a", "b", "c"]}
    sampler = ParameterSampler(params_distribution, n_iter=7)
    samples = list(sampler)
    assert len(samples) == 7
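    # ParameterSampler samples without replacement only when every parameter
    # is a finite list; as soon as one entry is a scipy distribution (like
    # bernoulli above) it samples with replacement, so a small grid cannot
    # make it loop forever.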


def test_stochastic_gradient_loss_param():
    # Make sure predict_proba works when the loss is specified
    # as one of the parameters in the param_grid.
    param_grid = {
        "loss": ["log_loss"],
    }
    X = np.arange(24).reshape(6, -1)
    y = [0, 0, 0, 1, 1, 1]
    clf = GridSearchCV(
        estimator=SGDClassifier(loss="hinge"), param_grid=param_grid, cv=3
    )

    # When the estimator is not fitted, `predict_proba` is not available as
    # the loss is 'hinge'.
    assert not hasattr(clf, "predict_proba")

    clf.fit(X, y)
    clf.predict_proba(X)
    clf.predict_log_proba(X)

    # Make sure `predict_proba` is not available when setting loss=['hinge']
    # in param_grid
    param_grid = {
        "loss": ["hinge"],
    }
    clf = GridSearchCV(
        estimator=SGDClassifier(loss="hinge"), param_grid=param_grid, cv=3
    )
    assert not hasattr(clf, "predict_proba")
    clf.fit(X, y)
    assert not hasattr(clf, "predict_proba")
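    # predict_proba on the search object is exposed by delegation: before fit
    # it reflects the unfitted estimator (loss='hinge', so unavailable);
    # after fit it reflects best_estimator_, whose loss comes from the grid.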


def test_search_train_scores_set_to_false():
    X = np.arange(6).reshape(6, -1)
    y = [0, 0, 0, 1, 1, 1]
    clf = LinearSVC(dual="auto", random_state=0)

    gs = GridSearchCV(clf, param_grid={"C": [0.1, 0.2]}, cv=3)
    gs.fit(X, y)
    assert "mean_train_score" not in gs.cv_results_
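    # Train scores are opt-in because computing them is expensive:
    # return_train_score defaults to False, so no *_train_score columns are
    # expected in cv_results_ here.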


def test_grid_search_cv_splits_consistency():
    # Check if a one-time iterable is accepted as a cv parameter.
    n_samples = 100
    n_splits = 5
    X, y = make_classification(n_samples=n_samples, random_state=0)

    gs = GridSearchCV(
        LinearSVC(dual="auto", random_state=0),
        param_grid={"C": [0.1, 0.2, 0.3]},
        cv=OneTimeSplitter(n_splits=n_splits, n_samples=n_samples),
        return_train_score=True,
    )
    gs.fit(X, y)

    gs2 = GridSearchCV(
        LinearSVC(dual="auto", random_state=0),
        param_grid={"C": [0.1, 0.2, 0.3]},
        cv=KFold(n_splits=n_splits),
        return_train_score=True,
    )
    gs2.fit(X, y)

    # Give a generator as the cv parameter
    assert isinstance(
        KFold(n_splits=n_splits, shuffle=True, random_state=0).split(X, y),
        GeneratorType,
    )
    gs3 = GridSearchCV(
        LinearSVC(dual="auto", random_state=0),
        param_grid={"C": [0.1, 0.2, 0.3]},
        cv=KFold(n_splits=n_splits, shuffle=True, random_state=0).split(X, y),
        return_train_score=True,
    )
    gs3.fit(X, y)

    gs4 = GridSearchCV(
        LinearSVC(dual="auto", random_state=0),
        param_grid={"C": [0.1, 0.2, 0.3]},
        cv=KFold(n_splits=n_splits, shuffle=True, random_state=0),
        return_train_score=True,
    )
    gs4.fit(X, y)

    def _pop_time_keys(cv_results):
        for key in (
            "mean_fit_time",
            "std_fit_time",
            "mean_score_time",
            "std_score_time",
        ):
            cv_results.pop(key)
        return cv_results

    # Check if generators are supported as cv and
    # that the splits are consistent
    np.testing.assert_equal(
        _pop_time_keys(gs3.cv_results_), _pop_time_keys(gs4.cv_results_)
    )

    # OneTimeSplitter is a non-re-entrant cv where split can be called only
    # once. If ``cv.split`` were called once per param setting in
    # GridSearchCV.fit, the 2nd and 3rd parameters would not be evaluated, as
    # no train/test indices would be generated for the 2nd and subsequent
    # cv.split calls. This is a check to make sure cv.split is not called
    # once per param setting.
    np.testing.assert_equal(
        {k: v for k, v in gs.cv_results_.items() if not k.endswith("_time")},
        {k: v for k, v in gs2.cv_results_.items() if not k.endswith("_time")},
    )

    # Check consistency of folds across the parameters
    gs = GridSearchCV(
        LinearSVC(dual="auto", random_state=0),
        param_grid={"C": [0.1, 0.1, 0.2, 0.2]},
        cv=KFold(n_splits=n_splits, shuffle=True),
        return_train_score=True,
    )
    gs.fit(X, y)

    # As the first two param settings (C=0.1) and the next two param
    # settings (C=0.2) are the same, the test and train scores must also be
    # the same, as long as the same train/test indices are generated for all
    # the cv splits, for both param settings.
    for score_type in ("train", "test"):
        per_param_scores = {}
        for param_i in range(4):
            per_param_scores[param_i] = [
                gs.cv_results_["split%d_%s_score" % (s, score_type)][param_i]
                for s in range(5)
            ]

        assert_array_almost_equal(per_param_scores[0], per_param_scores[1])
        assert_array_almost_equal(per_param_scores[2], per_param_scores[3])
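    # All of the above relies on GridSearchCV calling cv.split only once and
    # reusing the resulting indices for every candidate, which is also what
    # makes one-shot iterators and generators usable as cv.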


def test_transform_inverse_transform_round_trip():
    clf = MockClassifier()
    grid_search = GridSearchCV(clf, {"foo_param": [1, 2, 3]}, cv=3, verbose=3)

    grid_search.fit(X, y)
    X_round_trip = grid_search.inverse_transform(grid_search.transform(X))
    assert_array_equal(X, X_round_trip)


def test_custom_run_search():
    def check_results(results, gscv):
        exp_results = gscv.cv_results_
        assert sorted(results.keys()) == sorted(exp_results)
        for k in results:
            if not k.endswith("_time"):
                # XXX: results['params'] is a list :|
                results[k] = np.asanyarray(results[k])
                if results[k].dtype.kind == "O":
                    assert_array_equal(
                        exp_results[k], results[k], err_msg="Checking " + k
                    )
                else:
                    assert_allclose(exp_results[k], results[k], err_msg="Checking " + k)

    def fit_grid(param_grid):
        return GridSearchCV(clf, param_grid, return_train_score=True).fit(X, y)

    class CustomSearchCV(BaseSearchCV):
        def __init__(self, estimator, **kwargs):
            super().__init__(estimator, **kwargs)

        def _run_search(self, evaluate):
            results = evaluate([{"max_depth": 1}, {"max_depth": 2}])
            check_results(results, fit_grid({"max_depth": [1, 2]}))
            results = evaluate([{"min_samples_split": 5}, {"min_samples_split": 10}])
            check_results(
                results,
                fit_grid([{"max_depth": [1, 2]}, {"min_samples_split": [5, 10]}]),
            )

    # Using a regressor to make sure each score differs
    clf = DecisionTreeRegressor(random_state=0)
    X, y = make_classification(n_samples=100, n_informative=4, random_state=0)
    mycv = CustomSearchCV(clf, return_train_score=True).fit(X, y)
    gscv = fit_grid([{"max_depth": [1, 2]}, {"min_samples_split": [5, 10]}])

    results = mycv.cv_results_
    check_results(results, gscv)
    for attr in dir(gscv):
        if (
            attr[0].islower()
            and attr[-1:] == "_"
            and attr
            not in {
                "cv_results_",
                "best_estimator_",
                "refit_time_",
                "classes_",
                "scorer_",
            }
        ):
            assert getattr(gscv, attr) == getattr(mycv, attr), (
                "Attribute %s not equal" % attr
            )
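    # Note that each call to ``evaluate`` returns the cumulative cv_results_
    # of all candidates evaluated so far, which is why the second
    # check_results call above compares against the combined two-part grid.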


def test__custom_fit_no_run_search():
    class NoRunSearchSearchCV(BaseSearchCV):
        def __init__(self, estimator, **kwargs):
            super().__init__(estimator, **kwargs)

        def fit(self, X, y=None, groups=None, **fit_params):
            return self

    # this should not raise any exceptions
    NoRunSearchSearchCV(SVC()).fit(X, y)

    class BadSearchCV(BaseSearchCV):
        def __init__(self, estimator, **kwargs):
            super().__init__(estimator, **kwargs)

    with pytest.raises(NotImplementedError, match="_run_search not implemented."):
        # this should raise a NotImplementedError
        BadSearchCV(SVC()).fit(X, y)


def test_empty_cv_iterator_error():
    # Use global X, y
    cv = KFold(n_splits=3).split(X)

    # Exhaust the iterator so that cv is empty; this should trigger the
    # expected ValueError when the search tries to use it.
    list(cv)

    train_size = 100
    ridge = RandomizedSearchCV(Ridge(), {"alpha": [1e-3, 1e-2, 1e-1]}, cv=cv, n_jobs=4)

    # assert that this raises an error
    with pytest.raises(
        ValueError,
        match=(
            "No fits were performed. "
            "Was the CV iterator empty\\? "
            "Were there no candidates\\?"
        ),
    ):
        ridge.fit(X[:train_size], y[:train_size])


def test_random_search_bad_cv():
    # Use global X, y
    class BrokenKFold(KFold):
        def get_n_splits(self, *args, **kw):
            return 1

    # Create a bad cv that reports a different number of splits than it
    # actually yields
    cv = BrokenKFold(n_splits=3)

    train_size = 100
    ridge = RandomizedSearchCV(Ridge(), {"alpha": [1e-3, 1e-2, 1e-1]}, cv=cv, n_jobs=4)

    # assert that this raises an error
    with pytest.raises(
        ValueError,
        match=(
            "cv.split and cv.get_n_splits returned "
            "inconsistent results. Expected \\d+ "
            "splits, got \\d+"
        ),
    ):
        ridge.fit(X[:train_size], y[:train_size])
  1642. @pytest.mark.parametrize("return_train_score", [False, True])
  1643. @pytest.mark.parametrize(
  1644. "SearchCV, specialized_params",
  1645. [
  1646. (GridSearchCV, {"param_grid": {"max_depth": [2, 3, 5, 8]}}),
  1647. (
  1648. RandomizedSearchCV,
  1649. {"param_distributions": {"max_depth": [2, 3, 5, 8]}, "n_iter": 4},
  1650. ),
  1651. ],
  1652. )
  1653. def test_searchcv_raise_warning_with_non_finite_score(
  1654. SearchCV, specialized_params, return_train_score
  1655. ):
  1656. # Non-regression test for:
  1657. # https://github.com/scikit-learn/scikit-learn/issues/10529
  1658. # Check that we raise a UserWarning when a non-finite score is
  1659. # computed in the SearchCV
  1660. X, y = make_classification(n_classes=2, random_state=0)
  1661. class FailingScorer:
  1662. """Scorer that will fail for some split but not all."""
  1663. def __init__(self):
  1664. self.n_counts = 0
  1665. def __call__(self, estimator, X, y):
  1666. self.n_counts += 1
  1667. if self.n_counts % 5 == 0:
  1668. return np.nan
  1669. return 1
  1670. grid = SearchCV(
  1671. DecisionTreeClassifier(),
  1672. scoring=FailingScorer(),
  1673. cv=3,
  1674. return_train_score=return_train_score,
  1675. **specialized_params,
  1676. )
  1677. with pytest.warns(UserWarning) as warn_msg:
  1678. grid.fit(X, y)
  1679. set_with_warning = ["test", "train"] if return_train_score else ["test"]
  1680. assert len(warn_msg) == len(set_with_warning)
  1681. for msg, dataset in zip(warn_msg, set_with_warning):
  1682. assert f"One or more of the {dataset} scores are non-finite" in str(msg.message)
  1683. # all non-finite scores should be equally ranked last
  1684. last_rank = grid.cv_results_["rank_test_score"].max()
  1685. non_finite_mask = np.isnan(grid.cv_results_["mean_test_score"])
  1686. assert_array_equal(grid.cv_results_["rank_test_score"][non_finite_mask], last_rank)
  1687. # all finite scores should be better ranked than the non-finite scores
  1688. assert np.all(grid.cv_results_["rank_test_score"][~non_finite_mask] < last_rank)
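    # A single nan fold score is enough to make a candidate's mean test score
    # nan, since the reported means are plain averages over the fold scores.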


def test_callable_multimetric_confusion_matrix():
    # Test that a callable returning multiple metrics inserts the correct
    # names and scores into the search CV object
    def custom_scorer(clf, X, y):
        y_pred = clf.predict(X)
        cm = confusion_matrix(y, y_pred)
        return {"tn": cm[0, 0], "fp": cm[0, 1], "fn": cm[1, 0], "tp": cm[1, 1]}

    X, y = make_classification(n_samples=40, n_features=4, random_state=42)
    est = LinearSVC(dual="auto", random_state=42)
    search = GridSearchCV(est, {"C": [0.1, 1]}, scoring=custom_scorer, refit="fp")

    search.fit(X, y)
    score_names = ["tn", "fp", "fn", "tp"]
    for name in score_names:
        assert "mean_test_{}".format(name) in search.cv_results_

    y_pred = search.predict(X)
    cm = confusion_matrix(y, y_pred)
    assert search.score(X, y) == pytest.approx(cm[0, 1])


def test_callable_multimetric_same_as_list_of_strings():
    # Test that a callable multimetric scorer gives the same results as a
    # list of strings
    def custom_scorer(est, X, y):
        y_pred = est.predict(X)
        return {
            "recall": recall_score(y, y_pred),
            "accuracy": accuracy_score(y, y_pred),
        }

    X, y = make_classification(n_samples=40, n_features=4, random_state=42)
    est = LinearSVC(dual="auto", random_state=42)
    search_callable = GridSearchCV(
        est, {"C": [0.1, 1]}, scoring=custom_scorer, refit="recall"
    )
    search_str = GridSearchCV(
        est, {"C": [0.1, 1]}, scoring=["recall", "accuracy"], refit="recall"
    )

    search_callable.fit(X, y)
    search_str.fit(X, y)

    assert search_callable.best_score_ == pytest.approx(search_str.best_score_)
    assert search_callable.best_index_ == search_str.best_index_
    assert search_callable.score(X, y) == pytest.approx(search_str.score(X, y))


def test_callable_single_metric_same_as_single_string():
    # Test that a callable scorer gives the same results as scoring with a
    # single string
    def custom_scorer(est, X, y):
        y_pred = est.predict(X)
        return recall_score(y, y_pred)

    X, y = make_classification(n_samples=40, n_features=4, random_state=42)
    est = LinearSVC(dual="auto", random_state=42)
    search_callable = GridSearchCV(
        est, {"C": [0.1, 1]}, scoring=custom_scorer, refit=True
    )
    search_str = GridSearchCV(est, {"C": [0.1, 1]}, scoring="recall", refit="recall")
    search_list_str = GridSearchCV(
        est, {"C": [0.1, 1]}, scoring=["recall"], refit="recall"
    )

    search_callable.fit(X, y)
    search_str.fit(X, y)
    search_list_str.fit(X, y)

    assert search_callable.best_score_ == pytest.approx(search_str.best_score_)
    assert search_callable.best_index_ == search_str.best_index_
    assert search_callable.score(X, y) == pytest.approx(search_str.score(X, y))

    assert search_list_str.best_score_ == pytest.approx(search_str.best_score_)
    assert search_list_str.best_index_ == search_str.best_index_
    assert search_list_str.score(X, y) == pytest.approx(search_str.score(X, y))


def test_callable_multimetric_error_on_invalid_key():
    # Raises when the callable scorer does not return a dict with the `refit`
    # key.
    def bad_scorer(est, X, y):
        return {"bad_name": 1}

    X, y = make_classification(n_samples=40, n_features=4, random_state=42)
    clf = GridSearchCV(
        LinearSVC(dual="auto", random_state=42),
        {"C": [0.1, 1]},
        scoring=bad_scorer,
        refit="good_name",
    )
    msg = (
        "For multi-metric scoring, the parameter refit must be set to a "
        "scorer key or a callable to refit"
    )
    with pytest.raises(ValueError, match=msg):
        clf.fit(X, y)


def test_callable_multimetric_error_failing_clf():
    # Warns when there is an estimator that fails to fit with a float
    # error_score
    def custom_scorer(est, X, y):
        return {"acc": 1}

    X, y = make_classification(n_samples=20, n_features=10, random_state=0)
    clf = FailingClassifier()
    gs = GridSearchCV(
        clf,
        [{"parameter": [0, 1, 2]}],
        scoring=custom_scorer,
        refit=False,
        error_score=0.1,
    )

    warning_message = re.compile(
        "5 fits failed.+total of 15.+The score on these"
        r" train-test partitions for these parameters will be set to 0\.1",
        flags=re.DOTALL,
    )
    with pytest.warns(FitFailedWarning, match=warning_message):
        gs.fit(X, y)

    assert_allclose(gs.cv_results_["mean_test_acc"], [1, 1, 0.1])


def test_callable_multimetric_clf_all_fits_fail():
    # Warns and raises when all estimators fail to fit.
    def custom_scorer(est, X, y):
        return {"acc": 1}

    X, y = make_classification(n_samples=20, n_features=10, random_state=0)
    clf = FailingClassifier()
    gs = GridSearchCV(
        clf,
        [{"parameter": [FailingClassifier.FAILING_PARAMETER] * 3}],
        scoring=custom_scorer,
        refit=False,
        error_score=0.1,
    )

    individual_fit_error_message = "ValueError: Failing classifier failed as required"
    error_message = re.compile(
        "All the 15 fits failed.+your model is misconfigured.+"
        f"{individual_fit_error_message}",
        flags=re.DOTALL,
    )
    with pytest.raises(ValueError, match=error_message):
        gs.fit(X, y)


def test_n_features_in():
    # make sure grid search and random search delegate n_features_in to the
    # best estimator
    n_features = 4
    X, y = make_classification(n_features=n_features)
    gbdt = HistGradientBoostingClassifier()
    param_grid = {"max_iter": [3, 4]}
    gs = GridSearchCV(gbdt, param_grid)
    rs = RandomizedSearchCV(gbdt, param_grid, n_iter=1)
    assert not hasattr(gs, "n_features_in_")
    assert not hasattr(rs, "n_features_in_")
    gs.fit(X, y)
    rs.fit(X, y)
    assert gs.n_features_in_ == n_features
    assert rs.n_features_in_ == n_features
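    # n_features_in_ is only available once a refitted best_estimator_
    # exists, hence the hasattr checks before fit.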
  1827. @pytest.mark.parametrize("pairwise", [True, False])
  1828. def test_search_cv_pairwise_property_delegated_to_base_estimator(pairwise):
  1829. """
  1830. Test implementation of BaseSearchCV has the pairwise tag
  1831. which matches the pairwise tag of its estimator.
  1832. This test make sure pairwise tag is delegated to the base estimator.
  1833. Non-regression test for issue #13920.
  1834. """
  1835. class TestEstimator(BaseEstimator):
  1836. def _more_tags(self):
  1837. return {"pairwise": pairwise}
  1838. est = TestEstimator()
  1839. attr_message = "BaseSearchCV pairwise tag must match estimator"
  1840. cv = GridSearchCV(est, {"n_neighbors": [10]})
  1841. assert pairwise == cv._get_tags()["pairwise"], attr_message


def test_search_cv__pairwise_property_delegated_to_base_estimator():
    """
    Test that the implementation of BaseSearchCV has a pairwise property
    which matches the pairwise tag of its estimator.
    This test makes sure the pairwise tag is delegated to the base estimator.

    Non-regression test for issue #13920.
    """

    class EstimatorPairwise(BaseEstimator):
        def __init__(self, pairwise=True):
            self.pairwise = pairwise

        def _more_tags(self):
            return {"pairwise": self.pairwise}

    est = EstimatorPairwise()
    attr_message = "BaseSearchCV _pairwise property must match estimator"

    for _pairwise_setting in [True, False]:
        est.set_params(pairwise=_pairwise_setting)
        cv = GridSearchCV(est, {"n_neighbors": [10]})
        assert _pairwise_setting == cv._get_tags()["pairwise"], attr_message


def test_search_cv_pairwise_property_equivalence_of_precomputed():
    """
    Test that the implementation of BaseSearchCV has a pairwise tag
    which matches the pairwise tag of its estimator.
    This test ensures the equivalence of 'precomputed'.

    Non-regression test for issue #13920.
    """
    n_samples = 50
    n_splits = 2
    X, y = make_classification(n_samples=n_samples, random_state=0)
    grid_params = {"n_neighbors": [10]}

    # defaults to euclidean metric (minkowski p = 2)
    clf = KNeighborsClassifier()
    cv = GridSearchCV(clf, grid_params, cv=n_splits)
    cv.fit(X, y)
    preds_original = cv.predict(X)

    # precompute euclidean metric to validate pairwise is working
    X_precomputed = euclidean_distances(X)
    clf = KNeighborsClassifier(metric="precomputed")
    cv = GridSearchCV(clf, grid_params, cv=n_splits)
    cv.fit(X_precomputed, y)
    preds_precomputed = cv.predict(X_precomputed)

    attr_message = "GridSearchCV not identical with precomputed metric"
    assert (preds_original == preds_precomputed).all(), attr_message
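    # The pairwise tag is what makes the CV machinery slice the precomputed
    # square distance matrix along both axes for each split, which is why the
    # predictions match the raw-feature run exactly.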


@pytest.mark.parametrize(
    "SearchCV, param_search",
    [(GridSearchCV, {"a": [0.1, 0.01]}), (RandomizedSearchCV, {"a": uniform(1, 3)})],
)
def test_scalar_fit_param(SearchCV, param_search):
    # unofficially sanctioned tolerance for scalar values in fit_params
    # non-regression test for:
    # https://github.com/scikit-learn/scikit-learn/issues/15805
    class TestEstimator(ClassifierMixin, BaseEstimator):
        def __init__(self, a=None):
            self.a = a

        def fit(self, X, y, r=None):
            self.r_ = r
            return self

        def predict(self, X):
            return np.zeros(shape=(len(X)))

    model = SearchCV(TestEstimator(), param_search)
    X, y = make_classification(random_state=42)
    model.fit(X, y, r=42)
    assert model.best_estimator_.r_ == 42
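    # Fit params that are not sample-aligned arrays (like the scalar r=42
    # here) are forwarded unchanged to each internal fit call instead of
    # being sliced per CV split.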


@pytest.mark.parametrize(
    "SearchCV, param_search",
    [
        (GridSearchCV, {"alpha": [0.1, 0.01]}),
        (RandomizedSearchCV, {"alpha": uniform(0.01, 0.1)}),
    ],
)
def test_scalar_fit_param_compat(SearchCV, param_search):
    # Check support for scalar values in fit_params, for instance in
    # LightGBM, which does not exactly respect the scikit-learn API contract
    # but which we do not want to break without an explicit deprecation cycle
    # and API recommendations for implementing early stopping with a user
    # provided validation set. Non-regression test for:
    # https://github.com/scikit-learn/scikit-learn/issues/15805
    X_train, X_valid, y_train, y_valid = train_test_split(
        *make_classification(random_state=42), random_state=42
    )

    class _FitParamClassifier(SGDClassifier):
        def fit(
            self,
            X,
            y,
            sample_weight=None,
            tuple_of_arrays=None,
            scalar_param=None,
            callable_param=None,
        ):
            super().fit(X, y, sample_weight=sample_weight)
            assert scalar_param > 0
            assert callable(callable_param)

            # The tuple of arrays should be preserved as a tuple.
            assert isinstance(tuple_of_arrays, tuple)
            assert tuple_of_arrays[0].ndim == 2
            assert tuple_of_arrays[1].ndim == 1
            return self

    def _fit_param_callable():
        pass

    model = SearchCV(_FitParamClassifier(), param_search)

    # NOTE: `fit_params` should be data dependent (e.g. `sample_weight`),
    # which is not the case for the following parameters. But this abuse is
    # common in popular third-party libraries and we should tolerate this
    # behavior for now, and be careful not to break support for those without
    # following a proper deprecation cycle.
    fit_params = {
        "tuple_of_arrays": (X_valid, y_valid),
        "callable_param": _fit_param_callable,
        "scalar_param": 42,
    }
    model.fit(X_train, y_train, **fit_params)


# FIXME: Replace this test with a full `check_estimator` once we have API only
# checks.
@pytest.mark.filterwarnings("ignore:The total space of parameters 4 is")
@pytest.mark.parametrize("SearchCV", [GridSearchCV, RandomizedSearchCV])
@pytest.mark.parametrize("Predictor", [MinimalRegressor, MinimalClassifier])
def test_search_cv_using_minimal_compatible_estimator(SearchCV, Predictor):
    # Check that a third-party library can run tests without inheriting from
    # BaseEstimator.
    rng = np.random.RandomState(0)
    X, y = rng.randn(25, 2), np.array([0] * 5 + [1] * 20)

    model = Pipeline(
        [("transformer", MinimalTransformer()), ("predictor", Predictor())]
    )
    params = {
        "transformer__param": [1, 10],
        "predictor__parama": [1, 10],
    }

    search = SearchCV(model, params, error_score="raise")
    search.fit(X, y)

    assert search.best_params_.keys() == params.keys()

    y_pred = search.predict(X)
    if is_classifier(search):
        assert_array_equal(y_pred, 1)
        assert search.score(X, y) == pytest.approx(accuracy_score(y, y_pred))
    else:
        assert_allclose(y_pred, y.mean())
        assert search.score(X, y) == pytest.approx(r2_score(y, y_pred))
  1979. @pytest.mark.parametrize("return_train_score", [True, False])
  1980. def test_search_cv_verbose_3(capsys, return_train_score):
  1981. """Check that search cv with verbose>2 shows the score for single
  1982. metrics. non-regression test for #19658."""
  1983. X, y = make_classification(n_samples=100, n_classes=2, flip_y=0.2, random_state=0)
  1984. clf = LinearSVC(dual="auto", random_state=0)
  1985. grid = {"C": [0.1]}
  1986. GridSearchCV(
  1987. clf,
  1988. grid,
  1989. scoring="accuracy",
  1990. verbose=3,
  1991. cv=3,
  1992. return_train_score=return_train_score,
  1993. ).fit(X, y)
  1994. captured = capsys.readouterr().out
  1995. if return_train_score:
  1996. match = re.findall(r"score=\(train=[\d\.]+, test=[\d.]+\)", captured)
  1997. else:
  1998. match = re.findall(r"score=[\d\.]+", captured)
  1999. assert len(match) == 3
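    # One "score=..." line is printed per fold, so with cv=3 and a single
    # candidate we expect exactly 3 matches.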


@pytest.mark.parametrize(
    "SearchCV, param_search",
    [
        (GridSearchCV, "param_grid"),
        (RandomizedSearchCV, "param_distributions"),
        (HalvingGridSearchCV, "param_grid"),
    ],
)
def test_search_estimator_param(SearchCV, param_search):
    # Test that the SearchCV object doesn't mutate the objects given in the
    # parameter grid
    X, y = make_classification(random_state=42)

    params = {"clf": [LinearSVC(dual="auto")], "clf__C": [0.01]}
    orig_C = params["clf"][0].C

    pipe = Pipeline([("trs", MinimalTransformer()), ("clf", None)])

    param_grid_search = {param_search: params}
    gs = SearchCV(pipe, refit=True, cv=2, scoring="accuracy", **param_grid_search).fit(
        X, y
    )

    # testing that the original object in params is not changed
    assert params["clf"][0].C == orig_C
    # testing that the search is setting the parameter of the step correctly
    assert gs.best_estimator_.named_steps["clf"].C == 0.01
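    # This is expected to hold because the search sets parameters on clones:
    # both the estimator and (in recent scikit-learn versions) the candidate
    # parameter values are cloned before set_params, so objects placed in the
    # grid are never fitted or mutated in place.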