_docs_extraction.py 3.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108
  1. """Utilities related to attribute docstring extraction."""
  2. from __future__ import annotations
  3. import ast
  4. import inspect
  5. import textwrap
  6. from typing import Any
  7. class DocstringVisitor(ast.NodeVisitor):
  8. def __init__(self) -> None:
  9. super().__init__()
  10. self.target: str | None = None
  11. self.attrs: dict[str, str] = {}
  12. self.previous_node_type: type[ast.AST] | None = None
  13. def visit(self, node: ast.AST) -> Any:
  14. node_result = super().visit(node)
  15. self.previous_node_type = type(node)
  16. return node_result
  17. def visit_AnnAssign(self, node: ast.AnnAssign) -> Any:
  18. if isinstance(node.target, ast.Name):
  19. self.target = node.target.id
  20. def visit_Expr(self, node: ast.Expr) -> Any:
  21. if (
  22. isinstance(node.value, ast.Constant)
  23. and isinstance(node.value.value, str)
  24. and self.previous_node_type is ast.AnnAssign
  25. ):
  26. docstring = inspect.cleandoc(node.value.value)
  27. if self.target:
  28. self.attrs[self.target] = docstring
  29. self.target = None
  30. def _dedent_source_lines(source: list[str]) -> str:
  31. # Required for nested class definitions, e.g. in a function block
  32. dedent_source = textwrap.dedent(''.join(source))
  33. if dedent_source.startswith((' ', '\t')):
  34. # We are in the case where there's a dedented (usually multiline) string
  35. # at a lower indentation level than the class itself. We wrap our class
  36. # in a function as a workaround.
  37. dedent_source = f'def dedent_workaround():\n{dedent_source}'
  38. return dedent_source
  39. def _extract_source_from_frame(cls: type[Any]) -> list[str] | None:
  40. frame = inspect.currentframe()
  41. while frame:
  42. if inspect.getmodule(frame) is inspect.getmodule(cls):
  43. lnum = frame.f_lineno
  44. try:
  45. lines, _ = inspect.findsource(frame)
  46. except OSError:
  47. # Source can't be retrieved (maybe because running in an interactive terminal),
  48. # we don't want to error here.
  49. pass
  50. else:
  51. block_lines = inspect.getblock(lines[lnum - 1 :])
  52. dedent_source = _dedent_source_lines(block_lines)
  53. try:
  54. block_tree = ast.parse(dedent_source)
  55. except SyntaxError:
  56. pass
  57. else:
  58. stmt = block_tree.body[0]
  59. if isinstance(stmt, ast.FunctionDef) and stmt.name == 'dedent_workaround':
  60. # `_dedent_source_lines` wrapped the class around the workaround function
  61. stmt = stmt.body[0]
  62. if isinstance(stmt, ast.ClassDef) and stmt.name == cls.__name__:
  63. return block_lines
  64. frame = frame.f_back
  65. def extract_docstrings_from_cls(cls: type[Any], use_inspect: bool = False) -> dict[str, str]:
  66. """Map model attributes and their corresponding docstring.
  67. Args:
  68. cls: The class of the Pydantic model to inspect.
  69. use_inspect: Whether to skip usage of frames to find the object and use
  70. the `inspect` module instead.
  71. Returns:
  72. A mapping containing attribute names and their corresponding docstring.
  73. """
  74. if use_inspect:
  75. # Might not work as expected if two classes have the same name in the same source file.
  76. try:
  77. source, _ = inspect.getsourcelines(cls)
  78. except OSError:
  79. return {}
  80. else:
  81. source = _extract_source_from_frame(cls)
  82. if not source:
  83. return {}
  84. dedent_source = _dedent_source_lines(source)
  85. visitor = DocstringVisitor()
  86. visitor.visit(ast.parse(dedent_source))
  87. return visitor.attrs