"""
The :mod:`sklearn.utils` module includes various utilities.
"""

import math
import numbers
import platform
import struct
import timeit
import warnings
from collections.abc import Sequence
from contextlib import contextmanager, suppress
from itertools import compress, islice

import numpy as np
from scipy.sparse import issparse

from .. import get_config
from ..exceptions import DataConversionWarning
from . import _joblib, metadata_routing
from ._bunch import Bunch
from ._estimator_html_repr import estimator_html_repr
from ._param_validation import Interval, validate_params
from .class_weight import compute_class_weight, compute_sample_weight
from .deprecation import deprecated
from .discovery import all_estimators
from .fixes import parse_version, threadpool_info
from .murmurhash import murmurhash3_32
from .validation import (
    _is_arraylike_not_scalar,
    as_float_array,
    assert_all_finite,
    check_array,
    check_consistent_length,
    check_random_state,
    check_scalar,
    check_symmetric,
    check_X_y,
    column_or_1d,
    indexable,
)

# Do not deprecate parallel_backend and register_parallel_backend as they are
# needed to tune `scikit-learn` behavior and have different effects if called
# from the vendored version or the site-package version. The others are
# utilities that are independent of scikit-learn so they are not part of
# scikit-learn public API.
parallel_backend = _joblib.parallel_backend
register_parallel_backend = _joblib.register_parallel_backend

__all__ = [
    "murmurhash3_32",
    "as_float_array",
    "assert_all_finite",
    "check_array",
    "check_random_state",
    "compute_class_weight",
    "compute_sample_weight",
    "column_or_1d",
    "check_consistent_length",
    "check_X_y",
    "check_scalar",
    "indexable",
    "check_symmetric",
    "indices_to_mask",
    "deprecated",
    "parallel_backend",
    "register_parallel_backend",
    "resample",
    "shuffle",
    "check_matplotlib_support",
    "all_estimators",
    "DataConversionWarning",
    "estimator_html_repr",
    "Bunch",
    "metadata_routing",
]

IS_PYPY = platform.python_implementation() == "PyPy"
_IS_32BIT = 8 * struct.calcsize("P") == 32


def _in_unstable_openblas_configuration():
    """Return True if in an unstable configuration for OpenBLAS."""

    # Import libraries which might load OpenBLAS.
    import numpy  # noqa
    import scipy  # noqa

    modules_info = threadpool_info()

    open_blas_used = any(info["internal_api"] == "openblas" for info in modules_info)
    if not open_blas_used:
        return False

    # OpenBLAS 0.3.16 fixed instability for arm64, see:
    # https://github.com/xianyi/OpenBLAS/blob/1b6db3dbba672b4f8af935bd43a1ff6cff4d20b7/Changelog.txt#L56-L58 # noqa
    openblas_arm64_stable_version = parse_version("0.3.16")
    for info in modules_info:
        if info["internal_api"] != "openblas":
            continue
        openblas_version = info.get("version")
        openblas_architecture = info.get("architecture")
        if openblas_version is None or openblas_architecture is None:
            # Cannot be sure that OpenBLAS is good enough. Assume unstable:
            return True
        if (
            openblas_architecture == "neoversen1"
            and parse_version(openblas_version) < openblas_arm64_stable_version
        ):
            # See discussions in https://github.com/numpy/numpy/issues/19411
            return True
    return False


def safe_mask(X, mask):
    """Return a mask which is safe to use on X.

    Parameters
    ----------
    X : {array-like, sparse matrix}
        Data on which to apply mask.

    mask : ndarray
        Mask to be used on X.

    Returns
    -------
    mask : ndarray
        Array that is safe to use on X.
    """
    mask = np.asarray(mask)
    if np.issubdtype(mask.dtype, np.signedinteger):
        return mask

    if hasattr(X, "toarray"):
        ind = np.arange(mask.shape[0])
        mask = ind[mask]
    return mask
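

# Illustrative sketch (kept as comments so nothing runs at import time): for a
# sparse input, `safe_mask` converts a boolean mask into integer indices, since
# boolean-mask row indexing of sparse matrices is unreliable across SciPy
# versions; dense inputs pass through unchanged.
#
#   >>> import numpy as np
#   >>> from scipy.sparse import csr_matrix
#   >>> X_sparse = csr_matrix(np.eye(3))
#   >>> safe_mask(X_sparse, np.array([True, False, True]))
#   array([0, 2])
#   >>> safe_mask(np.eye(3), np.array([True, False, True]))  # dense: unchanged
#   array([ True, False,  True])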


def axis0_safe_slice(X, mask, len_mask):
    """Return a mask which is safer to use on X than safe_mask.

    This mask is safer than safe_mask since it returns an
    empty array, when a sparse matrix is sliced with a boolean mask
    with all False, instead of raising an unhelpful error in older
    versions of SciPy.

    See: https://github.com/scipy/scipy/issues/5361

    Also note that we can avoid doing the dot product by checking if
    the len_mask is not zero in _huber_loss_and_gradient but this
    is not going to be the bottleneck, since the number of outliers
    and non_outliers is typically non-zero and it makes the code
    tougher to follow.

    Parameters
    ----------
    X : {array-like, sparse matrix}
        Data on which to apply mask.

    mask : ndarray
        Mask to be used on X.

    len_mask : int
        The length of the mask.

    Returns
    -------
    mask : ndarray
        Array that is safe to use on X.
    """
    if len_mask != 0:
        return X[safe_mask(X, mask), :]
    return np.zeros(shape=(0, X.shape[1]))
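

# Illustrative sketch (comments only): with an all-False mask, slicing a sparse
# matrix can fail on older SciPy, so `axis0_safe_slice` short-circuits to an
# empty dense array with the right number of columns.
#
#   >>> import numpy as np
#   >>> from scipy.sparse import csr_matrix
#   >>> X_sparse = csr_matrix(np.eye(3))
#   >>> mask = np.zeros(3, dtype=bool)
#   >>> axis0_safe_slice(X_sparse, mask, mask.sum()).shape
#   (0, 3)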


def _array_indexing(array, key, key_dtype, axis):
    """Index an array or scipy.sparse consistently across NumPy versions."""
    if issparse(array) and key_dtype == "bool":
        key = np.asarray(key)
    if isinstance(key, tuple):
        key = list(key)
    return array[key] if axis == 0 else array[:, key]


def _pandas_indexing(X, key, key_dtype, axis):
    """Index a pandas dataframe or a series."""
    if _is_arraylike_not_scalar(key):
        key = np.asarray(key)
    if key_dtype == "int" and not (isinstance(key, slice) or np.isscalar(key)):
        # using take() instead of iloc[] ensures the return value is a "proper"
        # copy that will not raise SettingWithCopyWarning
        return X.take(key, axis=axis)
    else:
        # check whether we should index with loc or iloc
        indexer = X.iloc if key_dtype == "int" else X.loc
        return indexer[:, key] if axis else indexer[key]


def _list_indexing(X, key, key_dtype):
    """Index a Python list."""
    if np.isscalar(key) or isinstance(key, slice):
        # key is a slice or a scalar
        return X[key]
    if key_dtype == "bool":
        # key is a boolean array-like
        return list(compress(X, key))
    # key is an integer array-like
    return [X[idx] for idx in key]
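

# Illustrative sketch (comments only): the three private helpers above share
# one calling convention, dispatched by container type. For a plain list:
#
#   >>> _list_indexing(["a", "b", "c"], [True, False, True], key_dtype="bool")
#   ['a', 'c']
#   >>> _list_indexing(["a", "b", "c"], [0, 2], key_dtype="int")
#   ['a', 'c']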


def _determine_key_type(key, accept_slice=True):
    """Determine the data type of key.

    Parameters
    ----------
    key : scalar, slice or array-like
        The key from which we want to infer the data type.

    accept_slice : bool, default=True
        Whether or not to raise an error if the key is a slice.

    Returns
    -------
    dtype : {'int', 'str', 'bool', None}
        Returns the data type of key.
    """
    err_msg = (
        "No valid specification of the columns. Only a scalar, list or "
        "slice of all integers or all strings, or boolean mask is "
        "allowed"
    )

    dtype_to_str = {int: "int", str: "str", bool: "bool", np.bool_: "bool"}
    array_dtype_to_str = {
        "i": "int",
        "u": "int",
        "b": "bool",
        "O": "str",
        "U": "str",
        "S": "str",
    }

    if key is None:
        return None
    if isinstance(key, tuple(dtype_to_str.keys())):
        try:
            return dtype_to_str[type(key)]
        except KeyError:
            raise ValueError(err_msg)
    if isinstance(key, slice):
        if not accept_slice:
            raise TypeError(
                "Only array-like or scalar are supported. A Python slice was given."
            )
        if key.start is None and key.stop is None:
            return None
        key_start_type = _determine_key_type(key.start)
        key_stop_type = _determine_key_type(key.stop)
        if key_start_type is not None and key_stop_type is not None:
            if key_start_type != key_stop_type:
                raise ValueError(err_msg)
        if key_start_type is not None:
            return key_start_type
        return key_stop_type
    if isinstance(key, (list, tuple)):
        unique_key = set(key)
        key_type = {_determine_key_type(elt) for elt in unique_key}
        if not key_type:
            return None
        if len(key_type) != 1:
            raise ValueError(err_msg)
        return key_type.pop()
    if hasattr(key, "dtype"):
        try:
            return array_dtype_to_str[key.dtype.kind]
        except KeyError:
            raise ValueError(err_msg)
    raise ValueError(err_msg)
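

# Illustrative sketch (comments only): the inferred dtype drives the dispatch
# in `_safe_indexing` and `_get_column_indices` below.
#
#   >>> import numpy as np
#   >>> _determine_key_type(3)
#   'int'
#   >>> _determine_key_type(slice(0, 2))
#   'int'
#   >>> _determine_key_type(["a", "b"])
#   'str'
#   >>> _determine_key_type(np.array([True, False]))
#   'bool'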


def _safe_indexing(X, indices, *, axis=0):
    """Return rows, items or columns of X using indices.

    .. warning::

        This utility is documented, but **private**. This means that
        backward compatibility might be broken without any deprecation
        cycle.

    Parameters
    ----------
    X : array-like, sparse-matrix, list, pandas.DataFrame, pandas.Series
        Data from which to sample rows, items or columns. `list` are only
        supported when `axis=0`.
    indices : bool, int, str, slice, array-like
        - If `axis=0`, boolean and integer array-like, integer slice,
          and scalar integer are supported.
        - If `axis=1`:
            - to select a single column, `indices` can be of `int` type for
              all `X` types and `str` only for dataframe. The selected subset
              will be 1D, unless `X` is a sparse matrix in which case it will
              be 2D.
            - to select multiple columns, `indices` can be one of the
              following: `list`, `array`, `slice`. The type used in
              these containers can be one of the following: `int`, 'bool' and
              `str`. However, `str` is only supported when `X` is a dataframe.
              The selected subset will be 2D.
    axis : int, default=0
        The axis along which `X` will be subsampled. `axis=0` will select
        rows while `axis=1` will select columns.

    Returns
    -------
    subset
        Subset of X on axis 0 or 1.

    Notes
    -----
    CSR, CSC, and LIL sparse matrices are supported. COO sparse matrices are
    not supported.
    """
    if indices is None:
        return X

    if axis not in (0, 1):
        raise ValueError(
            "'axis' should be either 0 (to index rows) or 1 (to index "
            "columns). Got {} instead.".format(axis)
        )

    indices_dtype = _determine_key_type(indices)

    if axis == 0 and indices_dtype == "str":
        raise ValueError("String indexing is not supported with 'axis=0'")

    if axis == 1 and X.ndim != 2:
        raise ValueError(
            "'X' should be a 2D NumPy array, 2D sparse matrix or pandas "
            "dataframe when indexing the columns (i.e. 'axis=1'). "
            "Got {} instead with {} dimension(s).".format(type(X), X.ndim)
        )

    if axis == 1 and indices_dtype == "str" and not hasattr(X, "loc"):
        raise ValueError(
            "Specifying the columns using strings is only supported for "
            "pandas DataFrames"
        )

    if hasattr(X, "iloc"):
        return _pandas_indexing(X, indices, indices_dtype, axis=axis)
    elif hasattr(X, "shape"):
        return _array_indexing(X, indices, indices_dtype, axis=axis)
    else:
        return _list_indexing(X, indices, indices_dtype)
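

# Illustrative sketch (comments only): `_safe_indexing` gives one indexing API
# across lists, NumPy arrays, sparse matrices and pandas objects.
#
#   >>> import numpy as np
#   >>> X = np.arange(6).reshape(3, 2)
#   >>> _safe_indexing(X, [0, 2], axis=0)          # rows 0 and 2
#   array([[0, 1],
#          [4, 5]])
#   >>> _safe_indexing(X, 1, axis=1)               # second column, 1D
#   array([1, 3, 5])
#   >>> _safe_indexing(["a", "b", "c"], [2, 0])    # lists work for axis=0
#   ['c', 'a']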


def _safe_assign(X, values, *, row_indexer=None, column_indexer=None):
    """Safe assignment to a numpy array, sparse matrix, or pandas dataframe.

    Parameters
    ----------
    X : {ndarray, sparse-matrix, dataframe}
        Array to be modified. It is expected to be 2-dimensional.

    values : ndarray
        The values to be assigned to `X`.

    row_indexer : array-like, dtype={int, bool}, default=None
        A 1-dimensional array to select the rows of interest. If `None`, all
        rows are selected.

    column_indexer : array-like, dtype={int, bool}, default=None
        A 1-dimensional array to select the columns of interest. If `None`, all
        columns are selected.
    """
    row_indexer = slice(None, None, None) if row_indexer is None else row_indexer
    column_indexer = (
        slice(None, None, None) if column_indexer is None else column_indexer
    )

    if hasattr(X, "iloc"):  # pandas dataframe
        with warnings.catch_warnings():
            # pandas >= 1.5 raises a warning when using iloc to set values in a column
            # that does not have the same type as the column being set. It happens
            # for instance when setting a categorical column with a string.
            # In the future the behavior won't change and the warning should disappear.
            # TODO(1.3): check if the warning is still raised or remove the filter.
            warnings.simplefilter("ignore", FutureWarning)
            X.iloc[row_indexer, column_indexer] = values
    else:  # numpy array or sparse matrix
        X[row_indexer, column_indexer] = values
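

# Illustrative sketch (comments only): assignment happens in place, with the
# same call for NumPy arrays and pandas dataframes.
#
#   >>> import numpy as np
#   >>> X = np.zeros((3, 2))
#   >>> _safe_assign(X, np.array([1.0, 2.0]), row_indexer=[0, 2],
#   ...              column_indexer=[1])
#   >>> X
#   array([[0., 1.],
#          [0., 0.],
#          [0., 2.]])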


def _get_column_indices(X, key):
    """Get feature column indices for input data X and key.

    For accepted values of `key`, see the docstring of
    :func:`_safe_indexing`.
    """
    n_columns = X.shape[1]

    key_dtype = _determine_key_type(key)

    if isinstance(key, (list, tuple)) and not key:
        # we get an empty list
        return []
    elif key_dtype in ("bool", "int"):
        # Convert key into positive indexes
        try:
            idx = _safe_indexing(np.arange(n_columns), key)
        except IndexError as e:
            raise ValueError(
                "all features must be in [0, {}] or [-{}, 0]".format(
                    n_columns - 1, n_columns
                )
            ) from e
        return np.atleast_1d(idx).tolist()
    elif key_dtype == "str":
        try:
            all_columns = X.columns
        except AttributeError:
            raise ValueError(
                "Specifying the columns using strings is only "
                "supported for pandas DataFrames"
            )
        if isinstance(key, str):
            columns = [key]
        elif isinstance(key, slice):
            start, stop = key.start, key.stop
            if start is not None:
                start = all_columns.get_loc(start)
            if stop is not None:
                # pandas indexing with strings is endpoint included
                stop = all_columns.get_loc(stop) + 1
            else:
                stop = n_columns + 1
            return list(islice(range(n_columns), start, stop))
        else:
            columns = list(key)

        try:
            column_indices = []
            for col in columns:
                col_idx = all_columns.get_loc(col)
                if not isinstance(col_idx, numbers.Integral):
                    raise ValueError(
                        f"Selected columns, {columns}, are not unique in dataframe"
                    )
                column_indices.append(col_idx)

        except KeyError as e:
            raise ValueError("A given column is not a column of the dataframe") from e

        return column_indices
    else:
        raise ValueError(
            "No valid specification of the columns. Only a "
            "scalar, list or slice of all integers or all "
            "strings, or boolean mask is allowed"
        )
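

# Illustrative sketch (comments only): the resolved indices are plain Python
# ints, whatever form the key takes. Assuming pandas is installed:
#
#   >>> import pandas as pd
#   >>> df = pd.DataFrame({"a": [1], "b": [2], "c": [3]})
#   >>> _get_column_indices(df, ["a", "c"])
#   [0, 2]
#   >>> _get_column_indices(df, slice("a", "b"))  # string slices include the end
#   [0, 1]
#   >>> _get_column_indices(df, [False, True, True])
#   [1, 2]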


@validate_params(
    {
        "replace": ["boolean"],
        "n_samples": [Interval(numbers.Integral, 1, None, closed="left"), None],
        "random_state": ["random_state"],
        "stratify": ["array-like", None],
    },
    prefer_skip_nested_validation=True,
)
def resample(*arrays, replace=True, n_samples=None, random_state=None, stratify=None):
    """Resample arrays or sparse matrices in a consistent way.

    The default strategy implements one step of the bootstrapping
    procedure.

    Parameters
    ----------
    *arrays : sequence of array-like of shape (n_samples,) or \
            (n_samples, n_outputs)
        Indexable data-structures can be arrays, lists, dataframes or scipy
        sparse matrices with consistent first dimension.

    replace : bool, default=True
        Implements resampling with replacement. If False, this will implement
        (sliced) random permutations.

    n_samples : int, default=None
        Number of samples to generate. If left to None this is
        automatically set to the first dimension of the arrays.
        If replace is False it should not be larger than the length of
        arrays.

    random_state : int, RandomState instance or None, default=None
        Determines random number generation for shuffling
        the data.
        Pass an int for reproducible results across multiple function calls.
        See :term:`Glossary <random_state>`.

    stratify : array-like of shape (n_samples,) or (n_samples, n_outputs), \
            default=None
        If not None, data is split in a stratified fashion, using this as
        the class labels.

    Returns
    -------
    resampled_arrays : sequence of array-like of shape (n_samples,) or \
            (n_samples, n_outputs)
        Sequence of resampled copies of the collections. The original arrays
        are not impacted.

    See Also
    --------
    shuffle : Shuffle arrays or sparse matrices in a consistent way.

    Examples
    --------
    It is possible to mix sparse and dense arrays in the same run::

      >>> import numpy as np
      >>> X = np.array([[1., 0.], [2., 1.], [0., 0.]])
      >>> y = np.array([0, 1, 2])

      >>> from scipy.sparse import coo_matrix
      >>> X_sparse = coo_matrix(X)

      >>> from sklearn.utils import resample
      >>> X, X_sparse, y = resample(X, X_sparse, y, random_state=0)
      >>> X
      array([[1., 0.],
             [2., 1.],
             [1., 0.]])

      >>> X_sparse
      <3x2 sparse matrix of type '<... 'numpy.float64'>'
          with 4 stored elements in Compressed Sparse Row format>

      >>> X_sparse.toarray()
      array([[1., 0.],
             [2., 1.],
             [1., 0.]])

      >>> y
      array([0, 1, 0])

      >>> resample(y, n_samples=2, random_state=0)
      array([0, 1])

    Example using stratification::

      >>> y = [0, 0, 1, 1, 1, 1, 1, 1, 1]
      >>> resample(y, n_samples=5, replace=False, stratify=y,
      ...          random_state=0)
      [1, 1, 1, 0, 1]
    """
    max_n_samples = n_samples
    random_state = check_random_state(random_state)

    if len(arrays) == 0:
        return None

    first = arrays[0]
    n_samples = first.shape[0] if hasattr(first, "shape") else len(first)

    if max_n_samples is None:
        max_n_samples = n_samples
    elif (max_n_samples > n_samples) and (not replace):
        raise ValueError(
            "Cannot sample %d out of arrays with dim %d when replace is False"
            % (max_n_samples, n_samples)
        )

    check_consistent_length(*arrays)

    if stratify is None:
        if replace:
            indices = random_state.randint(0, n_samples, size=(max_n_samples,))
        else:
            indices = np.arange(n_samples)
            random_state.shuffle(indices)
            indices = indices[:max_n_samples]
    else:
        # Code adapted from StratifiedShuffleSplit()
        y = check_array(stratify, ensure_2d=False, dtype=None)
        if y.ndim == 2:
            # for multi-label y, map each distinct row to a string repr
            # using join because str(row) uses an ellipsis if len(row) > 1000
            y = np.array([" ".join(row.astype("str")) for row in y])

        classes, y_indices = np.unique(y, return_inverse=True)
        n_classes = classes.shape[0]

        class_counts = np.bincount(y_indices)

        # Find the sorted list of instances for each class:
        # (np.unique above performs a sort, so code is O(n logn) already)
        class_indices = np.split(
            np.argsort(y_indices, kind="mergesort"), np.cumsum(class_counts)[:-1]
        )

        n_i = _approximate_mode(class_counts, max_n_samples, random_state)

        indices = []

        for i in range(n_classes):
            indices_i = random_state.choice(class_indices[i], n_i[i], replace=replace)
            indices.extend(indices_i)

        indices = random_state.permutation(indices)

    # convert sparse matrices to CSR for row-based indexing
    arrays = [a.tocsr() if issparse(a) else a for a in arrays]
    resampled_arrays = [_safe_indexing(a, indices) for a in arrays]
    if len(resampled_arrays) == 1:
        # syntactic sugar for the unit argument case
        return resampled_arrays[0]
    else:
        return resampled_arrays


def shuffle(*arrays, random_state=None, n_samples=None):
    """Shuffle arrays or sparse matrices in a consistent way.

    This is a convenience alias to ``resample(*arrays, replace=False)`` to do
    random permutations of the collections.

    Parameters
    ----------
    *arrays : sequence of indexable data-structures
        Indexable data-structures can be arrays, lists, dataframes or scipy
        sparse matrices with consistent first dimension.

    random_state : int, RandomState instance or None, default=None
        Determines random number generation for shuffling
        the data.
        Pass an int for reproducible results across multiple function calls.
        See :term:`Glossary <random_state>`.

    n_samples : int, default=None
        Number of samples to generate. If left to None this is
        automatically set to the first dimension of the arrays. It should
        not be larger than the length of arrays.

    Returns
    -------
    shuffled_arrays : sequence of indexable data-structures
        Sequence of shuffled copies of the collections. The original arrays
        are not impacted.

    See Also
    --------
    resample : Resample arrays or sparse matrices in a consistent way.

    Examples
    --------
    It is possible to mix sparse and dense arrays in the same run::

      >>> import numpy as np
      >>> X = np.array([[1., 0.], [2., 1.], [0., 0.]])
      >>> y = np.array([0, 1, 2])

      >>> from scipy.sparse import coo_matrix
      >>> X_sparse = coo_matrix(X)

      >>> from sklearn.utils import shuffle
      >>> X, X_sparse, y = shuffle(X, X_sparse, y, random_state=0)
      >>> X
      array([[0., 0.],
             [2., 1.],
             [1., 0.]])

      >>> X_sparse
      <3x2 sparse matrix of type '<... 'numpy.float64'>'
          with 3 stored elements in Compressed Sparse Row format>

      >>> X_sparse.toarray()
      array([[0., 0.],
             [2., 1.],
             [1., 0.]])

      >>> y
      array([2, 1, 0])

      >>> shuffle(y, n_samples=2, random_state=0)
      array([0, 1])
    """
    return resample(
        *arrays, replace=False, n_samples=n_samples, random_state=random_state
    )


def safe_sqr(X, *, copy=True):
    """Element wise squaring of array-likes and sparse matrices.

    Parameters
    ----------
    X : {array-like, ndarray, sparse matrix}

    copy : bool, default=True
        Whether to create a copy of X and operate on it or to perform
        inplace computation (default behaviour).

    Returns
    -------
    X ** 2 : element wise square
        Return the element-wise square of the input.
    """
    X = check_array(X, accept_sparse=["csr", "csc", "coo"], ensure_2d=False)
    if issparse(X):
        if copy:
            X = X.copy()
        X.data **= 2
    else:
        if copy:
            X = X**2
        else:
            X **= 2
    return X
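

# Illustrative sketch (comments only): for sparse inputs only the stored
# entries are squared, so sparsity is preserved.
#
#   >>> import numpy as np
#   >>> safe_sqr(np.array([-2, 3]))
#   array([4, 9])
#   >>> from scipy.sparse import csr_matrix
#   >>> safe_sqr(csr_matrix(np.array([[-2.0, 0.0, 3.0]]))).toarray()
#   array([[4., 0., 9.]])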


def _chunk_generator(gen, chunksize):
    """Chunk a generator ``gen`` into lists of length ``chunksize``. The last
    chunk may have a length less than ``chunksize``."""
    while True:
        chunk = list(islice(gen, chunksize))
        if chunk:
            yield chunk
        else:
            return
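

# Illustrative sketch (comments only): note that `gen` must be an iterator
# (e.g. wrapped in `iter(...)`), so that successive `islice` calls advance it.
#
#   >>> list(_chunk_generator(iter(range(7)), 3))
#   [[0, 1, 2], [3, 4, 5], [6]]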


@validate_params(
    {
        "n": [Interval(numbers.Integral, 1, None, closed="left")],
        "batch_size": [Interval(numbers.Integral, 1, None, closed="left")],
        "min_batch_size": [Interval(numbers.Integral, 0, None, closed="left")],
    },
    prefer_skip_nested_validation=True,
)
def gen_batches(n, batch_size, *, min_batch_size=0):
    """Generator to create slices containing `batch_size` elements from 0 to `n`.

    The last slice may contain fewer than `batch_size` elements, when
    `batch_size` does not divide `n`.

    Parameters
    ----------
    n : int
        Size of the sequence.

    batch_size : int
        Number of elements in each batch.

    min_batch_size : int, default=0
        Minimum number of elements in each batch.

    Yields
    ------
    slice of `batch_size` elements

    See Also
    --------
    gen_even_slices: Generator to create n_packs slices going up to n.

    Examples
    --------
    >>> from sklearn.utils import gen_batches
    >>> list(gen_batches(7, 3))
    [slice(0, 3, None), slice(3, 6, None), slice(6, 7, None)]
    >>> list(gen_batches(6, 3))
    [slice(0, 3, None), slice(3, 6, None)]
    >>> list(gen_batches(2, 3))
    [slice(0, 2, None)]
    >>> list(gen_batches(7, 3, min_batch_size=0))
    [slice(0, 3, None), slice(3, 6, None), slice(6, 7, None)]
    >>> list(gen_batches(7, 3, min_batch_size=2))
    [slice(0, 3, None), slice(3, 7, None)]
    """
    start = 0
    for _ in range(int(n // batch_size)):
        end = start + batch_size
        if end + min_batch_size > n:
            continue
        yield slice(start, end)
        start = end
    if start < n:
        yield slice(start, n)


def gen_even_slices(n, n_packs, *, n_samples=None):
    """Generator to create `n_packs` evenly spaced slices going up to `n`.

    If `n_packs` does not divide `n`, except for the first `n % n_packs`
    slices, remaining slices may contain fewer elements.

    Parameters
    ----------
    n : int
        Size of the sequence.

    n_packs : int
        Number of slices to generate.

    n_samples : int, default=None
        Number of samples. Pass `n_samples` when the slices are to be used for
        sparse matrix indexing; slicing off-the-end raises an exception, while
        it works for NumPy arrays.

    Yields
    ------
    `slice` representing a set of indices from 0 to n.

    See Also
    --------
    gen_batches: Generator to create slices containing batch_size elements
        from 0 to n.

    Examples
    --------
    >>> from sklearn.utils import gen_even_slices
    >>> list(gen_even_slices(10, 1))
    [slice(0, 10, None)]
    >>> list(gen_even_slices(10, 10))
    [slice(0, 1, None), slice(1, 2, None), ..., slice(9, 10, None)]
    >>> list(gen_even_slices(10, 5))
    [slice(0, 2, None), slice(2, 4, None), ..., slice(8, 10, None)]
    >>> list(gen_even_slices(10, 3))
    [slice(0, 4, None), slice(4, 7, None), slice(7, 10, None)]
    """
    start = 0
    if n_packs < 1:
        raise ValueError("gen_even_slices got n_packs=%s, must be >=1" % n_packs)
    for pack_num in range(n_packs):
        this_n = n // n_packs
        if pack_num < n % n_packs:
            this_n += 1
        if this_n > 0:
            end = start + this_n
            if n_samples is not None:
                end = min(n_samples, end)
            yield slice(start, end, None)
            start = end


def tosequence(x):
    """Cast iterable x to a Sequence, avoiding a copy if possible.

    Parameters
    ----------
    x : iterable
        The iterable to be converted.

    Returns
    -------
    x : Sequence
        If `x` is a NumPy array, it returns it as a `ndarray`. If `x`
        is a `Sequence`, `x` is returned as-is. If `x` is from any other
        type, `x` is returned cast as a list.
    """
    if isinstance(x, np.ndarray):
        return np.asarray(x)
    elif isinstance(x, Sequence):
        return x
    else:
        return list(x)
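

# Illustrative sketch (comments only): only non-Sequence iterables (e.g.
# generators, sets) are materialized into a list.
#
#   >>> tosequence(x for x in range(3))
#   [0, 1, 2]
#   >>> tosequence([1, 2, 3])   # already a Sequence, returned as-is
#   [1, 2, 3]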


def _to_object_array(sequence):
    """Convert sequence to a 1-D NumPy array of object dtype.

    numpy.array constructor has a similar use but its output
    is ambiguous. It can be a 1-D NumPy array of object dtype if
    the input is a ragged array, but if the input is a list of
    equal length arrays, then the output is a 2D numpy.array.
    _to_object_array solves this ambiguity by guaranteeing that
    the output is a 1-D NumPy array of objects for any input.

    Parameters
    ----------
    sequence : array-like of shape (n_elements,)
        The sequence to be converted.

    Returns
    -------
    out : ndarray of shape (n_elements,), dtype=object
        The converted sequence into a 1-D NumPy array of object dtype.

    Examples
    --------
    >>> import numpy as np
    >>> from sklearn.utils import _to_object_array
    >>> _to_object_array([np.array([0]), np.array([1])])
    array([array([0]), array([1])], dtype=object)
    >>> _to_object_array([np.array([0]), np.array([1, 2])])
    array([array([0]), array([1, 2])], dtype=object)
    """
    out = np.empty(len(sequence), dtype=object)
    out[:] = sequence
    return out


def indices_to_mask(indices, mask_length):
    """Convert list of indices to boolean mask.

    Parameters
    ----------
    indices : list-like
        List of integers treated as indices.
    mask_length : int
        Length of boolean mask to be generated.
        This parameter must be greater than max(indices).

    Returns
    -------
    mask : 1d boolean nd-array
        Boolean array that is True where indices are present, else False.

    Examples
    --------
    >>> from sklearn.utils import indices_to_mask
    >>> indices = [1, 2, 3, 4]
    >>> indices_to_mask(indices, 5)
    array([False,  True,  True,  True,  True])
    """
    if mask_length <= np.max(indices):
        raise ValueError("mask_length must be greater than max(indices)")

    mask = np.zeros(mask_length, dtype=bool)
    mask[indices] = True

    return mask


def _message_with_time(source, message, time):
    """Create one line message for logging purposes.

    Parameters
    ----------
    source : str
        String indicating the source or the reference of the message.

    message : str
        Short message.

    time : int
        Time in seconds.
    """
    start_message = "[%s] " % source

    # adapted from joblib.logger.short_format_time without the Windows -.1s
    # adjustment
    if time > 60:
        time_str = "%4.1fmin" % (time / 60)
    else:
        time_str = " %5.1fs" % time
    end_message = " %s, total=%s" % (message, time_str)
    dots_len = 70 - len(start_message) - len(end_message)
    return "%s%s%s" % (start_message, dots_len * ".", end_message)


@contextmanager
def _print_elapsed_time(source, message=None):
    """Log elapsed time to stdout when the context is exited.

    Parameters
    ----------
    source : str
        String indicating the source or the reference of the message.

    message : str, default=None
        Short message. If None, nothing will be printed.

    Returns
    -------
    context_manager
        Prints elapsed time upon exit if verbose.
    """
    if message is None:
        yield
    else:
        start = timeit.default_timer()
        yield
        print(_message_with_time(source, message, timeit.default_timer() - start))
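

# Illustrative sketch (comments only): the message is padded with dots to a
# fixed 70-character line, in the style of Pipeline's verbose output.
#
#   >>> print(_message_with_time("Pipeline", "fit", 3.2))  # doctest: +SKIP
#   [Pipeline] ........................................ fit, total=   3.2s
#
#   >>> with _print_elapsed_time("Pipeline", "fit"):       # doctest: +SKIP
#   ...     do_work()  # hypothetical workload; timing is printed on exit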


def get_chunk_n_rows(row_bytes, *, max_n_rows=None, working_memory=None):
    """Calculate how many rows can be processed within `working_memory`.

    Parameters
    ----------
    row_bytes : int
        The expected number of bytes of memory that will be consumed
        during the processing of each row.
    max_n_rows : int, default=None
        The maximum return value.
    working_memory : int or float, default=None
        The number of rows to fit inside this number of MiB will be
        returned. When None (default), the value of
        ``sklearn.get_config()['working_memory']`` is used.

    Returns
    -------
    int
        The number of rows which can be processed within `working_memory`.

    Warns
    -----
    Issues a UserWarning if ``row_bytes`` exceeds ``working_memory`` MiB.
    """
    if working_memory is None:
        working_memory = get_config()["working_memory"]

    chunk_n_rows = int(working_memory * (2**20) // row_bytes)
    if max_n_rows is not None:
        chunk_n_rows = min(chunk_n_rows, max_n_rows)
    if chunk_n_rows < 1:
        warnings.warn(
            "Could not adhere to working_memory config. "
            "Currently %.0fMiB, %.0fMiB required."
            % (working_memory, np.ceil(row_bytes * 2**-20))
        )
        chunk_n_rows = 1
    return chunk_n_rows
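

# Illustrative sketch (comments only): with 1024 MiB of working memory, rows
# of 10_000 float64 values (80_000 bytes each) give
# 1024 * 2**20 // 80_000 = 13421 rows per chunk.
#
#   >>> get_chunk_n_rows(row_bytes=10_000 * 8, working_memory=1024)
#   13421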


def _is_pandas_na(x):
    """Test if x is pandas.NA.

    We intentionally do not use this function to return `True` for `pd.NA` in
    `is_scalar_nan`, because estimators that support `pd.NA` are the exception
    rather than the rule at the moment. When `pd.NA` is more universally
    supported, we may reconsider this decision.

    Parameters
    ----------
    x : any type

    Returns
    -------
    boolean
    """
    with suppress(ImportError):
        from pandas import NA

        return x is NA

    return False


def is_scalar_nan(x):
    """Test if x is NaN.

    This function is meant to overcome the issue that np.isnan does not allow
    non-numerical types as input, and that np.nan is not float('nan').

    Parameters
    ----------
    x : any type
        Any scalar value.

    Returns
    -------
    bool
        Returns true if x is NaN, and false otherwise.

    Examples
    --------
    >>> import numpy as np
    >>> from sklearn.utils import is_scalar_nan
    >>> is_scalar_nan(np.nan)
    True
    >>> is_scalar_nan(float("nan"))
    True
    >>> is_scalar_nan(None)
    False
    >>> is_scalar_nan("")
    False
    >>> is_scalar_nan([np.nan])
    False
    """
    return isinstance(x, numbers.Real) and math.isnan(x)


def _approximate_mode(class_counts, n_draws, rng):
    """Computes approximate mode of multivariate hypergeometric.

    This is an approximation to the mode of the multivariate
    hypergeometric given by class_counts and n_draws.
    It shouldn't be off by more than one.

    It is the most likely outcome of drawing n_draws many
    samples from the population given by class_counts.

    Parameters
    ----------
    class_counts : ndarray of int
        Population per class.
    n_draws : int
        Number of draws (samples to draw) from the overall population.
    rng : random state
        Used to break ties.

    Returns
    -------
    sampled_classes : ndarray of int
        Number of samples drawn from each class.
        np.sum(sampled_classes) == n_draws

    Examples
    --------
    >>> import numpy as np
    >>> from sklearn.utils import _approximate_mode
    >>> _approximate_mode(class_counts=np.array([4, 2]), n_draws=3, rng=0)
    array([2, 1])
    >>> _approximate_mode(class_counts=np.array([5, 2]), n_draws=4, rng=0)
    array([3, 1])
    >>> _approximate_mode(class_counts=np.array([2, 2, 2, 1]),
    ...                   n_draws=2, rng=0)
    array([0, 1, 1, 0])
    >>> _approximate_mode(class_counts=np.array([2, 2, 2, 1]),
    ...                   n_draws=2, rng=42)
    array([1, 1, 0, 0])
    """
    rng = check_random_state(rng)
    # this computes a bad approximation to the mode of the
    # multivariate hypergeometric given by class_counts and n_draws
    continuous = class_counts / class_counts.sum() * n_draws
    # floored means we don't overshoot n_samples, but probably undershoot
    floored = np.floor(continuous)
    # we add samples according to how much "left over" probability
    # they had, until we arrive at n_samples
    need_to_add = int(n_draws - floored.sum())
    if need_to_add > 0:
        remainder = continuous - floored
        values = np.sort(np.unique(remainder))[::-1]
        # add according to remainder, but break ties
        # randomly to avoid biases
        for value in values:
            (inds,) = np.where(remainder == value)
            # if we need_to_add less than what's in inds
            # we draw randomly from them.
            # if we need to add more, we add them all and
            # go to the next value
            add_now = min(len(inds), need_to_add)
            inds = rng.choice(inds, size=add_now, replace=False)
            floored[inds] += 1
            need_to_add -= add_now
            if need_to_add == 0:
                break
    return floored.astype(int)


def check_matplotlib_support(caller_name):
    """Raise ImportError with detailed error message if mpl is not installed.

    Plot utilities like any of the Display's plotting functions should lazily import
    matplotlib and call this helper before any computation.

    Parameters
    ----------
    caller_name : str
        The name of the caller that requires matplotlib.
    """
    try:
        import matplotlib  # noqa
    except ImportError as e:
        raise ImportError(
            "{} requires matplotlib. You can install matplotlib with "
            "`pip install matplotlib`".format(caller_name)
        ) from e


def check_pandas_support(caller_name):
    """Raise ImportError with detailed error message if pandas is not installed.

    Utilities like :func:`fetch_openml` should lazily import
    pandas and call this helper before any computation.

    Parameters
    ----------
    caller_name : str
        The name of the caller that requires pandas.

    Returns
    -------
    pandas
        The pandas package.
    """
    try:
        import pandas  # noqa

        return pandas
    except ImportError as e:
        raise ImportError("{} requires pandas.".format(caller_name)) from e
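

# Illustrative usage sketch (comments only): callers bind the returned module
# instead of importing pandas at module level.
#
#   >>> pd = check_pandas_support("fetch_openml")  # doctest: +SKIP
#   >>> df = pd.DataFrame({"a": [1, 2]})           # doctest: +SKIP
#
# If pandas is missing, the raised ImportError names the caller:
# "fetch_openml requires pandas."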