test_docstrings.py 6.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208
  1. import re
  2. from inspect import signature
  3. from typing import Optional
  4. import pytest
  5. # make it possible to discover experimental estimators when calling `all_estimators`
  6. from sklearn.experimental import (
  7. enable_halving_search_cv, # noqa
  8. enable_iterative_imputer, # noqa
  9. )
  10. from sklearn.utils.discovery import all_displays, all_estimators, all_functions
# Obtain numpydoc's validation module through `pytest.importorskip`, which
# (per pytest's documentation) skips instead of failing when the requested
# module cannot be imported, i.e. when numpydoc is not installed.
numpydoc_validation = pytest.importorskip("numpydoc.validate")
  12. def get_all_methods():
  13. estimators = all_estimators()
  14. displays = all_displays()
  15. for name, Klass in estimators + displays:
  16. if name.startswith("_"):
  17. # skip private classes
  18. continue
  19. methods = []
  20. for name in dir(Klass):
  21. if name.startswith("_"):
  22. continue
  23. method_obj = getattr(Klass, name)
  24. if hasattr(method_obj, "__call__") or isinstance(method_obj, property):
  25. methods.append(name)
  26. methods.append(None)
  27. for method in sorted(methods, key=str):
  28. yield Klass, method
  29. def get_all_functions_names():
  30. functions = all_functions()
  31. for _, func in functions:
  32. # exclude functions from utils.fixex since they come from external packages
  33. if "utils.fixes" not in func.__module__:
  34. yield f"{func.__module__}.{func.__name__}"
  35. def filter_errors(errors, method, Klass=None):
  36. """
  37. Ignore some errors based on the method type.
  38. These rules are specific for scikit-learn."""
  39. for code, message in errors:
  40. # We ignore following error code,
  41. # - RT02: The first line of the Returns section
  42. # should contain only the type, ..
  43. # (as we may need refer to the name of the returned
  44. # object)
  45. # - GL01: Docstring text (summary) should start in the line
  46. # immediately after the opening quotes (not in the same line,
  47. # or leaving a blank line in between)
  48. # - GL02: If there's a blank line, it should be before the
  49. # first line of the Returns section, not after (it allows to have
  50. # short docstrings for properties).
  51. if code in ["RT02", "GL01", "GL02"]:
  52. continue
  53. # Ignore PR02: Unknown parameters for properties. We sometimes use
  54. # properties for ducktyping, i.e. SGDClassifier.predict_proba
  55. # Ignore GL08: Parsing of the method signature failed, possibly because this is
  56. # a property. Properties are sometimes used for deprecated attributes and the
  57. # attribute is already documented in the class docstring.
  58. #
  59. # All error codes:
  60. # https://numpydoc.readthedocs.io/en/latest/validation.html#built-in-validation-checks
  61. if code in ("PR02", "GL08") and Klass is not None and method is not None:
  62. method_obj = getattr(Klass, method)
  63. if isinstance(method_obj, property):
  64. continue
  65. # Following codes are only taken into account for the
  66. # top level class docstrings:
  67. # - ES01: No extended summary found
  68. # - SA01: See Also section not found
  69. # - EX01: No examples section found
  70. if method is not None and code in ["EX01", "SA01", "ES01"]:
  71. continue
  72. yield code, message
  73. def repr_errors(res, Klass=None, method: Optional[str] = None) -> str:
  74. """Pretty print original docstring and the obtained errors
  75. Parameters
  76. ----------
  77. res : dict
  78. result of numpydoc.validate.validate
  79. Klass : {Estimator, Display, None}
  80. estimator object or None
  81. method : str
  82. if estimator is not None, either the method name or None.
  83. Returns
  84. -------
  85. str
  86. String representation of the error.
  87. """
  88. if method is None:
  89. if hasattr(Klass, "__init__"):
  90. method = "__init__"
  91. elif Klass is None:
  92. raise ValueError("At least one of Klass, method should be provided")
  93. else:
  94. raise NotImplementedError
  95. if Klass is not None:
  96. obj = getattr(Klass, method)
  97. try:
  98. obj_signature = str(signature(obj))
  99. except TypeError:
  100. # In particular we can't parse the signature of properties
  101. obj_signature = (
  102. "\nParsing of the method signature failed, "
  103. "possibly because this is a property."
  104. )
  105. obj_name = Klass.__name__ + "." + method
  106. else:
  107. obj_signature = ""
  108. obj_name = method
  109. msg = "\n\n" + "\n\n".join(
  110. [
  111. str(res["file"]),
  112. obj_name + obj_signature,
  113. res["docstring"],
  114. "# Errors",
  115. "\n".join(
  116. " - {}: {}".format(code, message) for code, message in res["errors"]
  117. ),
  118. ]
  119. )
  120. return msg
  121. @pytest.mark.parametrize("function_name", get_all_functions_names())
  122. def test_function_docstring(function_name, request):
  123. """Check function docstrings using numpydoc."""
  124. res = numpydoc_validation.validate(function_name)
  125. res["errors"] = list(filter_errors(res["errors"], method="function"))
  126. if res["errors"]:
  127. msg = repr_errors(res, method=f"Tested function: {function_name}")
  128. raise ValueError(msg)
  129. @pytest.mark.parametrize("Klass, method", get_all_methods())
  130. def test_docstring(Klass, method, request):
  131. base_import_path = Klass.__module__
  132. import_path = [base_import_path, Klass.__name__]
  133. if method is not None:
  134. import_path.append(method)
  135. import_path = ".".join(import_path)
  136. res = numpydoc_validation.validate(import_path)
  137. res["errors"] = list(filter_errors(res["errors"], method, Klass=Klass))
  138. if res["errors"]:
  139. msg = repr_errors(res, Klass, method)
  140. raise ValueError(msg)
  141. if __name__ == "__main__":
  142. import argparse
  143. import sys
  144. parser = argparse.ArgumentParser(description="Validate docstring with numpydoc.")
  145. parser.add_argument("import_path", help="Import path to validate")
  146. args = parser.parse_args()
  147. res = numpydoc_validation.validate(args.import_path)
  148. import_path_sections = args.import_path.split(".")
  149. # When applied to classes, detect class method. For functions
  150. # method = None.
  151. # TODO: this detection can be improved. Currently we assume that we have
  152. # class # methods if the second path element before last is in camel case.
  153. if len(import_path_sections) >= 2 and re.match(
  154. r"(?:[A-Z][a-z]*)+", import_path_sections[-2]
  155. ):
  156. method = import_path_sections[-1]
  157. else:
  158. method = None
  159. res["errors"] = list(filter_errors(res["errors"], method))
  160. if res["errors"]:
  161. msg = repr_errors(res, method=args.import_path)
  162. print(msg)
  163. sys.exit(1)
  164. else:
  165. print("All docstring checks passed for {}!".format(args.import_path))