validation.py 79 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113111411151116111711181119112011211122112311241125112611271128112911301131113211331134113511361137113811391140114111421143114411451146114711481149115011511152115311541155115611571158115911601161116211631164116511661167116811691170117111721173117411751176117711781179118011811182118311841185118611871188118911901191119211931194119511961197119811991200120112021203120412051206120712081209121012111212121312141215121612171218121912201221122212231224122512261227122812291230123112321233123412351236123712381239124012411242124312441245124612471248124912501251125212531254125512561257125812591260126112621263126412651266126712681269127012711272127312741275127612771278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177717781779178017811782178317841785178617871788178917901791179217931794179517961797179817991800180118021803180418051806180718081809181018111812181318141815181618171818181918201821182218231824182518261827182818291830183118321833183418351836183718381839184018411842184318441845184618471848184918501851185218531854185518561857185818591860186118621863186418651866186718681869187018711872187318741875187618771878187918801881188218831884188518861887188818891890189118921893189418951896189718981899190019011902190319041905190619071908190919101911191219131914191519161917191819191920192119221923192419251926192719281929193019311932193319341935193619371938193919401941194219431944194519461947194819491950195119521953195419551956195719581959196019611962196319641965196619671968196919701971197219731974197519761977197819791980198119821983198419851986198719881989199019911992199319941995199619971998199920002001200220032004200520062007200820092010201120122013201420152016201720182019202020212022202320242025202620272028202920302031203220332034203520362037203820392040204120422043204420452046204720482049205020512052205320542055205620572058205920602061206220632064206520662067206820692070207120722073207420752076207720782079208020812082208320842085208620872088208920902091209220932094209520962097209820992100210121022103210421052106210721082109211021112112211321142115211621172118211921202121212221232124212521262127212821292130213121322133213421352136213721382139214021412142214321442145214621472148214921502151215221532154215521562157215821592160216121622163216421652166216721682169217021712172217321742175217621772178217921802181218221832184218521862187218821892190219121922193219421952196219721982199220022012202220322042205220622072208220922102211221222132214221522162217221822192220222122222223222422252226222722282229223022312232223322342235223622372238223922402241224222432244224522462247224822492250225122522253
  1. """Utilities for input validation"""
  2. # Authors: Olivier Grisel
  3. # Gael Varoquaux
  4. # Andreas Mueller
  5. # Lars Buitinck
  6. # Alexandre Gramfort
  7. # Nicolas Tresegnie
  8. # Sylvain Marie
  9. # License: BSD 3 clause
  10. import numbers
  11. import operator
  12. import warnings
  13. from contextlib import suppress
  14. from functools import reduce, wraps
  15. from inspect import Parameter, isclass, signature
  16. import joblib
  17. import numpy as np
  18. import scipy.sparse as sp
  19. from .. import get_config as _get_config
  20. from ..exceptions import DataConversionWarning, NotFittedError, PositiveSpectrumWarning
  21. from ..utils._array_api import _asarray_with_order, _is_numpy_namespace, get_namespace
  22. from ..utils.fixes import ComplexWarning
  23. from ._isfinite import FiniteStatus, cy_isfinite
  24. from .fixes import _object_dtype_isnan
  25. FLOAT_DTYPES = (np.float64, np.float32, np.float16)
  26. # This function is not used anymore at this moment in the code base but we keep it in
  27. # case that we merge a new public function without kwarg only by mistake, which would
  28. # require a deprecation cycle to fix.
  29. def _deprecate_positional_args(func=None, *, version="1.3"):
  30. """Decorator for methods that issues warnings for positional arguments.
  31. Using the keyword-only argument syntax in pep 3102, arguments after the
  32. * will issue a warning when passed as a positional argument.
  33. Parameters
  34. ----------
  35. func : callable, default=None
  36. Function to check arguments on.
  37. version : callable, default="1.3"
  38. The version when positional arguments will result in error.
  39. """
  40. def _inner_deprecate_positional_args(f):
  41. sig = signature(f)
  42. kwonly_args = []
  43. all_args = []
  44. for name, param in sig.parameters.items():
  45. if param.kind == Parameter.POSITIONAL_OR_KEYWORD:
  46. all_args.append(name)
  47. elif param.kind == Parameter.KEYWORD_ONLY:
  48. kwonly_args.append(name)
  49. @wraps(f)
  50. def inner_f(*args, **kwargs):
  51. extra_args = len(args) - len(all_args)
  52. if extra_args <= 0:
  53. return f(*args, **kwargs)
  54. # extra_args > 0
  55. args_msg = [
  56. "{}={}".format(name, arg)
  57. for name, arg in zip(kwonly_args[:extra_args], args[-extra_args:])
  58. ]
  59. args_msg = ", ".join(args_msg)
  60. warnings.warn(
  61. (
  62. f"Pass {args_msg} as keyword args. From version "
  63. f"{version} passing these as positional arguments "
  64. "will result in an error"
  65. ),
  66. FutureWarning,
  67. )
  68. kwargs.update(zip(sig.parameters, args))
  69. return f(**kwargs)
  70. return inner_f
  71. if func is not None:
  72. return _inner_deprecate_positional_args(func)
  73. return _inner_deprecate_positional_args
  74. def _assert_all_finite(
  75. X, allow_nan=False, msg_dtype=None, estimator_name=None, input_name=""
  76. ):
  77. """Like assert_all_finite, but only for ndarray."""
  78. xp, _ = get_namespace(X)
  79. if _get_config()["assume_finite"]:
  80. return
  81. X = xp.asarray(X)
  82. # for object dtype data, we only check for NaNs (GH-13254)
  83. if X.dtype == np.dtype("object") and not allow_nan:
  84. if _object_dtype_isnan(X).any():
  85. raise ValueError("Input contains NaN")
  86. # We need only consider float arrays, hence can early return for all else.
  87. if not xp.isdtype(X.dtype, ("real floating", "complex floating")):
  88. return
  89. # First try an O(n) time, O(1) space solution for the common case that
  90. # everything is finite; fall back to O(n) space `np.isinf/isnan` or custom
  91. # Cython implementation to prevent false positives and provide a detailed
  92. # error message.
  93. with np.errstate(over="ignore"):
  94. first_pass_isfinite = xp.isfinite(xp.sum(X))
  95. if first_pass_isfinite:
  96. return
  97. _assert_all_finite_element_wise(
  98. X,
  99. xp=xp,
  100. allow_nan=allow_nan,
  101. msg_dtype=msg_dtype,
  102. estimator_name=estimator_name,
  103. input_name=input_name,
  104. )
  105. def _assert_all_finite_element_wise(
  106. X, *, xp, allow_nan, msg_dtype=None, estimator_name=None, input_name=""
  107. ):
  108. # Cython implementation doesn't support FP16 or complex numbers
  109. use_cython = (
  110. xp is np and X.data.contiguous and X.dtype.type in {np.float32, np.float64}
  111. )
  112. if use_cython:
  113. out = cy_isfinite(X.reshape(-1), allow_nan=allow_nan)
  114. has_nan_error = False if allow_nan else out == FiniteStatus.has_nan
  115. has_inf = out == FiniteStatus.has_infinite
  116. else:
  117. has_inf = xp.any(xp.isinf(X))
  118. has_nan_error = False if allow_nan else xp.any(xp.isnan(X))
  119. if has_inf or has_nan_error:
  120. if has_nan_error:
  121. type_err = "NaN"
  122. else:
  123. msg_dtype = msg_dtype if msg_dtype is not None else X.dtype
  124. type_err = f"infinity or a value too large for {msg_dtype!r}"
  125. padded_input_name = input_name + " " if input_name else ""
  126. msg_err = f"Input {padded_input_name}contains {type_err}."
  127. if estimator_name and input_name == "X" and has_nan_error:
  128. # Improve the error message on how to handle missing values in
  129. # scikit-learn.
  130. msg_err += (
  131. f"\n{estimator_name} does not accept missing values"
  132. " encoded as NaN natively. For supervised learning, you might want"
  133. " to consider sklearn.ensemble.HistGradientBoostingClassifier and"
  134. " Regressor which accept missing values encoded as NaNs natively."
  135. " Alternatively, it is possible to preprocess the data, for"
  136. " instance by using an imputer transformer in a pipeline or drop"
  137. " samples with missing values. See"
  138. " https://scikit-learn.org/stable/modules/impute.html"
  139. " You can find a list of all estimators that handle NaN values"
  140. " at the following page:"
  141. " https://scikit-learn.org/stable/modules/impute.html"
  142. "#estimators-that-handle-nan-values"
  143. )
  144. raise ValueError(msg_err)
  145. def assert_all_finite(
  146. X,
  147. *,
  148. allow_nan=False,
  149. estimator_name=None,
  150. input_name="",
  151. ):
  152. """Throw a ValueError if X contains NaN or infinity.
  153. Parameters
  154. ----------
  155. X : {ndarray, sparse matrix}
  156. The input data.
  157. allow_nan : bool, default=False
  158. If True, do not throw error when `X` contains NaN.
  159. estimator_name : str, default=None
  160. The estimator name, used to construct the error message.
  161. input_name : str, default=""
  162. The data name used to construct the error message. In particular
  163. if `input_name` is "X" and the data has NaN values and
  164. allow_nan is False, the error message will link to the imputer
  165. documentation.
  166. """
  167. _assert_all_finite(
  168. X.data if sp.issparse(X) else X,
  169. allow_nan=allow_nan,
  170. estimator_name=estimator_name,
  171. input_name=input_name,
  172. )
  173. def as_float_array(X, *, copy=True, force_all_finite=True):
  174. """Convert an array-like to an array of floats.
  175. The new dtype will be np.float32 or np.float64, depending on the original
  176. type. The function can create a copy or modify the argument depending
  177. on the argument copy.
  178. Parameters
  179. ----------
  180. X : {array-like, sparse matrix}
  181. The input data.
  182. copy : bool, default=True
  183. If True, a copy of X will be created. If False, a copy may still be
  184. returned if X's dtype is not a floating point type.
  185. force_all_finite : bool or 'allow-nan', default=True
  186. Whether to raise an error on np.inf, np.nan, pd.NA in X. The
  187. possibilities are:
  188. - True: Force all values of X to be finite.
  189. - False: accepts np.inf, np.nan, pd.NA in X.
  190. - 'allow-nan': accepts only np.nan and pd.NA values in X. Values cannot
  191. be infinite.
  192. .. versionadded:: 0.20
  193. ``force_all_finite`` accepts the string ``'allow-nan'``.
  194. .. versionchanged:: 0.23
  195. Accepts `pd.NA` and converts it into `np.nan`
  196. Returns
  197. -------
  198. XT : {ndarray, sparse matrix}
  199. An array of type float.
  200. """
  201. if isinstance(X, np.matrix) or (
  202. not isinstance(X, np.ndarray) and not sp.issparse(X)
  203. ):
  204. return check_array(
  205. X,
  206. accept_sparse=["csr", "csc", "coo"],
  207. dtype=np.float64,
  208. copy=copy,
  209. force_all_finite=force_all_finite,
  210. ensure_2d=False,
  211. )
  212. elif sp.issparse(X) and X.dtype in [np.float32, np.float64]:
  213. return X.copy() if copy else X
  214. elif X.dtype in [np.float32, np.float64]: # is numpy array
  215. return X.copy("F" if X.flags["F_CONTIGUOUS"] else "C") if copy else X
  216. else:
  217. if X.dtype.kind in "uib" and X.dtype.itemsize <= 4:
  218. return_dtype = np.float32
  219. else:
  220. return_dtype = np.float64
  221. return X.astype(return_dtype)
  222. def _is_arraylike(x):
  223. """Returns whether the input is array-like."""
  224. return hasattr(x, "__len__") or hasattr(x, "shape") or hasattr(x, "__array__")
  225. def _is_arraylike_not_scalar(array):
  226. """Return True if array is array-like and not a scalar"""
  227. return _is_arraylike(array) and not np.isscalar(array)
  228. def _num_features(X):
  229. """Return the number of features in an array-like X.
  230. This helper function tries hard to avoid to materialize an array version
  231. of X unless necessary. For instance, if X is a list of lists,
  232. this function will return the length of the first element, assuming
  233. that subsequent elements are all lists of the same length without
  234. checking.
  235. Parameters
  236. ----------
  237. X : array-like
  238. array-like to get the number of features.
  239. Returns
  240. -------
  241. features : int
  242. Number of features
  243. """
  244. type_ = type(X)
  245. if type_.__module__ == "builtins":
  246. type_name = type_.__qualname__
  247. else:
  248. type_name = f"{type_.__module__}.{type_.__qualname__}"
  249. message = f"Unable to find the number of features from X of type {type_name}"
  250. if not hasattr(X, "__len__") and not hasattr(X, "shape"):
  251. if not hasattr(X, "__array__"):
  252. raise TypeError(message)
  253. # Only convert X to a numpy array if there is no cheaper, heuristic
  254. # option.
  255. X = np.asarray(X)
  256. if hasattr(X, "shape"):
  257. if not hasattr(X.shape, "__len__") or len(X.shape) <= 1:
  258. message += f" with shape {X.shape}"
  259. raise TypeError(message)
  260. return X.shape[1]
  261. first_sample = X[0]
  262. # Do not consider an array-like of strings or dicts to be a 2D array
  263. if isinstance(first_sample, (str, bytes, dict)):
  264. message += f" where the samples are of type {type(first_sample).__qualname__}"
  265. raise TypeError(message)
  266. try:
  267. # If X is a list of lists, for instance, we assume that all nested
  268. # lists have the same length without checking or converting to
  269. # a numpy array to keep this function call as cheap as possible.
  270. return len(first_sample)
  271. except Exception as err:
  272. raise TypeError(message) from err
  273. def _num_samples(x):
  274. """Return number of samples in array-like x."""
  275. message = "Expected sequence or array-like, got %s" % type(x)
  276. if hasattr(x, "fit") and callable(x.fit):
  277. # Don't get num_samples from an ensembles length!
  278. raise TypeError(message)
  279. if not hasattr(x, "__len__") and not hasattr(x, "shape"):
  280. if hasattr(x, "__array__"):
  281. x = np.asarray(x)
  282. else:
  283. raise TypeError(message)
  284. if hasattr(x, "shape") and x.shape is not None:
  285. if len(x.shape) == 0:
  286. raise TypeError(
  287. "Singleton array %r cannot be considered a valid collection." % x
  288. )
  289. # Check that shape is returning an integer or default to len
  290. # Dask dataframes may not return numeric shape[0] value
  291. if isinstance(x.shape[0], numbers.Integral):
  292. return x.shape[0]
  293. try:
  294. return len(x)
  295. except TypeError as type_error:
  296. raise TypeError(message) from type_error
  297. def check_memory(memory):
  298. """Check that ``memory`` is joblib.Memory-like.
  299. joblib.Memory-like means that ``memory`` can be converted into a
  300. joblib.Memory instance (typically a str denoting the ``location``)
  301. or has the same interface (has a ``cache`` method).
  302. Parameters
  303. ----------
  304. memory : None, str or object with the joblib.Memory interface
  305. - If string, the location where to create the `joblib.Memory` interface.
  306. - If None, no caching is done and the Memory object is completely transparent.
  307. Returns
  308. -------
  309. memory : object with the joblib.Memory interface
  310. A correct joblib.Memory object.
  311. Raises
  312. ------
  313. ValueError
  314. If ``memory`` is not joblib.Memory-like.
  315. """
  316. if memory is None or isinstance(memory, str):
  317. memory = joblib.Memory(location=memory, verbose=0)
  318. elif not hasattr(memory, "cache"):
  319. raise ValueError(
  320. "'memory' should be None, a string or have the same"
  321. " interface as joblib.Memory."
  322. " Got memory='{}' instead.".format(memory)
  323. )
  324. return memory
  325. def check_consistent_length(*arrays):
  326. """Check that all arrays have consistent first dimensions.
  327. Checks whether all objects in arrays have the same shape or length.
  328. Parameters
  329. ----------
  330. *arrays : list or tuple of input objects.
  331. Objects that will be checked for consistent length.
  332. """
  333. lengths = [_num_samples(X) for X in arrays if X is not None]
  334. uniques = np.unique(lengths)
  335. if len(uniques) > 1:
  336. raise ValueError(
  337. "Found input variables with inconsistent numbers of samples: %r"
  338. % [int(l) for l in lengths]
  339. )
  340. def _make_indexable(iterable):
  341. """Ensure iterable supports indexing or convert to an indexable variant.
  342. Convert sparse matrices to csr and other non-indexable iterable to arrays.
  343. Let `None` and indexable objects (e.g. pandas dataframes) pass unchanged.
  344. Parameters
  345. ----------
  346. iterable : {list, dataframe, ndarray, sparse matrix} or None
  347. Object to be converted to an indexable iterable.
  348. """
  349. if sp.issparse(iterable):
  350. return iterable.tocsr()
  351. elif hasattr(iterable, "__getitem__") or hasattr(iterable, "iloc"):
  352. return iterable
  353. elif iterable is None:
  354. return iterable
  355. return np.array(iterable)
  356. def indexable(*iterables):
  357. """Make arrays indexable for cross-validation.
  358. Checks consistent length, passes through None, and ensures that everything
  359. can be indexed by converting sparse matrices to csr and converting
  360. non-interable objects to arrays.
  361. Parameters
  362. ----------
  363. *iterables : {lists, dataframes, ndarrays, sparse matrices}
  364. List of objects to ensure sliceability.
  365. Returns
  366. -------
  367. result : list of {ndarray, sparse matrix, dataframe} or None
  368. Returns a list containing indexable arrays (i.e. NumPy array,
  369. sparse matrix, or dataframe) or `None`.
  370. """
  371. result = [_make_indexable(X) for X in iterables]
  372. check_consistent_length(*result)
  373. return result
  374. def _ensure_sparse_format(
  375. spmatrix,
  376. accept_sparse,
  377. dtype,
  378. copy,
  379. force_all_finite,
  380. accept_large_sparse,
  381. estimator_name=None,
  382. input_name="",
  383. ):
  384. """Convert a sparse matrix to a given format.
  385. Checks the sparse format of spmatrix and converts if necessary.
  386. Parameters
  387. ----------
  388. spmatrix : sparse matrix
  389. Input to validate and convert.
  390. accept_sparse : str, bool or list/tuple of str
  391. String[s] representing allowed sparse matrix formats ('csc',
  392. 'csr', 'coo', 'dok', 'bsr', 'lil', 'dia'). If the input is sparse but
  393. not in the allowed format, it will be converted to the first listed
  394. format. True allows the input to be any format. False means
  395. that a sparse matrix input will raise an error.
  396. dtype : str, type or None
  397. Data type of result. If None, the dtype of the input is preserved.
  398. copy : bool
  399. Whether a forced copy will be triggered. If copy=False, a copy might
  400. be triggered by a conversion.
  401. force_all_finite : bool or 'allow-nan'
  402. Whether to raise an error on np.inf, np.nan, pd.NA in X. The
  403. possibilities are:
  404. - True: Force all values of X to be finite.
  405. - False: accepts np.inf, np.nan, pd.NA in X.
  406. - 'allow-nan': accepts only np.nan and pd.NA values in X. Values cannot
  407. be infinite.
  408. .. versionadded:: 0.20
  409. ``force_all_finite`` accepts the string ``'allow-nan'``.
  410. .. versionchanged:: 0.23
  411. Accepts `pd.NA` and converts it into `np.nan`
  412. estimator_name : str, default=None
  413. The estimator name, used to construct the error message.
  414. input_name : str, default=""
  415. The data name used to construct the error message. In particular
  416. if `input_name` is "X" and the data has NaN values and
  417. allow_nan is False, the error message will link to the imputer
  418. documentation.
  419. Returns
  420. -------
  421. spmatrix_converted : sparse matrix.
  422. Matrix that is ensured to have an allowed type.
  423. """
  424. if dtype is None:
  425. dtype = spmatrix.dtype
  426. changed_format = False
  427. if isinstance(accept_sparse, str):
  428. accept_sparse = [accept_sparse]
  429. # Indices dtype validation
  430. _check_large_sparse(spmatrix, accept_large_sparse)
  431. if accept_sparse is False:
  432. raise TypeError(
  433. "A sparse matrix was passed, but dense "
  434. "data is required. Use X.toarray() to "
  435. "convert to a dense numpy array."
  436. )
  437. elif isinstance(accept_sparse, (list, tuple)):
  438. if len(accept_sparse) == 0:
  439. raise ValueError(
  440. "When providing 'accept_sparse' "
  441. "as a tuple or list, it must contain at "
  442. "least one string value."
  443. )
  444. # ensure correct sparse format
  445. if spmatrix.format not in accept_sparse:
  446. # create new with correct sparse
  447. spmatrix = spmatrix.asformat(accept_sparse[0])
  448. changed_format = True
  449. elif accept_sparse is not True:
  450. # any other type
  451. raise ValueError(
  452. "Parameter 'accept_sparse' should be a string, "
  453. "boolean or list of strings. You provided "
  454. "'accept_sparse={}'.".format(accept_sparse)
  455. )
  456. if dtype != spmatrix.dtype:
  457. # convert dtype
  458. spmatrix = spmatrix.astype(dtype)
  459. elif copy and not changed_format:
  460. # force copy
  461. spmatrix = spmatrix.copy()
  462. if force_all_finite:
  463. if not hasattr(spmatrix, "data"):
  464. warnings.warn(
  465. "Can't check %s sparse matrix for nan or inf." % spmatrix.format,
  466. stacklevel=2,
  467. )
  468. else:
  469. _assert_all_finite(
  470. spmatrix.data,
  471. allow_nan=force_all_finite == "allow-nan",
  472. estimator_name=estimator_name,
  473. input_name=input_name,
  474. )
  475. return spmatrix
  476. def _ensure_no_complex_data(array):
  477. if (
  478. hasattr(array, "dtype")
  479. and array.dtype is not None
  480. and hasattr(array.dtype, "kind")
  481. and array.dtype.kind == "c"
  482. ):
  483. raise ValueError("Complex data not supported\n{}\n".format(array))
  484. def _check_estimator_name(estimator):
  485. if estimator is not None:
  486. if isinstance(estimator, str):
  487. return estimator
  488. else:
  489. return estimator.__class__.__name__
  490. return None
  491. def _pandas_dtype_needs_early_conversion(pd_dtype):
  492. """Return True if pandas extension pd_dtype need to be converted early."""
  493. # Check these early for pandas versions without extension dtypes
  494. from pandas import SparseDtype
  495. from pandas.api.types import (
  496. is_bool_dtype,
  497. is_float_dtype,
  498. is_integer_dtype,
  499. )
  500. if is_bool_dtype(pd_dtype):
  501. # bool and extension booleans need early conversion because __array__
  502. # converts mixed dtype dataframes into object dtypes
  503. return True
  504. if isinstance(pd_dtype, SparseDtype):
  505. # Sparse arrays will be converted later in `check_array`
  506. return False
  507. try:
  508. from pandas.api.types import is_extension_array_dtype
  509. except ImportError:
  510. return False
  511. if isinstance(pd_dtype, SparseDtype) or not is_extension_array_dtype(pd_dtype):
  512. # Sparse arrays will be converted later in `check_array`
  513. # Only handle extension arrays for integer and floats
  514. return False
  515. elif is_float_dtype(pd_dtype):
  516. # Float ndarrays can normally support nans. They need to be converted
  517. # first to map pd.NA to np.nan
  518. return True
  519. elif is_integer_dtype(pd_dtype):
  520. # XXX: Warn when converting from a high integer to a float
  521. return True
  522. return False
  523. def _is_extension_array_dtype(array):
  524. # Pandas extension arrays have a dtype with an na_value
  525. return hasattr(array, "dtype") and hasattr(array.dtype, "na_value")
  526. def check_array(
  527. array,
  528. accept_sparse=False,
  529. *,
  530. accept_large_sparse=True,
  531. dtype="numeric",
  532. order=None,
  533. copy=False,
  534. force_all_finite=True,
  535. ensure_2d=True,
  536. allow_nd=False,
  537. ensure_min_samples=1,
  538. ensure_min_features=1,
  539. estimator=None,
  540. input_name="",
  541. ):
  542. """Input validation on an array, list, sparse matrix or similar.
  543. By default, the input is checked to be a non-empty 2D array containing
  544. only finite values. If the dtype of the array is object, attempt
  545. converting to float, raising on failure.
  546. Parameters
  547. ----------
  548. array : object
  549. Input object to check / convert.
  550. accept_sparse : str, bool or list/tuple of str, default=False
  551. String[s] representing allowed sparse matrix formats, such as 'csc',
  552. 'csr', etc. If the input is sparse but not in the allowed format,
  553. it will be converted to the first listed format. True allows the input
  554. to be any format. False means that a sparse matrix input will
  555. raise an error.
  556. accept_large_sparse : bool, default=True
  557. If a CSR, CSC, COO or BSR sparse matrix is supplied and accepted by
  558. accept_sparse, accept_large_sparse=False will cause it to be accepted
  559. only if its indices are stored with a 32-bit dtype.
  560. .. versionadded:: 0.20
  561. dtype : 'numeric', type, list of type or None, default='numeric'
  562. Data type of result. If None, the dtype of the input is preserved.
  563. If "numeric", dtype is preserved unless array.dtype is object.
  564. If dtype is a list of types, conversion on the first type is only
  565. performed if the dtype of the input is not in the list.
  566. order : {'F', 'C'} or None, default=None
  567. Whether an array will be forced to be fortran or c-style.
  568. When order is None (default), then if copy=False, nothing is ensured
  569. about the memory layout of the output array; otherwise (copy=True)
  570. the memory layout of the returned array is kept as close as possible
  571. to the original array.
  572. copy : bool, default=False
  573. Whether a forced copy will be triggered. If copy=False, a copy might
  574. be triggered by a conversion.
  575. force_all_finite : bool or 'allow-nan', default=True
  576. Whether to raise an error on np.inf, np.nan, pd.NA in array. The
  577. possibilities are:
  578. - True: Force all values of array to be finite.
  579. - False: accepts np.inf, np.nan, pd.NA in array.
  580. - 'allow-nan': accepts only np.nan and pd.NA values in array. Values
  581. cannot be infinite.
  582. .. versionadded:: 0.20
  583. ``force_all_finite`` accepts the string ``'allow-nan'``.
  584. .. versionchanged:: 0.23
  585. Accepts `pd.NA` and converts it into `np.nan`
  586. ensure_2d : bool, default=True
  587. Whether to raise a value error if array is not 2D.
  588. allow_nd : bool, default=False
  589. Whether to allow array.ndim > 2.
  590. ensure_min_samples : int, default=1
  591. Make sure that the array has a minimum number of samples in its first
  592. axis (rows for a 2D array). Setting to 0 disables this check.
  593. ensure_min_features : int, default=1
  594. Make sure that the 2D array has some minimum number of features
  595. (columns). The default value of 1 rejects empty datasets.
  596. This check is only enforced when the input data has effectively 2
  597. dimensions or is originally 1D and ``ensure_2d`` is True. Setting to 0
  598. disables this check.
  599. estimator : str or estimator instance, default=None
  600. If passed, include the name of the estimator in warning messages.
  601. input_name : str, default=""
  602. The data name used to construct the error message. In particular
  603. if `input_name` is "X" and the data has NaN values and
  604. allow_nan is False, the error message will link to the imputer
  605. documentation.
  606. .. versionadded:: 1.1.0
  607. Returns
  608. -------
  609. array_converted : object
  610. The converted and validated array.
  611. """
  612. if isinstance(array, np.matrix):
  613. raise TypeError(
  614. "np.matrix is not supported. Please convert to a numpy array with "
  615. "np.asarray. For more information see: "
  616. "https://numpy.org/doc/stable/reference/generated/numpy.matrix.html"
  617. )
  618. xp, is_array_api_compliant = get_namespace(array)
  619. # store reference to original array to check if copy is needed when
  620. # function returns
  621. array_orig = array
  622. # store whether originally we wanted numeric dtype
  623. dtype_numeric = isinstance(dtype, str) and dtype == "numeric"
  624. dtype_orig = getattr(array, "dtype", None)
  625. if not is_array_api_compliant and not hasattr(dtype_orig, "kind"):
  626. # not a data type (e.g. a column named dtype in a pandas DataFrame)
  627. dtype_orig = None
  628. # check if the object contains several dtypes (typically a pandas
  629. # DataFrame), and store them. If not, store None.
  630. dtypes_orig = None
  631. pandas_requires_conversion = False
  632. if hasattr(array, "dtypes") and hasattr(array.dtypes, "__array__"):
  633. # throw warning if columns are sparse. If all columns are sparse, then
  634. # array.sparse exists and sparsity will be preserved (later).
  635. with suppress(ImportError):
  636. from pandas import SparseDtype
  637. def is_sparse(dtype):
  638. return isinstance(dtype, SparseDtype)
  639. if not hasattr(array, "sparse") and array.dtypes.apply(is_sparse).any():
  640. warnings.warn(
  641. "pandas.DataFrame with sparse columns found."
  642. "It will be converted to a dense numpy array."
  643. )
  644. dtypes_orig = list(array.dtypes)
  645. pandas_requires_conversion = any(
  646. _pandas_dtype_needs_early_conversion(i) for i in dtypes_orig
  647. )
  648. if all(isinstance(dtype_iter, np.dtype) for dtype_iter in dtypes_orig):
  649. dtype_orig = np.result_type(*dtypes_orig)
  650. elif pandas_requires_conversion and any(d == object for d in dtypes_orig):
  651. # Force object if any of the dtypes is an object
  652. dtype_orig = object
  653. elif (_is_extension_array_dtype(array) or hasattr(array, "iloc")) and hasattr(
  654. array, "dtype"
  655. ):
  656. # array is a pandas series
  657. pandas_requires_conversion = _pandas_dtype_needs_early_conversion(array.dtype)
  658. if isinstance(array.dtype, np.dtype):
  659. dtype_orig = array.dtype
  660. else:
  661. # Set to None to let array.astype work out the best dtype
  662. dtype_orig = None
  663. if dtype_numeric:
  664. if (
  665. dtype_orig is not None
  666. and hasattr(dtype_orig, "kind")
  667. and dtype_orig.kind == "O"
  668. ):
  669. # if input is object, convert to float.
  670. dtype = xp.float64
  671. else:
  672. dtype = None
  673. if isinstance(dtype, (list, tuple)):
  674. if dtype_orig is not None and dtype_orig in dtype:
  675. # no dtype conversion required
  676. dtype = None
  677. else:
  678. # dtype conversion required. Let's select the first element of the
  679. # list of accepted types.
  680. dtype = dtype[0]
  681. if pandas_requires_conversion:
  682. # pandas dataframe requires conversion earlier to handle extension dtypes with
  683. # nans
  684. # Use the original dtype for conversion if dtype is None
  685. new_dtype = dtype_orig if dtype is None else dtype
  686. array = array.astype(new_dtype)
  687. # Since we converted here, we do not need to convert again later
  688. dtype = None
  689. if dtype is not None and _is_numpy_namespace(xp):
  690. dtype = np.dtype(dtype)
  691. if force_all_finite not in (True, False, "allow-nan"):
  692. raise ValueError(
  693. 'force_all_finite should be a bool or "allow-nan". Got {!r} instead'.format(
  694. force_all_finite
  695. )
  696. )
  697. if dtype is not None and _is_numpy_namespace(xp):
  698. # convert to dtype object to conform to Array API to be use `xp.isdtype` later
  699. dtype = np.dtype(dtype)
  700. estimator_name = _check_estimator_name(estimator)
  701. context = " by %s" % estimator_name if estimator is not None else ""
  702. # When all dataframe columns are sparse, convert to a sparse array
  703. if hasattr(array, "sparse") and array.ndim > 1:
  704. with suppress(ImportError):
  705. from pandas import SparseDtype # noqa: F811
  706. def is_sparse(dtype):
  707. return isinstance(dtype, SparseDtype)
  708. if array.dtypes.apply(is_sparse).all():
  709. # DataFrame.sparse only supports `to_coo`
  710. array = array.sparse.to_coo()
  711. if array.dtype == np.dtype("object"):
  712. unique_dtypes = set([dt.subtype.name for dt in array_orig.dtypes])
  713. if len(unique_dtypes) > 1:
  714. raise ValueError(
  715. "Pandas DataFrame with mixed sparse extension arrays "
  716. "generated a sparse matrix with object dtype which "
  717. "can not be converted to a scipy sparse matrix."
  718. "Sparse extension arrays should all have the same "
  719. "numeric type."
  720. )
  721. if sp.issparse(array):
  722. _ensure_no_complex_data(array)
  723. array = _ensure_sparse_format(
  724. array,
  725. accept_sparse=accept_sparse,
  726. dtype=dtype,
  727. copy=copy,
  728. force_all_finite=force_all_finite,
  729. accept_large_sparse=accept_large_sparse,
  730. estimator_name=estimator_name,
  731. input_name=input_name,
  732. )
  733. else:
  734. # If np.array(..) gives ComplexWarning, then we convert the warning
  735. # to an error. This is needed because specifying a non complex
  736. # dtype to the function converts complex to real dtype,
  737. # thereby passing the test made in the lines following the scope
  738. # of warnings context manager.
  739. with warnings.catch_warnings():
  740. try:
  741. warnings.simplefilter("error", ComplexWarning)
  742. if dtype is not None and xp.isdtype(dtype, "integral"):
  743. # Conversion float -> int should not contain NaN or
  744. # inf (numpy#14412). We cannot use casting='safe' because
  745. # then conversion float -> int would be disallowed.
  746. array = _asarray_with_order(array, order=order, xp=xp)
  747. if xp.isdtype(array.dtype, ("real floating", "complex floating")):
  748. _assert_all_finite(
  749. array,
  750. allow_nan=False,
  751. msg_dtype=dtype,
  752. estimator_name=estimator_name,
  753. input_name=input_name,
  754. )
  755. array = xp.astype(array, dtype, copy=False)
  756. else:
  757. array = _asarray_with_order(array, order=order, dtype=dtype, xp=xp)
  758. except ComplexWarning as complex_warning:
  759. raise ValueError(
  760. "Complex data not supported\n{}\n".format(array)
  761. ) from complex_warning
  762. # It is possible that the np.array(..) gave no warning. This happens
  763. # when no dtype conversion happened, for example dtype = None. The
  764. # result is that np.array(..) produces an array of complex dtype
  765. # and we need to catch and raise exception for such cases.
  766. _ensure_no_complex_data(array)
  767. if ensure_2d:
  768. # If input is scalar raise error
  769. if array.ndim == 0:
  770. raise ValueError(
  771. "Expected 2D array, got scalar array instead:\narray={}.\n"
  772. "Reshape your data either using array.reshape(-1, 1) if "
  773. "your data has a single feature or array.reshape(1, -1) "
  774. "if it contains a single sample.".format(array)
  775. )
  776. # If input is 1D raise error
  777. if array.ndim == 1:
  778. raise ValueError(
  779. "Expected 2D array, got 1D array instead:\narray={}.\n"
  780. "Reshape your data either using array.reshape(-1, 1) if "
  781. "your data has a single feature or array.reshape(1, -1) "
  782. "if it contains a single sample.".format(array)
  783. )
  784. if dtype_numeric and hasattr(array.dtype, "kind") and array.dtype.kind in "USV":
  785. raise ValueError(
  786. "dtype='numeric' is not compatible with arrays of bytes/strings."
  787. "Convert your data to numeric values explicitly instead."
  788. )
  789. if not allow_nd and array.ndim >= 3:
  790. raise ValueError(
  791. "Found array with dim %d. %s expected <= 2."
  792. % (array.ndim, estimator_name)
  793. )
  794. if force_all_finite:
  795. _assert_all_finite(
  796. array,
  797. input_name=input_name,
  798. estimator_name=estimator_name,
  799. allow_nan=force_all_finite == "allow-nan",
  800. )
  801. if ensure_min_samples > 0:
  802. n_samples = _num_samples(array)
  803. if n_samples < ensure_min_samples:
  804. raise ValueError(
  805. "Found array with %d sample(s) (shape=%s) while a"
  806. " minimum of %d is required%s."
  807. % (n_samples, array.shape, ensure_min_samples, context)
  808. )
  809. if ensure_min_features > 0 and array.ndim == 2:
  810. n_features = array.shape[1]
  811. if n_features < ensure_min_features:
  812. raise ValueError(
  813. "Found array with %d feature(s) (shape=%s) while"
  814. " a minimum of %d is required%s."
  815. % (n_features, array.shape, ensure_min_features, context)
  816. )
  817. if copy:
  818. if _is_numpy_namespace(xp):
  819. # only make a copy if `array` and `array_orig` may share memory`
  820. if np.may_share_memory(array, array_orig):
  821. array = _asarray_with_order(
  822. array, dtype=dtype, order=order, copy=True, xp=xp
  823. )
  824. else:
  825. # always make a copy for non-numpy arrays
  826. array = _asarray_with_order(
  827. array, dtype=dtype, order=order, copy=True, xp=xp
  828. )
  829. return array
  830. def _check_large_sparse(X, accept_large_sparse=False):
  831. """Raise a ValueError if X has 64bit indices and accept_large_sparse=False"""
  832. if not accept_large_sparse:
  833. supported_indices = ["int32"]
  834. if X.getformat() == "coo":
  835. index_keys = ["col", "row"]
  836. elif X.getformat() in ["csr", "csc", "bsr"]:
  837. index_keys = ["indices", "indptr"]
  838. else:
  839. return
  840. for key in index_keys:
  841. indices_datatype = getattr(X, key).dtype
  842. if indices_datatype not in supported_indices:
  843. raise ValueError(
  844. "Only sparse matrices with 32-bit integer indices are accepted."
  845. f" Got {indices_datatype} indices. Please do report a minimal"
  846. " reproducer on scikit-learn issue tracker so that support for"
  847. " your use-case can be studied by maintainers. See:"
  848. " https://scikit-learn.org/dev/developers/minimal_reproducer.html"
  849. )
  850. def check_X_y(
  851. X,
  852. y,
  853. accept_sparse=False,
  854. *,
  855. accept_large_sparse=True,
  856. dtype="numeric",
  857. order=None,
  858. copy=False,
  859. force_all_finite=True,
  860. ensure_2d=True,
  861. allow_nd=False,
  862. multi_output=False,
  863. ensure_min_samples=1,
  864. ensure_min_features=1,
  865. y_numeric=False,
  866. estimator=None,
  867. ):
  868. """Input validation for standard estimators.
  869. Checks X and y for consistent length, enforces X to be 2D and y 1D. By
  870. default, X is checked to be non-empty and containing only finite values.
  871. Standard input checks are also applied to y, such as checking that y
  872. does not have np.nan or np.inf targets. For multi-label y, set
  873. multi_output=True to allow 2D and sparse y. If the dtype of X is
  874. object, attempt converting to float, raising on failure.
  875. Parameters
  876. ----------
  877. X : {ndarray, list, sparse matrix}
  878. Input data.
  879. y : {ndarray, list, sparse matrix}
  880. Labels.
  881. accept_sparse : str, bool or list of str, default=False
  882. String[s] representing allowed sparse matrix formats, such as 'csc',
  883. 'csr', etc. If the input is sparse but not in the allowed format,
  884. it will be converted to the first listed format. True allows the input
  885. to be any format. False means that a sparse matrix input will
  886. raise an error.
  887. accept_large_sparse : bool, default=True
  888. If a CSR, CSC, COO or BSR sparse matrix is supplied and accepted by
  889. accept_sparse, accept_large_sparse will cause it to be accepted only
  890. if its indices are stored with a 32-bit dtype.
  891. .. versionadded:: 0.20
  892. dtype : 'numeric', type, list of type or None, default='numeric'
  893. Data type of result. If None, the dtype of the input is preserved.
  894. If "numeric", dtype is preserved unless array.dtype is object.
  895. If dtype is a list of types, conversion on the first type is only
  896. performed if the dtype of the input is not in the list.
  897. order : {'F', 'C'}, default=None
  898. Whether an array will be forced to be fortran or c-style. If
  899. `None`, then the input data's order is preserved when possible.
  900. copy : bool, default=False
  901. Whether a forced copy will be triggered. If copy=False, a copy might
  902. be triggered by a conversion.
  903. force_all_finite : bool or 'allow-nan', default=True
  904. Whether to raise an error on np.inf, np.nan, pd.NA in X. This parameter
  905. does not influence whether y can have np.inf, np.nan, pd.NA values.
  906. The possibilities are:
  907. - True: Force all values of X to be finite.
  908. - False: accepts np.inf, np.nan, pd.NA in X.
  909. - 'allow-nan': accepts only np.nan or pd.NA values in X. Values cannot
  910. be infinite.
  911. .. versionadded:: 0.20
  912. ``force_all_finite`` accepts the string ``'allow-nan'``.
  913. .. versionchanged:: 0.23
  914. Accepts `pd.NA` and converts it into `np.nan`
  915. ensure_2d : bool, default=True
  916. Whether to raise a value error if X is not 2D.
  917. allow_nd : bool, default=False
  918. Whether to allow X.ndim > 2.
  919. multi_output : bool, default=False
  920. Whether to allow 2D y (array or sparse matrix). If false, y will be
  921. validated as a vector. y cannot have np.nan or np.inf values if
  922. multi_output=True.
  923. ensure_min_samples : int, default=1
  924. Make sure that X has a minimum number of samples in its first
  925. axis (rows for a 2D array).
  926. ensure_min_features : int, default=1
  927. Make sure that the 2D array has some minimum number of features
  928. (columns). The default value of 1 rejects empty datasets.
  929. This check is only enforced when X has effectively 2 dimensions or
  930. is originally 1D and ``ensure_2d`` is True. Setting to 0 disables
  931. this check.
  932. y_numeric : bool, default=False
  933. Whether to ensure that y has a numeric type. If dtype of y is object,
  934. it is converted to float64. Should only be used for regression
  935. algorithms.
  936. estimator : str or estimator instance, default=None
  937. If passed, include the name of the estimator in warning messages.
  938. Returns
  939. -------
  940. X_converted : object
  941. The converted and validated X.
  942. y_converted : object
  943. The converted and validated y.
  944. """
  945. if y is None:
  946. if estimator is None:
  947. estimator_name = "estimator"
  948. else:
  949. estimator_name = _check_estimator_name(estimator)
  950. raise ValueError(
  951. f"{estimator_name} requires y to be passed, but the target y is None"
  952. )
  953. X = check_array(
  954. X,
  955. accept_sparse=accept_sparse,
  956. accept_large_sparse=accept_large_sparse,
  957. dtype=dtype,
  958. order=order,
  959. copy=copy,
  960. force_all_finite=force_all_finite,
  961. ensure_2d=ensure_2d,
  962. allow_nd=allow_nd,
  963. ensure_min_samples=ensure_min_samples,
  964. ensure_min_features=ensure_min_features,
  965. estimator=estimator,
  966. input_name="X",
  967. )
  968. y = _check_y(y, multi_output=multi_output, y_numeric=y_numeric, estimator=estimator)
  969. check_consistent_length(X, y)
  970. return X, y
  971. def _check_y(y, multi_output=False, y_numeric=False, estimator=None):
  972. """Isolated part of check_X_y dedicated to y validation"""
  973. if multi_output:
  974. y = check_array(
  975. y,
  976. accept_sparse="csr",
  977. force_all_finite=True,
  978. ensure_2d=False,
  979. dtype=None,
  980. input_name="y",
  981. estimator=estimator,
  982. )
  983. else:
  984. estimator_name = _check_estimator_name(estimator)
  985. y = column_or_1d(y, warn=True)
  986. _assert_all_finite(y, input_name="y", estimator_name=estimator_name)
  987. _ensure_no_complex_data(y)
  988. if y_numeric and y.dtype.kind == "O":
  989. y = y.astype(np.float64)
  990. return y
  991. def column_or_1d(y, *, dtype=None, warn=False):
  992. """Ravel column or 1d numpy array, else raises an error.
  993. Parameters
  994. ----------
  995. y : array-like
  996. Input data.
  997. dtype : data-type, default=None
  998. Data type for `y`.
  999. .. versionadded:: 1.2
  1000. warn : bool, default=False
  1001. To control display of warnings.
  1002. Returns
  1003. -------
  1004. y : ndarray
  1005. Output data.
  1006. Raises
  1007. ------
  1008. ValueError
  1009. If `y` is not a 1D array or a 2D array with a single row or column.
  1010. """
  1011. xp, _ = get_namespace(y)
  1012. y = check_array(
  1013. y,
  1014. ensure_2d=False,
  1015. dtype=dtype,
  1016. input_name="y",
  1017. force_all_finite=False,
  1018. ensure_min_samples=0,
  1019. )
  1020. shape = y.shape
  1021. if len(shape) == 1:
  1022. return _asarray_with_order(xp.reshape(y, (-1,)), order="C", xp=xp)
  1023. if len(shape) == 2 and shape[1] == 1:
  1024. if warn:
  1025. warnings.warn(
  1026. (
  1027. "A column-vector y was passed when a 1d array was"
  1028. " expected. Please change the shape of y to "
  1029. "(n_samples, ), for example using ravel()."
  1030. ),
  1031. DataConversionWarning,
  1032. stacklevel=2,
  1033. )
  1034. return _asarray_with_order(xp.reshape(y, (-1,)), order="C", xp=xp)
  1035. raise ValueError(
  1036. "y should be a 1d array, got an array of shape {} instead.".format(shape)
  1037. )
  1038. def check_random_state(seed):
  1039. """Turn seed into a np.random.RandomState instance.
  1040. Parameters
  1041. ----------
  1042. seed : None, int or instance of RandomState
  1043. If seed is None, return the RandomState singleton used by np.random.
  1044. If seed is an int, return a new RandomState instance seeded with seed.
  1045. If seed is already a RandomState instance, return it.
  1046. Otherwise raise ValueError.
  1047. Returns
  1048. -------
  1049. :class:`numpy:numpy.random.RandomState`
  1050. The random state object based on `seed` parameter.
  1051. """
  1052. if seed is None or seed is np.random:
  1053. return np.random.mtrand._rand
  1054. if isinstance(seed, numbers.Integral):
  1055. return np.random.RandomState(seed)
  1056. if isinstance(seed, np.random.RandomState):
  1057. return seed
  1058. raise ValueError(
  1059. "%r cannot be used to seed a numpy.random.RandomState instance" % seed
  1060. )
  1061. def has_fit_parameter(estimator, parameter):
  1062. """Check whether the estimator's fit method supports the given parameter.
  1063. Parameters
  1064. ----------
  1065. estimator : object
  1066. An estimator to inspect.
  1067. parameter : str
  1068. The searched parameter.
  1069. Returns
  1070. -------
  1071. is_parameter : bool
  1072. Whether the parameter was found to be a named parameter of the
  1073. estimator's fit method.
  1074. Examples
  1075. --------
  1076. >>> from sklearn.svm import SVC
  1077. >>> from sklearn.utils.validation import has_fit_parameter
  1078. >>> has_fit_parameter(SVC(), "sample_weight")
  1079. True
  1080. """
  1081. return parameter in signature(estimator.fit).parameters
  1082. def check_symmetric(array, *, tol=1e-10, raise_warning=True, raise_exception=False):
  1083. """Make sure that array is 2D, square and symmetric.
  1084. If the array is not symmetric, then a symmetrized version is returned.
  1085. Optionally, a warning or exception is raised if the matrix is not
  1086. symmetric.
  1087. Parameters
  1088. ----------
  1089. array : {ndarray, sparse matrix}
  1090. Input object to check / convert. Must be two-dimensional and square,
  1091. otherwise a ValueError will be raised.
  1092. tol : float, default=1e-10
  1093. Absolute tolerance for equivalence of arrays. Default = 1E-10.
  1094. raise_warning : bool, default=True
  1095. If True then raise a warning if conversion is required.
  1096. raise_exception : bool, default=False
  1097. If True then raise an exception if array is not symmetric.
  1098. Returns
  1099. -------
  1100. array_sym : {ndarray, sparse matrix}
  1101. Symmetrized version of the input array, i.e. the average of array
  1102. and array.transpose(). If sparse, then duplicate entries are first
  1103. summed and zeros are eliminated.
  1104. """
  1105. if (array.ndim != 2) or (array.shape[0] != array.shape[1]):
  1106. raise ValueError(
  1107. "array must be 2-dimensional and square. shape = {0}".format(array.shape)
  1108. )
  1109. if sp.issparse(array):
  1110. diff = array - array.T
  1111. # only csr, csc, and coo have `data` attribute
  1112. if diff.format not in ["csr", "csc", "coo"]:
  1113. diff = diff.tocsr()
  1114. symmetric = np.all(abs(diff.data) < tol)
  1115. else:
  1116. symmetric = np.allclose(array, array.T, atol=tol)
  1117. if not symmetric:
  1118. if raise_exception:
  1119. raise ValueError("Array must be symmetric")
  1120. if raise_warning:
  1121. warnings.warn(
  1122. (
  1123. "Array is not symmetric, and will be converted "
  1124. "to symmetric by average with its transpose."
  1125. ),
  1126. stacklevel=2,
  1127. )
  1128. if sp.issparse(array):
  1129. conversion = "to" + array.format
  1130. array = getattr(0.5 * (array + array.T), conversion)()
  1131. else:
  1132. array = 0.5 * (array + array.T)
  1133. return array
  1134. def _is_fitted(estimator, attributes=None, all_or_any=all):
  1135. """Determine if an estimator is fitted
  1136. Parameters
  1137. ----------
  1138. estimator : estimator instance
  1139. Estimator instance for which the check is performed.
  1140. attributes : str, list or tuple of str, default=None
  1141. Attribute name(s) given as string or a list/tuple of strings
  1142. Eg.: ``["coef_", "estimator_", ...], "coef_"``
  1143. If `None`, `estimator` is considered fitted if there exist an
  1144. attribute that ends with a underscore and does not start with double
  1145. underscore.
  1146. all_or_any : callable, {all, any}, default=all
  1147. Specify whether all or any of the given attributes must exist.
  1148. Returns
  1149. -------
  1150. fitted : bool
  1151. Whether the estimator is fitted.
  1152. """
  1153. if attributes is not None:
  1154. if not isinstance(attributes, (list, tuple)):
  1155. attributes = [attributes]
  1156. return all_or_any([hasattr(estimator, attr) for attr in attributes])
  1157. if hasattr(estimator, "__sklearn_is_fitted__"):
  1158. return estimator.__sklearn_is_fitted__()
  1159. fitted_attrs = [
  1160. v for v in vars(estimator) if v.endswith("_") and not v.startswith("__")
  1161. ]
  1162. return len(fitted_attrs) > 0
  1163. def check_is_fitted(estimator, attributes=None, *, msg=None, all_or_any=all):
  1164. """Perform is_fitted validation for estimator.
  1165. Checks if the estimator is fitted by verifying the presence of
  1166. fitted attributes (ending with a trailing underscore) and otherwise
  1167. raises a NotFittedError with the given message.
  1168. If an estimator does not set any attributes with a trailing underscore, it
  1169. can define a ``__sklearn_is_fitted__`` method returning a boolean to specify if the
  1170. estimator is fitted or not.
  1171. Parameters
  1172. ----------
  1173. estimator : estimator instance
  1174. Estimator instance for which the check is performed.
  1175. attributes : str, list or tuple of str, default=None
  1176. Attribute name(s) given as string or a list/tuple of strings
  1177. Eg.: ``["coef_", "estimator_", ...], "coef_"``
  1178. If `None`, `estimator` is considered fitted if there exist an
  1179. attribute that ends with a underscore and does not start with double
  1180. underscore.
  1181. msg : str, default=None
  1182. The default error message is, "This %(name)s instance is not fitted
  1183. yet. Call 'fit' with appropriate arguments before using this
  1184. estimator."
  1185. For custom messages if "%(name)s" is present in the message string,
  1186. it is substituted for the estimator name.
  1187. Eg. : "Estimator, %(name)s, must be fitted before sparsifying".
  1188. all_or_any : callable, {all, any}, default=all
  1189. Specify whether all or any of the given attributes must exist.
  1190. Raises
  1191. ------
  1192. TypeError
  1193. If the estimator is a class or not an estimator instance
  1194. NotFittedError
  1195. If the attributes are not found.
  1196. """
  1197. if isclass(estimator):
  1198. raise TypeError("{} is a class, not an instance.".format(estimator))
  1199. if msg is None:
  1200. msg = (
  1201. "This %(name)s instance is not fitted yet. Call 'fit' with "
  1202. "appropriate arguments before using this estimator."
  1203. )
  1204. if not hasattr(estimator, "fit"):
  1205. raise TypeError("%s is not an estimator instance." % (estimator))
  1206. if not _is_fitted(estimator, attributes, all_or_any):
  1207. raise NotFittedError(msg % {"name": type(estimator).__name__})
  1208. def check_non_negative(X, whom):
  1209. """
  1210. Check if there is any negative value in an array.
  1211. Parameters
  1212. ----------
  1213. X : {array-like, sparse matrix}
  1214. Input data.
  1215. whom : str
  1216. Who passed X to this function.
  1217. """
  1218. xp, _ = get_namespace(X)
  1219. # avoid X.min() on sparse matrix since it also sorts the indices
  1220. if sp.issparse(X):
  1221. if X.format in ["lil", "dok"]:
  1222. X = X.tocsr()
  1223. if X.data.size == 0:
  1224. X_min = 0
  1225. else:
  1226. X_min = X.data.min()
  1227. else:
  1228. X_min = xp.min(X)
  1229. if X_min < 0:
  1230. raise ValueError("Negative values in data passed to %s" % whom)
  1231. def check_scalar(
  1232. x,
  1233. name,
  1234. target_type,
  1235. *,
  1236. min_val=None,
  1237. max_val=None,
  1238. include_boundaries="both",
  1239. ):
  1240. """Validate scalar parameters type and value.
  1241. Parameters
  1242. ----------
  1243. x : object
  1244. The scalar parameter to validate.
  1245. name : str
  1246. The name of the parameter to be printed in error messages.
  1247. target_type : type or tuple
  1248. Acceptable data types for the parameter.
  1249. min_val : float or int, default=None
  1250. The minimum valid value the parameter can take. If None (default) it
  1251. is implied that the parameter does not have a lower bound.
  1252. max_val : float or int, default=None
  1253. The maximum valid value the parameter can take. If None (default) it
  1254. is implied that the parameter does not have an upper bound.
  1255. include_boundaries : {"left", "right", "both", "neither"}, default="both"
  1256. Whether the interval defined by `min_val` and `max_val` should include
  1257. the boundaries. Possible choices are:
  1258. - `"left"`: only `min_val` is included in the valid interval.
  1259. It is equivalent to the interval `[ min_val, max_val )`.
  1260. - `"right"`: only `max_val` is included in the valid interval.
  1261. It is equivalent to the interval `( min_val, max_val ]`.
  1262. - `"both"`: `min_val` and `max_val` are included in the valid interval.
  1263. It is equivalent to the interval `[ min_val, max_val ]`.
  1264. - `"neither"`: neither `min_val` nor `max_val` are included in the
  1265. valid interval. It is equivalent to the interval `( min_val, max_val )`.
  1266. Returns
  1267. -------
  1268. x : numbers.Number
  1269. The validated number.
  1270. Raises
  1271. ------
  1272. TypeError
  1273. If the parameter's type does not match the desired type.
  1274. ValueError
  1275. If the parameter's value violates the given bounds.
  1276. If `min_val`, `max_val` and `include_boundaries` are inconsistent.
  1277. """
  1278. def type_name(t):
  1279. """Convert type into humman readable string."""
  1280. module = t.__module__
  1281. qualname = t.__qualname__
  1282. if module == "builtins":
  1283. return qualname
  1284. elif t == numbers.Real:
  1285. return "float"
  1286. elif t == numbers.Integral:
  1287. return "int"
  1288. return f"{module}.{qualname}"
  1289. if not isinstance(x, target_type):
  1290. if isinstance(target_type, tuple):
  1291. types_str = ", ".join(type_name(t) for t in target_type)
  1292. target_type_str = f"{{{types_str}}}"
  1293. else:
  1294. target_type_str = type_name(target_type)
  1295. raise TypeError(
  1296. f"{name} must be an instance of {target_type_str}, not"
  1297. f" {type(x).__qualname__}."
  1298. )
  1299. expected_include_boundaries = ("left", "right", "both", "neither")
  1300. if include_boundaries not in expected_include_boundaries:
  1301. raise ValueError(
  1302. f"Unknown value for `include_boundaries`: {repr(include_boundaries)}. "
  1303. f"Possible values are: {expected_include_boundaries}."
  1304. )
  1305. if max_val is None and include_boundaries == "right":
  1306. raise ValueError(
  1307. "`include_boundaries`='right' without specifying explicitly `max_val` "
  1308. "is inconsistent."
  1309. )
  1310. if min_val is None and include_boundaries == "left":
  1311. raise ValueError(
  1312. "`include_boundaries`='left' without specifying explicitly `min_val` "
  1313. "is inconsistent."
  1314. )
  1315. comparison_operator = (
  1316. operator.lt if include_boundaries in ("left", "both") else operator.le
  1317. )
  1318. if min_val is not None and comparison_operator(x, min_val):
  1319. raise ValueError(
  1320. f"{name} == {x}, must be"
  1321. f" {'>=' if include_boundaries in ('left', 'both') else '>'} {min_val}."
  1322. )
  1323. comparison_operator = (
  1324. operator.gt if include_boundaries in ("right", "both") else operator.ge
  1325. )
  1326. if max_val is not None and comparison_operator(x, max_val):
  1327. raise ValueError(
  1328. f"{name} == {x}, must be"
  1329. f" {'<=' if include_boundaries in ('right', 'both') else '<'} {max_val}."
  1330. )
  1331. return x
  1332. def _check_psd_eigenvalues(lambdas, enable_warnings=False):
  1333. """Check the eigenvalues of a positive semidefinite (PSD) matrix.
  1334. Checks the provided array of PSD matrix eigenvalues for numerical or
  1335. conditioning issues and returns a fixed validated version. This method
  1336. should typically be used if the PSD matrix is user-provided (e.g. a
  1337. Gram matrix) or computed using a user-provided dissimilarity metric
  1338. (e.g. kernel function), or if the decomposition process uses approximation
  1339. methods (randomized SVD, etc.).
  1340. It checks for three things:
  1341. - that there are no significant imaginary parts in eigenvalues (more than
  1342. 1e-5 times the maximum real part). If this check fails, it raises a
  1343. ``ValueError``. Otherwise all non-significant imaginary parts that may
  1344. remain are set to zero. This operation is traced with a
  1345. ``PositiveSpectrumWarning`` when ``enable_warnings=True``.
  1346. - that eigenvalues are not all negative. If this check fails, it raises a
  1347. ``ValueError``
  1348. - that there are no significant negative eigenvalues with absolute value
  1349. more than 1e-10 (1e-6) and more than 1e-5 (5e-3) times the largest
  1350. positive eigenvalue in double (simple) precision. If this check fails,
  1351. it raises a ``ValueError``. Otherwise all negative eigenvalues that may
  1352. remain are set to zero. This operation is traced with a
  1353. ``PositiveSpectrumWarning`` when ``enable_warnings=True``.
  1354. Finally, all the positive eigenvalues that are too small (with a value
  1355. smaller than the maximum eigenvalue multiplied by 1e-12 (2e-7)) are set to
  1356. zero. This operation is traced with a ``PositiveSpectrumWarning`` when
  1357. ``enable_warnings=True``.
  1358. Parameters
  1359. ----------
  1360. lambdas : array-like of shape (n_eigenvalues,)
  1361. Array of eigenvalues to check / fix.
  1362. enable_warnings : bool, default=False
  1363. When this is set to ``True``, a ``PositiveSpectrumWarning`` will be
  1364. raised when there are imaginary parts, negative eigenvalues, or
  1365. extremely small non-zero eigenvalues. Otherwise no warning will be
  1366. raised. In both cases, imaginary parts, negative eigenvalues, and
  1367. extremely small non-zero eigenvalues will be set to zero.
  1368. Returns
  1369. -------
  1370. lambdas_fixed : ndarray of shape (n_eigenvalues,)
  1371. A fixed validated copy of the array of eigenvalues.
  1372. Examples
  1373. --------
  1374. >>> from sklearn.utils.validation import _check_psd_eigenvalues
  1375. >>> _check_psd_eigenvalues([1, 2]) # nominal case
  1376. array([1, 2])
  1377. >>> _check_psd_eigenvalues([5, 5j]) # significant imag part
  1378. Traceback (most recent call last):
  1379. ...
  1380. ValueError: There are significant imaginary parts in eigenvalues (1
  1381. of the maximum real part). Either the matrix is not PSD, or there was
  1382. an issue while computing the eigendecomposition of the matrix.
  1383. >>> _check_psd_eigenvalues([5, 5e-5j]) # insignificant imag part
  1384. array([5., 0.])
  1385. >>> _check_psd_eigenvalues([-5, -1]) # all negative
  1386. Traceback (most recent call last):
  1387. ...
  1388. ValueError: All eigenvalues are negative (maximum is -1). Either the
  1389. matrix is not PSD, or there was an issue while computing the
  1390. eigendecomposition of the matrix.
  1391. >>> _check_psd_eigenvalues([5, -1]) # significant negative
  1392. Traceback (most recent call last):
  1393. ...
  1394. ValueError: There are significant negative eigenvalues (0.2 of the
  1395. maximum positive). Either the matrix is not PSD, or there was an issue
  1396. while computing the eigendecomposition of the matrix.
  1397. >>> _check_psd_eigenvalues([5, -5e-5]) # insignificant negative
  1398. array([5., 0.])
  1399. >>> _check_psd_eigenvalues([5, 4e-12]) # bad conditioning (too small)
  1400. array([5., 0.])
  1401. """
  1402. lambdas = np.array(lambdas)
  1403. is_double_precision = lambdas.dtype == np.float64
  1404. # note: the minimum value available is
  1405. # - single-precision: np.finfo('float32').eps = 1.2e-07
  1406. # - double-precision: np.finfo('float64').eps = 2.2e-16
  1407. # the various thresholds used for validation
  1408. # we may wish to change the value according to precision.
  1409. significant_imag_ratio = 1e-5
  1410. significant_neg_ratio = 1e-5 if is_double_precision else 5e-3
  1411. significant_neg_value = 1e-10 if is_double_precision else 1e-6
  1412. small_pos_ratio = 1e-12 if is_double_precision else 2e-7
  1413. # Check that there are no significant imaginary parts
  1414. if not np.isreal(lambdas).all():
  1415. max_imag_abs = np.abs(np.imag(lambdas)).max()
  1416. max_real_abs = np.abs(np.real(lambdas)).max()
  1417. if max_imag_abs > significant_imag_ratio * max_real_abs:
  1418. raise ValueError(
  1419. "There are significant imaginary parts in eigenvalues (%g "
  1420. "of the maximum real part). Either the matrix is not PSD, or "
  1421. "there was an issue while computing the eigendecomposition "
  1422. "of the matrix." % (max_imag_abs / max_real_abs)
  1423. )
  1424. # warn about imaginary parts being removed
  1425. if enable_warnings:
  1426. warnings.warn(
  1427. "There are imaginary parts in eigenvalues (%g "
  1428. "of the maximum real part). Either the matrix is not"
  1429. " PSD, or there was an issue while computing the "
  1430. "eigendecomposition of the matrix. Only the real "
  1431. "parts will be kept." % (max_imag_abs / max_real_abs),
  1432. PositiveSpectrumWarning,
  1433. )
  1434. # Remove all imaginary parts (even if zero)
  1435. lambdas = np.real(lambdas)
  1436. # Check that there are no significant negative eigenvalues
  1437. max_eig = lambdas.max()
  1438. if max_eig < 0:
  1439. raise ValueError(
  1440. "All eigenvalues are negative (maximum is %g). "
  1441. "Either the matrix is not PSD, or there was an "
  1442. "issue while computing the eigendecomposition of "
  1443. "the matrix." % max_eig
  1444. )
  1445. else:
  1446. min_eig = lambdas.min()
  1447. if (
  1448. min_eig < -significant_neg_ratio * max_eig
  1449. and min_eig < -significant_neg_value
  1450. ):
  1451. raise ValueError(
  1452. "There are significant negative eigenvalues (%g"
  1453. " of the maximum positive). Either the matrix is "
  1454. "not PSD, or there was an issue while computing "
  1455. "the eigendecomposition of the matrix." % (-min_eig / max_eig)
  1456. )
  1457. elif min_eig < 0:
  1458. # Remove all negative values and warn about it
  1459. if enable_warnings:
  1460. warnings.warn(
  1461. "There are negative eigenvalues (%g of the "
  1462. "maximum positive). Either the matrix is not "
  1463. "PSD, or there was an issue while computing the"
  1464. " eigendecomposition of the matrix. Negative "
  1465. "eigenvalues will be replaced with 0." % (-min_eig / max_eig),
  1466. PositiveSpectrumWarning,
  1467. )
  1468. lambdas[lambdas < 0] = 0
  1469. # Check for conditioning (small positive non-zeros)
  1470. too_small_lambdas = (0 < lambdas) & (lambdas < small_pos_ratio * max_eig)
  1471. if too_small_lambdas.any():
  1472. if enable_warnings:
  1473. warnings.warn(
  1474. "Badly conditioned PSD matrix spectrum: the largest "
  1475. "eigenvalue is more than %g times the smallest. "
  1476. "Small eigenvalues will be replaced with 0."
  1477. "" % (1 / small_pos_ratio),
  1478. PositiveSpectrumWarning,
  1479. )
  1480. lambdas[too_small_lambdas] = 0
  1481. return lambdas
  1482. def _check_sample_weight(
  1483. sample_weight, X, dtype=None, copy=False, only_non_negative=False
  1484. ):
  1485. """Validate sample weights.
  1486. Note that passing sample_weight=None will output an array of ones.
  1487. Therefore, in some cases, you may want to protect the call with:
  1488. if sample_weight is not None:
  1489. sample_weight = _check_sample_weight(...)
  1490. Parameters
  1491. ----------
  1492. sample_weight : {ndarray, Number or None}, shape (n_samples,)
  1493. Input sample weights.
  1494. X : {ndarray, list, sparse matrix}
  1495. Input data.
  1496. only_non_negative : bool, default=False,
  1497. Whether or not the weights are expected to be non-negative.
  1498. .. versionadded:: 1.0
  1499. dtype : dtype, default=None
  1500. dtype of the validated `sample_weight`.
  1501. If None, and the input `sample_weight` is an array, the dtype of the
  1502. input is preserved; otherwise an array with the default numpy dtype
  1503. is be allocated. If `dtype` is not one of `float32`, `float64`,
  1504. `None`, the output will be of dtype `float64`.
  1505. copy : bool, default=False
  1506. If True, a copy of sample_weight will be created.
  1507. Returns
  1508. -------
  1509. sample_weight : ndarray of shape (n_samples,)
  1510. Validated sample weight. It is guaranteed to be "C" contiguous.
  1511. """
  1512. n_samples = _num_samples(X)
  1513. if dtype is not None and dtype not in [np.float32, np.float64]:
  1514. dtype = np.float64
  1515. if sample_weight is None:
  1516. sample_weight = np.ones(n_samples, dtype=dtype)
  1517. elif isinstance(sample_weight, numbers.Number):
  1518. sample_weight = np.full(n_samples, sample_weight, dtype=dtype)
  1519. else:
  1520. if dtype is None:
  1521. dtype = [np.float64, np.float32]
  1522. sample_weight = check_array(
  1523. sample_weight,
  1524. accept_sparse=False,
  1525. ensure_2d=False,
  1526. dtype=dtype,
  1527. order="C",
  1528. copy=copy,
  1529. input_name="sample_weight",
  1530. )
  1531. if sample_weight.ndim != 1:
  1532. raise ValueError("Sample weights must be 1D array or scalar")
  1533. if sample_weight.shape != (n_samples,):
  1534. raise ValueError(
  1535. "sample_weight.shape == {}, expected {}!".format(
  1536. sample_weight.shape, (n_samples,)
  1537. )
  1538. )
  1539. if only_non_negative:
  1540. check_non_negative(sample_weight, "`sample_weight`")
  1541. return sample_weight
  1542. def _allclose_dense_sparse(x, y, rtol=1e-7, atol=1e-9):
  1543. """Check allclose for sparse and dense data.
  1544. Both x and y need to be either sparse or dense, they
  1545. can't be mixed.
  1546. Parameters
  1547. ----------
  1548. x : {array-like, sparse matrix}
  1549. First array to compare.
  1550. y : {array-like, sparse matrix}
  1551. Second array to compare.
  1552. rtol : float, default=1e-7
  1553. Relative tolerance; see numpy.allclose.
  1554. atol : float, default=1e-9
  1555. absolute tolerance; see numpy.allclose. Note that the default here is
  1556. more tolerant than the default for numpy.testing.assert_allclose, where
  1557. atol=0.
  1558. """
  1559. if sp.issparse(x) and sp.issparse(y):
  1560. x = x.tocsr()
  1561. y = y.tocsr()
  1562. x.sum_duplicates()
  1563. y.sum_duplicates()
  1564. return (
  1565. np.array_equal(x.indices, y.indices)
  1566. and np.array_equal(x.indptr, y.indptr)
  1567. and np.allclose(x.data, y.data, rtol=rtol, atol=atol)
  1568. )
  1569. elif not sp.issparse(x) and not sp.issparse(y):
  1570. return np.allclose(x, y, rtol=rtol, atol=atol)
  1571. raise ValueError(
  1572. "Can only compare two sparse matrices, not a sparse matrix and an array"
  1573. )
  1574. def _check_response_method(estimator, response_method):
  1575. """Check if `response_method` is available in estimator and return it.
  1576. .. versionadded:: 1.3
  1577. Parameters
  1578. ----------
  1579. estimator : estimator instance
  1580. Classifier or regressor to check.
  1581. response_method : {"predict_proba", "decision_function", "predict"} or \
  1582. list of such str
  1583. Specifies the response method to use get prediction from an estimator
  1584. (i.e. :term:`predict_proba`, :term:`decision_function` or
  1585. :term:`predict`). Possible choices are:
  1586. - if `str`, it corresponds to the name to the method to return;
  1587. - if a list of `str`, it provides the method names in order of
  1588. preference. The method returned corresponds to the first method in
  1589. the list and which is implemented by `estimator`.
  1590. Returns
  1591. -------
  1592. prediction_method : callable
  1593. Prediction method of estimator.
  1594. Raises
  1595. ------
  1596. AttributeError
  1597. If `response_method` is not available in `estimator`.
  1598. """
  1599. if isinstance(response_method, str):
  1600. list_methods = [response_method]
  1601. else:
  1602. list_methods = response_method
  1603. prediction_method = [getattr(estimator, method, None) for method in list_methods]
  1604. prediction_method = reduce(lambda x, y: x or y, prediction_method)
  1605. if prediction_method is None:
  1606. raise AttributeError(
  1607. f"{estimator.__class__.__name__} has none of the following attributes: "
  1608. f"{', '.join(list_methods)}."
  1609. )
  1610. return prediction_method
  1611. def _check_fit_params(X, fit_params, indices=None):
  1612. """Check and validate the parameters passed during `fit`.
  1613. Parameters
  1614. ----------
  1615. X : array-like of shape (n_samples, n_features)
  1616. Data array.
  1617. fit_params : dict
  1618. Dictionary containing the parameters passed at fit.
  1619. indices : array-like of shape (n_samples,), default=None
  1620. Indices to be selected if the parameter has the same size as `X`.
  1621. Returns
  1622. -------
  1623. fit_params_validated : dict
  1624. Validated parameters. We ensure that the values support indexing.
  1625. """
  1626. from . import _safe_indexing
  1627. fit_params_validated = {}
  1628. for param_key, param_value in fit_params.items():
  1629. if not _is_arraylike(param_value) or _num_samples(param_value) != _num_samples(
  1630. X
  1631. ):
  1632. # Non-indexable pass-through (for now for backward-compatibility).
  1633. # https://github.com/scikit-learn/scikit-learn/issues/15805
  1634. fit_params_validated[param_key] = param_value
  1635. else:
  1636. # Any other fit_params should support indexing
  1637. # (e.g. for cross-validation).
  1638. fit_params_validated[param_key] = _make_indexable(param_value)
  1639. fit_params_validated[param_key] = _safe_indexing(
  1640. fit_params_validated[param_key], indices
  1641. )
  1642. return fit_params_validated
  1643. def _get_feature_names(X):
  1644. """Get feature names from X.
  1645. Support for other array containers should place its implementation here.
  1646. Parameters
  1647. ----------
  1648. X : {ndarray, dataframe} of shape (n_samples, n_features)
  1649. Array container to extract feature names.
  1650. - pandas dataframe : The columns will be considered to be feature
  1651. names. If the dataframe contains non-string feature names, `None` is
  1652. returned.
  1653. - All other array containers will return `None`.
  1654. Returns
  1655. -------
  1656. names: ndarray or None
  1657. Feature names of `X`. Unrecognized array containers will return `None`.
  1658. """
  1659. feature_names = None
  1660. # extract feature names for support array containers
  1661. if hasattr(X, "columns"):
  1662. feature_names = np.asarray(X.columns, dtype=object)
  1663. if feature_names is None or len(feature_names) == 0:
  1664. return
  1665. types = sorted(t.__qualname__ for t in set(type(v) for v in feature_names))
  1666. # mixed type of string and non-string is not supported
  1667. if len(types) > 1 and "str" in types:
  1668. raise TypeError(
  1669. "Feature names are only supported if all input features have string names, "
  1670. f"but your input has {types} as feature name / column name types. "
  1671. "If you want feature names to be stored and validated, you must convert "
  1672. "them all to strings, by using X.columns = X.columns.astype(str) for "
  1673. "example. Otherwise you can remove feature / column names from your input "
  1674. "data, or convert them all to a non-string data type."
  1675. )
  1676. # Only feature names of all strings are supported
  1677. if len(types) == 1 and types[0] == "str":
  1678. return feature_names
  1679. def _check_feature_names_in(estimator, input_features=None, *, generate_names=True):
  1680. """Check `input_features` and generate names if needed.
  1681. Commonly used in :term:`get_feature_names_out`.
  1682. Parameters
  1683. ----------
  1684. input_features : array-like of str or None, default=None
  1685. Input features.
  1686. - If `input_features` is `None`, then `feature_names_in_` is
  1687. used as feature names in. If `feature_names_in_` is not defined,
  1688. then the following input feature names are generated:
  1689. `["x0", "x1", ..., "x(n_features_in_ - 1)"]`.
  1690. - If `input_features` is an array-like, then `input_features` must
  1691. match `feature_names_in_` if `feature_names_in_` is defined.
  1692. generate_names : bool, default=True
  1693. Whether to generate names when `input_features` is `None` and
  1694. `estimator.feature_names_in_` is not defined. This is useful for transformers
  1695. that validates `input_features` but do not require them in
  1696. :term:`get_feature_names_out` e.g. `PCA`.
  1697. Returns
  1698. -------
  1699. feature_names_in : ndarray of str or `None`
  1700. Feature names in.
  1701. """
  1702. feature_names_in_ = getattr(estimator, "feature_names_in_", None)
  1703. n_features_in_ = getattr(estimator, "n_features_in_", None)
  1704. if input_features is not None:
  1705. input_features = np.asarray(input_features, dtype=object)
  1706. if feature_names_in_ is not None and not np.array_equal(
  1707. feature_names_in_, input_features
  1708. ):
  1709. raise ValueError("input_features is not equal to feature_names_in_")
  1710. if n_features_in_ is not None and len(input_features) != n_features_in_:
  1711. raise ValueError(
  1712. "input_features should have length equal to number of "
  1713. f"features ({n_features_in_}), got {len(input_features)}"
  1714. )
  1715. return input_features
  1716. if feature_names_in_ is not None:
  1717. return feature_names_in_
  1718. if not generate_names:
  1719. return
  1720. # Generates feature names if `n_features_in_` is defined
  1721. if n_features_in_ is None:
  1722. raise ValueError("Unable to generate feature names without n_features_in_")
  1723. return np.asarray([f"x{i}" for i in range(n_features_in_)], dtype=object)
  1724. def _generate_get_feature_names_out(estimator, n_features_out, input_features=None):
  1725. """Generate feature names out for estimator using the estimator name as the prefix.
  1726. The input_feature names are validated but not used. This function is useful
  1727. for estimators that generate their own names based on `n_features_out`, i.e. PCA.
  1728. Parameters
  1729. ----------
  1730. estimator : estimator instance
  1731. Estimator producing output feature names.
  1732. n_feature_out : int
  1733. Number of feature names out.
  1734. input_features : array-like of str or None, default=None
  1735. Only used to validate feature names with `estimator.feature_names_in_`.
  1736. Returns
  1737. -------
  1738. feature_names_in : ndarray of str or `None`
  1739. Feature names in.
  1740. """
  1741. _check_feature_names_in(estimator, input_features, generate_names=False)
  1742. estimator_name = estimator.__class__.__name__.lower()
  1743. return np.asarray(
  1744. [f"{estimator_name}{i}" for i in range(n_features_out)], dtype=object
  1745. )
  1746. def _check_monotonic_cst(estimator, monotonic_cst=None):
  1747. """Check the monotonic constraints and return the corresponding array.
  1748. This helper function should be used in the `fit` method of an estimator
  1749. that supports monotonic constraints and called after the estimator has
  1750. introspected input data to set the `n_features_in_` and optionally the
  1751. `feature_names_in_` attributes.
  1752. .. versionadded:: 1.2
  1753. Parameters
  1754. ----------
  1755. estimator : estimator instance
  1756. monotonic_cst : array-like of int, dict of str or None, default=None
  1757. Monotonic constraints for the features.
  1758. - If array-like, then it should contain only -1, 0 or 1. Each value
  1759. will be checked to be in [-1, 0, 1]. If a value is -1, then the
  1760. corresponding feature is required to be monotonically decreasing.
  1761. - If dict, then it the keys should be the feature names occurring in
  1762. `estimator.feature_names_in_` and the values should be -1, 0 or 1.
  1763. - If None, then an array of 0s will be allocated.
  1764. Returns
  1765. -------
  1766. monotonic_cst : ndarray of int
  1767. Monotonic constraints for each feature.
  1768. """
  1769. original_monotonic_cst = monotonic_cst
  1770. if monotonic_cst is None or isinstance(monotonic_cst, dict):
  1771. monotonic_cst = np.full(
  1772. shape=estimator.n_features_in_,
  1773. fill_value=0,
  1774. dtype=np.int8,
  1775. )
  1776. if isinstance(original_monotonic_cst, dict):
  1777. if not hasattr(estimator, "feature_names_in_"):
  1778. raise ValueError(
  1779. f"{estimator.__class__.__name__} was not fitted on data "
  1780. "with feature names. Pass monotonic_cst as an integer "
  1781. "array instead."
  1782. )
  1783. unexpected_feature_names = list(
  1784. set(original_monotonic_cst) - set(estimator.feature_names_in_)
  1785. )
  1786. unexpected_feature_names.sort() # deterministic error message
  1787. n_unexpeced = len(unexpected_feature_names)
  1788. if unexpected_feature_names:
  1789. if len(unexpected_feature_names) > 5:
  1790. unexpected_feature_names = unexpected_feature_names[:5]
  1791. unexpected_feature_names.append("...")
  1792. raise ValueError(
  1793. f"monotonic_cst contains {n_unexpeced} unexpected feature "
  1794. f"names: {unexpected_feature_names}."
  1795. )
  1796. for feature_idx, feature_name in enumerate(estimator.feature_names_in_):
  1797. if feature_name in original_monotonic_cst:
  1798. cst = original_monotonic_cst[feature_name]
  1799. if cst not in [-1, 0, 1]:
  1800. raise ValueError(
  1801. f"monotonic_cst['{feature_name}'] must be either "
  1802. f"-1, 0 or 1. Got {cst!r}."
  1803. )
  1804. monotonic_cst[feature_idx] = cst
  1805. else:
  1806. unexpected_cst = np.setdiff1d(monotonic_cst, [-1, 0, 1])
  1807. if unexpected_cst.shape[0]:
  1808. raise ValueError(
  1809. "monotonic_cst must be an array-like of -1, 0 or 1. Observed "
  1810. f"values: {unexpected_cst.tolist()}."
  1811. )
  1812. monotonic_cst = np.asarray(monotonic_cst, dtype=np.int8)
  1813. if monotonic_cst.shape[0] != estimator.n_features_in_:
  1814. raise ValueError(
  1815. f"monotonic_cst has shape {monotonic_cst.shape} but the input data "
  1816. f"X has {estimator.n_features_in_} features."
  1817. )
  1818. return monotonic_cst
  1819. def _check_pos_label_consistency(pos_label, y_true):
  1820. """Check if `pos_label` need to be specified or not.
  1821. In binary classification, we fix `pos_label=1` if the labels are in the set
  1822. {-1, 1} or {0, 1}. Otherwise, we raise an error asking to specify the
  1823. `pos_label` parameters.
  1824. Parameters
  1825. ----------
  1826. pos_label : int, float, bool, str or None
  1827. The positive label.
  1828. y_true : ndarray of shape (n_samples,)
  1829. The target vector.
  1830. Returns
  1831. -------
  1832. pos_label : int, float, bool or str
  1833. If `pos_label` can be inferred, it will be returned.
  1834. Raises
  1835. ------
  1836. ValueError
  1837. In the case that `y_true` does not have label in {-1, 1} or {0, 1},
  1838. it will raise a `ValueError`.
  1839. """
  1840. # ensure binary classification if pos_label is not specified
  1841. # classes.dtype.kind in ('O', 'U', 'S') is required to avoid
  1842. # triggering a FutureWarning by calling np.array_equal(a, b)
  1843. # when elements in the two arrays are not comparable.
  1844. classes = np.unique(y_true)
  1845. if pos_label is None and (
  1846. classes.dtype.kind in "OUS"
  1847. or not (
  1848. np.array_equal(classes, [0, 1])
  1849. or np.array_equal(classes, [-1, 1])
  1850. or np.array_equal(classes, [0])
  1851. or np.array_equal(classes, [-1])
  1852. or np.array_equal(classes, [1])
  1853. )
  1854. ):
  1855. classes_repr = ", ".join([repr(c) for c in classes.tolist()])
  1856. raise ValueError(
  1857. f"y_true takes value in {{{classes_repr}}} and pos_label is not "
  1858. "specified: either make y_true take value in {0, 1} or "
  1859. "{-1, 1} or pass pos_label explicitly."
  1860. )
  1861. elif pos_label is None:
  1862. pos_label = 1
  1863. return pos_label