# Authors: Gilles Louppe, Mathieu Blondel, Maheshakya Wijewardena
# License: BSD 3 clause

from copy import deepcopy
from numbers import Integral, Real

import numpy as np

from ..base import BaseEstimator, MetaEstimatorMixin, _fit_context, clone
from ..exceptions import NotFittedError
from ..utils._param_validation import HasMethods, Interval, Options
from ..utils._tags import _safe_tags
from ..utils.metaestimators import available_if
from ..utils.validation import _num_features, check_is_fitted, check_scalar
from ._base import SelectorMixin, _get_feature_importances


def _calculate_threshold(estimator, importances, threshold):
    """Interpret the threshold value"""
    if threshold is None:
        # determine default from estimator
        est_name = estimator.__class__.__name__
        is_l1_penalized = hasattr(estimator, "penalty") and estimator.penalty == "l1"
        is_lasso = "Lasso" in est_name
        is_elasticnet_l1_penalized = "ElasticNet" in est_name and (
            (hasattr(estimator, "l1_ratio_") and np.isclose(estimator.l1_ratio_, 1.0))
            or (hasattr(estimator, "l1_ratio") and np.isclose(estimator.l1_ratio, 1.0))
        )
        if is_l1_penalized or is_lasso or is_elasticnet_l1_penalized:
            # the natural default threshold is 0 when l1 penalty was used
            threshold = 1e-5
        else:
            threshold = "mean"

    if isinstance(threshold, str):
        if "*" in threshold:
            scale, reference = threshold.split("*")
            scale = float(scale.strip())
            reference = reference.strip()

            if reference == "median":
                reference = np.median(importances)
            elif reference == "mean":
                reference = np.mean(importances)
            else:
                raise ValueError("Unknown reference: " + reference)

            threshold = scale * reference

        elif threshold == "median":
            threshold = np.median(importances)

        elif threshold == "mean":
            threshold = np.mean(importances)

        else:
            raise ValueError(
                "Expected threshold='mean' or threshold='median' got %s" % threshold
            )

    else:
        threshold = float(threshold)

    return threshold
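
# Illustrative sketch (not part of the original module): how string thresholds
# are resolved. The estimator is only inspected when ``threshold is None``, so
# a bare ``object()`` suffices for the first two calls.
#
#     import numpy as np
#     importances = np.array([0.1, 0.2, 0.3])
#     _calculate_threshold(object(), importances, "1.25*mean")  # 1.25 * 0.2 -> 0.25
#     _calculate_threshold(object(), importances, "median")     # -> 0.2
#     _calculate_threshold(object(), importances, None)         # no l1 penalty, falls back to "mean" -> 0.2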


def _estimator_has(attr):
    """Check if we can delegate a method to the underlying estimator.

    First, we check the fitted estimator if available, otherwise we
    check the unfitted estimator.
    """
    return lambda self: (
        hasattr(self.estimator_, attr)
        if hasattr(self, "estimator_")
        else hasattr(self.estimator, attr)
    )
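
# Illustrative sketch (not part of the original module): ``available_if`` calls
# the predicate returned by ``_estimator_has`` at attribute-access time, so a
# delegated method such as ``partial_fit`` only shows up on ``SelectFromModel``
# when the wrapped estimator implements it.
#
#     from sklearn.linear_model import LogisticRegression, SGDClassifier
#     hasattr(SelectFromModel(SGDClassifier()), "partial_fit")       # True
#     hasattr(SelectFromModel(LogisticRegression()), "partial_fit")  # False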


class SelectFromModel(MetaEstimatorMixin, SelectorMixin, BaseEstimator):
    """Meta-transformer for selecting features based on importance weights.

    .. versionadded:: 0.17

    Read more in the :ref:`User Guide <select_from_model>`.

    Parameters
    ----------
    estimator : object
        The base estimator from which the transformer is built.
        This can be either a fitted (if ``prefit`` is set to True)
        or a non-fitted estimator. The estimator should have a
        ``feature_importances_`` or ``coef_`` attribute after fitting.
        Otherwise, the ``importance_getter`` parameter should be used.

    threshold : str or float, default=None
        The threshold value to use for feature selection. Features whose
        absolute importance value is greater or equal are kept while the others
        are discarded. If "median" (resp. "mean"), then the ``threshold`` value
        is the median (resp. the mean) of the feature importances. A scaling
        factor (e.g., "1.25*mean") may also be used. If None and if the
        estimator has a parameter penalty set to l1, either explicitly
        or implicitly (e.g., Lasso), the threshold used is 1e-5.
        Otherwise, "mean" is used by default.

    prefit : bool, default=False
        Whether a prefit model is expected to be passed into the constructor
        directly or not.
        If `True`, `estimator` must be a fitted estimator.
        If `False`, `estimator` is fitted and updated by calling
        `fit` and `partial_fit`, respectively.

    norm_order : non-zero int, inf, -inf, default=1
        Order of the norm used to filter the vectors of coefficients below
        ``threshold`` in the case where the ``coef_`` attribute of the
        estimator is of dimension 2.

    max_features : int, callable, default=None
        The maximum number of features to select.

        - If an integer, then it specifies the maximum number of features to
          allow.
        - If a callable, then it specifies how to calculate the maximum number of
          features allowed by using the output of `max_features(X)`.
        - If `None`, then all features are kept.

        To only select based on ``max_features``, set ``threshold=-np.inf``.

        .. versionadded:: 0.20

        .. versionchanged:: 1.1
           `max_features` accepts a callable.

    importance_getter : str or callable, default='auto'
        If 'auto', uses the feature importance either through a ``coef_``
        attribute or ``feature_importances_`` attribute of estimator.

        Also accepts a string that specifies an attribute name/path
        for extracting feature importance (implemented with `attrgetter`).
        For example, give `regressor_.coef_` in case of
        :class:`~sklearn.compose.TransformedTargetRegressor` or
        `named_steps.clf.feature_importances_` in case of
        :class:`~sklearn.pipeline.Pipeline` with its last step named `clf`.

        If `callable`, overrides the default feature importance getter.
        The callable is passed the fitted estimator and should return the
        importance of each feature.

        .. versionadded:: 0.24

    Attributes
    ----------
    estimator_ : estimator
        The base estimator from which the transformer is built. This attribute
        exists only when `fit` has been called.

        - If `prefit=True`, it is a deep copy of `estimator`.
        - If `prefit=False`, it is a clone of `estimator` and fit on the data
          passed to `fit` or `partial_fit`.

    n_features_in_ : int
        Number of features seen during :term:`fit`. Only defined if the
        underlying estimator exposes such an attribute when fit.

        .. versionadded:: 0.24

    max_features_ : int
        Maximum number of features calculated during :term:`fit`. Only defined
        if ``max_features`` is not `None`.

        - If `max_features` is an `int`, then `max_features_ = max_features`.
        - If `max_features` is a callable, then `max_features_ = max_features(X)`.

        .. versionadded:: 1.1

    feature_names_in_ : ndarray of shape (`n_features_in_`,)
        Names of features seen during :term:`fit`. Defined only when `X`
        has feature names that are all strings.

        .. versionadded:: 1.0

    threshold_ : float
        The threshold value used for feature selection.

    See Also
    --------
    RFE : Recursive feature elimination based on importance weights.
    RFECV : Recursive feature elimination with built-in cross-validated
        selection of the best number of features.
    SequentialFeatureSelector : Sequential cross-validation based feature
        selection. Does not rely on importance weights.

    Notes
    -----
    Allows NaN/Inf in the input if the underlying estimator does as well.

    Examples
    --------
    >>> from sklearn.feature_selection import SelectFromModel
    >>> from sklearn.linear_model import LogisticRegression
    >>> X = [[ 0.87, -1.34,  0.31 ],
    ...      [-2.79, -0.02, -0.85 ],
    ...      [-1.34, -0.48, -2.55 ],
    ...      [ 1.92,  1.48,  0.65 ]]
    >>> y = [0, 1, 0, 1]
    >>> selector = SelectFromModel(estimator=LogisticRegression()).fit(X, y)
    >>> selector.estimator_.coef_
    array([[-0.3252302 ,  0.83462377,  0.49750423]])
    >>> selector.threshold_
    0.55245...
    >>> selector.get_support()
    array([False,  True, False])
    >>> selector.transform(X)
    array([[-1.34],
           [-0.02],
           [-0.48],
           [ 1.48]])

    Using a callable to create a selector that can use no more than half
    of the input features.

    >>> def half_callable(X):
    ...     return round(len(X[0]) / 2)
    >>> half_selector = SelectFromModel(estimator=LogisticRegression(),
    ...                                 max_features=half_callable)
    >>> _ = half_selector.fit(X, y)
    >>> half_selector.max_features_
    2
    """

    _parameter_constraints: dict = {
        "estimator": [HasMethods("fit")],
        "threshold": [Interval(Real, None, None, closed="both"), str, None],
        "prefit": ["boolean"],
        "norm_order": [
            Interval(Integral, None, -1, closed="right"),
            Interval(Integral, 1, None, closed="left"),
            Options(Real, {np.inf, -np.inf}),
        ],
        "max_features": [Interval(Integral, 0, None, closed="left"), callable, None],
        "importance_getter": [str, callable],
    }
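
    # Illustrative note (not part of the original file): the two ``Integral``
    # intervals above are open around zero, so ``norm_order=0`` fails parameter
    # validation when ``fit`` runs, matching the "non-zero int" docstring.
    #
    #     SelectFromModel(LogisticRegression(), norm_order=0).fit(X, y)
    #     # raises InvalidParameterError (assuming sklearn's standard validation)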

    def __init__(
        self,
        estimator,
        *,
        threshold=None,
        prefit=False,
        norm_order=1,
        max_features=None,
        importance_getter="auto",
    ):
        self.estimator = estimator
        self.threshold = threshold
        self.prefit = prefit
        self.importance_getter = importance_getter
        self.norm_order = norm_order
        self.max_features = max_features

    def _get_support_mask(self):
        estimator = getattr(self, "estimator_", self.estimator)
        max_features = getattr(self, "max_features_", self.max_features)

        if self.prefit:
            try:
                check_is_fitted(self.estimator)
            except NotFittedError as exc:
                raise NotFittedError(
                    "When `prefit=True`, `estimator` is expected to be a fitted "
                    "estimator."
                ) from exc
            if callable(max_features):
                # This branch is executed when `transform` is called directly and thus
                # `max_features_` is not set and we fallback using `self.max_features`
                # that is not validated
                raise NotFittedError(
                    "When `prefit=True` and `max_features` is a callable, call `fit` "
                    "before calling `transform`."
                )
            elif max_features is not None and not isinstance(max_features, Integral):
                raise ValueError(
                    f"`max_features` must be an integer. Got `max_features={max_features}` "
                    "instead."
                )

        scores = _get_feature_importances(
            estimator=estimator,
            getter=self.importance_getter,
            transform_func="norm",
            norm_order=self.norm_order,
        )
        threshold = _calculate_threshold(estimator, scores, self.threshold)
        if self.max_features is not None:
            mask = np.zeros_like(scores, dtype=bool)
            candidate_indices = np.argsort(-scores, kind="mergesort")[:max_features]
            mask[candidate_indices] = True
        else:
            mask = np.ones_like(scores, dtype=bool)
        mask[scores < threshold] = False
        return mask
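
    # Illustrative sketch (not part of the original class): combining
    # ``threshold=-np.inf`` with ``max_features`` turns the mask above into a
    # pure top-k selection, since no score falls below ``-np.inf``.
    #
    #     from sklearn.ensemble import RandomForestClassifier
    #     top2 = SelectFromModel(RandomForestClassifier(random_state=0),
    #                            threshold=-np.inf, max_features=2)
    #     top2.fit(X, y).transform(X)  # keeps the 2 highest-scoring columns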

    def _check_max_features(self, X):
        if self.max_features is not None:
            n_features = _num_features(X)

            if callable(self.max_features):
                max_features = self.max_features(X)
            else:  # int
                max_features = self.max_features

            check_scalar(
                max_features,
                "max_features",
                Integral,
                min_val=0,
                max_val=n_features,
            )
            self.max_features_ = max_features
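
    # Illustrative sketch (not part of the original class): a data-dependent
    # cap passed as a callable; ``check_scalar`` above validates its result
    # against the number of features in ``X``.
    #
    #     cap = lambda X: min(2, _num_features(X))
    #     selector = SelectFromModel(LogisticRegression(), max_features=cap)
    #     selector.fit(X, y).max_features_  # e.g. 2 when X has >= 2 features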

    @_fit_context(
        # SelectFromModel.estimator is not validated yet
        prefer_skip_nested_validation=False
    )
    def fit(self, X, y=None, **fit_params):
        """Fit the SelectFromModel meta-transformer.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            The training input samples.

        y : array-like of shape (n_samples,), default=None
            The target values (integers that correspond to classes in
            classification, real numbers in regression).

        **fit_params : dict
            Other estimator specific parameters.

        Returns
        -------
        self : object
            Fitted estimator.
        """
        self._check_max_features(X)

        if self.prefit:
            try:
                check_is_fitted(self.estimator)
            except NotFittedError as exc:
                raise NotFittedError(
                    "When `prefit=True`, `estimator` is expected to be a fitted "
                    "estimator."
                ) from exc
            self.estimator_ = deepcopy(self.estimator)
        else:
            self.estimator_ = clone(self.estimator)
            self.estimator_.fit(X, y, **fit_params)

        if hasattr(self.estimator_, "feature_names_in_"):
            self.feature_names_in_ = self.estimator_.feature_names_in_
        else:
            self._check_feature_names(X, reset=True)

        return self
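
    # Illustrative sketch (not part of the original class): with ``prefit=True``
    # the estimator is trained once outside, and ``fit`` only validates and
    # deep-copies it into ``estimator_`` instead of re-fitting.
    #
    #     from sklearn.linear_model import LogisticRegression
    #     clf = LogisticRegression().fit(X, y)
    #     selector = SelectFromModel(clf, prefit=True).fit(X, y)
    #     selector.transform(X)  # filters columns without re-fitting ``clf``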

    @property
    def threshold_(self):
        """Threshold value used for feature selection."""
        scores = _get_feature_importances(
            estimator=self.estimator_,
            getter=self.importance_getter,
            transform_func="norm",
            norm_order=self.norm_order,
        )
        return _calculate_threshold(self.estimator, scores, self.threshold)

    @available_if(_estimator_has("partial_fit"))
    @_fit_context(
        # SelectFromModel.estimator is not validated yet
        prefer_skip_nested_validation=False
    )
    def partial_fit(self, X, y=None, **fit_params):
        """Fit the SelectFromModel meta-transformer only once.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            The training input samples.

        y : array-like of shape (n_samples,), default=None
            The target values (integers that correspond to classes in
            classification, real numbers in regression).

        **fit_params : dict
            Other estimator specific parameters.

        Returns
        -------
        self : object
            Fitted estimator.
        """
        first_call = not hasattr(self, "estimator_")

        if first_call:
            self._check_max_features(X)

        if self.prefit:
            if first_call:
                try:
                    check_is_fitted(self.estimator)
                except NotFittedError as exc:
                    raise NotFittedError(
                        "When `prefit=True`, `estimator` is expected to be a fitted "
                        "estimator."
                    ) from exc
                self.estimator_ = deepcopy(self.estimator)
            return self

        if first_call:
            self.estimator_ = clone(self.estimator)
        self.estimator_.partial_fit(X, y, **fit_params)

        if hasattr(self.estimator_, "feature_names_in_"):
            self.feature_names_in_ = self.estimator_.feature_names_in_
        else:
            self._check_feature_names(X, reset=first_call)

        return self
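
    # Illustrative sketch (not part of the original class): incremental feature
    # selection with an estimator exposing ``partial_fit``. The ``batches``
    # iterable is hypothetical; ``classes`` is forwarded via ``**fit_params``.
    #
    #     from sklearn.linear_model import SGDClassifier
    #     selector = SelectFromModel(SGDClassifier(random_state=0))
    #     for X_batch, y_batch in batches:
    #         selector = selector.partial_fit(X_batch, y_batch, classes=[0, 1])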

    @property
    def n_features_in_(self):
        """Number of features seen during `fit`."""
        # For consistency with other estimators we raise an AttributeError so
        # that hasattr() fails if the estimator isn't fitted.
        try:
            check_is_fitted(self)
        except NotFittedError as nfe:
            raise AttributeError(
                "{} object has no n_features_in_ attribute.".format(
                    self.__class__.__name__
                )
            ) from nfe

        return self.estimator_.n_features_in_

    def _more_tags(self):
        return {"allow_nan": _safe_tags(self.estimator, key="allow_nan")}