# test_parallel.py — tests that the scikit-learn global configuration is
# correctly propagated to joblib workers by sklearn.utils.parallel.
  1. import time
  2. import joblib
  3. import numpy as np
  4. import pytest
  5. from numpy.testing import assert_array_equal
  6. from sklearn import config_context, get_config
  7. from sklearn.compose import make_column_transformer
  8. from sklearn.datasets import load_iris
  9. from sklearn.ensemble import RandomForestClassifier
  10. from sklearn.model_selection import GridSearchCV
  11. from sklearn.pipeline import make_pipeline
  12. from sklearn.preprocessing import StandardScaler
  13. from sklearn.utils.parallel import Parallel, delayed
  14. def get_working_memory():
  15. return get_config()["working_memory"]
  16. @pytest.mark.parametrize("n_jobs", [1, 2])
  17. @pytest.mark.parametrize("backend", ["loky", "threading", "multiprocessing"])
  18. def test_configuration_passes_through_to_joblib(n_jobs, backend):
  19. # Tests that the global global configuration is passed to joblib jobs
  20. with config_context(working_memory=123):
  21. results = Parallel(n_jobs=n_jobs, backend=backend)(
  22. delayed(get_working_memory)() for _ in range(2)
  23. )
  24. assert_array_equal(results, [123] * 2)
  25. def test_parallel_delayed_warnings():
  26. """Informative warnings should be raised when mixing sklearn and joblib API"""
  27. # We should issue a warning when one wants to use sklearn.utils.fixes.Parallel
  28. # with joblib.delayed. The config will not be propagated to the workers.
  29. warn_msg = "`sklearn.utils.parallel.Parallel` needs to be used in conjunction"
  30. with pytest.warns(UserWarning, match=warn_msg) as records:
  31. Parallel()(joblib.delayed(time.sleep)(0) for _ in range(10))
  32. assert len(records) == 10
  33. # We should issue a warning if one wants to use sklearn.utils.fixes.delayed with
  34. # joblib.Parallel
  35. warn_msg = (
  36. "`sklearn.utils.parallel.delayed` should be used with "
  37. "`sklearn.utils.parallel.Parallel` to make it possible to propagate"
  38. )
  39. with pytest.warns(UserWarning, match=warn_msg) as records:
  40. joblib.Parallel()(delayed(time.sleep)(0) for _ in range(10))
  41. assert len(records) == 10
  42. @pytest.mark.parametrize("n_jobs", [1, 2])
  43. def test_dispatch_config_parallel(n_jobs):
  44. """Check that we properly dispatch the configuration in parallel processing.
  45. Non-regression test for:
  46. https://github.com/scikit-learn/scikit-learn/issues/25239
  47. """
  48. pd = pytest.importorskip("pandas")
  49. iris = load_iris(as_frame=True)
  50. class TransformerRequiredDataFrame(StandardScaler):
  51. def fit(self, X, y=None):
  52. assert isinstance(X, pd.DataFrame), "X should be a DataFrame"
  53. return super().fit(X, y)
  54. def transform(self, X, y=None):
  55. assert isinstance(X, pd.DataFrame), "X should be a DataFrame"
  56. return super().transform(X, y)
  57. dropper = make_column_transformer(
  58. ("drop", [0]),
  59. remainder="passthrough",
  60. n_jobs=n_jobs,
  61. )
  62. param_grid = {"randomforestclassifier__max_depth": [1, 2, 3]}
  63. search_cv = GridSearchCV(
  64. make_pipeline(
  65. dropper,
  66. TransformerRequiredDataFrame(),
  67. RandomForestClassifier(n_estimators=5, n_jobs=n_jobs),
  68. ),
  69. param_grid,
  70. cv=5,
  71. n_jobs=n_jobs,
  72. error_score="raise", # this search should not fail
  73. )
  74. # make sure that `fit` would fail in case we don't request dataframe
  75. with pytest.raises(AssertionError, match="X should be a DataFrame"):
  76. search_cv.fit(iris.data, iris.target)
  77. with config_context(transform_output="pandas"):
  78. # we expect each intermediate steps to output a DataFrame
  79. search_cv.fit(iris.data, iris.target)
  80. assert not np.isnan(search_cv.cv_results_["mean_test_score"]).any()