_set_output.py 9.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299
  1. from functools import wraps
  2. from scipy.sparse import issparse
  3. from .._config import get_config
  4. from . import check_pandas_support
  5. from ._available_if import available_if
  6. def _wrap_in_pandas_container(
  7. data_to_wrap,
  8. *,
  9. columns,
  10. index=None,
  11. ):
  12. """Create a Pandas DataFrame.
  13. If `data_to_wrap` is a DataFrame, then the `columns` and `index` will be changed
  14. inplace. If `data_to_wrap` is a ndarray, then a new DataFrame is created with
  15. `columns` and `index`.
  16. Parameters
  17. ----------
  18. data_to_wrap : {ndarray, dataframe}
  19. Data to be wrapped as pandas dataframe.
  20. columns : callable, ndarray, or None
  21. The column names or a callable that returns the column names. The
  22. callable is useful if the column names require some computation.
  23. If `columns` is a callable that raises an error, `columns` will have
  24. the same semantics as `None`. If `None` and `data_to_wrap` is already a
  25. dataframe, then the column names are not changed. If `None` and
  26. `data_to_wrap` is **not** a dataframe, then columns are
  27. `range(n_features)`.
  28. index : array-like, default=None
  29. Index for data. `index` is ignored if `data_to_wrap` is already a DataFrame.
  30. Returns
  31. -------
  32. dataframe : DataFrame
  33. Container with column names or unchanged `output`.
  34. """
  35. if issparse(data_to_wrap):
  36. raise ValueError("Pandas output does not support sparse data.")
  37. if callable(columns):
  38. try:
  39. columns = columns()
  40. except Exception:
  41. columns = None
  42. pd = check_pandas_support("Setting output container to 'pandas'")
  43. if isinstance(data_to_wrap, pd.DataFrame):
  44. if columns is not None:
  45. data_to_wrap.columns = columns
  46. return data_to_wrap
  47. return pd.DataFrame(data_to_wrap, index=index, columns=columns, copy=False)
  48. def _get_output_config(method, estimator=None):
  49. """Get output config based on estimator and global configuration.
  50. Parameters
  51. ----------
  52. method : {"transform"}
  53. Estimator's method for which the output container is looked up.
  54. estimator : estimator instance or None
  55. Estimator to get the output configuration from. If `None`, check global
  56. configuration is used.
  57. Returns
  58. -------
  59. config : dict
  60. Dictionary with keys:
  61. - "dense": specifies the dense container for `method`. This can be
  62. `"default"` or `"pandas"`.
  63. """
  64. est_sklearn_output_config = getattr(estimator, "_sklearn_output_config", {})
  65. if method in est_sklearn_output_config:
  66. dense_config = est_sklearn_output_config[method]
  67. else:
  68. dense_config = get_config()[f"{method}_output"]
  69. if dense_config not in {"default", "pandas"}:
  70. raise ValueError(
  71. f"output config must be 'default' or 'pandas' got {dense_config}"
  72. )
  73. return {"dense": dense_config}
  74. def _wrap_data_with_container(method, data_to_wrap, original_input, estimator):
  75. """Wrap output with container based on an estimator's or global config.
  76. Parameters
  77. ----------
  78. method : {"transform"}
  79. Estimator's method to get container output for.
  80. data_to_wrap : {ndarray, dataframe}
  81. Data to wrap with container.
  82. original_input : {ndarray, dataframe}
  83. Original input of function.
  84. estimator : estimator instance
  85. Estimator with to get the output configuration from.
  86. Returns
  87. -------
  88. output : {ndarray, dataframe}
  89. If the output config is "default" or the estimator is not configured
  90. for wrapping return `data_to_wrap` unchanged.
  91. If the output config is "pandas", return `data_to_wrap` as a pandas
  92. DataFrame.
  93. """
  94. output_config = _get_output_config(method, estimator)
  95. if output_config["dense"] == "default" or not _auto_wrap_is_configured(estimator):
  96. return data_to_wrap
  97. def _is_pandas_df(X):
  98. """Return True if the X is a pandas dataframe.
  99. This is is backport from 1.4 in 1.3.1 for compatibility.
  100. """
  101. import sys
  102. if hasattr(X, "columns") and hasattr(X, "iloc"):
  103. # Likely a pandas DataFrame, we explicitly check the type to confirm.
  104. try:
  105. pd = sys.modules["pandas"]
  106. except KeyError:
  107. return False
  108. return isinstance(X, pd.DataFrame)
  109. return False
  110. # dense_config == "pandas"
  111. index = original_input.index if _is_pandas_df(original_input) else None
  112. return _wrap_in_pandas_container(
  113. data_to_wrap=data_to_wrap,
  114. index=index,
  115. columns=estimator.get_feature_names_out,
  116. )
  117. def _wrap_method_output(f, method):
  118. """Wrapper used by `_SetOutputMixin` to automatically wrap methods."""
  119. @wraps(f)
  120. def wrapped(self, X, *args, **kwargs):
  121. data_to_wrap = f(self, X, *args, **kwargs)
  122. if isinstance(data_to_wrap, tuple):
  123. # only wrap the first output for cross decomposition
  124. return_tuple = (
  125. _wrap_data_with_container(method, data_to_wrap[0], X, self),
  126. *data_to_wrap[1:],
  127. )
  128. # Support for namedtuples `_make` is a documented API for namedtuples:
  129. # https://docs.python.org/3/library/collections.html#collections.somenamedtuple._make
  130. if hasattr(type(data_to_wrap), "_make"):
  131. return type(data_to_wrap)._make(return_tuple)
  132. return return_tuple
  133. return _wrap_data_with_container(method, data_to_wrap, X, self)
  134. return wrapped
  135. def _auto_wrap_is_configured(estimator):
  136. """Return True if estimator is configured for auto-wrapping the transform method.
  137. `_SetOutputMixin` sets `_sklearn_auto_wrap_output_keys` to `set()` if auto wrapping
  138. is manually disabled.
  139. """
  140. auto_wrap_output_keys = getattr(estimator, "_sklearn_auto_wrap_output_keys", set())
  141. return (
  142. hasattr(estimator, "get_feature_names_out")
  143. and "transform" in auto_wrap_output_keys
  144. )
  145. class _SetOutputMixin:
  146. """Mixin that dynamically wraps methods to return container based on config.
  147. Currently `_SetOutputMixin` wraps `transform` and `fit_transform` and configures
  148. it based on `set_output` of the global configuration.
  149. `set_output` is only defined if `get_feature_names_out` is defined and
  150. `auto_wrap_output_keys` is the default value.
  151. """
  152. def __init_subclass__(cls, auto_wrap_output_keys=("transform",), **kwargs):
  153. super().__init_subclass__(**kwargs)
  154. # Dynamically wraps `transform` and `fit_transform` and configure it's
  155. # output based on `set_output`.
  156. if not (
  157. isinstance(auto_wrap_output_keys, tuple) or auto_wrap_output_keys is None
  158. ):
  159. raise ValueError("auto_wrap_output_keys must be None or a tuple of keys.")
  160. if auto_wrap_output_keys is None:
  161. cls._sklearn_auto_wrap_output_keys = set()
  162. return
  163. # Mapping from method to key in configurations
  164. method_to_key = {
  165. "transform": "transform",
  166. "fit_transform": "transform",
  167. }
  168. cls._sklearn_auto_wrap_output_keys = set()
  169. for method, key in method_to_key.items():
  170. if not hasattr(cls, method) or key not in auto_wrap_output_keys:
  171. continue
  172. cls._sklearn_auto_wrap_output_keys.add(key)
  173. # Only wrap methods defined by cls itself
  174. if method not in cls.__dict__:
  175. continue
  176. wrapped_method = _wrap_method_output(getattr(cls, method), key)
  177. setattr(cls, method, wrapped_method)
  178. @available_if(_auto_wrap_is_configured)
  179. def set_output(self, *, transform=None):
  180. """Set output container.
  181. See :ref:`sphx_glr_auto_examples_miscellaneous_plot_set_output.py`
  182. for an example on how to use the API.
  183. Parameters
  184. ----------
  185. transform : {"default", "pandas"}, default=None
  186. Configure output of `transform` and `fit_transform`.
  187. - `"default"`: Default output format of a transformer
  188. - `"pandas"`: DataFrame output
  189. - `None`: Transform configuration is unchanged
  190. Returns
  191. -------
  192. self : estimator instance
  193. Estimator instance.
  194. """
  195. if transform is None:
  196. return self
  197. if not hasattr(self, "_sklearn_output_config"):
  198. self._sklearn_output_config = {}
  199. self._sklearn_output_config["transform"] = transform
  200. return self
  201. def _safe_set_output(estimator, *, transform=None):
  202. """Safely call estimator.set_output and error if it not available.
  203. This is used by meta-estimators to set the output for child estimators.
  204. Parameters
  205. ----------
  206. estimator : estimator instance
  207. Estimator instance.
  208. transform : {"default", "pandas"}, default=None
  209. Configure output of the following estimator's methods:
  210. - `"transform"`
  211. - `"fit_transform"`
  212. If `None`, this operation is a no-op.
  213. Returns
  214. -------
  215. estimator : estimator instance
  216. Estimator instance.
  217. """
  218. set_output_for_transform = (
  219. hasattr(estimator, "transform")
  220. or hasattr(estimator, "fit_transform")
  221. and transform is not None
  222. )
  223. if not set_output_for_transform:
  224. # If estimator can not transform, then `set_output` does not need to be
  225. # called.
  226. return
  227. if not hasattr(estimator, "set_output"):
  228. raise ValueError(
  229. f"Unable to configure output for {estimator} because `set_output` "
  230. "is not available."
  231. )
  232. return estimator.set_output(transform=transform)