  1. """Base class for ensemble-based estimators."""
  2. # Authors: Gilles Louppe
  3. # License: BSD 3 clause
  4. import warnings
  5. from abc import ABCMeta, abstractmethod
  6. from typing import List
  7. import numpy as np
  8. from joblib import effective_n_jobs
  9. from ..base import BaseEstimator, MetaEstimatorMixin, clone, is_classifier, is_regressor
  10. from ..utils import Bunch, _print_elapsed_time, check_random_state, deprecated
  11. from ..utils.metaestimators import _BaseComposition


def _fit_single_estimator(
    estimator, X, y, sample_weight=None, message_clsname=None, message=None
):
    """Private function used to fit an estimator within a job."""
    if sample_weight is not None:
        try:
            with _print_elapsed_time(message_clsname, message):
                estimator.fit(X, y, sample_weight=sample_weight)
        except TypeError as exc:
            if "unexpected keyword argument 'sample_weight'" in str(exc):
                raise TypeError(
                    "Underlying estimator {} does not support sample weights.".format(
                        estimator.__class__.__name__
                    )
                ) from exc
            raise
    else:
        with _print_elapsed_time(message_clsname, message):
            estimator.fit(X, y)
    return estimator
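
# Illustrative sketch (not part of the upstream module): `_fit_single_estimator`
# delegates to ``estimator.fit``, forwarding ``sample_weight`` when given and
# raising a clearer TypeError if the estimator does not accept it.
#
#     import numpy as np
#     from sklearn.tree import DecisionTreeRegressor
#
#     X = np.array([[0.0], [1.0], [2.0], [3.0]])
#     y = np.array([0.0, 0.0, 1.0, 1.0])
#     fitted = _fit_single_estimator(
#         DecisionTreeRegressor(max_depth=1), X, y, sample_weight=np.ones(4)
#     )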


def _set_random_states(estimator, random_state=None):
    """Set fixed random_state parameters for an estimator.

    Finds all parameters ending ``random_state`` and sets them to integers
    derived from ``random_state``.

    Parameters
    ----------
    estimator : estimator supporting get/set_params
        Estimator with potential randomness managed by random_state
        parameters.

    random_state : int, RandomState instance or None, default=None
        Pseudo-random number generator to control the generation of the random
        integers. Pass an int for reproducible output across multiple function
        calls.
        See :term:`Glossary <random_state>`.

    Notes
    -----
    This does not necessarily set *all* ``random_state`` attributes that
    control an estimator's randomness, only those accessible through
    ``estimator.get_params()``. ``random_state``s not controlled include
    those belonging to:

        * cross-validation splitters
        * ``scipy.stats`` rvs
    """
    random_state = check_random_state(random_state)
    to_set = {}
    for key in sorted(estimator.get_params(deep=True)):
        if key == "random_state" or key.endswith("__random_state"):
            to_set[key] = random_state.randint(np.iinfo(np.int32).max)

    if to_set:
        estimator.set_params(**to_set)
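
# Sketch of the effect (illustrative only, assuming scikit-learn >= 1.2 where
# BaggingClassifier exposes an ``estimator`` parameter): every parameter whose
# name ends with ``random_state``, including nested ``<name>__random_state``
# keys, is replaced by a fixed integer drawn from ``random_state``.
#
#     from sklearn.ensemble import BaggingClassifier
#     from sklearn.tree import DecisionTreeClassifier
#
#     bagging = BaggingClassifier(estimator=DecisionTreeClassifier())
#     _set_random_states(bagging, random_state=0)
#     # `random_state` and `estimator__random_state` are now fixed ints, so
#     # clones of `bagging` reproduce the same randomness.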


class BaseEnsemble(MetaEstimatorMixin, BaseEstimator, metaclass=ABCMeta):
    """Base class for all ensemble classes.

    Warning: This class should not be used directly. Use derived classes
    instead.

    Parameters
    ----------
    estimator : object
        The base estimator from which the ensemble is built.

    n_estimators : int, default=10
        The number of estimators in the ensemble.

    estimator_params : list of str, default=tuple()
        The list of attributes to use as parameters when instantiating a
        new base estimator. If none are given, default parameters are used.

    base_estimator : object, default="deprecated"
        Use `estimator` instead.

        .. deprecated:: 1.2
            `base_estimator` is deprecated and will be removed in 1.4.
            Use `estimator` instead.

    Attributes
    ----------
    estimator_ : estimator
        The base estimator from which the ensemble is grown.

    base_estimator_ : estimator
        The base estimator from which the ensemble is grown.

        .. deprecated:: 1.2
            `base_estimator_` is deprecated and will be removed in 1.4.
            Use `estimator_` instead.

    estimators_ : list of estimators
        The collection of fitted base estimators.
    """

    # overwrite _required_parameters from MetaEstimatorMixin
    _required_parameters: List[str] = []

    @abstractmethod
    def __init__(
        self,
        estimator=None,
        *,
        n_estimators=10,
        estimator_params=tuple(),
        base_estimator="deprecated",
    ):
        # Set parameters
        self.estimator = estimator
        self.n_estimators = n_estimators
        self.estimator_params = estimator_params
        self.base_estimator = base_estimator

        # Don't instantiate estimators now! Parameters of base_estimator might
        # still change, e.g. when grid-searching with the nested object syntax.
        # self.estimators_ needs to be filled by the derived classes in fit.
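
    # Why instantiation is deferred (illustrative sketch, assuming a derived
    # ensemble such as BaggingClassifier): nested parameters of the base
    # estimator may still change after construction, e.g. during a grid
    # search, so sub-estimators are only cloned from `estimator_` inside `fit`.
    #
    #     from sklearn.ensemble import BaggingClassifier
    #     from sklearn.model_selection import GridSearchCV
    #     from sklearn.tree import DecisionTreeClassifier
    #
    #     search = GridSearchCV(
    #         BaggingClassifier(estimator=DecisionTreeClassifier()),
    #         param_grid={"estimator__max_depth": [2, 4, None]},
    #     )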

    def _validate_estimator(self, default=None):
        """Check the base estimator.

        Sets the `estimator_` attribute.
        """
        if self.estimator is not None and (
            self.base_estimator not in [None, "deprecated"]
        ):
            raise ValueError(
                "Both `estimator` and `base_estimator` were set. Only set `estimator`."
            )

        if self.estimator is not None:
            self.estimator_ = self.estimator
        elif self.base_estimator != "deprecated":
            warnings.warn(
                (
                    "`base_estimator` was renamed to `estimator` in version 1.2 and "
                    "will be removed in 1.4."
                ),
                FutureWarning,
            )
            if self.base_estimator is not None:
                self.estimator_ = self.base_estimator
            else:
                self.estimator_ = default
        else:
            self.estimator_ = default
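
    # Resolution sketch (illustrative; `default` is whatever the derived class
    # passes from its own `fit`):
    #
    #     self.estimator set            -> estimator_ = self.estimator
    #     only self.base_estimator set  -> FutureWarning, estimator_ = self.base_estimator
    #     neither set                   -> estimator_ = default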

    # TODO(1.4): remove
    # mypy error: Decorated property not supported
    @deprecated(  # type: ignore
        "Attribute `base_estimator_` was deprecated in version 1.2 and will be removed "
        "in 1.4. Use `estimator_` instead."
    )
    @property
    def base_estimator_(self):
        """Estimator used to grow the ensemble."""
        return self.estimator_

    def _make_estimator(self, append=True, random_state=None):
        """Make and configure a copy of the `estimator_` attribute.

        Warning: This method should be used to properly instantiate new
        sub-estimators.
        """
        estimator = clone(self.estimator_)
        estimator.set_params(**{p: getattr(self, p) for p in self.estimator_params})

        if random_state is not None:
            _set_random_states(estimator, random_state)

        if append:
            self.estimators_.append(estimator)

        return estimator
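
    # Typical use inside a derived class's `fit` (hedged sketch;
    # `self.random_state` is a hypothetical parameter of the derived ensemble):
    #
    #     random_state = check_random_state(self.random_state)
    #     self.estimators_ = []
    #     for _ in range(self.n_estimators):
    #         # each clone receives fixed integer seeds drawn from `random_state`
    #         self._make_estimator(append=True, random_state=random_state)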

    def __len__(self):
        """Return the number of estimators in the ensemble."""
        return len(self.estimators_)

    def __getitem__(self, index):
        """Return the index'th estimator in the ensemble."""
        return self.estimators_[index]

    def __iter__(self):
        """Return iterator over estimators in the ensemble."""
        return iter(self.estimators_)
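
    # The methods above make a fitted ensemble behave like a sequence of its
    # sub-estimators (illustrative sketch using RandomForestClassifier, which
    # derives from BaseEnsemble):
    #
    #     from sklearn.datasets import load_iris
    #     from sklearn.ensemble import RandomForestClassifier
    #
    #     X, y = load_iris(return_X_y=True)
    #     forest = RandomForestClassifier(n_estimators=5).fit(X, y)
    #     len(forest)                                   # 5
    #     forest[0]                                     # first fitted tree
    #     depths = [tree.get_depth() for tree in forest]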


def _partition_estimators(n_estimators, n_jobs):
    """Private function used to partition estimators between jobs."""
    # Compute the number of jobs
    n_jobs = min(effective_n_jobs(n_jobs), n_estimators)

    # Partition estimators between jobs
    n_estimators_per_job = np.full(n_jobs, n_estimators // n_jobs, dtype=int)
    n_estimators_per_job[: n_estimators % n_jobs] += 1
    starts = np.cumsum(n_estimators_per_job)

    return n_jobs, n_estimators_per_job.tolist(), [0] + starts.tolist()
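
# Worked example (illustrative only): splitting 10 estimators over 3 workers.
#
#     _partition_estimators(10, n_jobs=3)
#     # -> (3, [4, 3, 3], [0, 4, 7, 10])
#     # worker i is responsible for estimators starts[i]:starts[i + 1]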


class _BaseHeterogeneousEnsemble(
    MetaEstimatorMixin, _BaseComposition, metaclass=ABCMeta
):
    """Base class for heterogeneous ensembles of learners.

    Parameters
    ----------
    estimators : list of (str, estimator) tuples
        The estimators to combine in the ensemble. Each element of the
        list is defined as a tuple of string (i.e. name of the estimator) and
        an estimator instance. An estimator can be set to `'drop'` using
        `set_params`.

    Attributes
    ----------
    estimators_ : list of estimators
        The elements of the `estimators` parameter, having been fitted on the
        training data. If an estimator has been set to `'drop'`, it will not
        appear in `estimators_`.
    """

    _required_parameters = ["estimators"]

    @property
    def named_estimators(self):
        """Dictionary to access any fitted sub-estimators by name.

        Returns
        -------
        :class:`~sklearn.utils.Bunch`
        """
        return Bunch(**dict(self.estimators))
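
    # Illustrative sketch (assuming a derived class such as VotingClassifier):
    # the property wraps the (name, estimator) pairs given to the constructor
    # in a Bunch, which supports both key and attribute access.
    #
    #     from sklearn.ensemble import VotingClassifier
    #     from sklearn.linear_model import LogisticRegression
    #     from sklearn.svm import SVC
    #
    #     ens = VotingClassifier(
    #         estimators=[("lr", LogisticRegression()), ("svc", SVC())]
    #     )
    #     ens.named_estimators["lr"]    # the LogisticRegression instance
    #     ens.named_estimators.svc      # attribute access via Bunch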

    @abstractmethod
    def __init__(self, estimators):
        self.estimators = estimators

    def _validate_estimators(self):
        if len(self.estimators) == 0:
            raise ValueError(
                "Invalid 'estimators' attribute, 'estimators' should be a "
                "non-empty list of (string, estimator) tuples."
            )
        names, estimators = zip(*self.estimators)
        # defined by MetaEstimatorMixin
        self._validate_names(names)

        has_estimator = any(est != "drop" for est in estimators)
        if not has_estimator:
            raise ValueError(
                "All estimators are dropped. At least one is required "
                "to be an estimator."
            )

        is_estimator_type = is_classifier if is_classifier(self) else is_regressor

        for est in estimators:
            if est != "drop" and not is_estimator_type(est):
                raise ValueError(
                    "The estimator {} should be a {}.".format(
                        est.__class__.__name__, is_estimator_type.__name__[3:]
                    )
                )

        return names, estimators
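
    # Validation contract, continuing the VotingClassifier sketch above
    # (illustrative only): the names and estimator instances come back as
    # tuples; a ValueError is raised if every entry is 'drop' or if, e.g., a
    # regressor is passed to a classifier ensemble.
    #
    #     names, ests = ens._validate_estimators()
    #     # names == ("lr", "svc"); ests holds the two estimator instances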

    def set_params(self, **params):
        """
        Set the parameters of an estimator from the ensemble.

        Valid parameter keys can be listed with `get_params()`. Note that you
        can directly set the parameters of the estimators contained in
        `estimators`.

        Parameters
        ----------
        **params : keyword arguments
            Specific parameters using e.g.
            `set_params(parameter_name=new_value)`. In addition to setting the
            parameters of the ensemble, the individual estimators within
            `estimators` can also be set, or removed by setting them to
            'drop'.

        Returns
        -------
        self : object
            Estimator instance.
        """
        super()._set_params("estimators", **params)
        return self
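
    # Continuing the VotingClassifier sketch above (illustrative only):
    #
    #     ens.set_params(lr__C=10.0)    # tune a nested parameter of the "lr" estimator
    #     ens.set_params(svc="drop")    # exclude the "svc" estimator from fitting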

    def get_params(self, deep=True):
        """
        Get the parameters of an estimator from the ensemble.

        Returns the parameters given in the constructor as well as the
        estimators contained within the `estimators` parameter.

        Parameters
        ----------
        deep : bool, default=True
            Setting it to True gets the various estimators and their
            parameters as well.

        Returns
        -------
        params : dict
            Parameter and estimator names mapped to their values, or parameter
            names mapped to their values.
        """
        return super()._get_params("estimators", deep=deep)
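
    # Continuing the VotingClassifier sketch above (illustrative only): nested
    # estimator parameters are exposed when `deep=True`.
    #
    #     sorted(ens.get_params(deep=False))       # constructor params only
    #     "lr__C" in ens.get_params(deep=True)     # True: nested params included
    #     ens.get_params(deep=True)["lr"]          # the LogisticRegression instance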