| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197 |
- import builtins
- import time
- from concurrent.futures import ThreadPoolExecutor
- import pytest
- import sklearn
- from sklearn import config_context, get_config, set_config
- from sklearn.utils.parallel import Parallel, delayed
def test_config_context():
    """Check nesting, restoration and argument checks of ``config_context``."""
    # Configuration expected whenever no override is active.
    defaults = {
        "assume_finite": False,
        "working_memory": 1024,
        "print_changed_only": True,
        "display": "diagram",
        "array_api_dispatch": False,
        "pairwise_dist_chunk_size": 256,
        "enable_cython_pairwise_dist": True,
        "transform_output": "default",
        "enable_metadata_routing": False,
        "skip_parameter_validation": False,
    }
    assert get_config() == defaults

    # Building the context manager without entering it changes nothing.
    config_context(assume_finite=True)
    assert get_config()["assume_finite"] is False

    with config_context(assume_finite=True):
        # Only the overridden key differs from the defaults.
        assert get_config() == {**defaults, "assume_finite": True}
    assert get_config()["assume_finite"] is False

    with config_context(assume_finite=True):
        # Passing None keeps the value of the enclosing context.
        with config_context(assume_finite=None):
            assert get_config()["assume_finite"] is True
        assert get_config()["assume_finite"] is True

        with config_context(assume_finite=False):
            assert get_config()["assume_finite"] is False

            with config_context(assume_finite=None):
                assert get_config()["assume_finite"] is False

                # A global change made inside a context that did not itself
                # modify the setting is discarded once that context exits.
                set_config(assume_finite=True)
                assert get_config()["assume_finite"] is True

            assert get_config()["assume_finite"] is False
        assert get_config()["assume_finite"] is True

    # Leaving every context brings back the full default configuration.
    assert get_config() == defaults

    # Positional arguments are rejected.
    with pytest.raises(TypeError):
        config_context(True)

    # Unknown keyword arguments are rejected.
    with pytest.raises(TypeError):
        config_context(do_something_else=True).__enter__()
def test_config_context_exception():
    """Check that the previous config is restored when the body raises."""
    assert get_config()["assume_finite"] is False

    # The ValueError escapes the config context and is absorbed here, so the
    # final assertion runs after the context has been exited by an exception.
    with pytest.raises(ValueError):
        with config_context(assume_finite=True):
            assert get_config()["assume_finite"] is True
            raise ValueError()

    assert get_config()["assume_finite"] is False
def test_set_config():
    """Check ``set_config`` updates, ``None`` no-ops and unknown keywords."""

    def current():
        return get_config()["assume_finite"]

    assert current() is False

    # (value passed to set_config, value expected afterwards); passing None
    # is expected to leave the current setting untouched.
    steps = [
        (None, False),
        (True, True),
        (None, True),
        (False, False),
    ]
    for requested, expected in steps:
        set_config(assume_finite=requested)
        assert current() is expected

    # Unknown keyword arguments are rejected.
    with pytest.raises(TypeError):
        set_config(do_something_else=True)
def set_assume_finite(assume_finite, sleep_duration):
    """Set `assume_finite` in a context, sleep, then report the value seen.

    The value is read back after `sleep_duration` seconds while the context
    is still active, and returned to the caller.
    """
    with config_context(assume_finite=assume_finite):
        time.sleep(sleep_duration)
        observed = get_config()["assume_finite"]
    return observed
@pytest.mark.parametrize("backend", ["loky", "multiprocessing", "threading"])
def test_config_threadsafe_joblib(backend):
    """Check config isolation between concurrent jobs for each joblib backend.

    Four tasks are dispatched to two workers. Each task sets `assume_finite`
    to its own value, sleeps for 0.1s or 0.2s, and returns the value it then
    reads back. Every task must report the value it was given, i.e. it must
    not be influenced by another task setting a different value meanwhile.
    """
    expected = [False, True, False, True]
    pauses = [0.1, 0.2, 0.1, 0.2]

    runner = Parallel(backend=backend, n_jobs=2)
    tasks = (
        delayed(set_assume_finite)(flag, pause)
        for flag, pause in zip(expected, pauses)
    )
    results = runner(tasks)

    assert results == expected
def test_config_threadsafe():
    """Check config isolation between plain threads.

    Mirrors `test_config_threadsafe_joblib`, but drives `set_assume_finite`
    through a `ThreadPoolExecutor` instead of joblib.
    """
    expected = [False, True, False, True]
    pauses = [0.1, 0.2, 0.1, 0.2]

    with ThreadPoolExecutor(max_workers=2) as pool:
        results = list(pool.map(set_assume_finite, expected, pauses))

    assert results == expected
def test_config_array_api_dispatch_error(monkeypatch):
    """Check the error raised when array_api_compat cannot be imported."""
    real_import = builtins.__import__

    def import_hiding_array_api_compat(name, *args, **kwargs):
        # Make array_api_compat look uninstalled; delegate everything else.
        if name != "array_api_compat":
            return real_import(name, *args, **kwargs)
        raise ImportError

    monkeypatch.setattr(builtins, "__import__", import_hiding_array_api_compat)

    # Both ways of enabling dispatch must report the missing dependency.
    with pytest.raises(
        ImportError, match="array_api_compat is required"
    ), config_context(array_api_dispatch=True):
        pass

    with pytest.raises(ImportError, match="array_api_compat is required"):
        set_config(array_api_dispatch=True)
def test_config_array_api_dispatch_error_numpy(monkeypatch):
    """Check the error raised when the NumPy version is reported as 1.20."""
    real_import = builtins.__import__

    def import_faking_array_api_compat(name, *args, **kwargs):
        # Pretend array_api_compat is installed by returning a dummy object.
        if name != "array_api_compat":
            return real_import(name, *args, **kwargs)
        return object()

    monkeypatch.setattr(builtins, "__import__", import_faking_array_api_compat)
    # Patch the version string on the numpy module referenced by _array_api.
    monkeypatch.setattr(sklearn.utils._array_api.numpy, "__version__", "1.20")

    # Both ways of enabling dispatch must report the NumPy version problem.
    with pytest.raises(
        ImportError, match="NumPy must be 1.21 or newer"
    ), config_context(array_api_dispatch=True):
        pass

    with pytest.raises(ImportError, match="NumPy must be 1.21 or newer"):
        set_config(array_api_dispatch=True)
|