discovery.py 7.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217
  1. import inspect
  2. import pkgutil
  3. from importlib import import_module
  4. from operator import itemgetter
  5. from pathlib import Path
  6. _MODULE_TO_IGNORE = {
  7. "tests",
  8. "externals",
  9. "setup",
  10. "conftest",
  11. "experimental",
  12. "estimator_checks",
  13. }
  14. def all_estimators(type_filter=None):
  15. """Get a list of all estimators from `sklearn`.
  16. This function crawls the module and gets all classes that inherit
  17. from BaseEstimator. Classes that are defined in test-modules are not
  18. included.
  19. Parameters
  20. ----------
  21. type_filter : {"classifier", "regressor", "cluster", "transformer"} \
  22. or list of such str, default=None
  23. Which kind of estimators should be returned. If None, no filter is
  24. applied and all estimators are returned. Possible values are
  25. 'classifier', 'regressor', 'cluster' and 'transformer' to get
  26. estimators only of these specific types, or a list of these to
  27. get the estimators that fit at least one of the types.
  28. Returns
  29. -------
  30. estimators : list of tuples
  31. List of (name, class), where ``name`` is the class name as string
  32. and ``class`` is the actual type of the class.
  33. """
  34. # lazy import to avoid circular imports from sklearn.base
  35. from ..base import (
  36. BaseEstimator,
  37. ClassifierMixin,
  38. ClusterMixin,
  39. RegressorMixin,
  40. TransformerMixin,
  41. )
  42. from . import IS_PYPY
  43. from ._testing import ignore_warnings
  44. def is_abstract(c):
  45. if not (hasattr(c, "__abstractmethods__")):
  46. return False
  47. if not len(c.__abstractmethods__):
  48. return False
  49. return True
  50. all_classes = []
  51. root = str(Path(__file__).parent.parent) # sklearn package
  52. # Ignore deprecation warnings triggered at import time and from walking
  53. # packages
  54. with ignore_warnings(category=FutureWarning):
  55. for _, module_name, _ in pkgutil.walk_packages(path=[root], prefix="sklearn."):
  56. module_parts = module_name.split(".")
  57. if (
  58. any(part in _MODULE_TO_IGNORE for part in module_parts)
  59. or "._" in module_name
  60. ):
  61. continue
  62. module = import_module(module_name)
  63. classes = inspect.getmembers(module, inspect.isclass)
  64. classes = [
  65. (name, est_cls) for name, est_cls in classes if not name.startswith("_")
  66. ]
  67. # TODO: Remove when FeatureHasher is implemented in PYPY
  68. # Skips FeatureHasher for PYPY
  69. if IS_PYPY and "feature_extraction" in module_name:
  70. classes = [
  71. (name, est_cls)
  72. for name, est_cls in classes
  73. if name == "FeatureHasher"
  74. ]
  75. all_classes.extend(classes)
  76. all_classes = set(all_classes)
  77. estimators = [
  78. c
  79. for c in all_classes
  80. if (issubclass(c[1], BaseEstimator) and c[0] != "BaseEstimator")
  81. ]
  82. # get rid of abstract base classes
  83. estimators = [c for c in estimators if not is_abstract(c[1])]
  84. if type_filter is not None:
  85. if not isinstance(type_filter, list):
  86. type_filter = [type_filter]
  87. else:
  88. type_filter = list(type_filter) # copy
  89. filtered_estimators = []
  90. filters = {
  91. "classifier": ClassifierMixin,
  92. "regressor": RegressorMixin,
  93. "transformer": TransformerMixin,
  94. "cluster": ClusterMixin,
  95. }
  96. for name, mixin in filters.items():
  97. if name in type_filter:
  98. type_filter.remove(name)
  99. filtered_estimators.extend(
  100. [est for est in estimators if issubclass(est[1], mixin)]
  101. )
  102. estimators = filtered_estimators
  103. if type_filter:
  104. raise ValueError(
  105. "Parameter type_filter must be 'classifier', "
  106. "'regressor', 'transformer', 'cluster' or "
  107. "None, got"
  108. f" {repr(type_filter)}."
  109. )
  110. # drop duplicates, sort for reproducibility
  111. # itemgetter is used to ensure the sort does not extend to the 2nd item of
  112. # the tuple
  113. return sorted(set(estimators), key=itemgetter(0))
  114. def all_displays():
  115. """Get a list of all displays from `sklearn`.
  116. Returns
  117. -------
  118. displays : list of tuples
  119. List of (name, class), where ``name`` is the display class name as
  120. string and ``class`` is the actual type of the class.
  121. """
  122. # lazy import to avoid circular imports from sklearn.base
  123. from ._testing import ignore_warnings
  124. all_classes = []
  125. root = str(Path(__file__).parent.parent) # sklearn package
  126. # Ignore deprecation warnings triggered at import time and from walking
  127. # packages
  128. with ignore_warnings(category=FutureWarning):
  129. for _, module_name, _ in pkgutil.walk_packages(path=[root], prefix="sklearn."):
  130. module_parts = module_name.split(".")
  131. if (
  132. any(part in _MODULE_TO_IGNORE for part in module_parts)
  133. or "._" in module_name
  134. ):
  135. continue
  136. module = import_module(module_name)
  137. classes = inspect.getmembers(module, inspect.isclass)
  138. classes = [
  139. (name, display_class)
  140. for name, display_class in classes
  141. if not name.startswith("_") and name.endswith("Display")
  142. ]
  143. all_classes.extend(classes)
  144. return sorted(set(all_classes), key=itemgetter(0))
  145. def _is_checked_function(item):
  146. if not inspect.isfunction(item):
  147. return False
  148. if item.__name__.startswith("_"):
  149. return False
  150. mod = item.__module__
  151. if not mod.startswith("sklearn.") or mod.endswith("estimator_checks"):
  152. return False
  153. return True
  154. def all_functions():
  155. """Get a list of all functions from `sklearn`.
  156. Returns
  157. -------
  158. functions : list of tuples
  159. List of (name, function), where ``name`` is the function name as
  160. string and ``function`` is the actual function.
  161. """
  162. # lazy import to avoid circular imports from sklearn.base
  163. from ._testing import ignore_warnings
  164. all_functions = []
  165. root = str(Path(__file__).parent.parent) # sklearn package
  166. # Ignore deprecation warnings triggered at import time and from walking
  167. # packages
  168. with ignore_warnings(category=FutureWarning):
  169. for _, module_name, _ in pkgutil.walk_packages(path=[root], prefix="sklearn."):
  170. module_parts = module_name.split(".")
  171. if (
  172. any(part in _MODULE_TO_IGNORE for part in module_parts)
  173. or "._" in module_name
  174. ):
  175. continue
  176. module = import_module(module_name)
  177. functions = inspect.getmembers(module, _is_checked_function)
  178. functions = [
  179. (func.__name__, func)
  180. for name, func in functions
  181. if not name.startswith("_")
  182. ]
  183. all_functions.extend(functions)
  184. # drop duplicates, sort for reproducibility
  185. # itemgetter is used to ensure the sort does not extend to the 2nd item of
  186. # the tuple
  187. return sorted(set(all_functions), key=itemgetter(0))