test_config.py 6.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197
  1. import builtins
  2. import time
  3. from concurrent.futures import ThreadPoolExecutor
  4. import pytest
  5. import sklearn
  6. from sklearn import config_context, get_config, set_config
  7. from sklearn.utils.parallel import Parallel, delayed
  8. def test_config_context():
  9. assert get_config() == {
  10. "assume_finite": False,
  11. "working_memory": 1024,
  12. "print_changed_only": True,
  13. "display": "diagram",
  14. "array_api_dispatch": False,
  15. "pairwise_dist_chunk_size": 256,
  16. "enable_cython_pairwise_dist": True,
  17. "transform_output": "default",
  18. "enable_metadata_routing": False,
  19. "skip_parameter_validation": False,
  20. }
  21. # Not using as a context manager affects nothing
  22. config_context(assume_finite=True)
  23. assert get_config()["assume_finite"] is False
  24. with config_context(assume_finite=True):
  25. assert get_config() == {
  26. "assume_finite": True,
  27. "working_memory": 1024,
  28. "print_changed_only": True,
  29. "display": "diagram",
  30. "array_api_dispatch": False,
  31. "pairwise_dist_chunk_size": 256,
  32. "enable_cython_pairwise_dist": True,
  33. "transform_output": "default",
  34. "enable_metadata_routing": False,
  35. "skip_parameter_validation": False,
  36. }
  37. assert get_config()["assume_finite"] is False
  38. with config_context(assume_finite=True):
  39. with config_context(assume_finite=None):
  40. assert get_config()["assume_finite"] is True
  41. assert get_config()["assume_finite"] is True
  42. with config_context(assume_finite=False):
  43. assert get_config()["assume_finite"] is False
  44. with config_context(assume_finite=None):
  45. assert get_config()["assume_finite"] is False
  46. # global setting will not be retained outside of context that
  47. # did not modify this setting
  48. set_config(assume_finite=True)
  49. assert get_config()["assume_finite"] is True
  50. assert get_config()["assume_finite"] is False
  51. assert get_config()["assume_finite"] is True
  52. assert get_config() == {
  53. "assume_finite": False,
  54. "working_memory": 1024,
  55. "print_changed_only": True,
  56. "display": "diagram",
  57. "array_api_dispatch": False,
  58. "pairwise_dist_chunk_size": 256,
  59. "enable_cython_pairwise_dist": True,
  60. "transform_output": "default",
  61. "enable_metadata_routing": False,
  62. "skip_parameter_validation": False,
  63. }
  64. # No positional arguments
  65. with pytest.raises(TypeError):
  66. config_context(True)
  67. # No unknown arguments
  68. with pytest.raises(TypeError):
  69. config_context(do_something_else=True).__enter__()
  70. def test_config_context_exception():
  71. assert get_config()["assume_finite"] is False
  72. try:
  73. with config_context(assume_finite=True):
  74. assert get_config()["assume_finite"] is True
  75. raise ValueError()
  76. except ValueError:
  77. pass
  78. assert get_config()["assume_finite"] is False
  79. def test_set_config():
  80. assert get_config()["assume_finite"] is False
  81. set_config(assume_finite=None)
  82. assert get_config()["assume_finite"] is False
  83. set_config(assume_finite=True)
  84. assert get_config()["assume_finite"] is True
  85. set_config(assume_finite=None)
  86. assert get_config()["assume_finite"] is True
  87. set_config(assume_finite=False)
  88. assert get_config()["assume_finite"] is False
  89. # No unknown arguments
  90. with pytest.raises(TypeError):
  91. set_config(do_something_else=True)
  92. def set_assume_finite(assume_finite, sleep_duration):
  93. """Return the value of assume_finite after waiting `sleep_duration`."""
  94. with config_context(assume_finite=assume_finite):
  95. time.sleep(sleep_duration)
  96. return get_config()["assume_finite"]
  97. @pytest.mark.parametrize("backend", ["loky", "multiprocessing", "threading"])
  98. def test_config_threadsafe_joblib(backend):
  99. """Test that the global config is threadsafe with all joblib backends.
  100. Two jobs are spawned and sets assume_finite to two different values.
  101. When the job with a duration 0.1s completes, the assume_finite value
  102. should be the same as the value passed to the function. In other words,
  103. it is not influenced by the other job setting assume_finite to True.
  104. """
  105. assume_finites = [False, True, False, True]
  106. sleep_durations = [0.1, 0.2, 0.1, 0.2]
  107. items = Parallel(backend=backend, n_jobs=2)(
  108. delayed(set_assume_finite)(assume_finite, sleep_dur)
  109. for assume_finite, sleep_dur in zip(assume_finites, sleep_durations)
  110. )
  111. assert items == [False, True, False, True]
  112. def test_config_threadsafe():
  113. """Uses threads directly to test that the global config does not change
  114. between threads. Same test as `test_config_threadsafe_joblib` but with
  115. `ThreadPoolExecutor`."""
  116. assume_finites = [False, True, False, True]
  117. sleep_durations = [0.1, 0.2, 0.1, 0.2]
  118. with ThreadPoolExecutor(max_workers=2) as e:
  119. items = [
  120. output
  121. for output in e.map(set_assume_finite, assume_finites, sleep_durations)
  122. ]
  123. assert items == [False, True, False, True]
  124. def test_config_array_api_dispatch_error(monkeypatch):
  125. """Check error is raised when array_api_compat is not installed."""
  126. # Hide array_api_compat import
  127. orig_import = builtins.__import__
  128. def mocked_import(name, *args, **kwargs):
  129. if name == "array_api_compat":
  130. raise ImportError
  131. return orig_import(name, *args, **kwargs)
  132. monkeypatch.setattr(builtins, "__import__", mocked_import)
  133. with pytest.raises(ImportError, match="array_api_compat is required"):
  134. with config_context(array_api_dispatch=True):
  135. pass
  136. with pytest.raises(ImportError, match="array_api_compat is required"):
  137. set_config(array_api_dispatch=True)
  138. def test_config_array_api_dispatch_error_numpy(monkeypatch):
  139. """Check error when NumPy is too old"""
  140. # Pretend that array_api_compat is installed.
  141. orig_import = builtins.__import__
  142. def mocked_import(name, *args, **kwargs):
  143. if name == "array_api_compat":
  144. return object()
  145. return orig_import(name, *args, **kwargs)
  146. monkeypatch.setattr(builtins, "__import__", mocked_import)
  147. monkeypatch.setattr(sklearn.utils._array_api.numpy, "__version__", "1.20")
  148. with pytest.raises(ImportError, match="NumPy must be 1.21 or newer"):
  149. with config_context(array_api_dispatch=True):
  150. pass
  151. with pytest.raises(ImportError, match="NumPy must be 1.21 or newer"):
  152. set_config(array_api_dispatch=True)