| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190 |
- """Compatibility fixes for older version of python, numpy and scipy
- If you add content to this file, please give the version of the package
- at which the fix is no longer needed.
- """
- # Authors: Emmanuelle Gouillart <emmanuelle.gouillart@normalesup.org>
- # Gael Varoquaux <gael.varoquaux@normalesup.org>
- # Fabian Pedregosa <fpedregosa@acm.org>
- # Lars Buitinck
- #
- # License: BSD 3 clause
- import sys
- from importlib import resources
- import numpy as np
- import scipy
- import scipy.sparse.linalg
- import scipy.stats
- import threadpoolctl
- import sklearn
- from ..externals._packaging.version import parse as parse_version
- from .deprecation import deprecated
- np_version = parse_version(np.__version__)
- np_base_version = parse_version(np_version.base_version)
- sp_version = parse_version(scipy.__version__)
- sp_base_version = parse_version(sp_version.base_version)
- try:
- from scipy.optimize._linesearch import line_search_wolfe1, line_search_wolfe2
- except ImportError: # SciPy < 1.8
- from scipy.optimize.linesearch import line_search_wolfe2, line_search_wolfe1 # type: ignore # noqa
- def _object_dtype_isnan(X):
- return X != X
- # Rename the `method` kwarg to `interpolation` for NumPy < 1.22, because
- # `interpolation` kwarg was deprecated in favor of `method` in NumPy >= 1.22.
- def _percentile(a, q, *, method="linear", **kwargs):
- return np.percentile(a, q, interpolation=method, **kwargs)
- if np_version < parse_version("1.22"):
- percentile = _percentile
- else: # >= 1.22
- from numpy import percentile # type: ignore # noqa
- # compatibility fix for threadpoolctl >= 3.0.0
- # since version 3 it's possible to setup a global threadpool controller to avoid
- # looping through all loaded shared libraries each time.
- # the global controller is created during the first call to threadpoolctl.
- def _get_threadpool_controller():
- if not hasattr(threadpoolctl, "ThreadpoolController"):
- return None
- if not hasattr(sklearn, "_sklearn_threadpool_controller"):
- sklearn._sklearn_threadpool_controller = threadpoolctl.ThreadpoolController()
- return sklearn._sklearn_threadpool_controller
- def threadpool_limits(limits=None, user_api=None):
- controller = _get_threadpool_controller()
- if controller is not None:
- return controller.limit(limits=limits, user_api=user_api)
- else:
- return threadpoolctl.threadpool_limits(limits=limits, user_api=user_api)
- threadpool_limits.__doc__ = threadpoolctl.threadpool_limits.__doc__
- def threadpool_info():
- controller = _get_threadpool_controller()
- if controller is not None:
- return controller.info()
- else:
- return threadpoolctl.threadpool_info()
- threadpool_info.__doc__ = threadpoolctl.threadpool_info.__doc__
- @deprecated(
- "The function `delayed` has been moved from `sklearn.utils.fixes` to "
- "`sklearn.utils.parallel`. This import path will be removed in 1.5."
- )
- def delayed(function):
- from sklearn.utils.parallel import delayed
- return delayed(function)
- # TODO: Remove when SciPy 1.11 is the minimum supported version
- def _mode(a, axis=0):
- if sp_version >= parse_version("1.9.0"):
- mode = scipy.stats.mode(a, axis=axis, keepdims=True)
- if sp_version >= parse_version("1.10.999"):
- # scipy.stats.mode has changed returned array shape with axis=None
- # and keepdims=True, see https://github.com/scipy/scipy/pull/17561
- if axis is None:
- mode = np.ravel(mode)
- return mode
- return scipy.stats.mode(a, axis=axis)
- # TODO: Remove when Scipy 1.12 is the minimum supported version
- if sp_base_version >= parse_version("1.12.0"):
- _sparse_linalg_cg = scipy.sparse.linalg.cg
- else:
- def _sparse_linalg_cg(A, b, **kwargs):
- if "rtol" in kwargs:
- kwargs["tol"] = kwargs.pop("rtol")
- if "atol" not in kwargs:
- kwargs["atol"] = "legacy"
- return scipy.sparse.linalg.cg(A, b, **kwargs)
- ###############################################################################
- # Backport of Python 3.9's importlib.resources
- # TODO: Remove when Python 3.9 is the minimum supported version
- def _open_text(data_module, data_file_name):
- if sys.version_info >= (3, 9):
- return resources.files(data_module).joinpath(data_file_name).open("r")
- else:
- return resources.open_text(data_module, data_file_name)
- def _open_binary(data_module, data_file_name):
- if sys.version_info >= (3, 9):
- return resources.files(data_module).joinpath(data_file_name).open("rb")
- else:
- return resources.open_binary(data_module, data_file_name)
- def _read_text(descr_module, descr_file_name):
- if sys.version_info >= (3, 9):
- return resources.files(descr_module).joinpath(descr_file_name).read_text()
- else:
- return resources.read_text(descr_module, descr_file_name)
- def _path(data_module, data_file_name):
- if sys.version_info >= (3, 9):
- return resources.as_file(resources.files(data_module).joinpath(data_file_name))
- else:
- return resources.path(data_module, data_file_name)
- def _is_resource(data_module, data_file_name):
- if sys.version_info >= (3, 9):
- return resources.files(data_module).joinpath(data_file_name).is_file()
- else:
- return resources.is_resource(data_module, data_file_name)
- def _contents(data_module):
- if sys.version_info >= (3, 9):
- return (
- resource.name
- for resource in resources.files(data_module).iterdir()
- if resource.is_file()
- )
- else:
- return resources.contents(data_module)
- # For +1.25 NumPy versions exceptions and warnings are being moved
- # to a dedicated submodule.
- if np_version >= parse_version("1.25.0"):
- from numpy.exceptions import ComplexWarning, VisibleDeprecationWarning
- else:
- from numpy import ComplexWarning, VisibleDeprecationWarning # type: ignore # noqa
- # TODO: Remove when Scipy 1.6 is the minimum supported version
- try:
- from scipy.integrate import trapezoid # type: ignore # noqa
- except ImportError:
- from scipy.integrate import trapz as trapezoid # type: ignore # noqa
|