| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
70370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232123312341235123612371238123912401241124212431244124512461247124812491250125112521253125412551256125712581259126012611262126312641265126612671268126912701271127212731274127512761
277127812791280 |
- """threadpoolctl
- This module provides utilities to introspect native libraries that rely on
- thread pools (notably BLAS and OpenMP implementations) and dynamically set the
- maximal number of threads they can use.
- """
- # License: BSD 3-Clause
- # The code to introspect dynamically loaded libraries on POSIX systems is
- # adapted from code by Intel developer @anton-malakhov available at
- # https://github.com/IntelPython/smp (Copyright (c) 2017, Intel Corporation)
- # and also published under the BSD 3-Clause license
- import os
- import re
- import sys
- import ctypes
- import itertools
- import textwrap
- from typing import final
- import warnings
- from ctypes.util import find_library
- from abc import ABC, abstractmethod
- from functools import lru_cache
- from contextlib import ContextDecorator
- __version__ = "3.5.0"
- __all__ = [
- "threadpool_limits",
- "threadpool_info",
- "ThreadpoolController",
- "LibController",
- "register",
- ]
# Loading several OpenMP runtimes in the same process can cause runtime errors
# or even segfaults. This happens easily in Python when importing and using
# compiled extensions built with different compilers, and therefore different
# OpenMP runtimes. In particular libiomp (used by Intel ICC) and libomp (used
# by clang/llvm) tend to crash, for instance when calling BLAS inside a prange.
# The environment variable below allows multiple OpenMP libraries to be loaded.
# It should not degrade performance since potential over-subscription is
# handled manually, in the sections of the code where nested OpenMP loops can
# happen, by dynamically reconfiguring the inner OpenMP runtime to temporarily
# disable it while under the scope of the outer OpenMP parallel section.
# A value already present in the environment is left untouched.
if "KMP_DUPLICATE_LIB_OK" not in os.environ:
    os.environ["KMP_DUPLICATE_LIB_OK"] = "True"
# Structure used to cast the info on a dynamically loaded library. See
# https://linux.die.net/man/3/dl_iterate_phdr for more details.
# The integer widths follow the pointer size of the running interpreter.
if sys.maxsize > 2**32:
    _SYSTEM_UINT, _SYSTEM_UINT_HALF = ctypes.c_uint64, ctypes.c_uint32
else:
    _SYSTEM_UINT, _SYSTEM_UINT_HALF = ctypes.c_uint32, ctypes.c_uint16


class _dl_phdr_info(ctypes.Structure):
    """ctypes layout of the leading fields of a `dl_phdr_info` entry."""

    _fields_ = [
        # Base address of object
        ("dlpi_addr", _SYSTEM_UINT),
        # path to the library
        ("dlpi_name", ctypes.c_char_p),
        # pointer on dlpi_headers
        ("dlpi_phdr", ctypes.c_void_p),
        # number of elements in dlpi_phdr
        ("dlpi_phnum", _SYSTEM_UINT_HALF),
    ]
# The RTLD_NOLOAD flag for loading shared libraries is not defined on Windows:
# fall back to the default ctypes loading mode when `os` does not provide it.
_RTLD_NOLOAD = getattr(os, "RTLD_NOLOAD", ctypes.DEFAULT_MODE)
class LibController(ABC):
    """Abstract base class for the individual library controllers

    A library controller must expose the following class attributes:
        - user_api : str
            Usually the name of the library or generic specification the library
            implements, e.g. "blas" is a specification with different
            implementations.
        - internal_api : str
            Usually the name of the library or concrete implementation of some
            specification, e.g. "openblas" is an implementation of the "blas"
            specification.
        - filename_prefixes : tuple
            Possible prefixes of the shared library's filename that allow to
            identify the library. e.g. "libopenblas" for libopenblas.so.

    and implement the following methods: `get_num_threads`, `set_num_threads` and
    `get_version`.

    Threadpoolctl loops through all the loaded shared libraries and tries to match
    the filename of each library with the `filename_prefixes`. If a match is found,
    a controller is instantiated and a handle to the library is stored in the
    `dynlib` attribute as a `ctypes.CDLL` object. It can be used to access the
    necessary symbols of the shared library to implement the above methods.

    The following information will be exposed in the info dictionary:
      - user_api : standardized API, if any, or a copy of internal_api.
      - internal_api : implementation-specific API.
      - num_threads : the current thread limit.
      - prefix : prefix of the shared library's filename.
      - filepath : path to the loaded shared library.
      - version : version of the library (if available).

    In addition, each library controller may expose internal API specific entries.
    They must be set as attributes in the `set_additional_attributes` method.
    """

    @final
    def __init__(self, *, filepath=None, prefix=None, parent=None):
        """This is not meant to be overridden by subclasses."""
        # The assignment order below is also the order of the instance
        # attributes reported by `info`.
        self.parent = parent
        self.prefix = prefix
        self.filepath = filepath
        self.dynlib = ctypes.CDLL(filepath, mode=_RTLD_NOLOAD)
        affixes = self._find_affixes()
        self._symbol_prefix, self._symbol_suffix = affixes
        self.version = self.get_version()
        self.set_additional_attributes()

    def info(self):
        """Return relevant info wrapped in a dict"""
        # Instance attributes that are implementation details, not information.
        hidden = ("dynlib", "parent", "_symbol_prefix", "_symbol_suffix")
        exposed = {
            "user_api": self.user_api,
            "internal_api": self.internal_api,
            "num_threads": self.num_threads,
        }
        for name, value in vars(self).items():
            if name not in hidden:
                exposed[name] = value
        return exposed

    def set_additional_attributes(self):
        """Set additional attributes meant to be exposed in the info dict"""

    @property
    def num_threads(self):
        """Exposes the current thread limit as a dynamic property

        This is not meant to be used or overridden by subclasses.
        """
        return self.get_num_threads()

    @abstractmethod
    def get_num_threads(self):
        """Return the maximum number of threads available to use"""

    @abstractmethod
    def set_num_threads(self, num_threads):
        """Set the maximum number of threads to use"""

    @abstractmethod
    def get_version(self):
        """Return the version of the shared library"""

    def _find_affixes(self):
        """Return the affixes for the symbols of the shared library"""
        return "", ""

    def _get_symbol(self, name):
        """Return the symbol of the shared library accounting for the affixes

        Return None when `dynlib` has no attribute with the affixed name.
        """
        full_name = f"{self._symbol_prefix}{name}{self._symbol_suffix}"
        return getattr(self.dynlib, full_name, None)
class OpenBLASController(LibController):
    """Controller class for OpenBLAS"""

    user_api = "blas"
    internal_api = "openblas"
    filename_prefixes = ("libopenblas", "libblas", "libscipy_openblas")

    # Known decorations of the OpenBLAS symbol names.
    _symbol_prefixes = ("", "scipy_")
    _symbol_suffixes = ("", "64_", "_64")

    # All variations of "openblas_get_num_threads", accounting for the affixes
    check_symbols = tuple(
        f"{prefix}openblas_get_num_threads{suffix}"
        for prefix, suffix in itertools.product(_symbol_prefixes, _symbol_suffixes)
    )

    def _find_affixes(self):
        """Return the (prefix, suffix) pair decorating the symbols of `dynlib`

        The pair is detected by probing every known variation of
        "openblas_get_num_threads". When none of them is found, the empty
        affixes are returned so that the caller can always unpack a 2-tuple;
        the symbol lookups then simply resolve to None.
        """
        for prefix, suffix in itertools.product(
            self._symbol_prefixes, self._symbol_suffixes
        ):
            if hasattr(self.dynlib, f"{prefix}openblas_get_num_threads{suffix}"):
                return prefix, suffix
        return "", ""

    def set_additional_attributes(self):
        """Expose the threading layer and the detected architecture"""
        self.threading_layer = self._get_threading_layer()
        self.architecture = self._get_architecture()

    def get_num_threads(self):
        """Return the current thread limit, or None if the symbol is missing"""
        get_num_threads_func = self._get_symbol("openblas_get_num_threads")
        if get_num_threads_func is not None:
            return get_num_threads_func()
        return None

    def set_num_threads(self, num_threads):
        """Set the thread limit; does nothing if the symbol is missing"""
        set_num_threads_func = self._get_symbol("openblas_set_num_threads")
        if set_num_threads_func is not None:
            return set_num_threads_func(num_threads)
        return None

    def get_version(self):
        """Return the OpenBLAS version string, or None when it is unavailable"""
        # None means OpenBLAS is not loaded or version < 0.3.4, since OpenBLAS
        # did not expose its version before that.
        get_version_func = self._get_symbol("openblas_get_config")
        if get_version_func is None:
            return None

        get_version_func.restype = ctypes.c_char_p
        raw_config = get_version_func()
        # With a c_char_p restype a NULL pointer comes back as None.
        if not raw_config:
            return None

        config = raw_config.split()
        # Expected layout: b"OpenBLAS <version> ...". Anything else (including
        # a string with fewer than two tokens) is reported as unknown.
        if len(config) >= 2 and config[0] == b"OpenBLAS":
            return config[1].decode("utf-8")
        return None

    def _get_threading_layer(self):
        """Return the threading layer of OpenBLAS"""
        get_threading_layer_func = self._get_symbol("openblas_get_parallel")
        if get_threading_layer_func is None:
            return "unknown"

        threading_layer = get_threading_layer_func()
        if threading_layer == 2:
            return "openmp"
        if threading_layer == 1:
            return "pthreads"
        return "disabled"

    def _get_architecture(self):
        """Return the architecture detected by OpenBLAS"""
        get_architecture_func = self._get_symbol("openblas_get_corename")
        if get_architecture_func is not None:
            get_architecture_func.restype = ctypes.c_char_p
            return get_architecture_func().decode("utf-8")
        return None
class BLISController(LibController):
    """Controller class for BLIS"""

    user_api = "blas"
    internal_api = "blis"
    filename_prefixes = ("libblis", "libblas")
    check_symbols = (
        "bli_thread_get_num_threads",
        "bli_thread_set_num_threads",
        "bli_info_get_version_str",
        "bli_info_get_enable_openmp",
        "bli_info_get_enable_pthreads",
        "bli_arch_query_id",
        "bli_arch_string",
    )

    def set_additional_attributes(self):
        """Expose the threading layer and the detected architecture"""
        self.threading_layer = self._get_threading_layer()
        self.architecture = self._get_architecture()

    def get_num_threads(self):
        """Return the current thread limit (None if the symbol is missing)"""
        getter = getattr(self.dynlib, "bli_thread_get_num_threads", lambda: None)
        reported = getter()
        # by default BLIS is single-threaded and get_num_threads
        # returns -1. We map it to 1 for consistency with other libraries.
        if reported == -1:
            return 1
        return reported

    def set_num_threads(self, num_threads):
        """Set the thread limit; does nothing if the symbol is missing"""
        setter = getattr(
            self.dynlib, "bli_thread_set_num_threads", lambda num_threads: None
        )
        return setter(num_threads)

    def get_version(self):
        """Return the BLIS version string, or None if the symbol is missing"""
        version_func = getattr(self.dynlib, "bli_info_get_version_str", None)
        if version_func is not None:
            version_func.restype = ctypes.c_char_p
            return version_func().decode("utf-8")
        return None

    def _get_threading_layer(self):
        """Return the threading layer of BLIS"""
        # Probed in order; a missing symbol counts as "not enabled".
        candidates = (
            ("bli_info_get_enable_openmp", "openmp"),
            ("bli_info_get_enable_pthreads", "pthreads"),
        )
        for symbol, layer in candidates:
            is_enabled = getattr(self.dynlib, symbol, lambda: False)
            if is_enabled():
                return layer
        return "disabled"

    def _get_architecture(self):
        """Return the architecture detected by BLIS"""
        query_id = getattr(self.dynlib, "bli_arch_query_id", None)
        arch_string = getattr(self.dynlib, "bli_arch_string", None)
        if query_id is None or arch_string is None:
            return None

        # the true restype should be BLIS' arch_t (enum) but int should work
        # for us:
        query_id.restype = ctypes.c_int
        arch_string.restype = ctypes.c_char_p
        arch_id = query_id()
        return arch_string(arch_id).decode("utf-8")
class FlexiBLASController(LibController):
    """Controller class for FlexiBLAS"""

    user_api = "blas"
    internal_api = "flexiblas"
    filename_prefixes = ("libflexiblas",)
    check_symbols = (
        "flexiblas_get_num_threads",
        "flexiblas_set_num_threads",
        "flexiblas_get_version",
        "flexiblas_list",
        "flexiblas_list_loaded",
        "flexiblas_current_backend",
    )

    @property
    def loaded_backends(self):
        """Backends currently loaded, queried from the library on each access"""
        return self._get_backend_list(loaded=True)

    @property
    def current_backend(self):
        """Backend currently in use, queried from the library on each access"""
        return self._get_current_backend()

    def info(self):
        """Return relevant info wrapped in a dict"""
        # The loaded and current backends are dynamic properties, not instance
        # attributes, so they have to be added on top of the base info.
        exposed_attrs = super().info()
        exposed_attrs.update(
            loaded_backends=self.loaded_backends,
            current_backend=self.current_backend,
        )
        return exposed_attrs

    def set_additional_attributes(self):
        """Expose the backends listed by `flexiblas_list`"""
        self.available_backends = self._get_backend_list(loaded=False)

    def get_num_threads(self):
        """Return the current thread limit (None if the symbol is missing)"""
        getter = getattr(self.dynlib, "flexiblas_get_num_threads", lambda: None)
        reported = getter()
        # A reported value of -1 is mapped to 1 for consistency with the
        # other libraries.
        if reported == -1:
            return 1
        return reported

    def set_num_threads(self, num_threads):
        """Set the thread limit; does nothing if the symbol is missing"""
        setter = getattr(
            self.dynlib, "flexiblas_set_num_threads", lambda num_threads: None
        )
        return setter(num_threads)

    def get_version(self):
        """Return the version as "major.minor.patch", or None if unavailable"""
        version_func = getattr(self.dynlib, "flexiblas_get_version", None)
        if version_func is None:
            return None

        # The three components are written by the library through pointers.
        major, minor, patch = ctypes.c_int(), ctypes.c_int(), ctypes.c_int()
        version_func(ctypes.byref(major), ctypes.byref(minor), ctypes.byref(patch))
        return f"{major.value}.{minor.value}.{patch.value}"

    def _get_backend_list(self, loaded=False):
        """Return the list of available backends for FlexiBLAS.

        If loaded is False, return the list of available backends from the
        FlexiBLAS configuration. If loaded is True, return the list of actually
        loaded backends. Return None when the listing symbol is missing.
        """
        func_name = "flexiblas_list_loaded" if loaded else "flexiblas_list"
        list_func = getattr(self.dynlib, func_name, None)
        if list_func is None:
            return None

        # Called with a null buffer, the function's return value is used as
        # the number of entries to enumerate.
        n_backends = list_func(None, 0, 0)

        backends = []
        for index in range(n_backends):
            name_buffer = ctypes.create_string_buffer(1024)
            list_func(name_buffer, 1024, index)
            name = name_buffer.value.decode("utf-8")
            # We don't know when to expect __FALLBACK__ but it is not a real
            # backend and does not show up when running flexiblas list.
            if name != "__FALLBACK__":
                backends.append(name)
        return backends

    def _get_current_backend(self):
        """Return the backend of FlexiBLAS"""
        backend_func = getattr(self.dynlib, "flexiblas_current_backend", None)
        if backend_func is None:
            return None

        name_buffer = ctypes.create_string_buffer(1024)
        backend_func(name_buffer, ctypes.sizeof(name_buffer))
        return name_buffer.value.decode("utf-8")

    def switch_backend(self, backend):
        """Switch the backend of FlexiBLAS

        Parameters
        ----------
        backend : str
            The name or the path to the shared library of the backend to switch
            to. If the backend is not already loaded, it will be loaded first.

        Raises
        ------
        RuntimeError
            If the backend cannot be loaded or cannot be switched to.
        """
        if backend not in self.loaded_backends:
            if backend in self.available_backends:
                symbol = "flexiblas_load_backend"
            else:  # assume backend is a path to a shared library
                symbol = "flexiblas_load_backend_library"
            # A missing symbol is treated like a failed load (-1).
            load_func = getattr(self.dynlib, symbol, lambda _: -1)

            status = load_func(str(backend).encode("utf-8"))
            if status == -1:
                raise RuntimeError(
                    f"Failed to load backend {backend!r}. It must either be the name of"
                    " a backend available in the FlexiBLAS configuration "
                    f"{self.available_backends} or the path to a valid shared library."
                )

            # Trigger a new search of loaded shared libraries since loading a new
            # backend caused a dlopen.
            self.parent._load_libraries()

        switch_func = getattr(self.dynlib, "flexiblas_switch", lambda _: -1)
        # The backend is selected by its position in the list of loaded backends.
        position = self.loaded_backends.index(backend)
        status = switch_func(position)
        if status == -1:
            raise RuntimeError(f"Failed to switch to backend {backend!r}.")
class MKLController(LibController):
    """Controller class for MKL"""

    user_api = "blas"
    internal_api = "mkl"
    filename_prefixes = ("libmkl_rt", "mkl_rt", "libblas")
    check_symbols = (
        "MKL_Get_Max_Threads",
        "MKL_Set_Num_Threads",
        "MKL_Get_Version_String",
        "MKL_Set_Threading_Layer",
    )

    def set_additional_attributes(self):
        """Expose the threading layer"""
        self.threading_layer = self._get_threading_layer()

    def get_num_threads(self):
        """Return the current thread limit (None if the symbol is missing)"""
        getter = getattr(self.dynlib, "MKL_Get_Max_Threads", lambda: None)
        return getter()

    def set_num_threads(self, num_threads):
        """Set the thread limit; does nothing if the symbol is missing"""
        setter = getattr(self.dynlib, "MKL_Set_Num_Threads", lambda num_threads: None)
        return setter(num_threads)

    def get_version(self):
        """Return the MKL version, or None if the symbol is missing

        The token following "Version " in the version string is returned when
        it can be found, otherwise the whole (stripped) version string.
        """
        if not hasattr(self.dynlib, "MKL_Get_Version_String"):
            return None

        buffer = ctypes.create_string_buffer(200)
        self.dynlib.MKL_Get_Version_String(buffer, 200)

        version = buffer.value.decode("utf-8")
        match = re.search(r"Version ([^ ]+) ", version)
        if match is not None:
            version = match.group(1)
        return version.strip()

    def _get_threading_layer(self):
        """Return the threading layer of MKL"""
        layer_map = {
            0: "intel",
            1: "sequential",
            2: "pgi",
            3: "gnu",
            4: "tbb",
            -1: "not specified",
        }
        # The function mkl_set_threading_layer returns the current threading
        # layer. Calling it with an invalid threading layer allows us to safely
        # get the threading layer
        set_threading_layer = getattr(
            self.dynlib, "MKL_Set_Threading_Layer", lambda layer: -1
        )
        current_layer = set_threading_layer(-1)
        return layer_map[current_layer]
class OpenMPController(LibController):
    """Controller class for OpenMP"""

    user_api = "openmp"
    internal_api = "openmp"
    filename_prefixes = ("libiomp", "libgomp", "libomp", "vcomp")
    check_symbols = (
        "omp_get_max_threads",
        "omp_get_num_threads",
    )

    def get_num_threads(self):
        """Return the current thread limit (None if the symbol is missing)"""
        getter = getattr(self.dynlib, "omp_get_max_threads", lambda: None)
        return getter()

    def set_num_threads(self, num_threads):
        """Set the thread limit; does nothing if the symbol is missing"""
        setter = getattr(self.dynlib, "omp_set_num_threads", lambda num_threads: None)
        return setter(num_threads)

    def get_version(self):
        """Always return None"""
        # There is no way to get the version number programmatically in OpenMP.
        return None
# Controllers for the libraries that we'll look for in the loaded libraries.
# Third party libraries can register their own controllers.
_ALL_CONTROLLERS = [
    OpenBLASController,
    BLISController,
    MKLController,
    OpenMPController,
    FlexiBLASController,
]

# Helpers for the doc and test names.
# Duplicates are removed with `dict.fromkeys` rather than `set` so that the
# resulting lists keep the first-seen order of `_ALL_CONTROLLERS`. With a set
# the order depends on string hashing and can change from one interpreter run
# to the next, which made the generated docstrings and the default iteration
# order over the user APIs non reproducible.
_ALL_USER_APIS = list(dict.fromkeys(lib.user_api for lib in _ALL_CONTROLLERS))
_ALL_INTERNAL_APIS = [lib.internal_api for lib in _ALL_CONTROLLERS]
_ALL_PREFIXES = list(
    dict.fromkeys(
        prefix for lib in _ALL_CONTROLLERS for prefix in lib.filename_prefixes
    )
)
_ALL_BLAS_LIBRARIES = [
    lib.internal_api for lib in _ALL_CONTROLLERS if lib.user_api == "blas"
]
_ALL_OPENMP_LIBRARIES = OpenMPController.filename_prefixes
def register(controller):
    """Register a new controller

    The controller class is appended to the list of known controllers, and its
    `user_api`, `internal_api` and `filename_prefixes` are added to the
    corresponding module-level helper lists.
    """
    _ALL_CONTROLLERS.append(controller)

    # Keep the helper lists used for validation and documentation in sync.
    new_user_api = controller.user_api
    _ALL_USER_APIS.append(new_user_api)

    new_internal_api = controller.internal_api
    _ALL_INTERNAL_APIS.append(new_internal_api)

    new_prefixes = controller.filename_prefixes
    _ALL_PREFIXES.extend(new_prefixes)
def _format_docstring(*args, **kwargs):
    """Return a decorator filling the placeholders of an object's docstring

    The docstring of the decorated object is passed through `str.format` with
    the given positional and keyword arguments. Objects without a docstring
    are returned untouched.
    """

    def decorator(obj):
        template = obj.__doc__
        if template is None:
            return obj
        obj.__doc__ = template.format(*args, **kwargs)
        return obj

    return decorator
@lru_cache(maxsize=10000)
def _realpath(filepath):
    """Small caching wrapper around os.path.realpath to limit system calls"""
    resolved = os.path.realpath(filepath)
    return resolved
@_format_docstring(USER_APIS=list(_ALL_USER_APIS), INTERNAL_APIS=_ALL_INTERNAL_APIS)
def threadpool_info():
    """Return the maximal number of threads for each detected library.

    Return a list with all the supported libraries that have been found. Each
    library is represented by a dict with the following information:

      - "user_api" : user API. Possible values are {USER_APIS}.
      - "internal_api": internal API. Possible values are {INTERNAL_APIS}.
      - "prefix" : filename prefix of the specific implementation.
      - "filepath": path to the loaded library.
      - "version": version of the library (if available).
      - "num_threads": the current thread limit.

    In addition, each library may contain internal_api specific entries.
    """
    # A fresh controller is built on every call.
    controller = ThreadpoolController()
    return controller.info()
class _ThreadpoolLimiter:
    """The guts of ThreadpoolController.limit

    Refer to the docstring of ThreadpoolController.limit for more details.

    It will only act on the library controllers held by the provided `controller`.
    Using the default constructor sets the limits right away such that it can be
    used as a callable. Setting the limits can be delayed by using the `wrap` class
    method such that it can be used as a decorator.
    """

    def __init__(self, controller, *, limits=None, user_api=None):
        # `_check_params` may read `self._controller`, so it is assigned first.
        self._controller = controller
        self._limits, self._user_api, self._prefixes = self._check_params(
            limits, user_api
        )
        # Snapshot taken before applying the limits, used to restore them.
        self._original_info = self._controller.info()
        self._set_threadpool_limits()

    def __enter__(self):
        return self

    def __exit__(self, type, value, traceback):
        self.restore_original_limits()

    @classmethod
    def wrap(cls, controller, *, limits=None, user_api=None):
        """Return an instance of this class that can be used as a decorator"""
        return _ThreadpoolLimiterDecorator(
            controller=controller, limits=limits, user_api=user_api
        )

    def restore_original_limits(self):
        """Set the limits back to their original values"""
        # The snapshot is matched to the library controllers by position.
        pairs = zip(self._controller.lib_controllers, self._original_info)
        for lib_controller, original_info in pairs:
            lib_controller.set_num_threads(original_info["num_threads"])

    # Alias of `restore_original_limits` for backward compatibility
    unregister = restore_original_limits

    def get_original_num_threads(self):
        """Original num_threads from before calling threadpool_limits

        Return a dict `{user_api: num_threads}`. When the libraries of a user
        API had different limits, the minimum is reported and a warning is
        emitted. A user API without any library is mapped to None.
        """
        num_threads = {}
        ambiguous_apis = []

        for user_api in self._user_api:
            distinct_limits = {
                lib_info["num_threads"]
                for lib_info in self._original_info
                if lib_info["user_api"] == user_api
            }
            if not distinct_limits:
                limit = None
            elif len(distinct_limits) == 1:
                (limit,) = distinct_limits
            else:
                limit = min(distinct_limits)
                ambiguous_apis.append(user_api)
            num_threads[user_api] = limit

        if ambiguous_apis:
            warnings.warn(
                "Multiple value possible for following user apis: "
                + ", ".join(ambiguous_apis)
                + ". Returning the minimum."
            )

        return num_threads

    def _check_params(self, limits, user_api):
        """Suitable values for the _limits, _user_api and _prefixes attributes

        Return a tuple `(limits, user_api, prefixes)` where `limits` is either
        None or a dict `{key: num_threads}`, and `user_api` / `prefixes` are
        lists.

        Raises
        ------
        ValueError
            If `limits` is an int or None and `user_api` is not a known user API.
        TypeError
            If `limits` cannot be interpreted as a dict of limits.
        """
        if isinstance(limits, str) and limits == "sequential_blas_under_openmp":
            # The controller decides; the `user_api` argument is ignored.
            sequential_params = (
                self._controller._get_params_for_sequential_blas_under_openmp()
            )
            limits, user_api = sequential_params.values()

        if limits is None or isinstance(limits, int):
            if user_api is None:
                user_api = _ALL_USER_APIS
            elif user_api in _ALL_USER_APIS:
                user_api = [user_api]
            else:
                raise ValueError(
                    f"user_api must be either in {_ALL_USER_APIS} or None. Got "
                    f"{user_api} instead."
                )

            if limits is not None:
                limits = {api: limits for api in user_api}
            return limits, user_api, []

        if isinstance(limits, list):
            # This should be a list of dicts of library info, for
            # compatibility with the result from threadpool_info.
            limits = {
                lib_info["prefix"]: lib_info["num_threads"] for lib_info in limits
            }
        elif isinstance(limits, ThreadpoolController):
            # To set the limits from the library controllers of a
            # ThreadpoolController object.
            limits = {
                lib_controller.prefix: lib_controller.num_threads
                for lib_controller in limits.lib_controllers
            }

        if not isinstance(limits, dict):
            raise TypeError(
                "limits must either be an int, a list, a dict, or "
                f"'sequential_blas_under_openmp'. Got {type(limits)} instead"
            )

        # With a dictionary, can set both specific limit for given
        # libraries and global limit for user_api. Fetch each separately.
        prefixes = [prefix for prefix in limits if prefix in _ALL_PREFIXES]
        user_api = [api for api in limits if api in _ALL_USER_APIS]
        return limits, user_api, prefixes

    def _set_threadpool_limits(self):
        """Change the maximal number of threads in selected thread pools.

        Only the library controllers whose prefix or user API is a key of
        `self._limits` are modified. Nothing is done when `self._limits` is None.
        """
        if self._limits is None:
            return

        for lib_controller in self._controller.lib_controllers:
            # self._limits is a dict {key: num_threads} where key is either
            # a prefix or a user_api. If a library matches both, the limit
            # corresponding to the prefix is chosen.
            if lib_controller.prefix in self._limits:
                key = lib_controller.prefix
            elif lib_controller.user_api in self._limits:
                key = lib_controller.user_api
            else:
                continue

            num_threads = self._limits[key]
            if num_threads is not None:
                lib_controller.set_num_threads(num_threads)


class _ThreadpoolLimiterDecorator(_ThreadpoolLimiter, ContextDecorator):
    """Same as _ThreadpoolLimiter but to be used as a decorator"""

    def __init__(self, controller, *, limits=None, user_api=None):
        # The controller must be stored before validating the parameters:
        # `_check_params` reads `self._controller` when `limits` is
        # "sequential_blas_under_openmp", and would otherwise fail with an
        # AttributeError.
        self._controller = controller
        self._limits, self._user_api, self._prefixes = self._check_params(
            limits, user_api
        )

    def __enter__(self):
        # we need to set the limits here and not in the __init__ because we want the
        # limits to be set when calling the decorated function, not when creating the
        # decorator.
        self._original_info = self._controller.info()
        self._set_threadpool_limits()
        return self
@_format_docstring(
    USER_APIS=", ".join(f'"{api}"' for api in _ALL_USER_APIS),
    BLAS_LIBS=", ".join(_ALL_BLAS_LIBRARIES),
    OPENMP_LIBS=", ".join(_ALL_OPENMP_LIBRARIES),
)
class threadpool_limits(_ThreadpoolLimiter):
    """Change the maximal number of threads that can be used in thread pools.

    This object can be used either as a callable (the construction of this object
    limits the number of threads), as a context manager in a `with` block to
    automatically restore the original state of the controlled libraries when
    exiting the block, or as a decorator through its `wrap` method.

    Set the maximal number of threads that can be used in thread pools used in
    the supported libraries to `limits`. This function works for libraries that
    are already loaded in the interpreter and can be changed dynamically.

    This effect is global and impacts the whole Python process. There is no thread
    level isolation as these libraries do not offer thread-local APIs to configure
    the number of threads to use in nested parallel calls.

    Parameters
    ----------
    limits : int, dict, 'sequential_blas_under_openmp' or None (default=None)
        The maximal number of threads that can be used in thread pools

        - If int, sets the maximum number of threads to `limits` for each
          library selected by `user_api`.

        - If it is a dictionary `{{key: max_threads}}`, this function sets a
          custom maximum number of threads for each `key` which can be either a
          `user_api` or a `prefix` for a specific library.

        - If 'sequential_blas_under_openmp', it will choose the appropriate
          `limits` and `user_api` parameters for the specific use case of
          sequential BLAS calls within an OpenMP parallel region. The `user_api`
          parameter is ignored.

        - If None, this function does not do anything.

    user_api : {USER_APIS} or None (default=None)
        APIs of libraries to limit. Used only if `limits` is an int.

        - If "blas", it will only limit BLAS supported libraries ({BLAS_LIBS}).

        - If "openmp", it will only limit OpenMP supported libraries
          ({OPENMP_LIBS}). Note that it can affect the number of threads used
          by the BLAS libraries if they rely on OpenMP.

        - If None, this function will apply to all supported libraries.
    """

    def __init__(self, limits=None, user_api=None):
        # A fresh controller is built for each instance.
        controller = ThreadpoolController()
        super().__init__(controller, limits=limits, user_api=user_api)

    @classmethod
    def wrap(cls, limits=None, user_api=None):
        """Return a limiter that can be used as a decorator"""
        controller = ThreadpoolController()
        return super().wrap(controller, limits=limits, user_api=user_api)
class ThreadpoolController:
    """Collection of LibController objects for all loaded supported libraries

    Attributes
    ----------
    lib_controllers : list of `LibController` objects
        The list of library controllers of all loaded supported libraries.
    """

    # Cache for libc under POSIX and a few system libraries under Windows.
    # We use a class level cache instead of an instance level cache because
    # it's very unlikely that a shared library will be unloaded and reloaded
    # during the lifetime of a program.
    _system_libraries = dict()

    def __init__(self):
        self.lib_controllers = []
        self._load_libraries()
        self._warn_if_incompatible_openmp()

    @classmethod
    def _from_controllers(cls, lib_controllers):
        """Build a controller holding the given library controllers.

        `__init__` is bypassed on purpose: the loaded shared libraries are not
        scanned again and no OpenMP compatibility warning is emitted.
        """
        new_controller = cls.__new__(cls)
        new_controller.lib_controllers = lib_controllers
        return new_controller

    def info(self):
        """Return lib_controllers info as a list of dicts"""
        return [lib_controller.info() for lib_controller in self.lib_controllers]

    def select(self, **kwargs):
        """Return a ThreadpoolController containing a subset of its current
        library controllers

        It will select all libraries matching at least one pair (key, value) from
        kwargs where key is an entry of the library info dict (like "user_api",
        "internal_api", "prefix", ...) and value is the value or a list of
        acceptable values for that entry.

        For instance, `ThreadpoolController().select(internal_api=["blis", "openblas"])`
        will select all library controllers whose internal_api is either "blis" or
        "openblas".
        """
        # Normalize every criterion to a list of acceptable values. A separate
        # mapping is built so that `kwargs` is not rewritten while iterated.
        accepted = {
            key: vals if isinstance(vals, list) else [vals]
            for key, vals in kwargs.items()
        }

        lib_controllers = [
            lib_controller
            for lib_controller in self.lib_controllers
            if any(
                getattr(lib_controller, key, None) in vals
                for key, vals in accepted.items()
            )
        ]

        return ThreadpoolController._from_controllers(lib_controllers)

    def _get_params_for_sequential_blas_under_openmp(self):
        """Return appropriate params to use for a sequential BLAS call in an OpenMP loop

        This function takes into account the unexpected behavior of OpenBLAS with the
        OpenMP threading layer.
        """
        # NOTE(review): `select` keeps a library as soon as ONE criterion matches,
        # so this is true when any library has internal_api "openblas" OR any
        # library has threading_layer "openmp" -- confirm this is intended.
        if self.select(
            internal_api="openblas", threading_layer="openmp"
        ).lib_controllers:
            return {"limits": None, "user_api": None}
        return {"limits": 1, "user_api": "blas"}

    @_format_docstring(
        USER_APIS=", ".join('"{}"'.format(api) for api in _ALL_USER_APIS),
        BLAS_LIBS=", ".join(_ALL_BLAS_LIBRARIES),
        OPENMP_LIBS=", ".join(_ALL_OPENMP_LIBRARIES),
    )
    def limit(self, *, limits=None, user_api=None):
        """Change the maximal number of threads that can be used in thread pools.

        This function returns an object that can be used either as a callable (the
        construction of this object limits the number of threads) or as a context
        manager, in a `with` block to automatically restore the original state of the
        controlled libraries when exiting the block.

        Set the maximal number of threads that can be used in thread pools used in
        the supported libraries to `limits`. This function works for libraries that
        are already loaded in the interpreter and can be changed dynamically.

        This effect is global and impacts the whole Python process. There is no thread
        level isolation as these libraries do not offer thread-local APIs to configure
        the number of threads to use in nested parallel calls.

        Parameters
        ----------
        limits : int, dict, 'sequential_blas_under_openmp' or None (default=None)
            The maximal number of threads that can be used in thread pools

            - If int, sets the maximum number of threads to `limits` for each
              library selected by `user_api`.

            - If it is a dictionary `{{key: max_threads}}`, this function sets a
              custom maximum number of threads for each `key` which can be either a
              `user_api` or a `prefix` for a specific library.

            - If 'sequential_blas_under_openmp', it will choose the appropriate
              `limits` and `user_api` parameters for the specific use case of
              sequential BLAS calls within an OpenMP parallel region. The
              `user_api` parameter is ignored.

            - If None, this function does not do anything.

        user_api : {USER_APIS} or None (default=None)
            APIs of libraries to limit. Used only if `limits` is an int.

            - If "blas", it will only limit BLAS supported libraries ({BLAS_LIBS}).

            - If "openmp", it will only limit OpenMP supported libraries
              ({OPENMP_LIBS}). Note that it can affect the number of threads used
              by the BLAS libraries if they rely on OpenMP.

            - If None, this function will apply to all supported libraries.
        """
        return _ThreadpoolLimiter(self, limits=limits, user_api=user_api)

    @_format_docstring(
        USER_APIS=", ".join('"{}"'.format(api) for api in _ALL_USER_APIS),
        BLAS_LIBS=", ".join(_ALL_BLAS_LIBRARIES),
        OPENMP_LIBS=", ".join(_ALL_OPENMP_LIBRARIES),
    )
    def wrap(self, *, limits=None, user_api=None):
        """Change the maximal number of threads that can be used in thread pools.

        This function returns an object that can be used as a decorator.

        Set the maximal number of threads that can be used in thread pools used in
        the supported libraries to `limits`. This function works for libraries that
        are already loaded in the interpreter and can be changed dynamically.

        Parameters
        ----------
        limits : int, dict or None (default=None)
            The maximal number of threads that can be used in thread pools

            - If int, sets the maximum number of threads to `limits` for each
              library selected by `user_api`.

            - If it is a dictionary `{{key: max_threads}}`, this function sets a
              custom maximum number of threads for each `key` which can be either a
              `user_api` or a `prefix` for a specific library.

            - If None, this function does not do anything.

        user_api : {USER_APIS} or None (default=None)
            APIs of libraries to limit. Used only if `limits` is an int.

            - If "blas", it will only limit BLAS supported libraries ({BLAS_LIBS}).

            - If "openmp", it will only limit OpenMP supported libraries
              ({OPENMP_LIBS}). Note that it can affect the number of threads used
              by the BLAS libraries if they rely on OpenMP.

            - If None, this function will apply to all supported libraries.
        """
        return _ThreadpoolLimiter.wrap(self, limits=limits, user_api=user_api)

    def __len__(self):
        return len(self.lib_controllers)

    def _load_libraries(self):
        """Loop through loaded shared libraries and store the supported ones"""
        if sys.platform == "darwin":
            self._find_libraries_with_dyld()
        elif sys.platform == "win32":
            self._find_libraries_with_enum_process_module_ex()
        elif "pyodide" in sys.modules:
            self._find_libraries_pyodide()
        else:
            self._find_libraries_with_dl_iterate_phdr()

    def _find_libraries_with_dl_iterate_phdr(self):
        """Loop through loaded libraries and store controllers for supported ones

        This function is expected to work on POSIX system only.
        This code is adapted from code by Intel developer @anton-malakhov
        available at https://github.com/IntelPython/smp

        Copyright (c) 2017, Intel Corporation published under the BSD 3-Clause
        license
        """
        libc = self._get_libc()
        if not hasattr(libc, "dl_iterate_phdr"):  # pragma: no cover
            warnings.warn(
                "Could not find dl_iterate_phdr in the C standard library.",
                RuntimeWarning,
            )
            return []

        # Callback function for `dl_iterate_phdr` which is called for every
        # library loaded in the current process until it returns 1.
        def match_library_callback(info, size, data):
            # Get the path of the current library
            filepath = info.contents.dlpi_name
            if filepath:
                filepath = filepath.decode("utf-8")

                # Store the library controller if it is supported and selected
                self._make_controller_from_path(filepath)
            # Always return 0 so that the iteration visits every library.
            return 0

        c_func_signature = ctypes.CFUNCTYPE(
            ctypes.c_int,  # Return type
            ctypes.POINTER(_dl_phdr_info),
            ctypes.c_size_t,
            ctypes.c_char_p,
        )
        c_match_library_callback = c_func_signature(match_library_callback)

        data = ctypes.c_char_p(b"")
        libc.dl_iterate_phdr(c_match_library_callback, data)

    def _find_libraries_with_dyld(self):
        """Loop through loaded libraries and store controllers for supported ones

        This function is expected to work on OSX system only
        """
        libc = self._get_libc()
        if not hasattr(libc, "_dyld_image_count"):  # pragma: no cover
            warnings.warn(
                "Could not find _dyld_image_count in the C standard library.",
                RuntimeWarning,
            )
            return []

        n_dyld = libc._dyld_image_count()
        libc._dyld_get_image_name.restype = ctypes.c_char_p

        for i in range(n_dyld):
            filepath = ctypes.string_at(libc._dyld_get_image_name(i))
            filepath = filepath.decode("utf-8")

            # Store the library controller if it is supported and selected
            self._make_controller_from_path(filepath)

    def _find_libraries_with_enum_process_module_ex(self):
        """Loop through loaded libraries and store controllers for supported ones

        This function is expected to work on windows system only.
        This code is adapted from code by Philipp Hagemeister @phihag available
        at https://stackoverflow.com/questions/17474574

        Raises
        ------
        OSError
            If the current process cannot be opened or if one of the Windows
            API calls used to enumerate the modules fails.
        """
        from ctypes.wintypes import DWORD, HMODULE, MAX_PATH

        PROCESS_QUERY_INFORMATION = 0x0400
        PROCESS_VM_READ = 0x0010

        LIST_LIBRARIES_ALL = 0x03

        ps_api = self._get_windll("Psapi")
        kernel_32 = self._get_windll("kernel32")

        h_process = kernel_32.OpenProcess(
            PROCESS_QUERY_INFORMATION | PROCESS_VM_READ, False, os.getpid()
        )
        if not h_process:  # pragma: no cover
            raise OSError(f"Could not open PID {os.getpid()}")

        try:
            buf_count = 256
            needed = DWORD()
            # Grow the buffer until it becomes large enough to hold all the
            # module headers
            while True:
                buf = (HMODULE * buf_count)()
                buf_size = ctypes.sizeof(buf)
                if not ps_api.EnumProcessModulesEx(
                    h_process,
                    ctypes.byref(buf),
                    buf_size,
                    ctypes.byref(needed),
                    LIST_LIBRARIES_ALL,
                ):
                    raise OSError("EnumProcessModulesEx failed")
                if buf_size >= needed.value:
                    break
                # `buf_size // buf_count` is the size of one module handle.
                buf_count = needed.value // (buf_size // buf_count)

            count = needed.value // (buf_size // buf_count)
            h_modules = map(HMODULE, buf[:count])

            # Loop through all the module headers and get the library path
            path_buf = ctypes.create_unicode_buffer(MAX_PATH)
            for h_module in h_modules:
                # Get the path of the current module. The last argument is the
                # capacity of `path_buf` in characters and must be passed by
                # value: it tells the API how much it is allowed to write.
                if not ps_api.GetModuleFileNameExW(
                    h_process, h_module, ctypes.byref(path_buf), DWORD(MAX_PATH)
                ):
                    raise OSError("GetModuleFileNameEx failed")
                filepath = path_buf.value

                # Store the library controller if it is supported and selected
                self._make_controller_from_path(filepath)
        finally:
            # Release the process handle even if the enumeration failed.
            kernel_32.CloseHandle(h_process)

    def _find_libraries_pyodide(self):
        """Pyodide specific implementation for finding loaded libraries.

        Adapted from suggestion in https://github.com/joblib/threadpoolctl/pull/169#issuecomment-1946696449.

        One day, we may have a simpler solution. libc dl_iterate_phdr needs to
        be implemented in Emscripten and exposed in Pyodide, see
        https://github.com/emscripten-core/emscripten/issues/21354 for more
        details.
        """
        try:
            from pyodide_js._module import LDSO
        except ImportError:
            warnings.warn(
                "Unable to import LDSO from pyodide_js._module. This should never "
                "happen."
            )
            return

        for filepath in LDSO.loadedLibsByName.as_object_map():
            # Some libraries are duplicated by Pyodide and do not exist in the
            # filesystem, so we first check for the existence of the file. For
            # more details, see
            # https://github.com/joblib/threadpoolctl/pull/169#issuecomment-1947946728
            if os.path.exists(filepath):
                self._make_controller_from_path(filepath)

    def _make_controller_from_path(self, filepath):
        """Store a library controller if it is supported and selected"""
        # Required to resolve symlinks
        filepath = _realpath(filepath)
        # `lower` required to take account of OpenMP dll case on Windows
        # (vcomp, VCOMP, Vcomp, ...)
        filename = os.path.basename(filepath).lower()

        # Loop through supported libraries to find if this filename corresponds
        # to a supported one.
        for controller_class in _ALL_CONTROLLERS:
            # check if filename matches a supported prefix
            prefix = self._check_prefix(filename, controller_class.filename_prefixes)

            # filename does not match any of the prefixes of the candidate
            # library. move to next library.
            if prefix is None:
                continue

            # workaround for BLAS libraries packaged by conda-forge on windows, which
            # are all renamed "libblas.dll". We thus have to check to which BLAS
            # implementation it actually corresponds looking for implementation
            # specific symbols.
            if prefix == "libblas":
                if filename.endswith(".dll"):
                    libblas = ctypes.CDLL(filepath, _RTLD_NOLOAD)
                    if not any(
                        hasattr(libblas, func)
                        for func in controller_class.check_symbols
                    ):
                        continue
                else:
                    # We ignore libblas on other platforms than windows because there
                    # might be a libblas dso coming with openblas for instance that
                    # can't be used to instantiate a pertinent LibController (many
                    # symbols are missing) and would create confusion by making a
                    # duplicate entry in threadpool_info.
                    continue

            # We already have a controller for this library: skip it before
            # paying the cost of instantiating another controller for it.
            if any(lib.filepath == filepath for lib in self.lib_controllers):
                continue

            # filename matches a prefix. Now we check if the library has the symbols we
            # are looking for. If none of the symbols exists, it's very likely not the
            # expected library (e.g. a library having a common prefix with one of the
            # our supported libraries). Otherwise, create and store the library
            # controller.
            lib_controller = controller_class(
                filepath=filepath, prefix=prefix, parent=self
            )

            if not hasattr(controller_class, "check_symbols") or any(
                hasattr(lib_controller.dynlib, func)
                for func in controller_class.check_symbols
            ):
                self.lib_controllers.append(lib_controller)

    def _check_prefix(self, library_basename, filename_prefixes):
        """Return the prefix library_basename starts with

        Return None if none matches.
        """
        for prefix in filename_prefixes:
            if library_basename.startswith(prefix):
                return prefix
        return None

    def _warn_if_incompatible_openmp(self):
        """Emit a warning if llvm-OpenMP and intel-OpenMP are both loaded"""
        prefixes = [lib_controller.prefix for lib_controller in self.lib_controllers]
        if "libomp" not in prefixes or "libiomp" not in prefixes:
            return

        # The message is only built when the warning is actually emitted.
        msg = textwrap.dedent(
            """
            Found Intel OpenMP ('libiomp') and LLVM OpenMP ('libomp') loaded at
            the same time. Both libraries are known to be incompatible and this
            can cause random crashes or deadlocks on Linux when loaded in the
            same Python program.
            Using threadpoolctl may cause crashes or deadlocks. For more
            information and possible workarounds, please see
                https://github.com/joblib/threadpoolctl/blob/master/multiple_openmp.md
            """
        )
        warnings.warn(msg, RuntimeWarning)

    @classmethod
    def _get_libc(cls):
        """Load the lib-C for unix systems."""
        libc = cls._system_libraries.get("libc")
        if libc is None:
            # Remark: If libc is statically linked or if Python is linked against an
            # alternative implementation of libc like musl, find_library will return
            # None and CDLL will load the main program itself which should contain the
            # libc symbols. We still name it libc for convenience.
            # If the main program does not contain the libc symbols, it's ok because
            # we check their presence later anyway.
            libc = ctypes.CDLL(find_library("c"), mode=_RTLD_NOLOAD)
            cls._system_libraries["libc"] = libc
        return libc

    @classmethod
    def _get_windll(cls, dll_name):
        """Load a windows DLL"""
        dll = cls._system_libraries.get(dll_name)
        if dll is None:
            dll = ctypes.WinDLL(f"{dll_name}.dll")
            cls._system_libraries[dll_name] = dll
        return dll
def _main():
    """Commandline interface to display thread-pool information and exit.

    Optionally imports the modules given with ``-i/--import`` and executes
    the statement given with ``-c/--command`` before printing the result of
    ``threadpool_info()`` as indented JSON on the standard output.
    """
    import argparse
    import importlib
    import json
    import sys

    parser = argparse.ArgumentParser(
        usage="python -m threadpoolctl -i numpy scipy.linalg xgboost",
        description="Display thread-pool information and exit.",
    )
    parser.add_argument(
        "-i",
        "--import",
        dest="modules",
        nargs="*",
        default=(),
        help="Python modules to import before introspecting thread-pools.",
    )
    parser.add_argument(
        "-c",
        "--command",
        help="a Python statement to execute before introspecting thread-pools.",
    )

    options = parser.parse_args(sys.argv[1:])
    for module in options.modules:
        try:
            importlib.import_module(module, package=None)
        except ImportError:
            # A module that cannot be imported is reported on stderr and
            # skipped; the remaining modules are still processed.
            print("WARNING: could not import", module, file=sys.stderr)

    if options.command:
        # The statement comes from the command line of the user running this
        # program and is executed as is, in the namespace of this function.
        exec(options.command)

    print(json.dumps(threadpool_info(), indent=2))
# Run the command line interface only when this file is executed as a script.
if __name__ == "__main__":
    _main()
|