base.py 40 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156
  1. """Base classes for all estimators."""
  2. # Author: Gael Varoquaux <gael.varoquaux@normalesup.org>
  3. # License: BSD 3 clause
  4. import copy
  5. import functools
  6. import inspect
  7. import platform
  8. import re
  9. import warnings
  10. from collections import defaultdict
  11. import numpy as np
  12. from . import __version__
  13. from ._config import config_context, get_config
  14. from .exceptions import InconsistentVersionWarning
  15. from .utils import _IS_32BIT
  16. from .utils._estimator_html_repr import estimator_html_repr
  17. from .utils._metadata_requests import _MetadataRequester
  18. from .utils._param_validation import validate_parameter_constraints
  19. from .utils._set_output import _SetOutputMixin
  20. from .utils._tags import (
  21. _DEFAULT_TAGS,
  22. )
  23. from .utils.validation import (
  24. _check_feature_names_in,
  25. _check_y,
  26. _generate_get_feature_names_out,
  27. _get_feature_names,
  28. _is_fitted,
  29. _num_features,
  30. check_array,
  31. check_is_fitted,
  32. check_X_y,
  33. )
  34. def clone(estimator, *, safe=True):
  35. """Construct a new unfitted estimator with the same parameters.
  36. Clone does a deep copy of the model in an estimator
  37. without actually copying attached data. It returns a new estimator
  38. with the same parameters that has not been fitted on any data.
  39. .. versionchanged:: 1.3
  40. Delegates to `estimator.__sklearn_clone__` if the method exists.
  41. Parameters
  42. ----------
  43. estimator : {list, tuple, set} of estimator instance or a single \
  44. estimator instance
  45. The estimator or group of estimators to be cloned.
  46. safe : bool, default=True
  47. If safe is False, clone will fall back to a deep copy on objects
  48. that are not estimators. Ignored if `estimator.__sklearn_clone__`
  49. exists.
  50. Returns
  51. -------
  52. estimator : object
  53. The deep copy of the input, an estimator if input is an estimator.
  54. Notes
  55. -----
  56. If the estimator's `random_state` parameter is an integer (or if the
  57. estimator doesn't have a `random_state` parameter), an *exact clone* is
  58. returned: the clone and the original estimator will give the exact same
  59. results. Otherwise, *statistical clone* is returned: the clone might
  60. return different results from the original estimator. More details can be
  61. found in :ref:`randomness`.
  62. """
  63. if hasattr(estimator, "__sklearn_clone__") and not inspect.isclass(estimator):
  64. return estimator.__sklearn_clone__()
  65. return _clone_parametrized(estimator, safe=safe)
  66. def _clone_parametrized(estimator, *, safe=True):
  67. """Default implementation of clone. See :func:`sklearn.base.clone` for details."""
  68. estimator_type = type(estimator)
  69. if estimator_type is dict:
  70. return {k: clone(v, safe=safe) for k, v in estimator.items()}
  71. elif estimator_type in (list, tuple, set, frozenset):
  72. return estimator_type([clone(e, safe=safe) for e in estimator])
  73. elif not hasattr(estimator, "get_params") or isinstance(estimator, type):
  74. if not safe:
  75. return copy.deepcopy(estimator)
  76. else:
  77. if isinstance(estimator, type):
  78. raise TypeError(
  79. "Cannot clone object. "
  80. + "You should provide an instance of "
  81. + "scikit-learn estimator instead of a class."
  82. )
  83. else:
  84. raise TypeError(
  85. "Cannot clone object '%s' (type %s): "
  86. "it does not seem to be a scikit-learn "
  87. "estimator as it does not implement a "
  88. "'get_params' method." % (repr(estimator), type(estimator))
  89. )
  90. klass = estimator.__class__
  91. new_object_params = estimator.get_params(deep=False)
  92. for name, param in new_object_params.items():
  93. new_object_params[name] = clone(param, safe=False)
  94. new_object = klass(**new_object_params)
  95. try:
  96. new_object._metadata_request = copy.deepcopy(estimator._metadata_request)
  97. except AttributeError:
  98. pass
  99. params_set = new_object.get_params(deep=False)
  100. # quick sanity check of the parameters of the clone
  101. for name in new_object_params:
  102. param1 = new_object_params[name]
  103. param2 = params_set[name]
  104. if param1 is not param2:
  105. raise RuntimeError(
  106. "Cannot clone object %s, as the constructor "
  107. "either does not set or modifies parameter %s" % (estimator, name)
  108. )
  109. # _sklearn_output_config is used by `set_output` to configure the output
  110. # container of an estimator.
  111. if hasattr(estimator, "_sklearn_output_config"):
  112. new_object._sklearn_output_config = copy.deepcopy(
  113. estimator._sklearn_output_config
  114. )
  115. return new_object
class BaseEstimator(_MetadataRequester):
    """Base class for all estimators in scikit-learn.

    Notes
    -----
    All estimators should specify all the parameters that can be set
    at the class level in their ``__init__`` as explicit keyword
    arguments (no ``*args`` or ``**kwargs``).
    """

    @classmethod
    def _get_param_names(cls):
        """Get parameter names for the estimator"""
        # fetch the constructor or the original constructor before
        # deprecation wrapping if any
        init = getattr(cls.__init__, "deprecated_original", cls.__init__)
        if init is object.__init__:
            # No explicit constructor to introspect
            return []

        # introspect the constructor arguments to find the model parameters
        # to represent
        init_signature = inspect.signature(init)
        # Consider the constructor parameters excluding 'self'
        parameters = [
            p
            for p in init_signature.parameters.values()
            if p.name != "self" and p.kind != p.VAR_KEYWORD
        ]
        for p in parameters:
            if p.kind == p.VAR_POSITIONAL:
                raise RuntimeError(
                    "scikit-learn estimators should always "
                    "specify their parameters in the signature"
                    " of their __init__ (no varargs)."
                    " %s with constructor %s doesn't "
                    " follow this convention." % (cls, init_signature)
                )
        # Extract and sort argument names excluding 'self'
        return sorted([p.name for p in parameters])

    def get_params(self, deep=True):
        """
        Get parameters for this estimator.

        Parameters
        ----------
        deep : bool, default=True
            If True, will return the parameters for this estimator and
            contained subobjects that are estimators.

        Returns
        -------
        params : dict
            Parameter names mapped to their values.
        """
        out = dict()
        for key in self._get_param_names():
            value = getattr(self, key)
            if deep and hasattr(value, "get_params") and not isinstance(value, type):
                # Nested estimator: expose its parameters under the
                # "<key>__<subparam>" naming scheme used by `set_params`.
                deep_items = value.get_params().items()
                out.update((key + "__" + k, val) for k, val in deep_items)
            out[key] = value
        return out

    def set_params(self, **params):
        """Set the parameters of this estimator.

        The method works on simple estimators as well as on nested objects
        (such as :class:`~sklearn.pipeline.Pipeline`). The latter have
        parameters of the form ``<component>__<parameter>`` so that it's
        possible to update each component of a nested object.

        Parameters
        ----------
        **params : dict
            Estimator parameters.

        Returns
        -------
        self : estimator instance
            Estimator instance.
        """
        if not params:
            # Simple optimization to gain speed (inspect is slow)
            return self
        valid_params = self.get_params(deep=True)

        nested_params = defaultdict(dict)  # grouped by prefix
        for key, value in params.items():
            key, delim, sub_key = key.partition("__")
            if key not in valid_params:
                local_valid_params = self._get_param_names()
                raise ValueError(
                    f"Invalid parameter {key!r} for estimator {self}. "
                    f"Valid parameters are: {local_valid_params!r}."
                )
            if delim:
                # "<key>__<sub_key>": defer to the nested estimator below.
                nested_params[key][sub_key] = value
            else:
                setattr(self, key, value)
                valid_params[key] = value

        for key, sub_params in nested_params.items():
            # TODO(1.4): remove specific handling of "base_estimator".
            # The "base_estimator" key is special. It was deprecated and
            # renamed to "estimator" for several estimators. This means we
            # need to translate it here and set sub-parameters on "estimator",
            # but only if the user did not explicitly set a value for
            # "base_estimator".
            if (
                key == "base_estimator"
                and valid_params[key] == "deprecated"
                and self.__module__.startswith("sklearn.")
            ):
                warnings.warn(
                    (
                        f"Parameter 'base_estimator' of {self.__class__.__name__} is"
                        " deprecated in favor of 'estimator'. See"
                        f" {self.__class__.__name__}'s docstring for more details."
                    ),
                    FutureWarning,
                    stacklevel=2,
                )
                key = "estimator"
            valid_params[key].set_params(**sub_params)

        return self

    def __sklearn_clone__(self):
        # Hook consumed by :func:`sklearn.base.clone`; the default is the
        # parameter-based clone.
        return _clone_parametrized(self)

    def __repr__(self, N_CHAR_MAX=700):
        # N_CHAR_MAX is the (approximate) maximum number of non-blank
        # characters to render. We pass it as an optional parameter to ease
        # the tests.

        from .utils._pprint import _EstimatorPrettyPrinter

        N_MAX_ELEMENTS_TO_SHOW = 30  # number of elements to show in sequences

        # use ellipsis for sequences with a lot of elements
        pp = _EstimatorPrettyPrinter(
            compact=True,
            indent=1,
            indent_at_name=True,
            n_max_elements_to_show=N_MAX_ELEMENTS_TO_SHOW,
        )

        repr_ = pp.pformat(self)

        # Use bruteforce ellipsis when there are a lot of non-blank characters
        n_nonblank = len("".join(repr_.split()))
        if n_nonblank > N_CHAR_MAX:
            lim = N_CHAR_MAX // 2  # apprx number of chars to keep on both ends
            regex = r"^(\s*\S){%d}" % lim
            # The regex '^(\s*\S){%d}' % n
            # matches from the start of the string until the nth non-blank
            # character:
            # - ^ matches the start of string
            # - (pattern){n} matches n repetitions of pattern
            # - \s*\S matches a non-blank char following zero or more blanks
            left_lim = re.match(regex, repr_).end()
            right_lim = re.match(regex, repr_[::-1]).end()

            if "\n" in repr_[left_lim:-right_lim]:
                # The left side and right side aren't on the same line.
                # To avoid weird cuts, e.g.:
                # categoric...ore',
                # we need to start the right side with an appropriate newline
                # character so that it renders properly as:
                # categoric...
                # handle_unknown='ignore',
                # so we add [^\n]*\n which matches until the next \n
                regex += r"[^\n]*\n"
                right_lim = re.match(regex, repr_[::-1]).end()

            ellipsis = "..."
            if left_lim + len(ellipsis) < len(repr_) - right_lim:
                # Only add ellipsis if it results in a shorter repr
                repr_ = repr_[:left_lim] + "..." + repr_[-right_lim:]

        return repr_

    def __getstate__(self):
        # Pickling support: stamp the scikit-learn version into the state of
        # first-party estimators so `__setstate__` can detect mismatches.
        if getattr(self, "__slots__", None):
            raise TypeError(
                "You cannot use `__slots__` in objects inheriting from "
                "`sklearn.base.BaseEstimator`."
            )

        try:
            state = super().__getstate__()
            if state is None:
                # For Python 3.11+, empty instance (no `__slots__`,
                # and `__dict__`) will return a state equal to `None`.
                state = self.__dict__.copy()
        except AttributeError:
            # Python < 3.11
            state = self.__dict__.copy()

        if type(self).__module__.startswith("sklearn."):
            return dict(state.items(), _sklearn_version=__version__)
        else:
            return state

    def __setstate__(self, state):
        # Warn (but proceed) when unpickling an estimator that was pickled
        # with a different scikit-learn version.
        if type(self).__module__.startswith("sklearn."):
            pickle_version = state.pop("_sklearn_version", "pre-0.18")
            if pickle_version != __version__:
                warnings.warn(
                    InconsistentVersionWarning(
                        estimator_name=self.__class__.__name__,
                        current_sklearn_version=__version__,
                        original_sklearn_version=pickle_version,
                    ),
                )
        try:
            super().__setstate__(state)
        except AttributeError:
            self.__dict__.update(state)

    def _more_tags(self):
        # Default estimator tags; subclasses/mixins override to customize.
        return _DEFAULT_TAGS

    def _get_tags(self):
        # Merge `_more_tags` from every class in the MRO, most-derived last
        # so subclasses win on conflicting keys.
        collected_tags = {}
        for base_class in reversed(inspect.getmro(self.__class__)):
            if hasattr(base_class, "_more_tags"):
                # need the if because mixins might not have _more_tags
                # but might do redundant work in estimators
                # (i.e. calling more tags on BaseEstimator multiple times)
                more_tags = base_class._more_tags(self)
                collected_tags.update(more_tags)
        return collected_tags

    def _check_n_features(self, X, reset):
        """Set the `n_features_in_` attribute, or check against it.

        Parameters
        ----------
        X : {ndarray, sparse matrix} of shape (n_samples, n_features)
            The input samples.
        reset : bool
            If True, the `n_features_in_` attribute is set to `X.shape[1]`.
            If False and the attribute exists, then check that it is equal to
            `X.shape[1]`. If False and the attribute does *not* exist, then
            the check is skipped.

            .. note::
               It is recommended to call reset=True in `fit` and in the first
               call to `partial_fit`. All other methods that validate `X`
               should set `reset=False`.
        """
        try:
            n_features = _num_features(X)
        except TypeError as e:
            if not reset and hasattr(self, "n_features_in_"):
                raise ValueError(
                    "X does not contain any features, but "
                    f"{self.__class__.__name__} is expecting "
                    f"{self.n_features_in_} features"
                ) from e
            # If the number of features is not defined and reset=True,
            # then we skip this check
            return

        if reset:
            self.n_features_in_ = n_features
            return

        if not hasattr(self, "n_features_in_"):
            # Skip this check if the expected number of expected input features
            # was not recorded by calling fit first. This is typically the case
            # for stateless transformers.
            return

        if n_features != self.n_features_in_:
            raise ValueError(
                f"X has {n_features} features, but {self.__class__.__name__} "
                f"is expecting {self.n_features_in_} features as input."
            )

    def _check_feature_names(self, X, *, reset):
        """Set or check the `feature_names_in_` attribute.

        .. versionadded:: 1.0

        Parameters
        ----------
        X : {ndarray, dataframe} of shape (n_samples, n_features)
            The input samples.

        reset : bool
            Whether to reset the `feature_names_in_` attribute.
            If False, the input will be checked for consistency with
            feature names of data provided when reset was last True.

            .. note::
               It is recommended to call `reset=True` in `fit` and in the first
               call to `partial_fit`. All other methods that validate `X`
               should set `reset=False`.
        """
        if reset:
            feature_names_in = _get_feature_names(X)
            if feature_names_in is not None:
                self.feature_names_in_ = feature_names_in
            elif hasattr(self, "feature_names_in_"):
                # Delete the attribute when the estimator is fitted on a new dataset
                # that has no feature names.
                delattr(self, "feature_names_in_")
            return

        fitted_feature_names = getattr(self, "feature_names_in_", None)
        X_feature_names = _get_feature_names(X)

        if fitted_feature_names is None and X_feature_names is None:
            # no feature names seen in fit and in X
            return

        if X_feature_names is not None and fitted_feature_names is None:
            warnings.warn(
                f"X has feature names, but {self.__class__.__name__} was fitted without"
                " feature names"
            )
            return

        if X_feature_names is None and fitted_feature_names is not None:
            warnings.warn(
                "X does not have valid feature names, but"
                f" {self.__class__.__name__} was fitted with feature names"
            )
            return

        # validate the feature names against the `feature_names_in_` attribute
        if len(fitted_feature_names) != len(X_feature_names) or np.any(
            fitted_feature_names != X_feature_names
        ):
            message = (
                "The feature names should match those that were passed during fit.\n"
            )
            fitted_feature_names_set = set(fitted_feature_names)
            X_feature_names_set = set(X_feature_names)

            unexpected_names = sorted(X_feature_names_set - fitted_feature_names_set)
            missing_names = sorted(fitted_feature_names_set - X_feature_names_set)

            def add_names(names):
                # Render at most 5 names as a bullet list, then an ellipsis.
                output = ""
                max_n_names = 5
                for i, name in enumerate(names):
                    if i >= max_n_names:
                        output += "- ...\n"
                        break
                    output += f"- {name}\n"
                return output

            if unexpected_names:
                message += "Feature names unseen at fit time:\n"
                message += add_names(unexpected_names)

            if missing_names:
                message += "Feature names seen at fit time, yet now missing:\n"
                message += add_names(missing_names)

            if not missing_names and not unexpected_names:
                message += (
                    "Feature names must be in the same order as they were in fit.\n"
                )

            raise ValueError(message)

    def _validate_data(
        self,
        X="no_validation",
        y="no_validation",
        reset=True,
        validate_separately=False,
        cast_to_ndarray=True,
        **check_params,
    ):
        """Validate input data and set or check the `n_features_in_` attribute.

        Parameters
        ----------
        X : {array-like, sparse matrix, dataframe} of shape \
                (n_samples, n_features), default='no_validation'
            The input samples.
            If `'no_validation'`, no validation is performed on `X`. This is
            useful for meta-estimator which can delegate input validation to
            their underlying estimator(s). In that case `y` must be passed and
            the only accepted `check_params` are `multi_output` and
            `y_numeric`.

        y : array-like of shape (n_samples,), default='no_validation'
            The targets.

            - If `None`, `check_array` is called on `X`. If the estimator's
              requires_y tag is True, then an error will be raised.
            - If `'no_validation'`, `check_array` is called on `X` and the
              estimator's requires_y tag is ignored. This is a default
              placeholder and is never meant to be explicitly set. In that case
              `X` must be passed.
            - Otherwise, only `y` with `_check_y` or both `X` and `y` are
              checked with either `check_array` or `check_X_y` depending on
              `validate_separately`.

        reset : bool, default=True
            Whether to reset the `n_features_in_` attribute.
            If False, the input will be checked for consistency with data
            provided when reset was last True.

            .. note::
               It is recommended to call reset=True in `fit` and in the first
               call to `partial_fit`. All other methods that validate `X`
               should set `reset=False`.

        validate_separately : False or tuple of dicts, default=False
            Only used if y is not None.
            If False, call validate_X_y(). Else, it must be a tuple of kwargs
            to be used for calling check_array() on X and y respectively.

            `estimator=self` is automatically added to these dicts to generate
            more informative error message in case of invalid input data.

        cast_to_ndarray : bool, default=True
            Cast `X` and `y` to ndarray with checks in `check_params`. If
            `False`, `X` and `y` are unchanged and only `feature_names_in_` and
            `n_features_in_` are checked.

        **check_params : kwargs
            Parameters passed to :func:`sklearn.utils.check_array` or
            :func:`sklearn.utils.check_X_y`. Ignored if validate_separately
            is not False.

            `estimator=self` is automatically added to these params to generate
            more informative error message in case of invalid input data.

        Returns
        -------
        out : {ndarray, sparse matrix} or tuple of these
            The validated input. A tuple is returned if both `X` and `y` are
            validated.
        """
        self._check_feature_names(X, reset=reset)

        if y is None and self._get_tags()["requires_y"]:
            raise ValueError(
                f"This {self.__class__.__name__} estimator "
                "requires y to be passed, but the target y is None."
            )

        no_val_X = isinstance(X, str) and X == "no_validation"
        no_val_y = y is None or isinstance(y, str) and y == "no_validation"

        if no_val_X and no_val_y:
            raise ValueError("Validation should be done on X, y or both.")

        default_check_params = {"estimator": self}
        check_params = {**default_check_params, **check_params}

        if not cast_to_ndarray:
            # Only the feature-name / n_features consistency checks apply;
            # the data itself passes through unchanged.
            if not no_val_X and no_val_y:
                out = X
            elif no_val_X and not no_val_y:
                out = y
            else:
                out = X, y
        elif not no_val_X and no_val_y:
            out = check_array(X, input_name="X", **check_params)
        elif no_val_X and not no_val_y:
            out = _check_y(y, **check_params)
        else:
            if validate_separately:
                # We need this because some estimators validate X and y
                # separately, and in general, separately calling check_array()
                # on X and y isn't equivalent to just calling check_X_y()
                # :(
                check_X_params, check_y_params = validate_separately
                if "estimator" not in check_X_params:
                    check_X_params = {**default_check_params, **check_X_params}
                X = check_array(X, input_name="X", **check_X_params)
                if "estimator" not in check_y_params:
                    check_y_params = {**default_check_params, **check_y_params}
                y = check_array(y, input_name="y", **check_y_params)
            else:
                X, y = check_X_y(X, y, **check_params)
            out = X, y

        if not no_val_X and check_params.get("ensure_2d", True):
            self._check_n_features(X, reset=reset)

        return out

    def _validate_params(self):
        """Validate types and values of constructor parameters

        The expected type and values must be defined in the `_parameter_constraints`
        class attribute, which is a dictionary `param_name: list of constraints`. See
        the docstring of `validate_parameter_constraints` for a description of the
        accepted constraints.
        """
        validate_parameter_constraints(
            self._parameter_constraints,
            self.get_params(deep=False),
            caller_name=self.__class__.__name__,
        )

    @property
    def _repr_html_(self):
        """HTML representation of estimator.

        This is redundant with the logic of `_repr_mimebundle_`. The latter
        should be favorted in the long term, `_repr_html_` is only
        implemented for consumers who do not interpret `_repr_mimbundle_`.
        """
        if get_config()["display"] != "diagram":
            raise AttributeError(
                "_repr_html_ is only defined when the "
                "'display' configuration option is set to "
                "'diagram'"
            )
        return self._repr_html_inner

    def _repr_html_inner(self):
        """This function is returned by the @property `_repr_html_` to make
        `hasattr(estimator, "_repr_html_") return `True` or `False` depending
        on `get_config()["display"]`.
        """
        return estimator_html_repr(self)

    def _repr_mimebundle_(self, **kwargs):
        """Mime bundle used by jupyter kernels to display estimator"""
        output = {"text/plain": repr(self)}
        if get_config()["display"] == "diagram":
            output["text/html"] = estimator_html_repr(self)
        return output
  577. class ClassifierMixin:
  578. """Mixin class for all classifiers in scikit-learn."""
  579. _estimator_type = "classifier"
  580. def score(self, X, y, sample_weight=None):
  581. """
  582. Return the mean accuracy on the given test data and labels.
  583. In multi-label classification, this is the subset accuracy
  584. which is a harsh metric since you require for each sample that
  585. each label set be correctly predicted.
  586. Parameters
  587. ----------
  588. X : array-like of shape (n_samples, n_features)
  589. Test samples.
  590. y : array-like of shape (n_samples,) or (n_samples, n_outputs)
  591. True labels for `X`.
  592. sample_weight : array-like of shape (n_samples,), default=None
  593. Sample weights.
  594. Returns
  595. -------
  596. score : float
  597. Mean accuracy of ``self.predict(X)`` w.r.t. `y`.
  598. """
  599. from .metrics import accuracy_score
  600. return accuracy_score(y, self.predict(X), sample_weight=sample_weight)
  601. def _more_tags(self):
  602. return {"requires_y": True}
  603. class RegressorMixin:
  604. """Mixin class for all regression estimators in scikit-learn."""
  605. _estimator_type = "regressor"
  606. def score(self, X, y, sample_weight=None):
  607. """Return the coefficient of determination of the prediction.
  608. The coefficient of determination :math:`R^2` is defined as
  609. :math:`(1 - \\frac{u}{v})`, where :math:`u` is the residual
  610. sum of squares ``((y_true - y_pred)** 2).sum()`` and :math:`v`
  611. is the total sum of squares ``((y_true - y_true.mean()) ** 2).sum()``.
  612. The best possible score is 1.0 and it can be negative (because the
  613. model can be arbitrarily worse). A constant model that always predicts
  614. the expected value of `y`, disregarding the input features, would get
  615. a :math:`R^2` score of 0.0.
  616. Parameters
  617. ----------
  618. X : array-like of shape (n_samples, n_features)
  619. Test samples. For some estimators this may be a precomputed
  620. kernel matrix or a list of generic objects instead with shape
  621. ``(n_samples, n_samples_fitted)``, where ``n_samples_fitted``
  622. is the number of samples used in the fitting for the estimator.
  623. y : array-like of shape (n_samples,) or (n_samples, n_outputs)
  624. True values for `X`.
  625. sample_weight : array-like of shape (n_samples,), default=None
  626. Sample weights.
  627. Returns
  628. -------
  629. score : float
  630. :math:`R^2` of ``self.predict(X)`` w.r.t. `y`.
  631. Notes
  632. -----
  633. The :math:`R^2` score used when calling ``score`` on a regressor uses
  634. ``multioutput='uniform_average'`` from version 0.23 to keep consistent
  635. with default value of :func:`~sklearn.metrics.r2_score`.
  636. This influences the ``score`` method of all the multioutput
  637. regressors (except for
  638. :class:`~sklearn.multioutput.MultiOutputRegressor`).
  639. """
  640. from .metrics import r2_score
  641. y_pred = self.predict(X)
  642. return r2_score(y, y_pred, sample_weight=sample_weight)
  643. def _more_tags(self):
  644. return {"requires_y": True}
  645. class ClusterMixin:
  646. """Mixin class for all cluster estimators in scikit-learn."""
  647. _estimator_type = "clusterer"
  648. def fit_predict(self, X, y=None):
  649. """
  650. Perform clustering on `X` and returns cluster labels.
  651. Parameters
  652. ----------
  653. X : array-like of shape (n_samples, n_features)
  654. Input data.
  655. y : Ignored
  656. Not used, present for API consistency by convention.
  657. Returns
  658. -------
  659. labels : ndarray of shape (n_samples,), dtype=np.int64
  660. Cluster labels.
  661. """
  662. # non-optimized default implementation; override when a better
  663. # method is possible for a given clustering algorithm
  664. self.fit(X)
  665. return self.labels_
  666. def _more_tags(self):
  667. return {"preserves_dtype": []}
  668. class BiclusterMixin:
  669. """Mixin class for all bicluster estimators in scikit-learn."""
  670. @property
  671. def biclusters_(self):
  672. """Convenient way to get row and column indicators together.
  673. Returns the ``rows_`` and ``columns_`` members.
  674. """
  675. return self.rows_, self.columns_
  676. def get_indices(self, i):
  677. """Row and column indices of the `i`'th bicluster.
  678. Only works if ``rows_`` and ``columns_`` attributes exist.
  679. Parameters
  680. ----------
  681. i : int
  682. The index of the cluster.
  683. Returns
  684. -------
  685. row_ind : ndarray, dtype=np.intp
  686. Indices of rows in the dataset that belong to the bicluster.
  687. col_ind : ndarray, dtype=np.intp
  688. Indices of columns in the dataset that belong to the bicluster.
  689. """
  690. rows = self.rows_[i]
  691. columns = self.columns_[i]
  692. return np.nonzero(rows)[0], np.nonzero(columns)[0]
  693. def get_shape(self, i):
  694. """Shape of the `i`'th bicluster.
  695. Parameters
  696. ----------
  697. i : int
  698. The index of the cluster.
  699. Returns
  700. -------
  701. n_rows : int
  702. Number of rows in the bicluster.
  703. n_cols : int
  704. Number of columns in the bicluster.
  705. """
  706. indices = self.get_indices(i)
  707. return tuple(len(i) for i in indices)
  708. def get_submatrix(self, i, data):
  709. """Return the submatrix corresponding to bicluster `i`.
  710. Parameters
  711. ----------
  712. i : int
  713. The index of the cluster.
  714. data : array-like of shape (n_samples, n_features)
  715. The data.
  716. Returns
  717. -------
  718. submatrix : ndarray of shape (n_rows, n_cols)
  719. The submatrix corresponding to bicluster `i`.
  720. Notes
  721. -----
  722. Works with sparse matrices. Only works if ``rows_`` and
  723. ``columns_`` attributes exist.
  724. """
  725. from .utils.validation import check_array
  726. data = check_array(data, accept_sparse="csr")
  727. row_ind, col_ind = self.get_indices(i)
  728. return data[row_ind[:, np.newaxis], col_ind]
  729. class TransformerMixin(_SetOutputMixin):
  730. """Mixin class for all transformers in scikit-learn.
  731. If :term:`get_feature_names_out` is defined, then :class:`BaseEstimator` will
  732. automatically wrap `transform` and `fit_transform` to follow the `set_output`
  733. API. See the :ref:`developer_api_set_output` for details.
  734. :class:`OneToOneFeatureMixin` and
  735. :class:`ClassNamePrefixFeaturesOutMixin` are helpful mixins for
  736. defining :term:`get_feature_names_out`.
  737. """
  738. def fit_transform(self, X, y=None, **fit_params):
  739. """
  740. Fit to data, then transform it.
  741. Fits transformer to `X` and `y` with optional parameters `fit_params`
  742. and returns a transformed version of `X`.
  743. Parameters
  744. ----------
  745. X : array-like of shape (n_samples, n_features)
  746. Input samples.
  747. y : array-like of shape (n_samples,) or (n_samples, n_outputs), \
  748. default=None
  749. Target values (None for unsupervised transformations).
  750. **fit_params : dict
  751. Additional fit parameters.
  752. Returns
  753. -------
  754. X_new : ndarray array of shape (n_samples, n_features_new)
  755. Transformed array.
  756. """
  757. # non-optimized default implementation; override when a better
  758. # method is possible for a given clustering algorithm
  759. if y is None:
  760. # fit method of arity 1 (unsupervised transformation)
  761. return self.fit(X, **fit_params).transform(X)
  762. else:
  763. # fit method of arity 2 (supervised transformation)
  764. return self.fit(X, y, **fit_params).transform(X)
  765. class OneToOneFeatureMixin:
  766. """Provides `get_feature_names_out` for simple transformers.
  767. This mixin assumes there's a 1-to-1 correspondence between input features
  768. and output features, such as :class:`~sklearn.preprocessing.StandardScaler`.
  769. """
  770. def get_feature_names_out(self, input_features=None):
  771. """Get output feature names for transformation.
  772. Parameters
  773. ----------
  774. input_features : array-like of str or None, default=None
  775. Input features.
  776. - If `input_features` is `None`, then `feature_names_in_` is
  777. used as feature names in. If `feature_names_in_` is not defined,
  778. then the following input feature names are generated:
  779. `["x0", "x1", ..., "x(n_features_in_ - 1)"]`.
  780. - If `input_features` is an array-like, then `input_features` must
  781. match `feature_names_in_` if `feature_names_in_` is defined.
  782. Returns
  783. -------
  784. feature_names_out : ndarray of str objects
  785. Same as input features.
  786. """
  787. check_is_fitted(self, "n_features_in_")
  788. return _check_feature_names_in(self, input_features)
  789. class ClassNamePrefixFeaturesOutMixin:
  790. """Mixin class for transformers that generate their own names by prefixing.
  791. This mixin is useful when the transformer needs to generate its own feature
  792. names out, such as :class:`~sklearn.decomposition.PCA`. For example, if
  793. :class:`~sklearn.decomposition.PCA` outputs 3 features, then the generated feature
  794. names out are: `["pca0", "pca1", "pca2"]`.
  795. This mixin assumes that a `_n_features_out` attribute is defined when the
  796. transformer is fitted. `_n_features_out` is the number of output features
  797. that the transformer will return in `transform` of `fit_transform`.
  798. """
  799. def get_feature_names_out(self, input_features=None):
  800. """Get output feature names for transformation.
  801. The feature names out will prefixed by the lowercased class name. For
  802. example, if the transformer outputs 3 features, then the feature names
  803. out are: `["class_name0", "class_name1", "class_name2"]`.
  804. Parameters
  805. ----------
  806. input_features : array-like of str or None, default=None
  807. Only used to validate feature names with the names seen in `fit`.
  808. Returns
  809. -------
  810. feature_names_out : ndarray of str objects
  811. Transformed feature names.
  812. """
  813. check_is_fitted(self, "_n_features_out")
  814. return _generate_get_feature_names_out(
  815. self, self._n_features_out, input_features=input_features
  816. )
  817. class DensityMixin:
  818. """Mixin class for all density estimators in scikit-learn."""
  819. _estimator_type = "DensityEstimator"
  820. def score(self, X, y=None):
  821. """Return the score of the model on the data `X`.
  822. Parameters
  823. ----------
  824. X : array-like of shape (n_samples, n_features)
  825. Test samples.
  826. y : Ignored
  827. Not used, present for API consistency by convention.
  828. Returns
  829. -------
  830. score : float
  831. """
  832. pass
  833. class OutlierMixin:
  834. """Mixin class for all outlier detection estimators in scikit-learn."""
  835. _estimator_type = "outlier_detector"
  836. def fit_predict(self, X, y=None):
  837. """Perform fit on X and returns labels for X.
  838. Returns -1 for outliers and 1 for inliers.
  839. Parameters
  840. ----------
  841. X : {array-like, sparse matrix} of shape (n_samples, n_features)
  842. The input samples.
  843. y : Ignored
  844. Not used, present for API consistency by convention.
  845. Returns
  846. -------
  847. y : ndarray of shape (n_samples,)
  848. 1 for inliers, -1 for outliers.
  849. """
  850. # override for transductive outlier detectors like LocalOulierFactor
  851. return self.fit(X).predict(X)
  852. class MetaEstimatorMixin:
  853. _required_parameters = ["estimator"]
  854. """Mixin class for all meta estimators in scikit-learn."""
  855. class MultiOutputMixin:
  856. """Mixin to mark estimators that support multioutput."""
  857. def _more_tags(self):
  858. return {"multioutput": True}
  859. class _UnstableArchMixin:
  860. """Mark estimators that are non-determinstic on 32bit or PowerPC"""
  861. def _more_tags(self):
  862. return {
  863. "non_deterministic": _IS_32BIT or platform.machine().startswith(
  864. ("ppc", "powerpc")
  865. )
  866. }
  867. def is_classifier(estimator):
  868. """Return True if the given estimator is (probably) a classifier.
  869. Parameters
  870. ----------
  871. estimator : object
  872. Estimator object to test.
  873. Returns
  874. -------
  875. out : bool
  876. True if estimator is a classifier and False otherwise.
  877. """
  878. return getattr(estimator, "_estimator_type", None) == "classifier"
  879. def is_regressor(estimator):
  880. """Return True if the given estimator is (probably) a regressor.
  881. Parameters
  882. ----------
  883. estimator : estimator instance
  884. Estimator object to test.
  885. Returns
  886. -------
  887. out : bool
  888. True if estimator is a regressor and False otherwise.
  889. """
  890. return getattr(estimator, "_estimator_type", None) == "regressor"
  891. def is_outlier_detector(estimator):
  892. """Return True if the given estimator is (probably) an outlier detector.
  893. Parameters
  894. ----------
  895. estimator : estimator instance
  896. Estimator object to test.
  897. Returns
  898. -------
  899. out : bool
  900. True if estimator is an outlier detector and False otherwise.
  901. """
  902. return getattr(estimator, "_estimator_type", None) == "outlier_detector"
  903. def _fit_context(*, prefer_skip_nested_validation):
  904. """Decorator to run the fit methods of estimators within context managers.
  905. Parameters
  906. ----------
  907. prefer_skip_nested_validation : bool
  908. If True, the validation of parameters of inner estimators or functions
  909. called during fit will be skipped.
  910. This is useful to avoid validating many times the parameters passed by the
  911. user from the public facing API. It's also useful to avoid validating
  912. parameters that we pass internally to inner functions that are guaranteed to
  913. be valid by the test suite.
  914. It should be set to True for most estimators, except for those that receive
  915. non-validated objects as parameters, such as meta-estimators that are given
  916. estimator objects.
  917. Returns
  918. -------
  919. decorated_fit : method
  920. The decorated fit method.
  921. """
  922. def decorator(fit_method):
  923. @functools.wraps(fit_method)
  924. def wrapper(estimator, *args, **kwargs):
  925. global_skip_validation = get_config()["skip_parameter_validation"]
  926. # we don't want to validate again for each call to partial_fit
  927. partial_fit_and_fitted = (
  928. fit_method.__name__ == "partial_fit" and _is_fitted(estimator)
  929. )
  930. if not global_skip_validation and not partial_fit_and_fitted:
  931. estimator._validate_params()
  932. with config_context(
  933. skip_parameter_validation=(
  934. prefer_skip_nested_validation or global_skip_validation
  935. )
  936. ):
  937. return fit_method(estimator, *args, **kwargs)
  938. return wrapper
  939. return decorator