# metaestimators.py
  1. """Utilities for meta-estimators"""
  2. # Author: Joel Nothman
  3. # Andreas Mueller
  4. # License: BSD
  5. from abc import ABCMeta, abstractmethod
  6. from contextlib import suppress
  7. from typing import Any, List
  8. import numpy as np
  9. from ..base import BaseEstimator
  10. from ..utils import _safe_indexing
  11. from ..utils._tags import _safe_tags
  12. from ._available_if import available_if
  13. __all__ = ["available_if"]
  14. class _BaseComposition(BaseEstimator, metaclass=ABCMeta):
  15. """Handles parameter management for classifiers composed of named estimators."""
  16. steps: List[Any]
  17. @abstractmethod
  18. def __init__(self):
  19. pass
  20. def _get_params(self, attr, deep=True):
  21. out = super().get_params(deep=deep)
  22. if not deep:
  23. return out
  24. estimators = getattr(self, attr)
  25. try:
  26. out.update(estimators)
  27. except (TypeError, ValueError):
  28. # Ignore TypeError for cases where estimators is not a list of
  29. # (name, estimator) and ignore ValueError when the list is not
  30. # formatted correctly. This is to prevent errors when calling
  31. # `set_params`. `BaseEstimator.set_params` calls `get_params` which
  32. # can error for invalid values for `estimators`.
  33. return out
  34. for name, estimator in estimators:
  35. if hasattr(estimator, "get_params"):
  36. for key, value in estimator.get_params(deep=True).items():
  37. out["%s__%s" % (name, key)] = value
  38. return out
  39. def _set_params(self, attr, **params):
  40. # Ensure strict ordering of parameter setting:
  41. # 1. All steps
  42. if attr in params:
  43. setattr(self, attr, params.pop(attr))
  44. # 2. Replace items with estimators in params
  45. items = getattr(self, attr)
  46. if isinstance(items, list) and items:
  47. # Get item names used to identify valid names in params
  48. # `zip` raises a TypeError when `items` does not contains
  49. # elements of length 2
  50. with suppress(TypeError):
  51. item_names, _ = zip(*items)
  52. for name in list(params.keys()):
  53. if "__" not in name and name in item_names:
  54. self._replace_estimator(attr, name, params.pop(name))
  55. # 3. Step parameters and other initialisation arguments
  56. super().set_params(**params)
  57. return self
  58. def _replace_estimator(self, attr, name, new_val):
  59. # assumes `name` is a valid estimator name
  60. new_estimators = list(getattr(self, attr))
  61. for i, (estimator_name, _) in enumerate(new_estimators):
  62. if estimator_name == name:
  63. new_estimators[i] = (name, new_val)
  64. break
  65. setattr(self, attr, new_estimators)
  66. def _validate_names(self, names):
  67. if len(set(names)) != len(names):
  68. raise ValueError("Names provided are not unique: {0!r}".format(list(names)))
  69. invalid_names = set(names).intersection(self.get_params(deep=False))
  70. if invalid_names:
  71. raise ValueError(
  72. "Estimator names conflict with constructor arguments: {0!r}".format(
  73. sorted(invalid_names)
  74. )
  75. )
  76. invalid_names = [name for name in names if "__" in name]
  77. if invalid_names:
  78. raise ValueError(
  79. "Estimator names must not contain __: got {0!r}".format(invalid_names)
  80. )
  81. def _safe_split(estimator, X, y, indices, train_indices=None):
  82. """Create subset of dataset and properly handle kernels.
  83. Slice X, y according to indices for cross-validation, but take care of
  84. precomputed kernel-matrices or pairwise affinities / distances.
  85. If ``estimator._pairwise is True``, X needs to be square and
  86. we slice rows and columns. If ``train_indices`` is not None,
  87. we slice rows using ``indices`` (assumed the test set) and columns
  88. using ``train_indices``, indicating the training set.
  89. Labels y will always be indexed only along the first axis.
  90. Parameters
  91. ----------
  92. estimator : object
  93. Estimator to determine whether we should slice only rows or rows and
  94. columns.
  95. X : array-like, sparse matrix or iterable
  96. Data to be indexed. If ``estimator._pairwise is True``,
  97. this needs to be a square array-like or sparse matrix.
  98. y : array-like, sparse matrix or iterable
  99. Targets to be indexed.
  100. indices : array of int
  101. Rows to select from X and y.
  102. If ``estimator._pairwise is True`` and ``train_indices is None``
  103. then ``indices`` will also be used to slice columns.
  104. train_indices : array of int or None, default=None
  105. If ``estimator._pairwise is True`` and ``train_indices is not None``,
  106. then ``train_indices`` will be use to slice the columns of X.
  107. Returns
  108. -------
  109. X_subset : array-like, sparse matrix or list
  110. Indexed data.
  111. y_subset : array-like, sparse matrix or list
  112. Indexed targets.
  113. """
  114. if _safe_tags(estimator, key="pairwise"):
  115. if not hasattr(X, "shape"):
  116. raise ValueError(
  117. "Precomputed kernels or affinity matrices have "
  118. "to be passed as arrays or sparse matrices."
  119. )
  120. # X is a precomputed square kernel matrix
  121. if X.shape[0] != X.shape[1]:
  122. raise ValueError("X should be a square kernel matrix")
  123. if train_indices is None:
  124. X_subset = X[np.ix_(indices, indices)]
  125. else:
  126. X_subset = X[np.ix_(indices, train_indices)]
  127. else:
  128. X_subset = _safe_indexing(X, indices)
  129. if y is not None:
  130. y_subset = _safe_indexing(y, indices)
  131. else:
  132. y_subset = None
  133. return X_subset, y_subset