_sources.py 4.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138
  1. # mypy: allow-untyped-defs
  2. import ast
  3. import functools
  4. import inspect
  5. from textwrap import dedent
  6. from typing import Any, List, NamedTuple, Optional, Tuple
  7. from torch._C import ErrorReport
  8. from torch._C._jit_tree_views import SourceRangeFactory
  9. def get_source_lines_and_file(
  10. obj: Any,
  11. error_msg: Optional[str] = None,
  12. ) -> Tuple[List[str], int, Optional[str]]:
  13. """
  14. Wrapper around inspect.getsourcelines and inspect.getsourcefile.
  15. Returns: (sourcelines, file_lino, filename)
  16. """
  17. filename = None # in case getsourcefile throws
  18. try:
  19. filename = inspect.getsourcefile(obj)
  20. sourcelines, file_lineno = inspect.getsourcelines(obj)
  21. except OSError as e:
  22. msg = (
  23. f"Can't get source for {obj}. TorchScript requires source access in "
  24. "order to carry out compilation, make sure original .py files are "
  25. "available."
  26. )
  27. if error_msg:
  28. msg += "\n" + error_msg
  29. raise OSError(msg) from e
  30. return sourcelines, file_lineno, filename
  31. def normalize_source_lines(sourcelines: List[str]) -> List[str]:
  32. """
  33. This helper function accepts a list of source lines. It finds the
  34. indentation level of the function definition (`def`), then it indents
  35. all lines in the function body to a point at or greater than that
  36. level. This allows for comments and continued string literals that
  37. are at a lower indentation than the rest of the code.
  38. Args:
  39. sourcelines: function source code, separated into lines by
  40. the '\n' character
  41. Returns:
  42. A list of source lines that have been correctly aligned
  43. """
  44. def remove_prefix(text, prefix):
  45. return text[text.startswith(prefix) and len(prefix) :]
  46. # Find the line and line number containing the function definition
  47. idx = None
  48. for i, l in enumerate(sourcelines):
  49. if l.lstrip().startswith("def"):
  50. idx = i
  51. break
  52. # This will happen when the function is a lambda- we won't find "def" anywhere in the source
  53. # lines in that case. Currently trying to JIT compile a lambda will throw an error up in
  54. # `parse_def()`, but we might want to handle this case in the future.
  55. if idx is None:
  56. return sourcelines
  57. # Get a string representing the amount of leading whitespace
  58. fn_def = sourcelines[idx]
  59. whitespace = fn_def.split("def")[0]
  60. # Add this leading whitespace to all lines before and after the `def`
  61. aligned_prefix = [
  62. whitespace + remove_prefix(s, whitespace) for s in sourcelines[:idx]
  63. ]
  64. aligned_suffix = [
  65. whitespace + remove_prefix(s, whitespace) for s in sourcelines[idx + 1 :]
  66. ]
  67. # Put it together again
  68. aligned_prefix.append(fn_def)
  69. return aligned_prefix + aligned_suffix
  70. # Thin wrapper around SourceRangeFactory to store extra metadata
  71. # about the function-to-be-compiled.
  72. class SourceContext(SourceRangeFactory):
  73. def __init__(
  74. self,
  75. source,
  76. filename,
  77. file_lineno,
  78. leading_whitespace_len,
  79. uses_true_division=True,
  80. funcname=None,
  81. ):
  82. super().__init__(source, filename, file_lineno, leading_whitespace_len)
  83. self.uses_true_division = uses_true_division
  84. self.filename = filename
  85. self.funcname = funcname
  86. @functools.lru_cache(maxsize=None)
  87. def make_source_context(*args):
  88. return SourceContext(*args)
  89. def fake_range():
  90. return SourceContext("", None, 0, 0).make_raw_range(0, 1)
  91. class ParsedDef(NamedTuple):
  92. ast: ast.Module
  93. ctx: SourceContext
  94. source: str
  95. filename: Optional[str]
  96. file_lineno: int
  97. def parse_def(fn):
  98. sourcelines, file_lineno, filename = get_source_lines_and_file(
  99. fn, ErrorReport.call_stack()
  100. )
  101. sourcelines = normalize_source_lines(sourcelines)
  102. source = "".join(sourcelines)
  103. dedent_src = dedent(source)
  104. py_ast = ast.parse(dedent_src)
  105. if len(py_ast.body) != 1 or not isinstance(py_ast.body[0], ast.FunctionDef):
  106. raise RuntimeError(
  107. f"Expected a single top-level function: {filename}:{file_lineno}"
  108. )
  109. leading_whitespace_len = len(source.split("\n", 1)[0]) - len(
  110. dedent_src.split("\n", 1)[0]
  111. )
  112. ctx = make_source_context(
  113. source, filename, file_lineno, leading_whitespace_len, True, fn.__name__
  114. )
  115. return ParsedDef(py_ast, ctx, source, filename, file_lineno)