| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334 |
- """Base class for ensemble-based estimators."""
- # Authors: Gilles Louppe
- # License: BSD 3 clause
- import warnings
- from abc import ABCMeta, abstractmethod
- from typing import List
- import numpy as np
- from joblib import effective_n_jobs
- from ..base import BaseEstimator, MetaEstimatorMixin, clone, is_classifier, is_regressor
- from ..utils import Bunch, _print_elapsed_time, check_random_state, deprecated
- from ..utils.metaestimators import _BaseComposition
def _fit_single_estimator(
    estimator, X, y, sample_weight=None, message_clsname=None, message=None
):
    """Fit one estimator inside a (possibly parallel) job.

    The ``fit`` call runs inside ``_print_elapsed_time(message_clsname, message)``.

    Parameters
    ----------
    estimator : estimator
        Object whose ``fit`` method is called; it is fitted in place.
    X, y
        Training data forwarded unchanged to ``estimator.fit``.
    sample_weight : array-like or None, default=None
        When not None, forwarded to ``fit`` as the ``sample_weight`` keyword.
    message_clsname, message : default=None
        Forwarded to ``_print_elapsed_time``.

    Returns
    -------
    estimator : estimator
        The same object that was passed in, after fitting.

    Raises
    ------
    TypeError
        With a clearer message when ``fit`` rejects the ``sample_weight``
        keyword. Any other ``TypeError`` is re-raised unchanged.
    """
    if sample_weight is None:
        with _print_elapsed_time(message_clsname, message):
            estimator.fit(X, y)
        return estimator

    try:
        with _print_elapsed_time(message_clsname, message):
            estimator.fit(X, y, sample_weight=sample_weight)
    except TypeError as exc:
        # Only translate the specific "unexpected keyword" failure; other
        # TypeErrors come from inside fit itself and must propagate as-is.
        if "unexpected keyword argument 'sample_weight'" not in str(exc):
            raise
        offending_name = estimator.__class__.__name__
        raise TypeError(
            "Underlying estimator {} does not support sample weights.".format(
                offending_name
            )
        ) from exc
    return estimator
def _set_random_states(estimator, random_state=None):
    """Replace every ``random_state`` parameter of an estimator by a fixed int.

    All parameter names returned by ``estimator.get_params(deep=True)`` that
    are exactly ``random_state`` or end with ``__random_state`` receive an
    integer drawn from ``random_state``. Names are visited in sorted order, so
    a given seed always yields the same assignment.

    Parameters
    ----------
    estimator : estimator supporting get/set_params
        Estimator with potential randomness managed by random_state
        parameters. It is modified in place through ``set_params``.

    random_state : int, RandomState instance or None, default=None
        Pseudo-random number generator to control the generation of the random
        integers. Pass an int for reproducible output across multiple function
        calls.
        See :term:`Glossary <random_state>`.

    Notes
    -----
    Only ``random_state`` values reachable through ``estimator.get_params()``
    are set. Other sources of randomness are left untouched, for example
    those held by:

    * cross-validation splitters
    * ``scipy.stats`` rvs
    """
    rng = check_random_state(random_state)
    # Seeds are drawn in [0, int32 max).
    max_seed = np.iinfo(np.int32).max
    # The dict comprehension iterates in sorted order, so the draws happen in
    # the same sequence as a plain loop would produce.
    seeds = {
        name: rng.randint(max_seed)
        for name in sorted(estimator.get_params(deep=True))
        if name == "random_state" or name.endswith("__random_state")
    }
    if seeds:
        estimator.set_params(**seeds)
class BaseEnsemble(MetaEstimatorMixin, BaseEstimator, metaclass=ABCMeta):
    """Base class for all ensemble classes.

    Warning: This class should not be used directly. Use derived classes
    instead.

    Parameters
    ----------
    estimator : object, default=None
        The base estimator from which the ensemble is built.

    n_estimators : int, default=10
        The number of estimators in the ensemble.

    estimator_params : list of str, default=tuple()
        Names of attributes of the ensemble whose values are copied onto each
        new base estimator (via ``set_params``) when it is instantiated. If
        none are given, default parameters are used.

    base_estimator : object, default="deprecated"
        Use `estimator` instead.

        .. deprecated:: 1.2
            `base_estimator` is deprecated and will be removed in 1.4.
            Use `estimator` instead.

    Attributes
    ----------
    estimator_ : estimator
        The base estimator from which the ensemble is grown.

    base_estimator_ : estimator
        The base estimator from which the ensemble is grown.

        .. deprecated:: 1.2
            `base_estimator_` is deprecated and will be removed in 1.4.
            Use `estimator_` instead.

    estimators_ : list of estimators
        The collection of fitted base estimators.
    """

    # Override the inherited ``_required_parameters``: every constructor
    # argument of this class has a default, so none is mandatory.
    _required_parameters: List[str] = []

    @abstractmethod
    def __init__(
        self,
        estimator=None,
        *,
        n_estimators=10,
        estimator_params=tuple(),
        base_estimator="deprecated",
    ):
        # Only record hyper-parameters here. Sub-estimators are deliberately
        # not built yet: nested parameters may still change (e.g. during a
        # grid search), so derived classes fill ``estimators_`` in ``fit``.
        self.estimator = estimator
        self.n_estimators = n_estimators
        self.estimator_params = estimator_params
        self.base_estimator = base_estimator

    def _validate_estimator(self, default=None):
        """Resolve which base estimator to use and store it in ``estimator_``.

        Resolution order:

        1. ``estimator`` if it is not None. It is an error to also pass a
           non-None ``base_estimator``.
        2. otherwise, if ``base_estimator`` was explicitly given (anything but
           the ``"deprecated"`` sentinel), emit a FutureWarning and use it,
           falling back to ``default`` when it is None;
        3. otherwise ``default``.

        Raises
        ------
        ValueError
            If both ``estimator`` and ``base_estimator`` are set.
        """
        chosen_estimator = self.estimator
        legacy_estimator = self.base_estimator

        if chosen_estimator is not None:
            if legacy_estimator not in [None, "deprecated"]:
                raise ValueError(
                    "Both `estimator` and `base_estimator` were set. Only set `estimator`."
                )
            self.estimator_ = chosen_estimator
            return

        resolved = default
        if legacy_estimator != "deprecated":
            # The user explicitly passed the old parameter name.
            warnings.warn(
                (
                    "`base_estimator` was renamed to `estimator` in version 1.2 and "
                    "will be removed in 1.4."
                ),
                FutureWarning,
            )
            if legacy_estimator is not None:
                resolved = legacy_estimator
        self.estimator_ = resolved

    # TODO(1.4): remove
    # mypy error: Decorated property not supported
    @deprecated(  # type: ignore
        "Attribute `base_estimator_` was deprecated in version 1.2 and will be removed "
        "in 1.4. Use `estimator_` instead."
    )
    @property
    def base_estimator_(self):
        """Estimator used to grow the ensemble (alias of ``estimator_``)."""
        return self.estimator_

    def _make_estimator(self, append=True, random_state=None):
        """Build a fresh, configured clone of ``estimator_``.

        Derived classes should create sub-estimators through this method.

        Parameters
        ----------
        append : bool, default=True
            If True, the new estimator is also appended to ``estimators_``.
        random_state : int, RandomState instance or None, default=None
            When not None, passed to ``_set_random_states`` to fix every
            ``random_state`` parameter of the clone.

        Returns
        -------
        estimator : estimator
            The newly created (unfitted) estimator.
        """
        member = clone(self.estimator_)
        # Copy the ensemble-level attributes listed in estimator_params onto
        # the clone.
        shared_params = {name: getattr(self, name) for name in self.estimator_params}
        member.set_params(**shared_params)

        if random_state is not None:
            _set_random_states(member, random_state)

        if append:
            self.estimators_.append(member)

        return member

    def __len__(self):
        """Number of estimators currently held in ``estimators_``."""
        return len(self.estimators_)

    def __getitem__(self, index):
        """Return ``estimators_[index]``; slicing is delegated to the list."""
        return self.estimators_[index]

    def __iter__(self):
        """Iterate over the estimators held in ``estimators_``."""
        return iter(self.estimators_)
def _partition_estimators(n_estimators, n_jobs):
    """Split ``n_estimators`` as evenly as possible across parallel jobs.

    Parameters
    ----------
    n_estimators : int
        Total number of estimators to distribute. Must be positive: with 0,
        the number of jobs becomes 0 and the integer division fails.
    n_jobs : int or None
        Requested number of jobs, resolved through ``effective_n_jobs``.

    Returns
    -------
    n_jobs : int
        Number of jobs actually used:
        ``min(effective_n_jobs(n_jobs), n_estimators)``.
    n_estimators_per_job : list of int
        Estimators handled by each job. The first
        ``n_estimators % n_jobs`` jobs get one extra.
    starts : list of int
        ``n_jobs + 1`` boundaries starting at 0. Job ``i`` handles estimators
        ``starts[i]:starts[i + 1]``.
    """
    # Never spawn more jobs than there are estimators.
    used_jobs = min(effective_n_jobs(n_jobs), n_estimators)

    per_job, leftover = divmod(n_estimators, used_jobs)
    counts = np.full(used_jobs, per_job, dtype=int)
    # Spread the remainder one-by-one over the leading jobs.
    counts[:leftover] += 1

    boundaries = [0] + np.cumsum(counts).tolist()
    return used_jobs, counts.tolist(), boundaries
class _BaseHeterogeneousEnsemble(
    MetaEstimatorMixin, _BaseComposition, metaclass=ABCMeta
):
    """Base class for heterogeneous ensemble of learners.

    Parameters
    ----------
    estimators : list of (str, estimator) tuples
        The estimators that make up the ensemble. Each element of the list is
        a tuple of a string (the name of the estimator) and an estimator
        instance. An estimator can be set to `'drop'` using `set_params`.

    Attributes
    ----------
    estimators_ : list of estimators
        The elements of the estimators parameter, having been fitted on the
        training data. If an estimator has been set to `'drop'`, it will not
        appear in `estimators_`.
    """

    _required_parameters = ["estimators"]

    @property
    def named_estimators(self):
        """Dictionary to access the sub-estimators by name.

        The mapping is built from the ``estimators`` constructor parameter
        (name -> estimator or ``'drop'``), not from the fitted
        ``estimators_`` attribute.

        Returns
        -------
        :class:`~sklearn.utils.Bunch`
        """
        return Bunch(**dict(self.estimators))

    @abstractmethod
    def __init__(self, estimators):
        self.estimators = estimators

    def _validate_estimators(self):
        """Validate the ``estimators`` parameter and split it into its parts.

        Returns
        -------
        names : tuple of str
            The estimator names, in the given order.
        estimators : tuple
            The estimators (or ``'drop'`` markers), in the given order.

        Raises
        ------
        ValueError
            If ``estimators`` is None or empty, if every estimator is
            ``'drop'``, or if a non-dropped estimator is not of the same
            kind (classifier / regressor) as this ensemble. The names are
            additionally checked by the inherited ``_validate_names``.
        """
        # Guard against None as well as an empty sequence, so that both raise
        # the explanatory ValueError instead of a bare
        # "object of type 'NoneType' has no len()" TypeError.
        if self.estimators is None or len(self.estimators) == 0:
            raise ValueError(
                "Invalid 'estimators' attribute, 'estimators' should be a "
                "non-empty list of (string, estimator) tuples."
            )
        names, estimators = zip(*self.estimators)
        # ``_validate_names`` is inherited from a base class (presumably
        # _BaseComposition) -- NOTE(review): confirm which one defines it.
        self._validate_names(names)

        has_estimator = any(est != "drop" for est in estimators)
        if not has_estimator:
            raise ValueError(
                "All estimators are dropped. At least one is required "
                "to be an estimator."
            )

        # Every kept estimator must match this ensemble's own kind. Spell the
        # kind out explicitly rather than slicing the checker's __name__.
        if is_classifier(self):
            is_expected_kind, expected_kind = is_classifier, "classifier"
        else:
            is_expected_kind, expected_kind = is_regressor, "regressor"
        for est in estimators:
            if est != "drop" and not is_expected_kind(est):
                raise ValueError(
                    "The estimator {} should be a {}.".format(
                        est.__class__.__name__, expected_kind
                    )
                )

        return names, estimators

    def set_params(self, **params):
        """
        Set the parameters of an estimator from the ensemble.

        Valid parameter keys can be listed with `get_params()`. Note that you
        can directly set the parameters of the estimators contained in
        `estimators`.

        Parameters
        ----------
        **params : keyword arguments
            Specific parameters using e.g.
            `set_params(parameter_name=new_value)`. In addition to setting the
            parameters of the estimator, the individual estimators of the
            ensemble can also be set, or can be removed by setting them to
            'drop'.

        Returns
        -------
        self : object
            Estimator instance.
        """
        super()._set_params("estimators", **params)
        return self

    def get_params(self, deep=True):
        """
        Get the parameters of an estimator from the ensemble.

        Returns the parameters given in the constructor as well as the
        estimators contained within the `estimators` parameter.

        Parameters
        ----------
        deep : bool, default=True
            Setting it to True gets the various estimators and the parameters
            of the estimators as well.

        Returns
        -------
        params : dict
            Parameter and estimator names mapped to their values or parameter
            names mapped to their values.
        """
        return super()._get_params("estimators", deep=deep)
|