| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162 |
- """Utilities for meta-estimators"""
- # Author: Joel Nothman
- # Andreas Mueller
- # License: BSD
- from abc import ABCMeta, abstractmethod
- from contextlib import suppress
- from typing import Any, List
- import numpy as np
- from ..base import BaseEstimator
- from ..utils import _safe_indexing
- from ..utils._tags import _safe_tags
- from ._available_if import available_if
- __all__ = ["available_if"]
class _BaseComposition(BaseEstimator, metaclass=ABCMeta):
    """Parameter handling for estimators built from named sub-estimators.

    The sub-estimators are stored on the instance, under an attribute whose
    name is passed to each helper as ``attr``, as a list of
    ``(name, estimator)`` pairs.
    """

    steps: List[Any]

    @abstractmethod
    def __init__(self):
        pass

    def _get_params(self, attr, deep=True):
        """Return the parameters of this object and of its sub-estimators.

        Parameters
        ----------
        attr : str
            Name of the attribute holding the ``(name, estimator)`` pairs.
        deep : bool, default=True
            When False, only the parameters reported by the parent class are
            returned.

        Returns
        -------
        params : dict
            Parent-class parameters, plus (when ``deep`` is true) one entry
            per named estimator and one ``<name>__<key>`` entry per nested
            parameter of every estimator exposing ``get_params``.
        """
        params = super().get_params(deep=deep)
        if deep:
            named_estimators = getattr(self, attr)
            try:
                params.update(named_estimators)
            except (TypeError, ValueError):
                # `dict.update` fails with TypeError when the attribute is
                # not a sequence of pairs and with ValueError when the
                # entries are malformed. Such values are tolerated here so
                # that `set_params` (which goes through `get_params`) does
                # not error on an invalid attribute value.
                return params

            for est_name, est in named_estimators:
                if not hasattr(est, "get_params"):
                    continue
                nested = est.get_params(deep=True)
                for sub_key, sub_value in nested.items():
                    params["%s__%s" % (est_name, sub_key)] = sub_value
        return params

    def _set_params(self, attr, **params):
        """Set parameters, replacing named sub-estimators where requested.

        Parameters
        ----------
        attr : str
            Name of the attribute holding the ``(name, estimator)`` pairs.
        **params : dict
            Parameters to set. A key equal to ``attr`` replaces the whole
            list; a key equal to an item name (and without ``__``) replaces
            that item; everything else is forwarded to the parent class.

        Returns
        -------
        self : object
            This instance.
        """
        # The order below is deliberate.
        # Step 1: replace the whole collection if it was passed.
        if attr in params:
            replacement = params.pop(attr)
            setattr(self, attr, replacement)

        # Step 2: replace individual items addressed by their name.
        current = getattr(self, attr)
        if isinstance(current, list) and current:
            # A TypeError raised inside this block (for instance by `zip`
            # when an item is not iterable) is suppressed; the parameters
            # that were not consumed are then left for step 3.
            with suppress(TypeError):
                known_names, _ = zip(*current)
                for key in list(params.keys()):
                    if "__" in key or key not in known_names:
                        continue
                    self._replace_estimator(attr, key, params.pop(key))

        # Step 3: nested parameters and remaining constructor arguments.
        super().set_params(**params)
        return self

    def _replace_estimator(self, attr, name, new_val):
        """Swap the first item called ``name`` for ``(name, new_val)``.

        ``name`` is expected to be one of the existing item names. A new
        list is always built and assigned back to the attribute ``attr``.
        """
        updated = list(getattr(self, attr))
        position = next(
            (
                idx
                for idx, (existing_name, _) in enumerate(updated)
                if existing_name == name
            ),
            None,
        )
        if position is not None:
            updated[position] = (name, new_val)
        setattr(self, attr, updated)

    def _validate_names(self, names):
        """Check that estimator names are usable as parameter names.

        Parameters
        ----------
        names : sequence of str
            Names to validate.

        Raises
        ------
        ValueError
            If the names are not unique, if any of them equals a key of
            ``self.get_params(deep=False)``, or if any of them contains
            ``__``.
        """
        distinct = set(names)
        if len(distinct) != len(names):
            raise ValueError(
                "Names provided are not unique: {0!r}".format(list(names))
            )

        clashing = distinct.intersection(self.get_params(deep=False))
        if clashing:
            raise ValueError(
                "Estimator names conflict with constructor arguments: {0!r}".format(
                    sorted(clashing)
                )
            )

        with_separator = [candidate for candidate in names if "__" in candidate]
        if with_separator:
            raise ValueError(
                "Estimator names must not contain __: got {0!r}".format(
                    with_separator
                )
            )
def _safe_split(estimator, X, y, indices, train_indices=None):
    """Extract the subset of ``X`` and ``y`` selected by ``indices``.

    When the ``pairwise`` tag of ``estimator`` is set, ``X`` is treated as a
    square precomputed matrix (e.g. a kernel, affinity or distance matrix)
    and is sliced along both axes: rows with ``indices`` and columns with
    ``train_indices`` if given, otherwise with ``indices``. For any other
    estimator only the first axis of ``X`` is indexed.

    ``y`` is always indexed along its first axis only.

    Parameters
    ----------
    estimator : object
        Estimator whose ``pairwise`` tag decides whether only rows, or rows
        and columns, of ``X`` are sliced.

    X : array-like, sparse matrix or iterable
        Data to be indexed. When the ``pairwise`` tag is set, this must
        expose ``shape`` and be square.

    y : array-like, sparse matrix or iterable
        Targets to be indexed. May be None.

    indices : array of int
        Rows to select from ``X`` and ``y``. When the ``pairwise`` tag is
        set and ``train_indices`` is None, also used to select columns.

    train_indices : array of int or None, default=None
        When the ``pairwise`` tag is set and this is not None, used to
        select the columns of ``X`` (``indices`` is then assumed to be the
        test set and ``train_indices`` the training set).

    Returns
    -------
    X_subset : array-like, sparse matrix or list
        Indexed data.

    y_subset : array-like, sparse matrix or list
        Indexed targets, or None when ``y`` is None.

    Raises
    ------
    ValueError
        If the ``pairwise`` tag is set and ``X`` has no ``shape`` attribute
        or is not square.
    """
    if not _safe_tags(estimator, key="pairwise"):
        X_subset = _safe_indexing(X, indices)
    else:
        if not hasattr(X, "shape"):
            raise ValueError(
                "Precomputed kernels or affinity matrices have "
                "to be passed as arrays or sparse matrices."
            )
        # A precomputed pairwise matrix must be square.
        if X.shape[0] != X.shape[1]:
            raise ValueError("X should be a square kernel matrix")
        # Rows always follow `indices`; columns follow the training set
        # when one is supplied.
        column_indices = indices if train_indices is None else train_indices
        X_subset = X[np.ix_(indices, column_indices)]

    y_subset = None if y is None else _safe_indexing(y, indices)

    return X_subset, y_subset
|