fixes.py 6.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190
"""Compatibility fixes for older versions of Python, NumPy, SciPy and
threadpoolctl.

If you add content to this file, please give the version of the package
at which the fix is no longer needed.
"""
  5. # Authors: Emmanuelle Gouillart <emmanuelle.gouillart@normalesup.org>
  6. # Gael Varoquaux <gael.varoquaux@normalesup.org>
  7. # Fabian Pedregosa <fpedregosa@acm.org>
  8. # Lars Buitinck
  9. #
  10. # License: BSD 3 clause
  11. import sys
  12. from importlib import resources
  13. import numpy as np
  14. import scipy
  15. import scipy.sparse.linalg
  16. import scipy.stats
  17. import threadpoolctl
  18. import sklearn
  19. from ..externals._packaging.version import parse as parse_version
  20. from .deprecation import deprecated
  21. np_version = parse_version(np.__version__)
  22. np_base_version = parse_version(np_version.base_version)
  23. sp_version = parse_version(scipy.__version__)
  24. sp_base_version = parse_version(sp_version.base_version)
  25. try:
  26. from scipy.optimize._linesearch import line_search_wolfe1, line_search_wolfe2
  27. except ImportError: # SciPy < 1.8
  28. from scipy.optimize.linesearch import line_search_wolfe2, line_search_wolfe1 # type: ignore # noqa
  29. def _object_dtype_isnan(X):
  30. return X != X
  31. # Rename the `method` kwarg to `interpolation` for NumPy < 1.22, because
  32. # `interpolation` kwarg was deprecated in favor of `method` in NumPy >= 1.22.
  33. def _percentile(a, q, *, method="linear", **kwargs):
  34. return np.percentile(a, q, interpolation=method, **kwargs)
  35. if np_version < parse_version("1.22"):
  36. percentile = _percentile
  37. else: # >= 1.22
  38. from numpy import percentile # type: ignore # noqa
  39. # compatibility fix for threadpoolctl >= 3.0.0
  40. # since version 3 it's possible to setup a global threadpool controller to avoid
  41. # looping through all loaded shared libraries each time.
  42. # the global controller is created during the first call to threadpoolctl.
  43. def _get_threadpool_controller():
  44. if not hasattr(threadpoolctl, "ThreadpoolController"):
  45. return None
  46. if not hasattr(sklearn, "_sklearn_threadpool_controller"):
  47. sklearn._sklearn_threadpool_controller = threadpoolctl.ThreadpoolController()
  48. return sklearn._sklearn_threadpool_controller
  49. def threadpool_limits(limits=None, user_api=None):
  50. controller = _get_threadpool_controller()
  51. if controller is not None:
  52. return controller.limit(limits=limits, user_api=user_api)
  53. else:
  54. return threadpoolctl.threadpool_limits(limits=limits, user_api=user_api)
  55. threadpool_limits.__doc__ = threadpoolctl.threadpool_limits.__doc__
  56. def threadpool_info():
  57. controller = _get_threadpool_controller()
  58. if controller is not None:
  59. return controller.info()
  60. else:
  61. return threadpoolctl.threadpool_info()
  62. threadpool_info.__doc__ = threadpoolctl.threadpool_info.__doc__
  63. @deprecated(
  64. "The function `delayed` has been moved from `sklearn.utils.fixes` to "
  65. "`sklearn.utils.parallel`. This import path will be removed in 1.5."
  66. )
  67. def delayed(function):
  68. from sklearn.utils.parallel import delayed
  69. return delayed(function)
  70. # TODO: Remove when SciPy 1.11 is the minimum supported version
  71. def _mode(a, axis=0):
  72. if sp_version >= parse_version("1.9.0"):
  73. mode = scipy.stats.mode(a, axis=axis, keepdims=True)
  74. if sp_version >= parse_version("1.10.999"):
  75. # scipy.stats.mode has changed returned array shape with axis=None
  76. # and keepdims=True, see https://github.com/scipy/scipy/pull/17561
  77. if axis is None:
  78. mode = np.ravel(mode)
  79. return mode
  80. return scipy.stats.mode(a, axis=axis)
  81. # TODO: Remove when Scipy 1.12 is the minimum supported version
  82. if sp_base_version >= parse_version("1.12.0"):
  83. _sparse_linalg_cg = scipy.sparse.linalg.cg
  84. else:
  85. def _sparse_linalg_cg(A, b, **kwargs):
  86. if "rtol" in kwargs:
  87. kwargs["tol"] = kwargs.pop("rtol")
  88. if "atol" not in kwargs:
  89. kwargs["atol"] = "legacy"
  90. return scipy.sparse.linalg.cg(A, b, **kwargs)
  91. ###############################################################################
  92. # Backport of Python 3.9's importlib.resources
  93. # TODO: Remove when Python 3.9 is the minimum supported version
  94. def _open_text(data_module, data_file_name):
  95. if sys.version_info >= (3, 9):
  96. return resources.files(data_module).joinpath(data_file_name).open("r")
  97. else:
  98. return resources.open_text(data_module, data_file_name)
  99. def _open_binary(data_module, data_file_name):
  100. if sys.version_info >= (3, 9):
  101. return resources.files(data_module).joinpath(data_file_name).open("rb")
  102. else:
  103. return resources.open_binary(data_module, data_file_name)
  104. def _read_text(descr_module, descr_file_name):
  105. if sys.version_info >= (3, 9):
  106. return resources.files(descr_module).joinpath(descr_file_name).read_text()
  107. else:
  108. return resources.read_text(descr_module, descr_file_name)
  109. def _path(data_module, data_file_name):
  110. if sys.version_info >= (3, 9):
  111. return resources.as_file(resources.files(data_module).joinpath(data_file_name))
  112. else:
  113. return resources.path(data_module, data_file_name)
  114. def _is_resource(data_module, data_file_name):
  115. if sys.version_info >= (3, 9):
  116. return resources.files(data_module).joinpath(data_file_name).is_file()
  117. else:
  118. return resources.is_resource(data_module, data_file_name)
  119. def _contents(data_module):
  120. if sys.version_info >= (3, 9):
  121. return (
  122. resource.name
  123. for resource in resources.files(data_module).iterdir()
  124. if resource.is_file()
  125. )
  126. else:
  127. return resources.contents(data_module)
  128. # For +1.25 NumPy versions exceptions and warnings are being moved
  129. # to a dedicated submodule.
  130. if np_version >= parse_version("1.25.0"):
  131. from numpy.exceptions import ComplexWarning, VisibleDeprecationWarning
  132. else:
  133. from numpy import ComplexWarning, VisibleDeprecationWarning # type: ignore # noqa
  134. # TODO: Remove when Scipy 1.6 is the minimum supported version
  135. try:
  136. from scipy.integrate import trapezoid # type: ignore # noqa
  137. except ImportError:
  138. from scipy.integrate import trapz as trapezoid # type: ignore # noqa