parallel.py 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127
"""Module that customizes joblib tools for scikit-learn usage."""
  2. import functools
  3. import warnings
  4. from functools import update_wrapper
  5. import joblib
  6. from .._config import config_context, get_config
  7. def _with_config(delayed_func, config):
  8. """Helper function that intends to attach a config to a delayed function."""
  9. if hasattr(delayed_func, "with_config"):
  10. return delayed_func.with_config(config)
  11. else:
  12. warnings.warn(
  13. (
  14. "`sklearn.utils.parallel.Parallel` needs to be used in "
  15. "conjunction with `sklearn.utils.parallel.delayed` instead of "
  16. "`joblib.delayed` to correctly propagate the scikit-learn "
  17. "configuration to the joblib workers."
  18. ),
  19. UserWarning,
  20. )
  21. return delayed_func
  22. class Parallel(joblib.Parallel):
  23. """Tweak of :class:`joblib.Parallel` that propagates the scikit-learn configuration.
  24. This subclass of :class:`joblib.Parallel` ensures that the active configuration
  25. (thread-local) of scikit-learn is propagated to the parallel workers for the
  26. duration of the execution of the parallel tasks.
  27. The API does not change and you can refer to :class:`joblib.Parallel`
  28. documentation for more details.
  29. .. versionadded:: 1.3
  30. """
  31. def __call__(self, iterable):
  32. """Dispatch the tasks and return the results.
  33. Parameters
  34. ----------
  35. iterable : iterable
  36. Iterable containing tuples of (delayed_function, args, kwargs) that should
  37. be consumed.
  38. Returns
  39. -------
  40. results : list
  41. List of results of the tasks.
  42. """
  43. # Capture the thread-local scikit-learn configuration at the time
  44. # Parallel.__call__ is issued since the tasks can be dispatched
  45. # in a different thread depending on the backend and on the value of
  46. # pre_dispatch and n_jobs.
  47. config = get_config()
  48. iterable_with_config = (
  49. (_with_config(delayed_func, config), args, kwargs)
  50. for delayed_func, args, kwargs in iterable
  51. )
  52. return super().__call__(iterable_with_config)
  53. # remove when https://github.com/joblib/joblib/issues/1071 is fixed
  54. def delayed(function):
  55. """Decorator used to capture the arguments of a function.
  56. This alternative to `joblib.delayed` is meant to be used in conjunction
  57. with `sklearn.utils.parallel.Parallel`. The latter captures the the scikit-
  58. learn configuration by calling `sklearn.get_config()` in the current
  59. thread, prior to dispatching the first task. The captured configuration is
  60. then propagated and enabled for the duration of the execution of the
  61. delayed function in the joblib workers.
  62. .. versionchanged:: 1.3
  63. `delayed` was moved from `sklearn.utils.fixes` to `sklearn.utils.parallel`
  64. in scikit-learn 1.3.
  65. Parameters
  66. ----------
  67. function : callable
  68. The function to be delayed.
  69. Returns
  70. -------
  71. output: tuple
  72. Tuple containing the delayed function, the positional arguments, and the
  73. keyword arguments.
  74. """
  75. @functools.wraps(function)
  76. def delayed_function(*args, **kwargs):
  77. return _FuncWrapper(function), args, kwargs
  78. return delayed_function
  79. class _FuncWrapper:
  80. """Load the global configuration before calling the function."""
  81. def __init__(self, function):
  82. self.function = function
  83. update_wrapper(self, self.function)
  84. def with_config(self, config):
  85. self.config = config
  86. return self
  87. def __call__(self, *args, **kwargs):
  88. config = getattr(self, "config", None)
  89. if config is None:
  90. warnings.warn(
  91. (
  92. "`sklearn.utils.parallel.delayed` should be used with"
  93. " `sklearn.utils.parallel.Parallel` to make it possible to"
  94. " propagate the scikit-learn configuration of the current thread to"
  95. " the joblib workers."
  96. ),
  97. UserWarning,
  98. )
  99. config = {}
  100. with config_context(**config):
  101. return self.function(*args, **kwargs)