_testing.py 33 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035
  1. """Testing utilities."""
  2. # Copyright (c) 2011, 2012
  3. # Authors: Pietro Berkes,
  4. # Andreas Muller
  5. # Mathieu Blondel
  6. # Olivier Grisel
  7. # Arnaud Joly
  8. # Denis Engemann
  9. # Giorgio Patrini
  10. # Thierry Guillemot
  11. # License: BSD 3 clause
  12. import atexit
  13. import contextlib
  14. import functools
  15. import inspect
  16. import os
  17. import os.path as op
  18. import re
  19. import shutil
  20. import sys
  21. import tempfile
  22. import unittest
  23. import warnings
  24. from collections.abc import Iterable, Sequence
  25. from functools import wraps
  26. from inspect import signature
  27. from subprocess import STDOUT, CalledProcessError, TimeoutExpired, check_output
  28. from unittest import TestCase
  29. import joblib
  30. import numpy as np
  31. import scipy as sp
  32. from numpy.testing import assert_allclose as np_assert_allclose
  33. from numpy.testing import (
  34. assert_almost_equal,
  35. assert_approx_equal,
  36. assert_array_almost_equal,
  37. assert_array_equal,
  38. assert_array_less,
  39. assert_no_warnings,
  40. )
  41. import sklearn
  42. from sklearn.utils import (
  43. _IS_32BIT,
  44. IS_PYPY,
  45. _in_unstable_openblas_configuration,
  46. )
  47. from sklearn.utils._array_api import _check_array_api_dispatch
  48. from sklearn.utils.fixes import threadpool_info
  49. from sklearn.utils.multiclass import check_classification_targets
  50. from sklearn.utils.validation import (
  51. check_array,
  52. check_is_fitted,
  53. check_X_y,
  54. )
  55. __all__ = [
  56. "assert_raises",
  57. "assert_raises_regexp",
  58. "assert_array_equal",
  59. "assert_almost_equal",
  60. "assert_array_almost_equal",
  61. "assert_array_less",
  62. "assert_approx_equal",
  63. "assert_allclose",
  64. "assert_run_python_script",
  65. "assert_no_warnings",
  66. "SkipTest",
  67. ]
  68. _dummy = TestCase("__init__")
  69. assert_raises = _dummy.assertRaises
  70. SkipTest = unittest.case.SkipTest
  71. assert_dict_equal = _dummy.assertDictEqual
  72. assert_raises_regex = _dummy.assertRaisesRegex
  73. # assert_raises_regexp is deprecated in Python 3.4 in favor of
  74. # assert_raises_regex but lets keep the backward compat in scikit-learn with
  75. # the old name for now
  76. assert_raises_regexp = assert_raises_regex
  77. def ignore_warnings(obj=None, category=Warning):
  78. """Context manager and decorator to ignore warnings.
  79. Note: Using this (in both variants) will clear all warnings
  80. from all python modules loaded. In case you need to test
  81. cross-module-warning-logging, this is not your tool of choice.
  82. Parameters
  83. ----------
  84. obj : callable, default=None
  85. callable where you want to ignore the warnings.
  86. category : warning class, default=Warning
  87. The category to filter. If Warning, all categories will be muted.
  88. Examples
  89. --------
  90. >>> import warnings
  91. >>> from sklearn.utils._testing import ignore_warnings
  92. >>> with ignore_warnings():
  93. ... warnings.warn('buhuhuhu')
  94. >>> def nasty_warn():
  95. ... warnings.warn('buhuhuhu')
  96. ... print(42)
  97. >>> ignore_warnings(nasty_warn)()
  98. 42
  99. """
  100. if isinstance(obj, type) and issubclass(obj, Warning):
  101. # Avoid common pitfall of passing category as the first positional
  102. # argument which result in the test not being run
  103. warning_name = obj.__name__
  104. raise ValueError(
  105. "'obj' should be a callable where you want to ignore warnings. "
  106. "You passed a warning class instead: 'obj={warning_name}'. "
  107. "If you want to pass a warning class to ignore_warnings, "
  108. "you should use 'category={warning_name}'".format(warning_name=warning_name)
  109. )
  110. elif callable(obj):
  111. return _IgnoreWarnings(category=category)(obj)
  112. else:
  113. return _IgnoreWarnings(category=category)
  114. class _IgnoreWarnings:
  115. """Improved and simplified Python warnings context manager and decorator.
  116. This class allows the user to ignore the warnings raised by a function.
  117. Copied from Python 2.7.5 and modified as required.
  118. Parameters
  119. ----------
  120. category : tuple of warning class, default=Warning
  121. The category to filter. By default, all the categories will be muted.
  122. """
  123. def __init__(self, category):
  124. self._record = True
  125. self._module = sys.modules["warnings"]
  126. self._entered = False
  127. self.log = []
  128. self.category = category
  129. def __call__(self, fn):
  130. """Decorator to catch and hide warnings without visual nesting."""
  131. @wraps(fn)
  132. def wrapper(*args, **kwargs):
  133. with warnings.catch_warnings():
  134. warnings.simplefilter("ignore", self.category)
  135. return fn(*args, **kwargs)
  136. return wrapper
  137. def __repr__(self):
  138. args = []
  139. if self._record:
  140. args.append("record=True")
  141. if self._module is not sys.modules["warnings"]:
  142. args.append("module=%r" % self._module)
  143. name = type(self).__name__
  144. return "%s(%s)" % (name, ", ".join(args))
  145. def __enter__(self):
  146. if self._entered:
  147. raise RuntimeError("Cannot enter %r twice" % self)
  148. self._entered = True
  149. self._filters = self._module.filters
  150. self._module.filters = self._filters[:]
  151. self._showwarning = self._module.showwarning
  152. warnings.simplefilter("ignore", self.category)
  153. def __exit__(self, *exc_info):
  154. if not self._entered:
  155. raise RuntimeError("Cannot exit %r without entering first" % self)
  156. self._module.filters = self._filters
  157. self._module.showwarning = self._showwarning
  158. self.log[:] = []
  159. def assert_raise_message(exceptions, message, function, *args, **kwargs):
  160. """Helper function to test the message raised in an exception.
  161. Given an exception, a callable to raise the exception, and
  162. a message string, tests that the correct exception is raised and
  163. that the message is a substring of the error thrown. Used to test
  164. that the specific message thrown during an exception is correct.
  165. Parameters
  166. ----------
  167. exceptions : exception or tuple of exception
  168. An Exception object.
  169. message : str
  170. The error message or a substring of the error message.
  171. function : callable
  172. Callable object to raise error.
  173. *args : the positional arguments to `function`.
  174. **kwargs : the keyword arguments to `function`.
  175. """
  176. try:
  177. function(*args, **kwargs)
  178. except exceptions as e:
  179. error_message = str(e)
  180. if message not in error_message:
  181. raise AssertionError(
  182. "Error message does not include the expected"
  183. " string: %r. Observed error message: %r" % (message, error_message)
  184. )
  185. else:
  186. # concatenate exception names
  187. if isinstance(exceptions, tuple):
  188. names = " or ".join(e.__name__ for e in exceptions)
  189. else:
  190. names = exceptions.__name__
  191. raise AssertionError("%s not raised by %s" % (names, function.__name__))
  192. def assert_allclose(
  193. actual, desired, rtol=None, atol=0.0, equal_nan=True, err_msg="", verbose=True
  194. ):
  195. """dtype-aware variant of numpy.testing.assert_allclose
  196. This variant introspects the least precise floating point dtype
  197. in the input argument and automatically sets the relative tolerance
  198. parameter to 1e-4 float32 and use 1e-7 otherwise (typically float64
  199. in scikit-learn).
  200. `atol` is always left to 0. by default. It should be adjusted manually
  201. to an assertion-specific value in case there are null values expected
  202. in `desired`.
  203. The aggregate tolerance is `atol + rtol * abs(desired)`.
  204. Parameters
  205. ----------
  206. actual : array_like
  207. Array obtained.
  208. desired : array_like
  209. Array desired.
  210. rtol : float, optional, default=None
  211. Relative tolerance.
  212. If None, it is set based on the provided arrays' dtypes.
  213. atol : float, optional, default=0.
  214. Absolute tolerance.
  215. equal_nan : bool, optional, default=True
  216. If True, NaNs will compare equal.
  217. err_msg : str, optional, default=''
  218. The error message to be printed in case of failure.
  219. verbose : bool, optional, default=True
  220. If True, the conflicting values are appended to the error message.
  221. Raises
  222. ------
  223. AssertionError
  224. If actual and desired are not equal up to specified precision.
  225. See Also
  226. --------
  227. numpy.testing.assert_allclose
  228. Examples
  229. --------
  230. >>> import numpy as np
  231. >>> from sklearn.utils._testing import assert_allclose
  232. >>> x = [1e-5, 1e-3, 1e-1]
  233. >>> y = np.arccos(np.cos(x))
  234. >>> assert_allclose(x, y, rtol=1e-5, atol=0)
  235. >>> a = np.full(shape=10, fill_value=1e-5, dtype=np.float32)
  236. >>> assert_allclose(a, 1e-5)
  237. """
  238. dtypes = []
  239. actual, desired = np.asanyarray(actual), np.asanyarray(desired)
  240. dtypes = [actual.dtype, desired.dtype]
  241. if rtol is None:
  242. rtols = [1e-4 if dtype == np.float32 else 1e-7 for dtype in dtypes]
  243. rtol = max(rtols)
  244. np_assert_allclose(
  245. actual,
  246. desired,
  247. rtol=rtol,
  248. atol=atol,
  249. equal_nan=equal_nan,
  250. err_msg=err_msg,
  251. verbose=verbose,
  252. )
  253. def assert_allclose_dense_sparse(x, y, rtol=1e-07, atol=1e-9, err_msg=""):
  254. """Assert allclose for sparse and dense data.
  255. Both x and y need to be either sparse or dense, they
  256. can't be mixed.
  257. Parameters
  258. ----------
  259. x : {array-like, sparse matrix}
  260. First array to compare.
  261. y : {array-like, sparse matrix}
  262. Second array to compare.
  263. rtol : float, default=1e-07
  264. relative tolerance; see numpy.allclose.
  265. atol : float, default=1e-9
  266. absolute tolerance; see numpy.allclose. Note that the default here is
  267. more tolerant than the default for numpy.testing.assert_allclose, where
  268. atol=0.
  269. err_msg : str, default=''
  270. Error message to raise.
  271. """
  272. if sp.sparse.issparse(x) and sp.sparse.issparse(y):
  273. x = x.tocsr()
  274. y = y.tocsr()
  275. x.sum_duplicates()
  276. y.sum_duplicates()
  277. assert_array_equal(x.indices, y.indices, err_msg=err_msg)
  278. assert_array_equal(x.indptr, y.indptr, err_msg=err_msg)
  279. assert_allclose(x.data, y.data, rtol=rtol, atol=atol, err_msg=err_msg)
  280. elif not sp.sparse.issparse(x) and not sp.sparse.issparse(y):
  281. # both dense
  282. assert_allclose(x, y, rtol=rtol, atol=atol, err_msg=err_msg)
  283. else:
  284. raise ValueError(
  285. "Can only compare two sparse matrices, not a sparse matrix and an array."
  286. )
  287. def set_random_state(estimator, random_state=0):
  288. """Set random state of an estimator if it has the `random_state` param.
  289. Parameters
  290. ----------
  291. estimator : object
  292. The estimator.
  293. random_state : int, RandomState instance or None, default=0
  294. Pseudo random number generator state.
  295. Pass an int for reproducible results across multiple function calls.
  296. See :term:`Glossary <random_state>`.
  297. """
  298. if "random_state" in estimator.get_params():
  299. estimator.set_params(random_state=random_state)
  300. try:
  301. _check_array_api_dispatch(True)
  302. ARRAY_API_COMPAT_FUNCTIONAL = True
  303. except ImportError:
  304. ARRAY_API_COMPAT_FUNCTIONAL = False
  305. try:
  306. import pytest
  307. skip_if_32bit = pytest.mark.skipif(_IS_32BIT, reason="skipped on 32bit platforms")
  308. fails_if_pypy = pytest.mark.xfail(IS_PYPY, reason="not compatible with PyPy")
  309. fails_if_unstable_openblas = pytest.mark.xfail(
  310. _in_unstable_openblas_configuration(),
  311. reason="OpenBLAS is unstable for this configuration",
  312. )
  313. skip_if_no_parallel = pytest.mark.skipif(
  314. not joblib.parallel.mp, reason="joblib is in serial mode"
  315. )
  316. skip_if_array_api_compat_not_configured = pytest.mark.skipif(
  317. not ARRAY_API_COMPAT_FUNCTIONAL,
  318. reason="requires array_api_compat installed and a new enough version of NumPy",
  319. )
  320. # Decorator for tests involving both BLAS calls and multiprocessing.
  321. #
  322. # Under POSIX (e.g. Linux or OSX), using multiprocessing in conjunction
  323. # with some implementation of BLAS (or other libraries that manage an
  324. # internal posix thread pool) can cause a crash or a freeze of the Python
  325. # process.
  326. #
  327. # In practice all known packaged distributions (from Linux distros or
  328. # Anaconda) of BLAS under Linux seems to be safe. So we this problem seems
  329. # to only impact OSX users.
  330. #
  331. # This wrapper makes it possible to skip tests that can possibly cause
  332. # this crash under OS X with.
  333. #
  334. # Under Python 3.4+ it is possible to use the `forkserver` start method
  335. # for multiprocessing to avoid this issue. However it can cause pickling
  336. # errors on interactively defined functions. It therefore not enabled by
  337. # default.
  338. if_safe_multiprocessing_with_blas = pytest.mark.skipif(
  339. sys.platform == "darwin", reason="Possible multi-process bug with some BLAS"
  340. )
  341. except ImportError:
  342. pass
  343. def check_skip_network():
  344. if int(os.environ.get("SKLEARN_SKIP_NETWORK_TESTS", 0)):
  345. raise SkipTest("Text tutorial requires large dataset download")
  346. def _delete_folder(folder_path, warn=False):
  347. """Utility function to cleanup a temporary folder if still existing.
  348. Copy from joblib.pool (for independence).
  349. """
  350. try:
  351. if os.path.exists(folder_path):
  352. # This can fail under windows,
  353. # but will succeed when called by atexit
  354. shutil.rmtree(folder_path)
  355. except OSError:
  356. if warn:
  357. warnings.warn("Could not delete temporary folder %s" % folder_path)
  358. class TempMemmap:
  359. """
  360. Parameters
  361. ----------
  362. data
  363. mmap_mode : str, default='r'
  364. """
  365. def __init__(self, data, mmap_mode="r"):
  366. self.mmap_mode = mmap_mode
  367. self.data = data
  368. def __enter__(self):
  369. data_read_only, self.temp_folder = create_memmap_backed_data(
  370. self.data, mmap_mode=self.mmap_mode, return_folder=True
  371. )
  372. return data_read_only
  373. def __exit__(self, exc_type, exc_val, exc_tb):
  374. _delete_folder(self.temp_folder)
  375. def _create_memmap_backed_array(array, filename, mmap_mode):
  376. # https://numpy.org/doc/stable/reference/generated/numpy.memmap.html
  377. fp = np.memmap(filename, dtype=array.dtype, mode="w+", shape=array.shape)
  378. fp[:] = array[:] # write array to memmap array
  379. fp.flush()
  380. memmap_backed_array = np.memmap(
  381. filename, dtype=array.dtype, mode=mmap_mode, shape=array.shape
  382. )
  383. return memmap_backed_array
  384. def _create_aligned_memmap_backed_arrays(data, mmap_mode, folder):
  385. if isinstance(data, np.ndarray):
  386. filename = op.join(folder, "data.dat")
  387. return _create_memmap_backed_array(data, filename, mmap_mode)
  388. if isinstance(data, Sequence) and all(
  389. isinstance(each, np.ndarray) for each in data
  390. ):
  391. return [
  392. _create_memmap_backed_array(
  393. array, op.join(folder, f"data{index}.dat"), mmap_mode
  394. )
  395. for index, array in enumerate(data)
  396. ]
  397. raise ValueError(
  398. "When creating aligned memmap-backed arrays, input must be a single array or a"
  399. " sequence of arrays"
  400. )
  401. def create_memmap_backed_data(data, mmap_mode="r", return_folder=False, aligned=False):
  402. """
  403. Parameters
  404. ----------
  405. data
  406. mmap_mode : str, default='r'
  407. return_folder : bool, default=False
  408. aligned : bool, default=False
  409. If True, if input is a single numpy array and if the input array is aligned,
  410. the memory mapped array will also be aligned. This is a workaround for
  411. https://github.com/joblib/joblib/issues/563.
  412. """
  413. temp_folder = tempfile.mkdtemp(prefix="sklearn_testing_")
  414. atexit.register(functools.partial(_delete_folder, temp_folder, warn=True))
  415. # OpenBLAS is known to segfault with unaligned data on the Prescott
  416. # architecture so force aligned=True on Prescott. For more details, see:
  417. # https://github.com/scipy/scipy/issues/14886
  418. has_prescott_openblas = any(
  419. True
  420. for info in threadpool_info()
  421. if info["internal_api"] == "openblas"
  422. # Prudently assume Prescott might be the architecture if it is unknown.
  423. and info.get("architecture", "prescott").lower() == "prescott"
  424. )
  425. if has_prescott_openblas:
  426. aligned = True
  427. if aligned:
  428. memmap_backed_data = _create_aligned_memmap_backed_arrays(
  429. data, mmap_mode, temp_folder
  430. )
  431. else:
  432. filename = op.join(temp_folder, "data.pkl")
  433. joblib.dump(data, filename)
  434. memmap_backed_data = joblib.load(filename, mmap_mode=mmap_mode)
  435. result = (
  436. memmap_backed_data if not return_folder else (memmap_backed_data, temp_folder)
  437. )
  438. return result
  439. # Utils to test docstrings
  440. def _get_args(function, varargs=False):
  441. """Helper to get function arguments."""
  442. try:
  443. params = signature(function).parameters
  444. except ValueError:
  445. # Error on builtin C function
  446. return []
  447. args = [
  448. key
  449. for key, param in params.items()
  450. if param.kind not in (param.VAR_POSITIONAL, param.VAR_KEYWORD)
  451. ]
  452. if varargs:
  453. varargs = [
  454. param.name
  455. for param in params.values()
  456. if param.kind == param.VAR_POSITIONAL
  457. ]
  458. if len(varargs) == 0:
  459. varargs = None
  460. return args, varargs
  461. else:
  462. return args
  463. def _get_func_name(func):
  464. """Get function full name.
  465. Parameters
  466. ----------
  467. func : callable
  468. The function object.
  469. Returns
  470. -------
  471. name : str
  472. The function name.
  473. """
  474. parts = []
  475. module = inspect.getmodule(func)
  476. if module:
  477. parts.append(module.__name__)
  478. qualname = func.__qualname__
  479. if qualname != func.__name__:
  480. parts.append(qualname[: qualname.find(".")])
  481. parts.append(func.__name__)
  482. return ".".join(parts)
  483. def check_docstring_parameters(func, doc=None, ignore=None):
  484. """Helper to check docstring.
  485. Parameters
  486. ----------
  487. func : callable
  488. The function object to test.
  489. doc : str, default=None
  490. Docstring if it is passed manually to the test.
  491. ignore : list, default=None
  492. Parameters to ignore.
  493. Returns
  494. -------
  495. incorrect : list
  496. A list of string describing the incorrect results.
  497. """
  498. from numpydoc import docscrape
  499. incorrect = []
  500. ignore = [] if ignore is None else ignore
  501. func_name = _get_func_name(func)
  502. if not func_name.startswith("sklearn.") or func_name.startswith(
  503. "sklearn.externals"
  504. ):
  505. return incorrect
  506. # Don't check docstring for property-functions
  507. if inspect.isdatadescriptor(func):
  508. return incorrect
  509. # Don't check docstring for setup / teardown pytest functions
  510. if func_name.split(".")[-1] in ("setup_module", "teardown_module"):
  511. return incorrect
  512. # Dont check estimator_checks module
  513. if func_name.split(".")[2] == "estimator_checks":
  514. return incorrect
  515. # Get the arguments from the function signature
  516. param_signature = list(filter(lambda x: x not in ignore, _get_args(func)))
  517. # drop self
  518. if len(param_signature) > 0 and param_signature[0] == "self":
  519. param_signature.remove("self")
  520. # Analyze function's docstring
  521. if doc is None:
  522. records = []
  523. with warnings.catch_warnings(record=True):
  524. warnings.simplefilter("error", UserWarning)
  525. try:
  526. doc = docscrape.FunctionDoc(func)
  527. except UserWarning as exp:
  528. if "potentially wrong underline length" in str(exp):
  529. # Catch warning raised as of numpydoc 1.2 when
  530. # the underline length for a section of a docstring
  531. # is not consistent.
  532. message = str(exp).split("\n")[:3]
  533. incorrect += [f"In function: {func_name}"] + message
  534. return incorrect
  535. records.append(str(exp))
  536. except Exception as exp:
  537. incorrect += [func_name + " parsing error: " + str(exp)]
  538. return incorrect
  539. if len(records):
  540. raise RuntimeError("Error for %s:\n%s" % (func_name, records[0]))
  541. param_docs = []
  542. for name, type_definition, param_doc in doc["Parameters"]:
  543. # Type hints are empty only if parameter name ended with :
  544. if not type_definition.strip():
  545. if ":" in name and name[: name.index(":")][-1:].strip():
  546. incorrect += [
  547. func_name
  548. + " There was no space between the param name and colon (%r)" % name
  549. ]
  550. elif name.rstrip().endswith(":"):
  551. incorrect += [
  552. func_name
  553. + " Parameter %r has an empty type spec. Remove the colon"
  554. % (name.lstrip())
  555. ]
  556. # Create a list of parameters to compare with the parameters gotten
  557. # from the func signature
  558. if "*" not in name:
  559. param_docs.append(name.split(":")[0].strip("` "))
  560. # If one of the docstring's parameters had an error then return that
  561. # incorrect message
  562. if len(incorrect) > 0:
  563. return incorrect
  564. # Remove the parameters that should be ignored from list
  565. param_docs = list(filter(lambda x: x not in ignore, param_docs))
  566. # The following is derived from pytest, Copyright (c) 2004-2017 Holger
  567. # Krekel and others, Licensed under MIT License. See
  568. # https://github.com/pytest-dev/pytest
  569. message = []
  570. for i in range(min(len(param_docs), len(param_signature))):
  571. if param_signature[i] != param_docs[i]:
  572. message += [
  573. "There's a parameter name mismatch in function"
  574. " docstring w.r.t. function signature, at index %s"
  575. " diff: %r != %r" % (i, param_signature[i], param_docs[i])
  576. ]
  577. break
  578. if len(param_signature) > len(param_docs):
  579. message += [
  580. "Parameters in function docstring have less items w.r.t."
  581. " function signature, first missing item: %s"
  582. % param_signature[len(param_docs)]
  583. ]
  584. elif len(param_signature) < len(param_docs):
  585. message += [
  586. "Parameters in function docstring have more items w.r.t."
  587. " function signature, first extra item: %s"
  588. % param_docs[len(param_signature)]
  589. ]
  590. # If there wasn't any difference in the parameters themselves between
  591. # docstring and signature including having the same length then return
  592. # empty list
  593. if len(message) == 0:
  594. return []
  595. import difflib
  596. import pprint
  597. param_docs_formatted = pprint.pformat(param_docs).splitlines()
  598. param_signature_formatted = pprint.pformat(param_signature).splitlines()
  599. message += ["Full diff:"]
  600. message.extend(
  601. line.strip()
  602. for line in difflib.ndiff(param_signature_formatted, param_docs_formatted)
  603. )
  604. incorrect.extend(message)
  605. # Prepend function name
  606. incorrect = ["In function: " + func_name] + incorrect
  607. return incorrect
  608. def assert_run_python_script(source_code, timeout=60):
  609. """Utility to check assertions in an independent Python subprocess.
  610. The script provided in the source code should return 0 and not print
  611. anything on stderr or stdout.
  612. This is a port from cloudpickle https://github.com/cloudpipe/cloudpickle
  613. Parameters
  614. ----------
  615. source_code : str
  616. The Python source code to execute.
  617. timeout : int, default=60
  618. Time in seconds before timeout.
  619. """
  620. fd, source_file = tempfile.mkstemp(suffix="_src_test_sklearn.py")
  621. os.close(fd)
  622. try:
  623. with open(source_file, "wb") as f:
  624. f.write(source_code.encode("utf-8"))
  625. cmd = [sys.executable, source_file]
  626. cwd = op.normpath(op.join(op.dirname(sklearn.__file__), ".."))
  627. env = os.environ.copy()
  628. try:
  629. env["PYTHONPATH"] = os.pathsep.join([cwd, env["PYTHONPATH"]])
  630. except KeyError:
  631. env["PYTHONPATH"] = cwd
  632. kwargs = {"cwd": cwd, "stderr": STDOUT, "env": env}
  633. # If coverage is running, pass the config file to the subprocess
  634. coverage_rc = os.environ.get("COVERAGE_PROCESS_START")
  635. if coverage_rc:
  636. kwargs["env"]["COVERAGE_PROCESS_START"] = coverage_rc
  637. kwargs["timeout"] = timeout
  638. try:
  639. try:
  640. out = check_output(cmd, **kwargs)
  641. except CalledProcessError as e:
  642. raise RuntimeError(
  643. "script errored with output:\n%s" % e.output.decode("utf-8")
  644. )
  645. if out != b"":
  646. raise AssertionError(out.decode("utf-8"))
  647. except TimeoutExpired as e:
  648. raise RuntimeError(
  649. "script timeout, output so far:\n%s" % e.output.decode("utf-8")
  650. )
  651. finally:
  652. os.unlink(source_file)
  653. def _convert_container(container, constructor_name, columns_name=None, dtype=None):
  654. """Convert a given container to a specific array-like with a dtype.
  655. Parameters
  656. ----------
  657. container : array-like
  658. The container to convert.
  659. constructor_name : {"list", "tuple", "array", "sparse", "dataframe", \
  660. "series", "index", "slice", "sparse_csr", "sparse_csc"}
  661. The type of the returned container.
  662. columns_name : index or array-like, default=None
  663. For pandas container supporting `columns_names`, it will affect
  664. specific names.
  665. dtype : dtype, default=None
  666. Force the dtype of the container. Does not apply to `"slice"`
  667. container.
  668. Returns
  669. -------
  670. converted_container
  671. """
  672. if constructor_name == "list":
  673. if dtype is None:
  674. return list(container)
  675. else:
  676. return np.asarray(container, dtype=dtype).tolist()
  677. elif constructor_name == "tuple":
  678. if dtype is None:
  679. return tuple(container)
  680. else:
  681. return tuple(np.asarray(container, dtype=dtype).tolist())
  682. elif constructor_name == "array":
  683. return np.asarray(container, dtype=dtype)
  684. elif constructor_name == "sparse":
  685. return sp.sparse.csr_matrix(container, dtype=dtype)
  686. elif constructor_name == "dataframe":
  687. pd = pytest.importorskip("pandas")
  688. return pd.DataFrame(container, columns=columns_name, dtype=dtype, copy=False)
  689. elif constructor_name == "series":
  690. pd = pytest.importorskip("pandas")
  691. return pd.Series(container, dtype=dtype)
  692. elif constructor_name == "index":
  693. pd = pytest.importorskip("pandas")
  694. return pd.Index(container, dtype=dtype)
  695. elif constructor_name == "slice":
  696. return slice(container[0], container[1])
  697. elif constructor_name == "sparse_csr":
  698. return sp.sparse.csr_matrix(container, dtype=dtype)
  699. elif constructor_name == "sparse_csc":
  700. return sp.sparse.csc_matrix(container, dtype=dtype)
  701. def raises(expected_exc_type, match=None, may_pass=False, err_msg=None):
  702. """Context manager to ensure exceptions are raised within a code block.
  703. This is similar to and inspired from pytest.raises, but supports a few
  704. other cases.
  705. This is only intended to be used in estimator_checks.py where we don't
  706. want to use pytest. In the rest of the code base, just use pytest.raises
  707. instead.
  708. Parameters
  709. ----------
  710. excepted_exc_type : Exception or list of Exception
  711. The exception that should be raised by the block. If a list, the block
  712. should raise one of the exceptions.
  713. match : str or list of str, default=None
  714. A regex that the exception message should match. If a list, one of
  715. the entries must match. If None, match isn't enforced.
  716. may_pass : bool, default=False
  717. If True, the block is allowed to not raise an exception. Useful in
  718. cases where some estimators may support a feature but others must
  719. fail with an appropriate error message. By default, the context
  720. manager will raise an exception if the block does not raise an
  721. exception.
  722. err_msg : str, default=None
  723. If the context manager fails (e.g. the block fails to raise the
  724. proper exception, or fails to match), then an AssertionError is
  725. raised with this message. By default, an AssertionError is raised
  726. with a default error message (depends on the kind of failure). Use
  727. this to indicate how users should fix their estimators to pass the
  728. checks.
  729. Attributes
  730. ----------
  731. raised_and_matched : bool
  732. True if an exception was raised and a match was found, False otherwise.
  733. """
  734. return _Raises(expected_exc_type, match, may_pass, err_msg)
  735. class _Raises(contextlib.AbstractContextManager):
  736. # see raises() for parameters
  737. def __init__(self, expected_exc_type, match, may_pass, err_msg):
  738. self.expected_exc_types = (
  739. expected_exc_type
  740. if isinstance(expected_exc_type, Iterable)
  741. else [expected_exc_type]
  742. )
  743. self.matches = [match] if isinstance(match, str) else match
  744. self.may_pass = may_pass
  745. self.err_msg = err_msg
  746. self.raised_and_matched = False
  747. def __exit__(self, exc_type, exc_value, _):
  748. # see
  749. # https://docs.python.org/2.5/whatsnew/pep-343.html#SECTION000910000000000000000
  750. if exc_type is None: # No exception was raised in the block
  751. if self.may_pass:
  752. return True # CM is happy
  753. else:
  754. err_msg = self.err_msg or f"Did not raise: {self.expected_exc_types}"
  755. raise AssertionError(err_msg)
  756. if not any(
  757. issubclass(exc_type, expected_type)
  758. for expected_type in self.expected_exc_types
  759. ):
  760. if self.err_msg is not None:
  761. raise AssertionError(self.err_msg) from exc_value
  762. else:
  763. return False # will re-raise the original exception
  764. if self.matches is not None:
  765. err_msg = self.err_msg or (
  766. "The error message should contain one of the following "
  767. "patterns:\n{}\nGot {}".format("\n".join(self.matches), str(exc_value))
  768. )
  769. if not any(re.search(match, str(exc_value)) for match in self.matches):
  770. raise AssertionError(err_msg) from exc_value
  771. self.raised_and_matched = True
  772. return True
  773. class MinimalClassifier:
  774. """Minimal classifier implementation with inheriting from BaseEstimator.
  775. This estimator should be tested with:
  776. * `check_estimator` in `test_estimator_checks.py`;
  777. * within a `Pipeline` in `test_pipeline.py`;
  778. * within a `SearchCV` in `test_search.py`.
  779. """
  780. _estimator_type = "classifier"
  781. def __init__(self, param=None):
  782. self.param = param
  783. def get_params(self, deep=True):
  784. return {"param": self.param}
  785. def set_params(self, **params):
  786. for key, value in params.items():
  787. setattr(self, key, value)
  788. return self
  789. def fit(self, X, y):
  790. X, y = check_X_y(X, y)
  791. check_classification_targets(y)
  792. self.classes_, counts = np.unique(y, return_counts=True)
  793. self._most_frequent_class_idx = counts.argmax()
  794. return self
  795. def predict_proba(self, X):
  796. check_is_fitted(self)
  797. X = check_array(X)
  798. proba_shape = (X.shape[0], self.classes_.size)
  799. y_proba = np.zeros(shape=proba_shape, dtype=np.float64)
  800. y_proba[:, self._most_frequent_class_idx] = 1.0
  801. return y_proba
  802. def predict(self, X):
  803. y_proba = self.predict_proba(X)
  804. y_pred = y_proba.argmax(axis=1)
  805. return self.classes_[y_pred]
  806. def score(self, X, y):
  807. from sklearn.metrics import accuracy_score
  808. return accuracy_score(y, self.predict(X))
  809. class MinimalRegressor:
  810. """Minimal regressor implementation with inheriting from BaseEstimator.
  811. This estimator should be tested with:
  812. * `check_estimator` in `test_estimator_checks.py`;
  813. * within a `Pipeline` in `test_pipeline.py`;
  814. * within a `SearchCV` in `test_search.py`.
  815. """
  816. _estimator_type = "regressor"
  817. def __init__(self, param=None):
  818. self.param = param
  819. def get_params(self, deep=True):
  820. return {"param": self.param}
  821. def set_params(self, **params):
  822. for key, value in params.items():
  823. setattr(self, key, value)
  824. return self
  825. def fit(self, X, y):
  826. X, y = check_X_y(X, y)
  827. self.is_fitted_ = True
  828. self._mean = np.mean(y)
  829. return self
  830. def predict(self, X):
  831. check_is_fitted(self)
  832. X = check_array(X)
  833. return np.ones(shape=(X.shape[0],)) * self._mean
  834. def score(self, X, y):
  835. from sklearn.metrics import r2_score
  836. return r2_score(y, self.predict(X))
  837. class MinimalTransformer:
  838. """Minimal transformer implementation with inheriting from
  839. BaseEstimator.
  840. This estimator should be tested with:
  841. * `check_estimator` in `test_estimator_checks.py`;
  842. * within a `Pipeline` in `test_pipeline.py`;
  843. * within a `SearchCV` in `test_search.py`.
  844. """
  845. def __init__(self, param=None):
  846. self.param = param
  847. def get_params(self, deep=True):
  848. return {"param": self.param}
  849. def set_params(self, **params):
  850. for key, value in params.items():
  851. setattr(self, key, value)
  852. return self
  853. def fit(self, X, y=None):
  854. check_array(X)
  855. self.is_fitted_ = True
  856. return self
  857. def transform(self, X, y=None):
  858. check_is_fitted(self)
  859. X = check_array(X)
  860. return X
  861. def fit_transform(self, X, y=None):
  862. return self.fit(X, y).transform(X, y)