test_common.py 4.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135
  1. """Test loaders for common functionality."""
  2. import inspect
  3. import os
  4. import numpy as np
  5. import pytest
  6. import sklearn.datasets
  7. def is_pillow_installed():
  8. try:
  9. import PIL # noqa
  10. return True
  11. except ImportError:
  12. return False
  13. FETCH_PYTEST_MARKERS = {
  14. "return_X_y": {
  15. "fetch_20newsgroups": pytest.mark.xfail(
  16. reason="X is a list and does not have a shape argument"
  17. ),
  18. "fetch_openml": pytest.mark.xfail(
  19. reason="fetch_opeml requires a dataset name or id"
  20. ),
  21. "fetch_lfw_people": pytest.mark.skipif(
  22. not is_pillow_installed(), reason="pillow is not installed"
  23. ),
  24. },
  25. "as_frame": {
  26. "fetch_openml": pytest.mark.xfail(
  27. reason="fetch_opeml requires a dataset name or id"
  28. ),
  29. },
  30. }
  31. def check_pandas_dependency_message(fetch_func):
  32. try:
  33. import pandas # noqa
  34. pytest.skip("This test requires pandas to not be installed")
  35. except ImportError:
  36. # Check that pandas is imported lazily and that an informative error
  37. # message is raised when pandas is missing:
  38. name = fetch_func.__name__
  39. expected_msg = f"{name} with as_frame=True requires pandas"
  40. with pytest.raises(ImportError, match=expected_msg):
  41. fetch_func(as_frame=True)
  42. def check_return_X_y(bunch, dataset_func):
  43. X_y_tuple = dataset_func(return_X_y=True)
  44. assert isinstance(X_y_tuple, tuple)
  45. assert X_y_tuple[0].shape == bunch.data.shape
  46. assert X_y_tuple[1].shape == bunch.target.shape
  47. def check_as_frame(
  48. bunch, dataset_func, expected_data_dtype=None, expected_target_dtype=None
  49. ):
  50. pd = pytest.importorskip("pandas")
  51. frame_bunch = dataset_func(as_frame=True)
  52. assert hasattr(frame_bunch, "frame")
  53. assert isinstance(frame_bunch.frame, pd.DataFrame)
  54. assert isinstance(frame_bunch.data, pd.DataFrame)
  55. assert frame_bunch.data.shape == bunch.data.shape
  56. if frame_bunch.target.ndim > 1:
  57. assert isinstance(frame_bunch.target, pd.DataFrame)
  58. else:
  59. assert isinstance(frame_bunch.target, pd.Series)
  60. assert frame_bunch.target.shape[0] == bunch.target.shape[0]
  61. if expected_data_dtype is not None:
  62. assert np.all(frame_bunch.data.dtypes == expected_data_dtype)
  63. if expected_target_dtype is not None:
  64. assert np.all(frame_bunch.target.dtypes == expected_target_dtype)
  65. # Test for return_X_y and as_frame=True
  66. frame_X, frame_y = dataset_func(as_frame=True, return_X_y=True)
  67. assert isinstance(frame_X, pd.DataFrame)
  68. if frame_y.ndim > 1:
  69. assert isinstance(frame_X, pd.DataFrame)
  70. else:
  71. assert isinstance(frame_y, pd.Series)
  72. def _skip_network_tests():
  73. return os.environ.get("SKLEARN_SKIP_NETWORK_TESTS", "1") == "1"
  74. def _generate_func_supporting_param(param, dataset_type=("load", "fetch")):
  75. markers_fetch = FETCH_PYTEST_MARKERS.get(param, {})
  76. for name, obj in inspect.getmembers(sklearn.datasets):
  77. if not inspect.isfunction(obj):
  78. continue
  79. is_dataset_type = any([name.startswith(t) for t in dataset_type])
  80. is_support_param = param in inspect.signature(obj).parameters
  81. if is_dataset_type and is_support_param:
  82. # check if we should skip if we don't have network support
  83. marks = [
  84. pytest.mark.skipif(
  85. condition=name.startswith("fetch") and _skip_network_tests(),
  86. reason="Skip because fetcher requires internet network",
  87. )
  88. ]
  89. if name in markers_fetch:
  90. marks.append(markers_fetch[name])
  91. yield pytest.param(name, obj, marks=marks)
  92. @pytest.mark.parametrize(
  93. "name, dataset_func", _generate_func_supporting_param("return_X_y")
  94. )
  95. def test_common_check_return_X_y(name, dataset_func):
  96. bunch = dataset_func()
  97. check_return_X_y(bunch, dataset_func)
  98. @pytest.mark.parametrize(
  99. "name, dataset_func", _generate_func_supporting_param("as_frame")
  100. )
  101. def test_common_check_as_frame(name, dataset_func):
  102. bunch = dataset_func()
  103. check_as_frame(bunch, dataset_func)
  104. @pytest.mark.parametrize(
  105. "name, dataset_func", _generate_func_supporting_param("as_frame")
  106. )
  107. def test_common_check_pandas_dependency(name, dataset_func):
  108. check_pandas_dependency_message(dataset_func)