| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127 |
- """Module that customizes joblib tools for scikit-learn usage."""
- import functools
- import warnings
- from functools import update_wrapper
- import joblib
- from .._config import config_context, get_config
def _with_config(delayed_func, config):
    """Attach ``config`` to ``delayed_func`` if it exposes ``with_config``.

    Parameters
    ----------
    delayed_func : callable
        Task callable. When it has a ``with_config`` attribute, that attribute
        is called with ``config`` and its result is returned.

    config : dict
        Configuration to hand over to ``delayed_func.with_config``.

    Returns
    -------
    callable
        The result of ``delayed_func.with_config(config)`` when available,
        otherwise ``delayed_func`` itself, unchanged (a ``UserWarning`` is
        emitted in that case).
    """
    # Guard clause: a callable without `with_config` cannot carry the
    # configuration, so warn and hand it back untouched.
    if not hasattr(delayed_func, "with_config"):
        message = (
            "`sklearn.utils.parallel.Parallel` needs to be used in "
            "conjunction with `sklearn.utils.parallel.delayed` instead of "
            "`joblib.delayed` to correctly propagate the scikit-learn "
            "configuration to the joblib workers."
        )
        warnings.warn(message, UserWarning)
        return delayed_func

    return delayed_func.with_config(config)
class Parallel(joblib.Parallel):
    """Variant of :class:`joblib.Parallel` that forwards the scikit-learn config.

    The (thread-local) scikit-learn configuration that is active when the
    instance is called is attached to every task so that it is available to
    the parallel workers while the tasks run.

    The API is the same as :class:`joblib.Parallel`; refer to its
    documentation for more details.

    .. versionadded:: 1.3
    """

    def __call__(self, iterable):
        """Dispatch the tasks and return the results.

        Parameters
        ----------
        iterable : iterable
            Iterable of ``(delayed_function, args, kwargs)`` tuples to be
            consumed.

        Returns
        -------
        results : list
            List of results of the tasks.
        """
        # The configuration is read here, once, rather than when each task is
        # dispatched: depending on the backend, `pre_dispatch` and `n_jobs`,
        # dispatching may happen in a different thread whose thread-local
        # configuration is not the caller's.
        captured_config = get_config()

        # Lazily re-emit each task with the captured configuration attached to
        # its callable; positional and keyword arguments pass through as-is.
        configured_tasks = (
            (_with_config(task_func, captured_config), task_args, task_kwargs)
            for task_func, task_args, task_kwargs in iterable
        )
        return super().__call__(configured_tasks)
- # remove when https://github.com/joblib/joblib/issues/1071 is fixed
def delayed(function):
    """Decorator used to capture the arguments of a function.

    This replacement for `joblib.delayed` is intended to be paired with
    `sklearn.utils.parallel.Parallel`. The latter captures the scikit-learn
    configuration by calling `sklearn.get_config()` in the current thread,
    prior to dispatching the first task. The captured configuration is then
    propagated and enabled for the duration of the execution of the delayed
    function in the joblib workers.

    .. versionchanged:: 1.3
       `delayed` was moved from `sklearn.utils.fixes` to `sklearn.utils.parallel`
       in scikit-learn 1.3.

    Parameters
    ----------
    function : callable
        The function to be delayed.

    Returns
    -------
    output: tuple
        Tuple containing the delayed function, the positional arguments, and the
        keyword arguments.
    """

    def capture_call(*args, **kwargs):
        # Nothing is executed here: the call is only recorded as a
        # (wrapped callable, positional args, keyword args) triple.
        wrapped = _FuncWrapper(function)
        return wrapped, args, kwargs

    # Make the capturing function look like `function` (name, docstring, ...).
    return functools.update_wrapper(capture_call, function)
class _FuncWrapper:
    """Callable that runs a function inside a scikit-learn config context.

    The configuration is supplied through :meth:`with_config`; when none was
    supplied, the function is called inside ``config_context()`` with no
    overrides and a ``UserWarning`` is emitted.
    """

    def __init__(self, function):
        # Keep the target and copy its metadata (name, docstring, ...) onto
        # this wrapper instance.
        self.function = function
        update_wrapper(self, self.function)

    def with_config(self, config):
        """Store ``config`` on the wrapper and return the wrapper itself."""
        self.config = config
        return self

    def __call__(self, *args, **kwargs):
        """Call the wrapped function with the stored configuration enabled."""
        active_config = getattr(self, "config", None)
        if active_config is None:
            # No configuration was attached via `with_config`: warn and fall
            # back to an empty set of overrides.
            message = (
                "`sklearn.utils.parallel.delayed` should be used with"
                " `sklearn.utils.parallel.Parallel` to make it possible to"
                " propagate the scikit-learn configuration of the current thread to"
                " the joblib workers."
            )
            warnings.warn(message, UserWarning)
            active_config = {}

        with config_context(**active_config):
            result = self.function(*args, **kwargs)
        return result
|