| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108 |
- """Utilities related to attribute docstring extraction."""
- from __future__ import annotations
- import ast
- import inspect
- import textwrap
- from typing import Any
class DocstringVisitor(ast.NodeVisitor):
    """AST visitor collecting attribute docstrings.

    A string literal expression is treated as an attribute docstring when the
    node visited immediately before it is an annotated assignment to a plain
    name. After visiting a tree, `attrs` maps attribute names to their cleaned
    docstrings.
    """

    def __init__(self) -> None:
        super().__init__()

        # Name bound by the latest annotated assignment that may still receive a docstring.
        self.target: str | None = None
        # Attribute name -> docstring (cleaned with `inspect.cleandoc`).
        self.attrs: dict[str, str] = {}
        # Type of the node whose visit finished most recently.
        self.previous_node_type: type[ast.AST] | None = None

    def visit(self, node: ast.AST) -> Any:
        """Visit `node`, then remember its type as the previously visited node type."""
        node_result = super().visit(node)
        self.previous_node_type = type(node)
        return node_result

    def visit_AnnAssign(self, node: ast.AnnAssign) -> Any:
        """Record the annotated name, clearing it when the target is not a plain name."""
        # Non-name targets (e.g. `obj.attr: int`) must reset the pending name,
        # otherwise a following string literal would be attributed to an
        # earlier, unrelated attribute.
        self.target = node.target.id if isinstance(node.target, ast.Name) else None

    def visit_Expr(self, node: ast.Expr) -> Any:
        """Store a string literal as the docstring of the pending annotated name."""
        if (
            isinstance(node.value, ast.Constant)
            and isinstance(node.value.value, str)
            and self.previous_node_type is ast.AnnAssign
        ):
            docstring = inspect.cleandoc(node.value.value)
            if self.target:
                self.attrs[self.target] = docstring
            # A name can only receive the docstring that directly follows it.
            self.target = None
def _dedent_source_lines(source: list[str]) -> str:
    """Join `source` lines into one string with common leading indentation removed.

    When the dedented text still starts with a space or a tab, it is placed
    inside a `dedent_workaround` function definition and that wrapped text is
    returned instead.
    """
    joined = ''.join(source)
    # Dedenting is needed for nested class definitions, e.g. in a function block.
    dedented = textwrap.dedent(joined)

    still_indented = dedented[:1] in (' ', '\t')
    if not still_indented:
        return dedented

    # A (usually multiline) string sits at a lower indentation level than the
    # class itself, so nothing could be stripped. Wrapping the class in a
    # function is used as a workaround for the remaining indentation.
    return 'def dedent_workaround():\n' + dedented
def _extract_source_from_frame(cls: type[Any]) -> list[str] | None:
    """Look for the source lines of `cls` by walking up the call stack.

    Starting at the current frame, each frame whose module is the module of
    `cls` is inspected: the block of source starting at the frame's current
    line is extracted and parsed, and returned if it is a class definition
    named `cls.__name__`.

    Args:
        cls: The class whose definition should be located.

    Returns:
        The source lines of the matching class block, or `None` if no frame
        yields one.
    """
    # The module of `cls` does not change while walking the stack; look it up once.
    cls_module = inspect.getmodule(cls)
    frame = inspect.currentframe()

    while frame:
        if inspect.getmodule(frame) is cls_module:
            lnum = frame.f_lineno
            try:
                lines, _ = inspect.findsource(frame)
            except OSError:
                # Source can't be retrieved (maybe because running in an interactive terminal),
                # we don't want to error here.
                pass
            else:
                block_lines = inspect.getblock(lines[lnum - 1 :])
                dedent_source = _dedent_source_lines(block_lines)
                try:
                    block_tree = ast.parse(dedent_source)
                except SyntaxError:
                    pass
                else:
                    # The block may parse to an empty module (no lines, or only
                    # comments/blank lines); treat it as "no match" instead of
                    # failing with an `IndexError`.
                    stmt = block_tree.body[0] if block_tree.body else None
                    if isinstance(stmt, ast.FunctionDef) and stmt.name == 'dedent_workaround':
                        # `_dedent_source_lines` wrapped the class in the workaround function
                        stmt = stmt.body[0]
                    if isinstance(stmt, ast.ClassDef) and stmt.name == cls.__name__:
                        return block_lines

        frame = frame.f_back

    return None
def extract_docstrings_from_cls(cls: type[Any], use_inspect: bool = False) -> dict[str, str]:
    """Map model attributes and their corresponding docstring.

    Args:
        cls: The class of the Pydantic model to inspect.
        use_inspect: Whether to skip usage of frames to find the object and use
            the `inspect` module instead.

    Returns:
        A mapping containing attribute names and their corresponding docstring.
        The mapping is empty when the source of `cls` cannot be found.
    """
    if use_inspect:
        # Might not work as expected if two classes have the same name in the same source file.
        try:
            source, _ = inspect.getsourcelines(cls)
        except (OSError, TypeError):
            # `OSError`: the source code cannot be retrieved.
            # `TypeError`: `inspect` considers the class to be built-in.
            # Either way there is nothing to extract.
            return {}
    else:
        source = _extract_source_from_frame(cls)

    if not source:
        return {}

    dedent_source = _dedent_source_lines(source)

    visitor = DocstringVisitor()
    visitor.visit(ast.parse(dedent_source))
    return visitor.attrs
|