_search.py 70 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
7177817791780178117821783178417851786178717881789179017911792179317941795179617971798179918001801180218031804180518061807180818091810181118121813
  1. """
  2. The :mod:`sklearn.model_selection._search` includes utilities to fine-tune the
  3. parameters of an estimator.
  4. """
  5. # Author: Alexandre Gramfort <alexandre.gramfort@inria.fr>,
  6. # Gael Varoquaux <gael.varoquaux@normalesup.org>
  7. # Andreas Mueller <amueller@ais.uni-bonn.de>
  8. # Olivier Grisel <olivier.grisel@ensta.org>
  9. # Raghav RV <rvraghav93@gmail.com>
  10. # License: BSD 3 clause
  11. import numbers
  12. import operator
  13. import time
  14. import warnings
  15. from abc import ABCMeta, abstractmethod
  16. from collections import defaultdict
  17. from collections.abc import Iterable, Mapping, Sequence
  18. from functools import partial, reduce
  19. from itertools import product
  20. import numpy as np
  21. from numpy.ma import MaskedArray
  22. from scipy.stats import rankdata
  23. from ..base import BaseEstimator, MetaEstimatorMixin, _fit_context, clone, is_classifier
  24. from ..exceptions import NotFittedError
  25. from ..metrics import check_scoring
  26. from ..metrics._scorer import _check_multimetric_scoring, get_scorer_names
  27. from ..utils import check_random_state
  28. from ..utils._param_validation import HasMethods, Interval, StrOptions
  29. from ..utils._tags import _safe_tags
  30. from ..utils.metaestimators import available_if
  31. from ..utils.parallel import Parallel, delayed
  32. from ..utils.random import sample_without_replacement
  33. from ..utils.validation import _check_fit_params, check_is_fitted, indexable
  34. from ._split import check_cv
  35. from ._validation import (
  36. _aggregate_score_dicts,
  37. _fit_and_score,
  38. _insert_error_scores,
  39. _normalize_score_results,
  40. _warn_or_raise_about_fit_failures,
  41. )
  42. __all__ = ["GridSearchCV", "ParameterGrid", "ParameterSampler", "RandomizedSearchCV"]
  43. class ParameterGrid:
  44. """Grid of parameters with a discrete number of values for each.
  45. Can be used to iterate over parameter value combinations with the
  46. Python built-in function iter.
  47. The order of the generated parameter combinations is deterministic.
  48. Read more in the :ref:`User Guide <grid_search>`.
  49. Parameters
  50. ----------
  51. param_grid : dict of str to sequence, or sequence of such
  52. The parameter grid to explore, as a dictionary mapping estimator
  53. parameters to sequences of allowed values.
  54. An empty dict signifies default parameters.
  55. A sequence of dicts signifies a sequence of grids to search, and is
  56. useful to avoid exploring parameter combinations that make no sense
  57. or have no effect. See the examples below.
  58. Examples
  59. --------
  60. >>> from sklearn.model_selection import ParameterGrid
  61. >>> param_grid = {'a': [1, 2], 'b': [True, False]}
  62. >>> list(ParameterGrid(param_grid)) == (
  63. ... [{'a': 1, 'b': True}, {'a': 1, 'b': False},
  64. ... {'a': 2, 'b': True}, {'a': 2, 'b': False}])
  65. True
  66. >>> grid = [{'kernel': ['linear']}, {'kernel': ['rbf'], 'gamma': [1, 10]}]
  67. >>> list(ParameterGrid(grid)) == [{'kernel': 'linear'},
  68. ... {'kernel': 'rbf', 'gamma': 1},
  69. ... {'kernel': 'rbf', 'gamma': 10}]
  70. True
  71. >>> ParameterGrid(grid)[1] == {'kernel': 'rbf', 'gamma': 1}
  72. True
  73. See Also
  74. --------
  75. GridSearchCV : Uses :class:`ParameterGrid` to perform a full parallelized
  76. parameter search.
  77. """
  78. def __init__(self, param_grid):
  79. if not isinstance(param_grid, (Mapping, Iterable)):
  80. raise TypeError(
  81. f"Parameter grid should be a dict or a list, got: {param_grid!r} of"
  82. f" type {type(param_grid).__name__}"
  83. )
  84. if isinstance(param_grid, Mapping):
  85. # wrap dictionary in a singleton list to support either dict
  86. # or list of dicts
  87. param_grid = [param_grid]
  88. # check if all entries are dictionaries of lists
  89. for grid in param_grid:
  90. if not isinstance(grid, dict):
  91. raise TypeError(f"Parameter grid is not a dict ({grid!r})")
  92. for key, value in grid.items():
  93. if isinstance(value, np.ndarray) and value.ndim > 1:
  94. raise ValueError(
  95. f"Parameter array for {key!r} should be one-dimensional, got:"
  96. f" {value!r} with shape {value.shape}"
  97. )
  98. if isinstance(value, str) or not isinstance(
  99. value, (np.ndarray, Sequence)
  100. ):
  101. raise TypeError(
  102. f"Parameter grid for parameter {key!r} needs to be a list or a"
  103. f" numpy array, but got {value!r} (of type "
  104. f"{type(value).__name__}) instead. Single values "
  105. "need to be wrapped in a list with one element."
  106. )
  107. if len(value) == 0:
  108. raise ValueError(
  109. f"Parameter grid for parameter {key!r} need "
  110. f"to be a non-empty sequence, got: {value!r}"
  111. )
  112. self.param_grid = param_grid
  113. def __iter__(self):
  114. """Iterate over the points in the grid.
  115. Returns
  116. -------
  117. params : iterator over dict of str to any
  118. Yields dictionaries mapping each estimator parameter to one of its
  119. allowed values.
  120. """
  121. for p in self.param_grid:
  122. # Always sort the keys of a dictionary, for reproducibility
  123. items = sorted(p.items())
  124. if not items:
  125. yield {}
  126. else:
  127. keys, values = zip(*items)
  128. for v in product(*values):
  129. params = dict(zip(keys, v))
  130. yield params
  131. def __len__(self):
  132. """Number of points on the grid."""
  133. # Product function that can handle iterables (np.prod can't).
  134. product = partial(reduce, operator.mul)
  135. return sum(
  136. product(len(v) for v in p.values()) if p else 1 for p in self.param_grid
  137. )
  138. def __getitem__(self, ind):
  139. """Get the parameters that would be ``ind``th in iteration
  140. Parameters
  141. ----------
  142. ind : int
  143. The iteration index
  144. Returns
  145. -------
  146. params : dict of str to any
  147. Equal to list(self)[ind]
  148. """
  149. # This is used to make discrete sampling without replacement memory
  150. # efficient.
  151. for sub_grid in self.param_grid:
  152. # XXX: could memoize information used here
  153. if not sub_grid:
  154. if ind == 0:
  155. return {}
  156. else:
  157. ind -= 1
  158. continue
  159. # Reverse so most frequent cycling parameter comes first
  160. keys, values_lists = zip(*sorted(sub_grid.items())[::-1])
  161. sizes = [len(v_list) for v_list in values_lists]
  162. total = np.prod(sizes)
  163. if ind >= total:
  164. # Try the next grid
  165. ind -= total
  166. else:
  167. out = {}
  168. for key, v_list, n in zip(keys, values_lists, sizes):
  169. ind, offset = divmod(ind, n)
  170. out[key] = v_list[offset]
  171. return out
  172. raise IndexError("ParameterGrid index out of range")
  173. class ParameterSampler:
  174. """Generator on parameters sampled from given distributions.
  175. Non-deterministic iterable over random candidate combinations for hyper-
  176. parameter search. If all parameters are presented as a list,
  177. sampling without replacement is performed. If at least one parameter
  178. is given as a distribution, sampling with replacement is used.
  179. It is highly recommended to use continuous distributions for continuous
  180. parameters.
  181. Read more in the :ref:`User Guide <grid_search>`.
  182. Parameters
  183. ----------
  184. param_distributions : dict
  185. Dictionary with parameters names (`str`) as keys and distributions
  186. or lists of parameters to try. Distributions must provide a ``rvs``
  187. method for sampling (such as those from scipy.stats.distributions).
  188. If a list is given, it is sampled uniformly.
  189. If a list of dicts is given, first a dict is sampled uniformly, and
  190. then a parameter is sampled using that dict as above.
  191. n_iter : int
  192. Number of parameter settings that are produced.
  193. random_state : int, RandomState instance or None, default=None
  194. Pseudo random number generator state used for random uniform sampling
  195. from lists of possible values instead of scipy.stats distributions.
  196. Pass an int for reproducible output across multiple
  197. function calls.
  198. See :term:`Glossary <random_state>`.
  199. Returns
  200. -------
  201. params : dict of str to any
  202. **Yields** dictionaries mapping each estimator parameter to
  203. as sampled value.
  204. Examples
  205. --------
  206. >>> from sklearn.model_selection import ParameterSampler
  207. >>> from scipy.stats.distributions import expon
  208. >>> import numpy as np
  209. >>> rng = np.random.RandomState(0)
  210. >>> param_grid = {'a':[1, 2], 'b': expon()}
  211. >>> param_list = list(ParameterSampler(param_grid, n_iter=4,
  212. ... random_state=rng))
  213. >>> rounded_list = [dict((k, round(v, 6)) for (k, v) in d.items())
  214. ... for d in param_list]
  215. >>> rounded_list == [{'b': 0.89856, 'a': 1},
  216. ... {'b': 0.923223, 'a': 1},
  217. ... {'b': 1.878964, 'a': 2},
  218. ... {'b': 1.038159, 'a': 2}]
  219. True
  220. """
  221. def __init__(self, param_distributions, n_iter, *, random_state=None):
  222. if not isinstance(param_distributions, (Mapping, Iterable)):
  223. raise TypeError(
  224. "Parameter distribution is not a dict or a list,"
  225. f" got: {param_distributions!r} of type "
  226. f"{type(param_distributions).__name__}"
  227. )
  228. if isinstance(param_distributions, Mapping):
  229. # wrap dictionary in a singleton list to support either dict
  230. # or list of dicts
  231. param_distributions = [param_distributions]
  232. for dist in param_distributions:
  233. if not isinstance(dist, dict):
  234. raise TypeError(
  235. "Parameter distribution is not a dict ({!r})".format(dist)
  236. )
  237. for key in dist:
  238. if not isinstance(dist[key], Iterable) and not hasattr(
  239. dist[key], "rvs"
  240. ):
  241. raise TypeError(
  242. f"Parameter grid for parameter {key!r} is not iterable "
  243. f"or a distribution (value={dist[key]})"
  244. )
  245. self.n_iter = n_iter
  246. self.random_state = random_state
  247. self.param_distributions = param_distributions
  248. def _is_all_lists(self):
  249. return all(
  250. all(not hasattr(v, "rvs") for v in dist.values())
  251. for dist in self.param_distributions
  252. )
  253. def __iter__(self):
  254. rng = check_random_state(self.random_state)
  255. # if all distributions are given as lists, we want to sample without
  256. # replacement
  257. if self._is_all_lists():
  258. # look up sampled parameter settings in parameter grid
  259. param_grid = ParameterGrid(self.param_distributions)
  260. grid_size = len(param_grid)
  261. n_iter = self.n_iter
  262. if grid_size < n_iter:
  263. warnings.warn(
  264. "The total space of parameters %d is smaller "
  265. "than n_iter=%d. Running %d iterations. For exhaustive "
  266. "searches, use GridSearchCV." % (grid_size, self.n_iter, grid_size),
  267. UserWarning,
  268. )
  269. n_iter = grid_size
  270. for i in sample_without_replacement(grid_size, n_iter, random_state=rng):
  271. yield param_grid[i]
  272. else:
  273. for _ in range(self.n_iter):
  274. dist = rng.choice(self.param_distributions)
  275. # Always sort the keys of a dictionary, for reproducibility
  276. items = sorted(dist.items())
  277. params = dict()
  278. for k, v in items:
  279. if hasattr(v, "rvs"):
  280. params[k] = v.rvs(random_state=rng)
  281. else:
  282. params[k] = v[rng.randint(len(v))]
  283. yield params
  284. def __len__(self):
  285. """Number of points that will be sampled."""
  286. if self._is_all_lists():
  287. grid_size = len(ParameterGrid(self.param_distributions))
  288. return min(self.n_iter, grid_size)
  289. else:
  290. return self.n_iter
  291. def _check_refit(search_cv, attr):
  292. if not search_cv.refit:
  293. raise AttributeError(
  294. f"This {type(search_cv).__name__} instance was initialized with "
  295. f"`refit=False`. {attr} is available only after refitting on the best "
  296. "parameters. You can refit an estimator manually using the "
  297. "`best_params_` attribute"
  298. )
  299. def _estimator_has(attr):
  300. """Check if we can delegate a method to the underlying estimator.
  301. Calling a prediction method will only be available if `refit=True`. In
  302. such case, we check first the fitted best estimator. If it is not
  303. fitted, we check the unfitted estimator.
  304. Checking the unfitted estimator allows to use `hasattr` on the `SearchCV`
  305. instance even before calling `fit`.
  306. """
  307. def check(self):
  308. _check_refit(self, attr)
  309. if hasattr(self, "best_estimator_"):
  310. # raise an AttributeError if `attr` does not exist
  311. getattr(self.best_estimator_, attr)
  312. return True
  313. # raise an AttributeError if `attr` does not exist
  314. getattr(self.estimator, attr)
  315. return True
  316. return check
  317. class BaseSearchCV(MetaEstimatorMixin, BaseEstimator, metaclass=ABCMeta):
  318. """Abstract base class for hyper parameter search with cross-validation."""
  319. _parameter_constraints: dict = {
  320. "estimator": [HasMethods(["fit"])],
  321. "scoring": [
  322. StrOptions(set(get_scorer_names())),
  323. callable,
  324. list,
  325. tuple,
  326. dict,
  327. None,
  328. ],
  329. "n_jobs": [numbers.Integral, None],
  330. "refit": ["boolean", str, callable],
  331. "cv": ["cv_object"],
  332. "verbose": ["verbose"],
  333. "pre_dispatch": [numbers.Integral, str],
  334. "error_score": [StrOptions({"raise"}), numbers.Real],
  335. "return_train_score": ["boolean"],
  336. }
  337. @abstractmethod
  338. def __init__(
  339. self,
  340. estimator,
  341. *,
  342. scoring=None,
  343. n_jobs=None,
  344. refit=True,
  345. cv=None,
  346. verbose=0,
  347. pre_dispatch="2*n_jobs",
  348. error_score=np.nan,
  349. return_train_score=True,
  350. ):
  351. self.scoring = scoring
  352. self.estimator = estimator
  353. self.n_jobs = n_jobs
  354. self.refit = refit
  355. self.cv = cv
  356. self.verbose = verbose
  357. self.pre_dispatch = pre_dispatch
  358. self.error_score = error_score
  359. self.return_train_score = return_train_score
  360. @property
  361. def _estimator_type(self):
  362. return self.estimator._estimator_type
  363. def _more_tags(self):
  364. # allows cross-validation to see 'precomputed' metrics
  365. return {
  366. "pairwise": _safe_tags(self.estimator, "pairwise"),
  367. "_xfail_checks": {
  368. "check_supervised_y_2d": "DataConversionWarning not caught"
  369. },
  370. }
  371. def score(self, X, y=None):
  372. """Return the score on the given data, if the estimator has been refit.
  373. This uses the score defined by ``scoring`` where provided, and the
  374. ``best_estimator_.score`` method otherwise.
  375. Parameters
  376. ----------
  377. X : array-like of shape (n_samples, n_features)
  378. Input data, where `n_samples` is the number of samples and
  379. `n_features` is the number of features.
  380. y : array-like of shape (n_samples, n_output) \
  381. or (n_samples,), default=None
  382. Target relative to X for classification or regression;
  383. None for unsupervised learning.
  384. Returns
  385. -------
  386. score : float
  387. The score defined by ``scoring`` if provided, and the
  388. ``best_estimator_.score`` method otherwise.
  389. """
  390. _check_refit(self, "score")
  391. check_is_fitted(self)
  392. if self.scorer_ is None:
  393. raise ValueError(
  394. "No score function explicitly defined, "
  395. "and the estimator doesn't provide one %s"
  396. % self.best_estimator_
  397. )
  398. if isinstance(self.scorer_, dict):
  399. if self.multimetric_:
  400. scorer = self.scorer_[self.refit]
  401. else:
  402. scorer = self.scorer_
  403. return scorer(self.best_estimator_, X, y)
  404. # callable
  405. score = self.scorer_(self.best_estimator_, X, y)
  406. if self.multimetric_:
  407. score = score[self.refit]
  408. return score
  409. @available_if(_estimator_has("score_samples"))
  410. def score_samples(self, X):
  411. """Call score_samples on the estimator with the best found parameters.
  412. Only available if ``refit=True`` and the underlying estimator supports
  413. ``score_samples``.
  414. .. versionadded:: 0.24
  415. Parameters
  416. ----------
  417. X : iterable
  418. Data to predict on. Must fulfill input requirements
  419. of the underlying estimator.
  420. Returns
  421. -------
  422. y_score : ndarray of shape (n_samples,)
  423. The ``best_estimator_.score_samples`` method.
  424. """
  425. check_is_fitted(self)
  426. return self.best_estimator_.score_samples(X)
  427. @available_if(_estimator_has("predict"))
  428. def predict(self, X):
  429. """Call predict on the estimator with the best found parameters.
  430. Only available if ``refit=True`` and the underlying estimator supports
  431. ``predict``.
  432. Parameters
  433. ----------
  434. X : indexable, length n_samples
  435. Must fulfill the input assumptions of the
  436. underlying estimator.
  437. Returns
  438. -------
  439. y_pred : ndarray of shape (n_samples,)
  440. The predicted labels or values for `X` based on the estimator with
  441. the best found parameters.
  442. """
  443. check_is_fitted(self)
  444. return self.best_estimator_.predict(X)
  445. @available_if(_estimator_has("predict_proba"))
  446. def predict_proba(self, X):
  447. """Call predict_proba on the estimator with the best found parameters.
  448. Only available if ``refit=True`` and the underlying estimator supports
  449. ``predict_proba``.
  450. Parameters
  451. ----------
  452. X : indexable, length n_samples
  453. Must fulfill the input assumptions of the
  454. underlying estimator.
  455. Returns
  456. -------
  457. y_pred : ndarray of shape (n_samples,) or (n_samples, n_classes)
  458. Predicted class probabilities for `X` based on the estimator with
  459. the best found parameters. The order of the classes corresponds
  460. to that in the fitted attribute :term:`classes_`.
  461. """
  462. check_is_fitted(self)
  463. return self.best_estimator_.predict_proba(X)
  464. @available_if(_estimator_has("predict_log_proba"))
  465. def predict_log_proba(self, X):
  466. """Call predict_log_proba on the estimator with the best found parameters.
  467. Only available if ``refit=True`` and the underlying estimator supports
  468. ``predict_log_proba``.
  469. Parameters
  470. ----------
  471. X : indexable, length n_samples
  472. Must fulfill the input assumptions of the
  473. underlying estimator.
  474. Returns
  475. -------
  476. y_pred : ndarray of shape (n_samples,) or (n_samples, n_classes)
  477. Predicted class log-probabilities for `X` based on the estimator
  478. with the best found parameters. The order of the classes
  479. corresponds to that in the fitted attribute :term:`classes_`.
  480. """
  481. check_is_fitted(self)
  482. return self.best_estimator_.predict_log_proba(X)
  483. @available_if(_estimator_has("decision_function"))
  484. def decision_function(self, X):
  485. """Call decision_function on the estimator with the best found parameters.
  486. Only available if ``refit=True`` and the underlying estimator supports
  487. ``decision_function``.
  488. Parameters
  489. ----------
  490. X : indexable, length n_samples
  491. Must fulfill the input assumptions of the
  492. underlying estimator.
  493. Returns
  494. -------
  495. y_score : ndarray of shape (n_samples,) or (n_samples, n_classes) \
  496. or (n_samples, n_classes * (n_classes-1) / 2)
  497. Result of the decision function for `X` based on the estimator with
  498. the best found parameters.
  499. """
  500. check_is_fitted(self)
  501. return self.best_estimator_.decision_function(X)
  502. @available_if(_estimator_has("transform"))
  503. def transform(self, X):
  504. """Call transform on the estimator with the best found parameters.
  505. Only available if the underlying estimator supports ``transform`` and
  506. ``refit=True``.
  507. Parameters
  508. ----------
  509. X : indexable, length n_samples
  510. Must fulfill the input assumptions of the
  511. underlying estimator.
  512. Returns
  513. -------
  514. Xt : {ndarray, sparse matrix} of shape (n_samples, n_features)
  515. `X` transformed in the new space based on the estimator with
  516. the best found parameters.
  517. """
  518. check_is_fitted(self)
  519. return self.best_estimator_.transform(X)
  520. @available_if(_estimator_has("inverse_transform"))
  521. def inverse_transform(self, Xt):
  522. """Call inverse_transform on the estimator with the best found params.
  523. Only available if the underlying estimator implements
  524. ``inverse_transform`` and ``refit=True``.
  525. Parameters
  526. ----------
  527. Xt : indexable, length n_samples
  528. Must fulfill the input assumptions of the
  529. underlying estimator.
  530. Returns
  531. -------
  532. X : {ndarray, sparse matrix} of shape (n_samples, n_features)
  533. Result of the `inverse_transform` function for `Xt` based on the
  534. estimator with the best found parameters.
  535. """
  536. check_is_fitted(self)
  537. return self.best_estimator_.inverse_transform(Xt)
  538. @property
  539. def n_features_in_(self):
  540. """Number of features seen during :term:`fit`.
  541. Only available when `refit=True`.
  542. """
  543. # For consistency with other estimators we raise a AttributeError so
  544. # that hasattr() fails if the search estimator isn't fitted.
  545. try:
  546. check_is_fitted(self)
  547. except NotFittedError as nfe:
  548. raise AttributeError(
  549. "{} object has no n_features_in_ attribute.".format(
  550. self.__class__.__name__
  551. )
  552. ) from nfe
  553. return self.best_estimator_.n_features_in_
  554. @property
  555. def classes_(self):
  556. """Class labels.
  557. Only available when `refit=True` and the estimator is a classifier.
  558. """
  559. _estimator_has("classes_")(self)
  560. return self.best_estimator_.classes_
  561. def _run_search(self, evaluate_candidates):
  562. """Repeatedly calls `evaluate_candidates` to conduct a search.
  563. This method, implemented in sub-classes, makes it possible to
  564. customize the scheduling of evaluations: GridSearchCV and
  565. RandomizedSearchCV schedule evaluations for their whole parameter
  566. search space at once but other more sequential approaches are also
  567. possible: for instance is possible to iteratively schedule evaluations
  568. for new regions of the parameter search space based on previously
  569. collected evaluation results. This makes it possible to implement
  570. Bayesian optimization or more generally sequential model-based
  571. optimization by deriving from the BaseSearchCV abstract base class.
  572. For example, Successive Halving is implemented by calling
  573. `evaluate_candidates` multiples times (once per iteration of the SH
  574. process), each time passing a different set of candidates with `X`
  575. and `y` of increasing sizes.
  576. Parameters
  577. ----------
  578. evaluate_candidates : callable
  579. This callback accepts:
  580. - a list of candidates, where each candidate is a dict of
  581. parameter settings.
  582. - an optional `cv` parameter which can be used to e.g.
  583. evaluate candidates on different dataset splits, or
  584. evaluate candidates on subsampled data (as done in the
  585. SucessiveHaling estimators). By default, the original `cv`
  586. parameter is used, and it is available as a private
  587. `_checked_cv_orig` attribute.
  588. - an optional `more_results` dict. Each key will be added to
  589. the `cv_results_` attribute. Values should be lists of
  590. length `n_candidates`
  591. It returns a dict of all results so far, formatted like
  592. ``cv_results_``.
  593. Important note (relevant whether the default cv is used or not):
  594. in randomized splitters, and unless the random_state parameter of
  595. cv was set to an int, calling cv.split() multiple times will
  596. yield different splits. Since cv.split() is called in
  597. evaluate_candidates, this means that candidates will be evaluated
  598. on different splits each time evaluate_candidates is called. This
  599. might be a methodological issue depending on the search strategy
  600. that you're implementing. To prevent randomized splitters from
  601. being used, you may use _split._yields_constant_splits()
  602. Examples
  603. --------
  604. ::
  605. def _run_search(self, evaluate_candidates):
  606. 'Try C=0.1 only if C=1 is better than C=10'
  607. all_results = evaluate_candidates([{'C': 1}, {'C': 10}])
  608. score = all_results['mean_test_score']
  609. if score[0] < score[1]:
  610. evaluate_candidates([{'C': 0.1}])
  611. """
  612. raise NotImplementedError("_run_search not implemented.")
  613. def _check_refit_for_multimetric(self, scores):
  614. """Check `refit` is compatible with `scores` is valid"""
  615. multimetric_refit_msg = (
  616. "For multi-metric scoring, the parameter refit must be set to a "
  617. "scorer key or a callable to refit an estimator with the best "
  618. "parameter setting on the whole data and make the best_* "
  619. "attributes available for that metric. If this is not needed, "
  620. f"refit should be set to False explicitly. {self.refit!r} was "
  621. "passed."
  622. )
  623. valid_refit_dict = isinstance(self.refit, str) and self.refit in scores
  624. if (
  625. self.refit is not False
  626. and not valid_refit_dict
  627. and not callable(self.refit)
  628. ):
  629. raise ValueError(multimetric_refit_msg)
  630. @staticmethod
  631. def _select_best_index(refit, refit_metric, results):
  632. """Select index of the best combination of hyperparemeters."""
  633. if callable(refit):
  634. # If callable, refit is expected to return the index of the best
  635. # parameter set.
  636. best_index = refit(results)
  637. if not isinstance(best_index, numbers.Integral):
  638. raise TypeError("best_index_ returned is not an integer")
  639. if best_index < 0 or best_index >= len(results["params"]):
  640. raise IndexError("best_index_ index out of range")
  641. else:
  642. best_index = results[f"rank_test_{refit_metric}"].argmin()
  643. return best_index
    @_fit_context(
        # *SearchCV.estimator is not validated yet
        prefer_skip_nested_validation=False
    )
    def fit(self, X, y=None, *, groups=None, **fit_params):
        """Run fit with all sets of parameters.

        Every batch of candidates proposed by ``_run_search`` is fitted and
        scored on every train/test split, and the accumulated results are
        stored in ``cv_results_``. When ``refit`` is enabled, the selected
        candidate is then refitted on the whole of ``X`` (and ``y``).

        Parameters
        ----------

        X : array-like of shape (n_samples, n_features)
            Training vector, where `n_samples` is the number of samples and
            `n_features` is the number of features.

        y : array-like of shape (n_samples, n_output) \
            or (n_samples,), default=None
            Target relative to X for classification or regression;
            None for unsupervised learning.

        groups : array-like of shape (n_samples,), default=None
            Group labels for the samples used while splitting the dataset into
            train/test set. Only used in conjunction with a "Group" :term:`cv`
            instance (e.g., :class:`~sklearn.model_selection.GroupKFold`).

        **fit_params : dict of str -> object
            Parameters passed to the `fit` method of the estimator.

            If a fit parameter is an array-like whose length is equal to
            `n_samples` then it will be split across CV groups along with `X`
            and `y`. For example, the :term:`sample_weight` parameter is split
            because `len(sample_weight) == len(X)`.

        Returns
        -------
        self : object
            Instance of fitted estimator.

        Raises
        ------
        ValueError
            If a batch of candidates produced no fits, if the splitter yielded
            a number of splits different from ``get_n_splits``, or if
            ``refit`` is incompatible with multi-metric scoring.
        TypeError, IndexError
            If a callable ``refit`` returns a non-integer or out-of-range index.
        """
        estimator = self.estimator
        # Suffix of the cv_results_ columns used to pick the best candidate;
        # single-metric results are stored under "score".
        refit_metric = "score"

        if callable(self.scoring):
            # Whether a callable returns one score or a dict of scores is only
            # known after it has been called (checked after the search below).
            scorers = self.scoring
        elif self.scoring is None or isinstance(self.scoring, str):
            scorers = check_scoring(self.estimator, self.scoring)
        else:
            # list / tuple / dict of scorers: multi-metric evaluation, so refit
            # must name one of the scorers (or be False / a callable).
            scorers = _check_multimetric_scoring(self.estimator, self.scoring)
            self._check_refit_for_multimetric(scorers)
            refit_metric = self.refit

        X, y, groups = indexable(X, y, groups)
        fit_params = _check_fit_params(X, fit_params)

        cv_orig = check_cv(self.cv, y, classifier=is_classifier(estimator))
        # Expected number of splits for every evaluate_candidates call, even
        # when a search strategy passes its own `cv` (validated below).
        n_splits = cv_orig.get_n_splits(X, y, groups)

        # Unfitted template; every individual fit works on its own clone.
        base_estimator = clone(self.estimator)

        parallel = Parallel(n_jobs=self.n_jobs, pre_dispatch=self.pre_dispatch)

        fit_and_score_kwargs = dict(
            scorer=scorers,
            fit_params=fit_params,
            return_train_score=self.return_train_score,
            return_n_test_samples=True,
            return_times=True,
            return_parameters=False,
            error_score=self.error_score,
            verbose=self.verbose,
        )
        # Rebound (via `nonlocal`) by every evaluate_candidates call so that it
        # always holds the formatted results accumulated so far.
        results = {}
        with parallel:
            # Accumulators shared across all evaluate_candidates calls.
            all_candidate_params = []
            all_out = []
            all_more_results = defaultdict(list)

            def evaluate_candidates(candidate_params, cv=None, more_results=None):
                # A strategy may supply its own splitter; otherwise (or if it
                # passes a falsy value) fall back to the one built from self.cv.
                cv = cv or cv_orig
                candidate_params = list(candidate_params)
                n_candidates = len(candidate_params)

                if self.verbose > 0:
                    print(
                        "Fitting {0} folds for each of {1} candidates,"
                        " totalling {2} fits".format(
                            n_splits, n_candidates, n_candidates * n_splits
                        )
                    )

                # product() consumes cv.split(...) once up front, so every
                # candidate of this batch is evaluated on the same splits.
                # Iteration is candidate-major: all splits of the first
                # candidate, then all splits of the next one, and so on.
                out = parallel(
                    delayed(_fit_and_score)(
                        clone(base_estimator),
                        X,
                        y,
                        train=train,
                        test=test,
                        parameters=parameters,
                        split_progress=(split_idx, n_splits),
                        candidate_progress=(cand_idx, n_candidates),
                        **fit_and_score_kwargs,
                    )
                    for (cand_idx, parameters), (split_idx, (train, test)) in product(
                        enumerate(candidate_params), enumerate(cv.split(X, y, groups))
                    )
                )

                if len(out) < 1:
                    raise ValueError(
                        "No fits were performed. "
                        "Was the CV iterator empty? "
                        "Were there no candidates?"
                    )
                elif len(out) != n_candidates * n_splits:
                    raise ValueError(
                        "cv.split and cv.get_n_splits returned "
                        "inconsistent results. Expected {} "
                        "splits, got {}".format(n_splits, len(out) // n_candidates)
                    )

                # NOTE(review): presumably warns on failed fits for a numeric
                # error_score and raises for error_score='raise' — see helper.
                _warn_or_raise_about_fit_failures(out, self.error_score)

                # For callable self.scoring, the return type is only known after
                # calling. If the return type is a dictionary, the error scores
                # can now be inserted with the correct key. The type checking
                # of out will be done in `_insert_error_scores`.
                if callable(self.scoring):
                    _insert_error_scores(out, self.error_score)

                all_candidate_params.extend(candidate_params)
                all_out.extend(out)

                if more_results is not None:
                    for key, value in more_results.items():
                        all_more_results[key].extend(value)

                nonlocal results
                # Re-format everything evaluated so far (all batches), not just
                # this batch, so strategies can inspect the full history.
                results = self._format_results(
                    all_candidate_params, n_splits, all_out, all_more_results
                )

                return results

            self._run_search(evaluate_candidates)

            # multimetric is determined here because in the case of a callable
            # self.scoring the return type is only known after calling.
            # NOTE(review): all_out is only filled by evaluate_candidates, so a
            # _run_search that never calls it makes this raise IndexError.
            first_test_score = all_out[0]["test_scores"]
            self.multimetric_ = isinstance(first_test_score, dict)

            # check refit_metric now for a callable scorer that is multimetric
            if callable(self.scoring) and self.multimetric_:
                self._check_refit_for_multimetric(first_test_score)
                refit_metric = self.refit

        # For multi-metric evaluation, store the best_index_, best_params_ and
        # best_score_ iff refit is one of the scorer names
        # In single metric evaluation, refit_metric is "score"
        if self.refit or not self.multimetric_:
            self.best_index_ = self._select_best_index(
                self.refit, refit_metric, results
            )
            if not callable(self.refit):
                # When refit is not a callable, the best score can be read
                # directly at the best index
                self.best_score_ = results[f"mean_test_{refit_metric}"][
                    self.best_index_
                ]
            self.best_params_ = results["params"][self.best_index_]

        if self.refit:
            # here we clone the estimator as well as the parameters, since
            # sometimes the parameters themselves might be estimators, e.g.
            # when we search over different estimators in a pipeline.
            # ref: https://github.com/scikit-learn/scikit-learn/pull/26786
            self.best_estimator_ = clone(base_estimator).set_params(
                **clone(self.best_params_, safe=False)
            )
            refit_start_time = time.time()
            if y is not None:
                self.best_estimator_.fit(X, y, **fit_params)
            else:
                self.best_estimator_.fit(X, **fit_params)
            refit_end_time = time.time()
            self.refit_time_ = refit_end_time - refit_start_time

            if hasattr(self.best_estimator_, "feature_names_in_"):
                self.feature_names_in_ = self.best_estimator_.feature_names_in_

        # Store the only scorer not as a dict for single metric evaluation
        self.scorer_ = scorers

        self.cv_results_ = results
        self.n_splits_ = n_splits

        return self
  806. def _format_results(self, candidate_params, n_splits, out, more_results=None):
  807. n_candidates = len(candidate_params)
  808. out = _aggregate_score_dicts(out)
  809. results = dict(more_results or {})
  810. for key, val in results.items():
  811. # each value is a list (as per evaluate_candidate's convention)
  812. # we convert it to an array for consistency with the other keys
  813. results[key] = np.asarray(val)
  814. def _store(key_name, array, weights=None, splits=False, rank=False):
  815. """A small helper to store the scores/times to the cv_results_"""
  816. # When iterated first by splits, then by parameters
  817. # We want `array` to have `n_candidates` rows and `n_splits` cols.
  818. array = np.array(array, dtype=np.float64).reshape(n_candidates, n_splits)
  819. if splits:
  820. for split_idx in range(n_splits):
  821. # Uses closure to alter the results
  822. results["split%d_%s" % (split_idx, key_name)] = array[:, split_idx]
  823. array_means = np.average(array, axis=1, weights=weights)
  824. results["mean_%s" % key_name] = array_means
  825. if key_name.startswith(("train_", "test_")) and np.any(
  826. ~np.isfinite(array_means)
  827. ):
  828. warnings.warn(
  829. (
  830. f"One or more of the {key_name.split('_')[0]} scores "
  831. f"are non-finite: {array_means}"
  832. ),
  833. category=UserWarning,
  834. )
  835. # Weighted std is not directly available in numpy
  836. array_stds = np.sqrt(
  837. np.average(
  838. (array - array_means[:, np.newaxis]) ** 2, axis=1, weights=weights
  839. )
  840. )
  841. results["std_%s" % key_name] = array_stds
  842. if rank:
  843. # When the fit/scoring fails `array_means` contains NaNs, we
  844. # will exclude them from the ranking process and consider them
  845. # as tied with the worst performers.
  846. if np.isnan(array_means).all():
  847. # All fit/scoring routines failed.
  848. rank_result = np.ones_like(array_means, dtype=np.int32)
  849. else:
  850. min_array_means = np.nanmin(array_means) - 1
  851. array_means = np.nan_to_num(array_means, nan=min_array_means)
  852. rank_result = rankdata(-array_means, method="min").astype(
  853. np.int32, copy=False
  854. )
  855. results["rank_%s" % key_name] = rank_result
  856. _store("fit_time", out["fit_time"])
  857. _store("score_time", out["score_time"])
  858. # Use one MaskedArray and mask all the places where the param is not
  859. # applicable for that candidate. Use defaultdict as each candidate may
  860. # not contain all the params
  861. param_results = defaultdict(
  862. partial(
  863. MaskedArray,
  864. np.empty(
  865. n_candidates,
  866. ),
  867. mask=True,
  868. dtype=object,
  869. )
  870. )
  871. for cand_idx, params in enumerate(candidate_params):
  872. for name, value in params.items():
  873. # An all masked empty array gets created for the key
  874. # `"param_%s" % name` at the first occurrence of `name`.
  875. # Setting the value at an index also unmasks that index
  876. param_results["param_%s" % name][cand_idx] = value
  877. results.update(param_results)
  878. # Store a list of param dicts at the key 'params'
  879. results["params"] = candidate_params
  880. test_scores_dict = _normalize_score_results(out["test_scores"])
  881. if self.return_train_score:
  882. train_scores_dict = _normalize_score_results(out["train_scores"])
  883. for scorer_name in test_scores_dict:
  884. # Computed the (weighted) mean and std for test scores alone
  885. _store(
  886. "test_%s" % scorer_name,
  887. test_scores_dict[scorer_name],
  888. splits=True,
  889. rank=True,
  890. weights=None,
  891. )
  892. if self.return_train_score:
  893. _store(
  894. "train_%s" % scorer_name,
  895. train_scores_dict[scorer_name],
  896. splits=True,
  897. )
  898. return results
  899. class GridSearchCV(BaseSearchCV):
  900. """Exhaustive search over specified parameter values for an estimator.
  901. Important members are fit, predict.
  902. GridSearchCV implements a "fit" and a "score" method.
  903. It also implements "score_samples", "predict", "predict_proba",
  904. "decision_function", "transform" and "inverse_transform" if they are
  905. implemented in the estimator used.
  906. The parameters of the estimator used to apply these methods are optimized
  907. by cross-validated grid-search over a parameter grid.
  908. Read more in the :ref:`User Guide <grid_search>`.
  909. Parameters
  910. ----------
  911. estimator : estimator object
  912. This is assumed to implement the scikit-learn estimator interface.
  913. Either estimator needs to provide a ``score`` function,
  914. or ``scoring`` must be passed.
  915. param_grid : dict or list of dictionaries
  916. Dictionary with parameters names (`str`) as keys and lists of
  917. parameter settings to try as values, or a list of such
  918. dictionaries, in which case the grids spanned by each dictionary
  919. in the list are explored. This enables searching over any sequence
  920. of parameter settings.
  921. scoring : str, callable, list, tuple or dict, default=None
  922. Strategy to evaluate the performance of the cross-validated model on
  923. the test set.
  924. If `scoring` represents a single score, one can use:
  925. - a single string (see :ref:`scoring_parameter`);
  926. - a callable (see :ref:`scoring`) that returns a single value.
  927. If `scoring` represents multiple scores, one can use:
  928. - a list or tuple of unique strings;
  929. - a callable returning a dictionary where the keys are the metric
  930. names and the values are the metric scores;
  931. - a dictionary with metric names as keys and callables a values.
  932. See :ref:`multimetric_grid_search` for an example.
  933. n_jobs : int, default=None
  934. Number of jobs to run in parallel.
  935. ``None`` means 1 unless in a :obj:`joblib.parallel_backend` context.
  936. ``-1`` means using all processors. See :term:`Glossary <n_jobs>`
  937. for more details.
  938. .. versionchanged:: v0.20
  939. `n_jobs` default changed from 1 to None
  940. refit : bool, str, or callable, default=True
  941. Refit an estimator using the best found parameters on the whole
  942. dataset.
  943. For multiple metric evaluation, this needs to be a `str` denoting the
  944. scorer that would be used to find the best parameters for refitting
  945. the estimator at the end.
  946. Where there are considerations other than maximum score in
  947. choosing a best estimator, ``refit`` can be set to a function which
  948. returns the selected ``best_index_`` given ``cv_results_``. In that
  949. case, the ``best_estimator_`` and ``best_params_`` will be set
  950. according to the returned ``best_index_`` while the ``best_score_``
  951. attribute will not be available.
  952. The refitted estimator is made available at the ``best_estimator_``
  953. attribute and permits using ``predict`` directly on this
  954. ``GridSearchCV`` instance.
  955. Also for multiple metric evaluation, the attributes ``best_index_``,
  956. ``best_score_`` and ``best_params_`` will only be available if
  957. ``refit`` is set and all of them will be determined w.r.t this specific
  958. scorer.
  959. See ``scoring`` parameter to know more about multiple metric
  960. evaluation.
  961. See :ref:`sphx_glr_auto_examples_model_selection_plot_grid_search_digits.py`
  962. to see how to design a custom selection strategy using a callable
  963. via `refit`.
  964. .. versionchanged:: 0.20
  965. Support for callable added.
  966. cv : int, cross-validation generator or an iterable, default=None
  967. Determines the cross-validation splitting strategy.
  968. Possible inputs for cv are:
  969. - None, to use the default 5-fold cross validation,
  970. - integer, to specify the number of folds in a `(Stratified)KFold`,
  971. - :term:`CV splitter`,
  972. - An iterable yielding (train, test) splits as arrays of indices.
  973. For integer/None inputs, if the estimator is a classifier and ``y`` is
  974. either binary or multiclass, :class:`StratifiedKFold` is used. In all
  975. other cases, :class:`KFold` is used. These splitters are instantiated
  976. with `shuffle=False` so the splits will be the same across calls.
  977. Refer :ref:`User Guide <cross_validation>` for the various
  978. cross-validation strategies that can be used here.
  979. .. versionchanged:: 0.22
  980. ``cv`` default value if None changed from 3-fold to 5-fold.
  981. verbose : int
  982. Controls the verbosity: the higher, the more messages.
  983. - >1 : the computation time for each fold and parameter candidate is
  984. displayed;
  985. - >2 : the score is also displayed;
  986. - >3 : the fold and candidate parameter indexes are also displayed
  987. together with the starting time of the computation.
  988. pre_dispatch : int, or str, default='2*n_jobs'
  989. Controls the number of jobs that get dispatched during parallel
  990. execution. Reducing this number can be useful to avoid an
  991. explosion of memory consumption when more jobs get dispatched
  992. than CPUs can process. This parameter can be:
  993. - None, in which case all the jobs are immediately
  994. created and spawned. Use this for lightweight and
  995. fast-running jobs, to avoid delays due to on-demand
  996. spawning of the jobs
  997. - An int, giving the exact number of total jobs that are
  998. spawned
  999. - A str, giving an expression as a function of n_jobs,
  1000. as in '2*n_jobs'
  1001. error_score : 'raise' or numeric, default=np.nan
  1002. Value to assign to the score if an error occurs in estimator fitting.
  1003. If set to 'raise', the error is raised. If a numeric value is given,
  1004. FitFailedWarning is raised. This parameter does not affect the refit
  1005. step, which will always raise the error.
  1006. return_train_score : bool, default=False
  1007. If ``False``, the ``cv_results_`` attribute will not include training
  1008. scores.
  1009. Computing training scores is used to get insights on how different
  1010. parameter settings impact the overfitting/underfitting trade-off.
  1011. However computing the scores on the training set can be computationally
  1012. expensive and is not strictly required to select the parameters that
  1013. yield the best generalization performance.
  1014. .. versionadded:: 0.19
  1015. .. versionchanged:: 0.21
  1016. Default value was changed from ``True`` to ``False``
  1017. Attributes
  1018. ----------
  1019. cv_results_ : dict of numpy (masked) ndarrays
  1020. A dict with keys as column headers and values as columns, that can be
  1021. imported into a pandas ``DataFrame``.
  1022. For instance the below given table
  1023. +------------+-----------+------------+-----------------+---+---------+
  1024. |param_kernel|param_gamma|param_degree|split0_test_score|...|rank_t...|
  1025. +============+===========+============+=================+===+=========+
  1026. | 'poly' | -- | 2 | 0.80 |...| 2 |
  1027. +------------+-----------+------------+-----------------+---+---------+
  1028. | 'poly' | -- | 3 | 0.70 |...| 4 |
  1029. +------------+-----------+------------+-----------------+---+---------+
  1030. | 'rbf' | 0.1 | -- | 0.80 |...| 3 |
  1031. +------------+-----------+------------+-----------------+---+---------+
  1032. | 'rbf' | 0.2 | -- | 0.93 |...| 1 |
  1033. +------------+-----------+------------+-----------------+---+---------+
  1034. will be represented by a ``cv_results_`` dict of::
  1035. {
  1036. 'param_kernel': masked_array(data = ['poly', 'poly', 'rbf', 'rbf'],
  1037. mask = [False False False False]...)
  1038. 'param_gamma': masked_array(data = [-- -- 0.1 0.2],
  1039. mask = [ True True False False]...),
  1040. 'param_degree': masked_array(data = [2.0 3.0 -- --],
  1041. mask = [False False True True]...),
  1042. 'split0_test_score' : [0.80, 0.70, 0.80, 0.93],
  1043. 'split1_test_score' : [0.82, 0.50, 0.70, 0.78],
  1044. 'mean_test_score' : [0.81, 0.60, 0.75, 0.85],
  1045. 'std_test_score' : [0.01, 0.10, 0.05, 0.08],
  1046. 'rank_test_score' : [2, 4, 3, 1],
  1047. 'split0_train_score' : [0.80, 0.92, 0.70, 0.93],
  1048. 'split1_train_score' : [0.82, 0.55, 0.70, 0.87],
  1049. 'mean_train_score' : [0.81, 0.74, 0.70, 0.90],
  1050. 'std_train_score' : [0.01, 0.19, 0.00, 0.03],
  1051. 'mean_fit_time' : [0.73, 0.63, 0.43, 0.49],
  1052. 'std_fit_time' : [0.01, 0.02, 0.01, 0.01],
  1053. 'mean_score_time' : [0.01, 0.06, 0.04, 0.04],
  1054. 'std_score_time' : [0.00, 0.00, 0.00, 0.01],
  1055. 'params' : [{'kernel': 'poly', 'degree': 2}, ...],
  1056. }
  1057. NOTE
  1058. The key ``'params'`` is used to store a list of parameter
  1059. settings dicts for all the parameter candidates.
  1060. The ``mean_fit_time``, ``std_fit_time``, ``mean_score_time`` and
  1061. ``std_score_time`` are all in seconds.
  1062. For multi-metric evaluation, the scores for all the scorers are
  1063. available in the ``cv_results_`` dict at the keys ending with that
  1064. scorer's name (``'_<scorer_name>'``) instead of ``'_score'`` shown
  1065. above. ('split0_test_precision', 'mean_train_precision' etc.)
  1066. best_estimator_ : estimator
  1067. Estimator that was chosen by the search, i.e. estimator
  1068. which gave highest score (or smallest loss if specified)
  1069. on the left out data. Not available if ``refit=False``.
  1070. See ``refit`` parameter for more information on allowed values.
  1071. best_score_ : float
  1072. Mean cross-validated score of the best_estimator
  1073. For multi-metric evaluation, this is present only if ``refit`` is
  1074. specified.
  1075. This attribute is not available if ``refit`` is a function.
  1076. best_params_ : dict
  1077. Parameter setting that gave the best results on the hold out data.
  1078. For multi-metric evaluation, this is present only if ``refit`` is
  1079. specified.
  1080. best_index_ : int
  1081. The index (of the ``cv_results_`` arrays) which corresponds to the best
  1082. candidate parameter setting.
  1083. The dict at ``search.cv_results_['params'][search.best_index_]`` gives
  1084. the parameter setting for the best model, that gives the highest
  1085. mean score (``search.best_score_``).
  1086. For multi-metric evaluation, this is present only if ``refit`` is
  1087. specified.
  1088. scorer_ : function or a dict
  1089. Scorer function used on the held out data to choose the best
  1090. parameters for the model.
  1091. For multi-metric evaluation, this attribute holds the validated
  1092. ``scoring`` dict which maps the scorer key to the scorer callable.
  1093. n_splits_ : int
  1094. The number of cross-validation splits (folds/iterations).
  1095. refit_time_ : float
  1096. Seconds used for refitting the best model on the whole dataset.
  1097. This is present only if ``refit`` is not False.
  1098. .. versionadded:: 0.20
  1099. multimetric_ : bool
  1100. Whether or not the scorers compute several metrics.
  1101. classes_ : ndarray of shape (n_classes,)
  1102. The classes labels. This is present only if ``refit`` is specified and
  1103. the underlying estimator is a classifier.
  1104. n_features_in_ : int
  1105. Number of features seen during :term:`fit`. Only defined if
  1106. `best_estimator_` is defined (see the documentation for the `refit`
  1107. parameter for more details) and that `best_estimator_` exposes
  1108. `n_features_in_` when fit.
  1109. .. versionadded:: 0.24
  1110. feature_names_in_ : ndarray of shape (`n_features_in_`,)
  1111. Names of features seen during :term:`fit`. Only defined if
  1112. `best_estimator_` is defined (see the documentation for the `refit`
  1113. parameter for more details) and that `best_estimator_` exposes
  1114. `feature_names_in_` when fit.
  1115. .. versionadded:: 1.0
  1116. See Also
  1117. --------
  1118. ParameterGrid : Generates all the combinations of a hyperparameter grid.
  1119. train_test_split : Utility function to split the data into a development
  1120. set usable for fitting a GridSearchCV instance and an evaluation set
  1121. for its final evaluation.
  1122. sklearn.metrics.make_scorer : Make a scorer from a performance metric or
  1123. loss function.
  1124. Notes
  1125. -----
  1126. The parameters selected are those that maximize the score of the left out
  1127. data, unless an explicit score is passed in which case it is used instead.
  1128. If `n_jobs` was set to a value higher than one, the data is copied for each
  1129. point in the grid (and not `n_jobs` times). This is done for efficiency
  1130. reasons if individual jobs take very little time, but may raise errors if
  1131. the dataset is large and not enough memory is available. A workaround in
  1132. this case is to set `pre_dispatch`. Then, the memory is copied only
  1133. `pre_dispatch` many times. A reasonable value for `pre_dispatch` is `2 *
  1134. n_jobs`.
  1135. Examples
  1136. --------
  1137. >>> from sklearn import svm, datasets
  1138. >>> from sklearn.model_selection import GridSearchCV
  1139. >>> iris = datasets.load_iris()
  1140. >>> parameters = {'kernel':('linear', 'rbf'), 'C':[1, 10]}
  1141. >>> svc = svm.SVC()
  1142. >>> clf = GridSearchCV(svc, parameters)
  1143. >>> clf.fit(iris.data, iris.target)
  1144. GridSearchCV(estimator=SVC(),
  1145. param_grid={'C': [1, 10], 'kernel': ('linear', 'rbf')})
  1146. >>> sorted(clf.cv_results_.keys())
  1147. ['mean_fit_time', 'mean_score_time', 'mean_test_score',...
  1148. 'param_C', 'param_kernel', 'params',...
  1149. 'rank_test_score', 'split0_test_score',...
  1150. 'split2_test_score', ...
  1151. 'std_fit_time', 'std_score_time', 'std_test_score']
  1152. """
  1153. _required_parameters = ["estimator", "param_grid"]
  1154. _parameter_constraints: dict = {
  1155. **BaseSearchCV._parameter_constraints,
  1156. "param_grid": [dict, list],
  1157. }
  1158. def __init__(
  1159. self,
  1160. estimator,
  1161. param_grid,
  1162. *,
  1163. scoring=None,
  1164. n_jobs=None,
  1165. refit=True,
  1166. cv=None,
  1167. verbose=0,
  1168. pre_dispatch="2*n_jobs",
  1169. error_score=np.nan,
  1170. return_train_score=False,
  1171. ):
  1172. super().__init__(
  1173. estimator=estimator,
  1174. scoring=scoring,
  1175. n_jobs=n_jobs,
  1176. refit=refit,
  1177. cv=cv,
  1178. verbose=verbose,
  1179. pre_dispatch=pre_dispatch,
  1180. error_score=error_score,
  1181. return_train_score=return_train_score,
  1182. )
  1183. self.param_grid = param_grid
  1184. def _run_search(self, evaluate_candidates):
  1185. """Search all candidates in param_grid"""
  1186. evaluate_candidates(ParameterGrid(self.param_grid))
  1187. class RandomizedSearchCV(BaseSearchCV):
  1188. """Randomized search on hyper parameters.
  1189. RandomizedSearchCV implements a "fit" and a "score" method.
  1190. It also implements "score_samples", "predict", "predict_proba",
  1191. "decision_function", "transform" and "inverse_transform" if they are
  1192. implemented in the estimator used.
  1193. The parameters of the estimator used to apply these methods are optimized
  1194. by cross-validated search over parameter settings.
  1195. In contrast to GridSearchCV, not all parameter values are tried out, but
  1196. rather a fixed number of parameter settings is sampled from the specified
  1197. distributions. The number of parameter settings that are tried is
  1198. given by n_iter.
  1199. If all parameters are presented as a list,
  1200. sampling without replacement is performed. If at least one parameter
  1201. is given as a distribution, sampling with replacement is used.
  1202. It is highly recommended to use continuous distributions for continuous
  1203. parameters.
  1204. Read more in the :ref:`User Guide <randomized_parameter_search>`.
  1205. .. versionadded:: 0.14
  1206. Parameters
  1207. ----------
  1208. estimator : estimator object
  1209. An object of that type is instantiated for each grid point.
  1210. This is assumed to implement the scikit-learn estimator interface.
  1211. Either estimator needs to provide a ``score`` function,
  1212. or ``scoring`` must be passed.
  1213. param_distributions : dict or list of dicts
  1214. Dictionary with parameters names (`str`) as keys and distributions
  1215. or lists of parameters to try. Distributions must provide a ``rvs``
  1216. method for sampling (such as those from scipy.stats.distributions).
  1217. If a list is given, it is sampled uniformly.
  1218. If a list of dicts is given, first a dict is sampled uniformly, and
  1219. then a parameter is sampled using that dict as above.
  1220. n_iter : int, default=10
  1221. Number of parameter settings that are sampled. n_iter trades
  1222. off runtime vs quality of the solution.
  1223. scoring : str, callable, list, tuple or dict, default=None
  1224. Strategy to evaluate the performance of the cross-validated model on
  1225. the test set.
  1226. If `scoring` represents a single score, one can use:
  1227. - a single string (see :ref:`scoring_parameter`);
  1228. - a callable (see :ref:`scoring`) that returns a single value.
  1229. If `scoring` represents multiple scores, one can use:
  1230. - a list or tuple of unique strings;
  1231. - a callable returning a dictionary where the keys are the metric
  1232. names and the values are the metric scores;
  1233. - a dictionary with metric names as keys and callables a values.
  1234. See :ref:`multimetric_grid_search` for an example.
  1235. If None, the estimator's score method is used.
  1236. n_jobs : int, default=None
  1237. Number of jobs to run in parallel.
  1238. ``None`` means 1 unless in a :obj:`joblib.parallel_backend` context.
  1239. ``-1`` means using all processors. See :term:`Glossary <n_jobs>`
  1240. for more details.
  1241. .. versionchanged:: v0.20
  1242. `n_jobs` default changed from 1 to None
  1243. refit : bool, str, or callable, default=True
  1244. Refit an estimator using the best found parameters on the whole
  1245. dataset.
  1246. For multiple metric evaluation, this needs to be a `str` denoting the
  1247. scorer that would be used to find the best parameters for refitting
  1248. the estimator at the end.
  1249. Where there are considerations other than maximum score in
  1250. choosing a best estimator, ``refit`` can be set to a function which
  1251. returns the selected ``best_index_`` given the ``cv_results``. In that
  1252. case, the ``best_estimator_`` and ``best_params_`` will be set
  1253. according to the returned ``best_index_`` while the ``best_score_``
  1254. attribute will not be available.
  1255. The refitted estimator is made available at the ``best_estimator_``
  1256. attribute and permits using ``predict`` directly on this
  1257. ``RandomizedSearchCV`` instance.
  1258. Also for multiple metric evaluation, the attributes ``best_index_``,
  1259. ``best_score_`` and ``best_params_`` will only be available if
  1260. ``refit`` is set and all of them will be determined w.r.t this specific
  1261. scorer.
  1262. See ``scoring`` parameter to know more about multiple metric
  1263. evaluation.
  1264. .. versionchanged:: 0.20
  1265. Support for callable added.
  1266. cv : int, cross-validation generator or an iterable, default=None
  1267. Determines the cross-validation splitting strategy.
  1268. Possible inputs for cv are:
  1269. - None, to use the default 5-fold cross validation,
  1270. - integer, to specify the number of folds in a `(Stratified)KFold`,
  1271. - :term:`CV splitter`,
  1272. - An iterable yielding (train, test) splits as arrays of indices.
  1273. For integer/None inputs, if the estimator is a classifier and ``y`` is
  1274. either binary or multiclass, :class:`StratifiedKFold` is used. In all
  1275. other cases, :class:`KFold` is used. These splitters are instantiated
  1276. with `shuffle=False` so the splits will be the same across calls.
  1277. Refer :ref:`User Guide <cross_validation>` for the various
  1278. cross-validation strategies that can be used here.
  1279. .. versionchanged:: 0.22
  1280. ``cv`` default value if None changed from 3-fold to 5-fold.
  1281. verbose : int
  1282. Controls the verbosity: the higher, the more messages.
  1283. - >1 : the computation time for each fold and parameter candidate is
  1284. displayed;
  1285. - >2 : the score is also displayed;
  1286. - >3 : the fold and candidate parameter indexes are also displayed
  1287. together with the starting time of the computation.
  1288. pre_dispatch : int, or str, default='2*n_jobs'
  1289. Controls the number of jobs that get dispatched during parallel
  1290. execution. Reducing this number can be useful to avoid an
  1291. explosion of memory consumption when more jobs get dispatched
  1292. than CPUs can process. This parameter can be:
  1293. - None, in which case all the jobs are immediately
  1294. created and spawned. Use this for lightweight and
  1295. fast-running jobs, to avoid delays due to on-demand
  1296. spawning of the jobs
  1297. - An int, giving the exact number of total jobs that are
  1298. spawned
  1299. - A str, giving an expression as a function of n_jobs,
  1300. as in '2*n_jobs'
  1301. random_state : int, RandomState instance or None, default=None
  1302. Pseudo random number generator state used for random uniform sampling
  1303. from lists of possible values instead of scipy.stats distributions.
  1304. Pass an int for reproducible output across multiple
  1305. function calls.
  1306. See :term:`Glossary <random_state>`.
  1307. error_score : 'raise' or numeric, default=np.nan
  1308. Value to assign to the score if an error occurs in estimator fitting.
  1309. If set to 'raise', the error is raised. If a numeric value is given,
  1310. FitFailedWarning is raised. This parameter does not affect the refit
  1311. step, which will always raise the error.
  1312. return_train_score : bool, default=False
  1313. If ``False``, the ``cv_results_`` attribute will not include training
  1314. scores.
  1315. Computing training scores is used to get insights on how different
  1316. parameter settings impact the overfitting/underfitting trade-off.
  1317. However computing the scores on the training set can be computationally
  1318. expensive and is not strictly required to select the parameters that
  1319. yield the best generalization performance.
  1320. .. versionadded:: 0.19
  1321. .. versionchanged:: 0.21
  1322. Default value was changed from ``True`` to ``False``
  1323. Attributes
  1324. ----------
  1325. cv_results_ : dict of numpy (masked) ndarrays
  1326. A dict with keys as column headers and values as columns, that can be
  1327. imported into a pandas ``DataFrame``.
  1328. For instance, the table below
  1329. +--------------+-------------+-------------------+---+---------------+
  1330. | param_kernel | param_gamma | split0_test_score |...|rank_test_score|
  1331. +==============+=============+===================+===+===============+
  1332. | 'rbf' | 0.1 | 0.80 |...| 1 |
  1333. +--------------+-------------+-------------------+---+---------------+
  1334. | 'rbf' | 0.2 | 0.84 |...| 3 |
  1335. +--------------+-------------+-------------------+---+---------------+
  1336. | 'rbf' | 0.3 | 0.70 |...| 2 |
  1337. +--------------+-------------+-------------------+---+---------------+
  1338. will be represented by a ``cv_results_`` dict of::
  1339. {
  1340. 'param_kernel' : masked_array(data = ['rbf', 'rbf', 'rbf'],
  1341. mask = False),
  1342. 'param_gamma' : masked_array(data = [0.1 0.2 0.3], mask = False),
  1343. 'split0_test_score' : [0.80, 0.84, 0.70],
  1344. 'split1_test_score' : [0.82, 0.50, 0.70],
  1345. 'mean_test_score' : [0.81, 0.67, 0.70],
  1346. 'std_test_score' : [0.01, 0.24, 0.00],
  1347. 'rank_test_score' : [1, 3, 2],
  1348. 'split0_train_score' : [0.80, 0.92, 0.70],
  1349. 'split1_train_score' : [0.82, 0.55, 0.70],
  1350. 'mean_train_score' : [0.81, 0.74, 0.70],
  1351. 'std_train_score' : [0.01, 0.19, 0.00],
  1352. 'mean_fit_time' : [0.73, 0.63, 0.43],
  1353. 'std_fit_time' : [0.01, 0.02, 0.01],
  1354. 'mean_score_time' : [0.01, 0.06, 0.04],
  1355. 'std_score_time' : [0.00, 0.00, 0.00],
  1356. 'params' : [{'kernel' : 'rbf', 'gamma' : 0.1}, ...],
  1357. }
  1358. NOTE
  1359. The key ``'params'`` is used to store a list of parameter
  1360. settings dicts for all the parameter candidates.
  1361. The ``mean_fit_time``, ``std_fit_time``, ``mean_score_time`` and
  1362. ``std_score_time`` are all in seconds.
  1363. For multi-metric evaluation, the scores for all the scorers are
  1364. available in the ``cv_results_`` dict at the keys ending with that
  1365. scorer's name (``'_<scorer_name>'``) instead of ``'_score'`` shown
  1366. above. ('split0_test_precision', 'mean_train_precision' etc.)
  1367. best_estimator_ : estimator
  1368. Estimator that was chosen by the search, i.e. estimator
  1369. which gave highest score (or smallest loss if specified)
  1370. on the left out data. Not available if ``refit=False``.
  1371. For multi-metric evaluation, this attribute is present only if
  1372. ``refit`` is specified.
  1373. See ``refit`` parameter for more information on allowed values.
  1374. best_score_ : float
  1375. Mean cross-validated score of the best_estimator.
  1376. For multi-metric evaluation, this is not available if ``refit`` is
  1377. ``False``. See ``refit`` parameter for more information.
  1378. This attribute is not available if ``refit`` is a function.
  1379. best_params_ : dict
  1380. Parameter setting that gave the best results on the hold out data.
  1381. For multi-metric evaluation, this is not available if ``refit`` is
  1382. ``False``. See ``refit`` parameter for more information.
  1383. best_index_ : int
  1384. The index (of the ``cv_results_`` arrays) which corresponds to the best
  1385. candidate parameter setting.
  1386. The dict at ``search.cv_results_['params'][search.best_index_]`` gives
  1387. the parameter setting for the best model, that gives the highest
  1388. mean score (``search.best_score_``).
  1389. For multi-metric evaluation, this is not available if ``refit`` is
  1390. ``False``. See ``refit`` parameter for more information.
  1391. scorer_ : function or a dict
  1392. Scorer function used on the held out data to choose the best
  1393. parameters for the model.
  1394. For multi-metric evaluation, this attribute holds the validated
  1395. ``scoring`` dict which maps the scorer key to the scorer callable.
  1396. n_splits_ : int
  1397. The number of cross-validation splits (folds/iterations).
  1398. refit_time_ : float
  1399. Seconds used for refitting the best model on the whole dataset.
  1400. This is present only if ``refit`` is not False.
  1401. .. versionadded:: 0.20
  1402. multimetric_ : bool
  1403. Whether or not the scorers compute several metrics.
  1404. classes_ : ndarray of shape (n_classes,)
  1405. The class labels. This is present only if ``refit`` is specified and
  1406. the underlying estimator is a classifier.
  1407. n_features_in_ : int
  1408. Number of features seen during :term:`fit`. Only defined if
  1409. `best_estimator_` is defined (see the documentation for the `refit`
  1410. parameter for more details) and that `best_estimator_` exposes
  1411. `n_features_in_` when fit.
  1412. .. versionadded:: 0.24
  1413. feature_names_in_ : ndarray of shape (`n_features_in_`,)
  1414. Names of features seen during :term:`fit`. Only defined if
  1415. `best_estimator_` is defined (see the documentation for the `refit`
  1416. parameter for more details) and that `best_estimator_` exposes
  1417. `feature_names_in_` when fit.
  1418. .. versionadded:: 1.0
  1419. See Also
  1420. --------
  1421. GridSearchCV : Does exhaustive search over a grid of parameters.
  1422. ParameterSampler : A generator over parameter settings, constructed from
  1423. param_distributions.
  1424. Notes
  1425. -----
  1426. The parameters selected are those that maximize the score of the held-out
  1427. data, according to the scoring parameter.
  1428. If `n_jobs` was set to a value higher than one, the data is copied for each
  1429. parameter setting (and not `n_jobs` times). This is done for efficiency
  1430. reasons if individual jobs take very little time, but may raise errors if
  1431. the dataset is large and not enough memory is available. A workaround in
  1432. this case is to set `pre_dispatch`. Then, the memory is copied only
  1433. `pre_dispatch` many times. A reasonable value for `pre_dispatch` is `2 *
  1434. n_jobs`.
  1435. Examples
  1436. --------
  1437. >>> from sklearn.datasets import load_iris
  1438. >>> from sklearn.linear_model import LogisticRegression
  1439. >>> from sklearn.model_selection import RandomizedSearchCV
  1440. >>> from scipy.stats import uniform
  1441. >>> iris = load_iris()
  1442. >>> logistic = LogisticRegression(solver='saga', tol=1e-2, max_iter=200,
  1443. ... random_state=0)
  1444. >>> distributions = dict(C=uniform(loc=0, scale=4),
  1445. ... penalty=['l2', 'l1'])
  1446. >>> clf = RandomizedSearchCV(logistic, distributions, random_state=0)
  1447. >>> search = clf.fit(iris.data, iris.target)
  1448. >>> search.best_params_
  1449. {'C': 2..., 'penalty': 'l1'}
  1450. """
  # Constructor arguments that have no default value in ``__init__`` and
  # therefore always have to be supplied by the caller.
  _required_parameters = ["estimator", "param_distributions"]
  # Validation rules for the constructor parameters: every rule declared by
  # the base search class, extended with the randomized-search specific ones.
  # ``dict(base, **extra)`` keeps the base keys first, then appends ``extra``.
  _parameter_constraints: dict = dict(
      BaseSearchCV._parameter_constraints,
      param_distributions=[dict, list],
      # The interval [1, inf) ensures at least one candidate is sampled.
      n_iter=[Interval(numbers.Integral, 1, None, closed="left")],
      random_state=["random_state"],
  )
  def __init__(
      self,
      estimator,
      param_distributions,
      *,
      n_iter=10,
      scoring=None,
      n_jobs=None,
      refit=True,
      cv=None,
      verbose=0,
      pre_dispatch="2*n_jobs",
      random_state=None,
      error_score=np.nan,
      return_train_score=False,
  ):
      # Settings specific to randomized search are stored verbatim so that
      # parameter introspection sees the exact values passed in.
      self.param_distributions = param_distributions
      self.n_iter = n_iter
      self.random_state = random_state
      # Everything else is shared with the other search estimators and is
      # handed to the base-class constructor by keyword.
      super().__init__(
          estimator=estimator,
          # how candidates are scored and which one is refitted
          scoring=scoring,
          refit=refit,
          error_score=error_score,
          return_train_score=return_train_score,
          # how the cross-validation work is split and scheduled
          cv=cv,
          n_jobs=n_jobs,
          pre_dispatch=pre_dispatch,
          verbose=verbose,
      )
  def _run_search(self, evaluate_candidates):
      """Evaluate ``n_iter`` candidates sampled from ``param_distributions``.

      The sampler is seeded with ``self.random_state``. Passing an int there
      makes the drawn candidates reproducible across calls.

      Parameters
      ----------
      evaluate_candidates : callable
          Callback that receives the parameter sampler and evaluates the
          candidate settings it produces. Its return value is discarded.
      """
      candidate_sampler = ParameterSampler(
          self.param_distributions,
          self.n_iter,
          random_state=self.random_state,
      )
      evaluate_candidates(candidate_sampler)