conftest.py 8.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258
  1. import platform
  2. import sys
  3. from contextlib import suppress
  4. from functools import wraps
  5. from os import environ
  6. from unittest import SkipTest
  7. import joblib
  8. import numpy as np
  9. import pytest
  10. from _pytest.doctest import DoctestItem
  11. from threadpoolctl import threadpool_limits
  12. from sklearn._min_dependencies import PYTEST_MIN_VERSION
  13. from sklearn.datasets import (
  14. fetch_20newsgroups,
  15. fetch_20newsgroups_vectorized,
  16. fetch_california_housing,
  17. fetch_covtype,
  18. fetch_kddcup99,
  19. fetch_olivetti_faces,
  20. fetch_rcv1,
  21. )
  22. from sklearn.tests import random_seed
  23. from sklearn.utils import _IS_32BIT
  24. from sklearn.utils.fixes import np_base_version, parse_version, sp_version
  25. if parse_version(pytest.__version__) < parse_version(PYTEST_MIN_VERSION):
  26. raise ImportError(
  27. f"Your version of pytest is too old. Got version {pytest.__version__}, you"
  28. f" should have pytest >= {PYTEST_MIN_VERSION} installed."
  29. )
  30. scipy_datasets_require_network = sp_version >= parse_version("1.10")
  31. def raccoon_face_or_skip():
  32. # SciPy >= 1.10 requires network to access to get data
  33. if scipy_datasets_require_network:
  34. run_network_tests = environ.get("SKLEARN_SKIP_NETWORK_TESTS", "1") == "0"
  35. if not run_network_tests:
  36. raise SkipTest("test is enabled when SKLEARN_SKIP_NETWORK_TESTS=0")
  37. try:
  38. import pooch # noqa
  39. except ImportError:
  40. raise SkipTest("test requires pooch to be installed")
  41. from scipy.datasets import face
  42. else:
  43. from scipy.misc import face
  44. return face(gray=True)
  45. dataset_fetchers = {
  46. "fetch_20newsgroups_fxt": fetch_20newsgroups,
  47. "fetch_20newsgroups_vectorized_fxt": fetch_20newsgroups_vectorized,
  48. "fetch_california_housing_fxt": fetch_california_housing,
  49. "fetch_covtype_fxt": fetch_covtype,
  50. "fetch_kddcup99_fxt": fetch_kddcup99,
  51. "fetch_olivetti_faces_fxt": fetch_olivetti_faces,
  52. "fetch_rcv1_fxt": fetch_rcv1,
  53. }
  54. if scipy_datasets_require_network:
  55. dataset_fetchers["raccoon_face_fxt"] = raccoon_face_or_skip
  56. _SKIP32_MARK = pytest.mark.skipif(
  57. environ.get("SKLEARN_RUN_FLOAT32_TESTS", "0") != "1",
  58. reason="Set SKLEARN_RUN_FLOAT32_TESTS=1 to run float32 dtype tests",
  59. )
  60. # Global fixtures
  61. @pytest.fixture(params=[pytest.param(np.float32, marks=_SKIP32_MARK), np.float64])
  62. def global_dtype(request):
  63. yield request.param
  64. def _fetch_fixture(f):
  65. """Fetch dataset (download if missing and requested by environment)."""
  66. download_if_missing = environ.get("SKLEARN_SKIP_NETWORK_TESTS", "1") == "0"
  67. @wraps(f)
  68. def wrapped(*args, **kwargs):
  69. kwargs["download_if_missing"] = download_if_missing
  70. try:
  71. return f(*args, **kwargs)
  72. except OSError as e:
  73. if str(e) != "Data not found and `download_if_missing` is False":
  74. raise
  75. pytest.skip("test is enabled when SKLEARN_SKIP_NETWORK_TESTS=0")
  76. return pytest.fixture(lambda: wrapped)
  77. # Adds fixtures for fetching data
  78. fetch_20newsgroups_fxt = _fetch_fixture(fetch_20newsgroups)
  79. fetch_20newsgroups_vectorized_fxt = _fetch_fixture(fetch_20newsgroups_vectorized)
  80. fetch_california_housing_fxt = _fetch_fixture(fetch_california_housing)
  81. fetch_covtype_fxt = _fetch_fixture(fetch_covtype)
  82. fetch_kddcup99_fxt = _fetch_fixture(fetch_kddcup99)
  83. fetch_olivetti_faces_fxt = _fetch_fixture(fetch_olivetti_faces)
  84. fetch_rcv1_fxt = _fetch_fixture(fetch_rcv1)
  85. raccoon_face_fxt = pytest.fixture(raccoon_face_or_skip)
  86. def pytest_collection_modifyitems(config, items):
  87. """Called after collect is completed.
  88. Parameters
  89. ----------
  90. config : pytest config
  91. items : list of collected items
  92. """
  93. run_network_tests = environ.get("SKLEARN_SKIP_NETWORK_TESTS", "1") == "0"
  94. skip_network = pytest.mark.skip(
  95. reason="test is enabled when SKLEARN_SKIP_NETWORK_TESTS=0"
  96. )
  97. # download datasets during collection to avoid thread unsafe behavior
  98. # when running pytest in parallel with pytest-xdist
  99. dataset_features_set = set(dataset_fetchers)
  100. datasets_to_download = set()
  101. for item in items:
  102. if not hasattr(item, "fixturenames"):
  103. continue
  104. item_fixtures = set(item.fixturenames)
  105. dataset_to_fetch = item_fixtures & dataset_features_set
  106. if not dataset_to_fetch:
  107. continue
  108. if run_network_tests:
  109. datasets_to_download |= dataset_to_fetch
  110. else:
  111. # network tests are skipped
  112. item.add_marker(skip_network)
  113. # Only download datasets on the first worker spawned by pytest-xdist
  114. # to avoid thread unsafe behavior. If pytest-xdist is not used, we still
  115. # download before tests run.
  116. worker_id = environ.get("PYTEST_XDIST_WORKER", "gw0")
  117. if worker_id == "gw0" and run_network_tests:
  118. for name in datasets_to_download:
  119. with suppress(SkipTest):
  120. dataset_fetchers[name]()
  121. for item in items:
  122. # Known failure on with GradientBoostingClassifier on ARM64
  123. if (
  124. item.name.endswith("GradientBoostingClassifier")
  125. and platform.machine() == "aarch64"
  126. ):
  127. marker = pytest.mark.xfail(
  128. reason=(
  129. "know failure. See "
  130. "https://github.com/scikit-learn/scikit-learn/issues/17797" # noqa
  131. )
  132. )
  133. item.add_marker(marker)
  134. skip_doctests = False
  135. try:
  136. import matplotlib # noqa
  137. except ImportError:
  138. skip_doctests = True
  139. reason = "matplotlib is required to run the doctests"
  140. if _IS_32BIT:
  141. reason = "doctest are only run when the default numpy int is 64 bits."
  142. skip_doctests = True
  143. elif sys.platform.startswith("win32"):
  144. reason = (
  145. "doctests are not run for Windows because numpy arrays "
  146. "repr is inconsistent across platforms."
  147. )
  148. skip_doctests = True
  149. if np_base_version >= parse_version("2"):
  150. reason = "Due to NEP 51 numpy scalar repr has changed in numpy 2"
  151. skip_doctests = True
  152. # Normally doctest has the entire module's scope. Here we set globs to an empty dict
  153. # to remove the module's scope:
  154. # https://docs.python.org/3/library/doctest.html#what-s-the-execution-context
  155. for item in items:
  156. if isinstance(item, DoctestItem):
  157. item.dtest.globs = {}
  158. if skip_doctests:
  159. skip_marker = pytest.mark.skip(reason=reason)
  160. for item in items:
  161. if isinstance(item, DoctestItem):
  162. # work-around an internal error with pytest if adding a skip
  163. # mark to a doctest in a contextmanager, see
  164. # https://github.com/pytest-dev/pytest/issues/8796 for more
  165. # details.
  166. if item.name != "sklearn._config.config_context":
  167. item.add_marker(skip_marker)
  168. try:
  169. import PIL # noqa
  170. pillow_installed = True
  171. except ImportError:
  172. pillow_installed = False
  173. if not pillow_installed:
  174. skip_marker = pytest.mark.skip(reason="pillow (or PIL) not installed!")
  175. for item in items:
  176. if item.name in [
  177. "sklearn.feature_extraction.image.PatchExtractor",
  178. "sklearn.feature_extraction.image.extract_patches_2d",
  179. ]:
  180. item.add_marker(skip_marker)
  181. @pytest.fixture(scope="function")
  182. def pyplot():
  183. """Setup and teardown fixture for matplotlib.
  184. This fixture checks if we can import matplotlib. If not, the tests will be
  185. skipped. Otherwise, we close the figures before and after running the
  186. functions.
  187. Returns
  188. -------
  189. pyplot : module
  190. The ``matplotlib.pyplot`` module.
  191. """
  192. pyplot = pytest.importorskip("matplotlib.pyplot")
  193. pyplot.close("all")
  194. yield pyplot
  195. pyplot.close("all")
  196. def pytest_configure(config):
  197. # Use matplotlib agg backend during the tests including doctests
  198. try:
  199. import matplotlib
  200. matplotlib.use("agg")
  201. except ImportError:
  202. pass
  203. allowed_parallelism = joblib.cpu_count(only_physical_cores=True)
  204. xdist_worker_count = environ.get("PYTEST_XDIST_WORKER_COUNT")
  205. if xdist_worker_count is not None:
  206. # Set the number of OpenMP and BLAS threads based on the number of workers
  207. # xdist is using to prevent oversubscription.
  208. allowed_parallelism = max(allowed_parallelism // int(xdist_worker_count), 1)
  209. threadpool_limits(allowed_parallelism)
  210. # Register global_random_seed plugin if it is not already registered
  211. if not config.pluginmanager.hasplugin("sklearn.tests.random_seed"):
  212. config.pluginmanager.register(random_seed)