| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
7037047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267 |
- # mypy: allow-untyped-defs
- import ast
- import dataclasses
- import inspect
- import re
- import string
- import sys
- from collections import namedtuple
- from textwrap import dedent
- from typing import List, Tuple # noqa: F401
- import torch
- import torch.jit.annotations
- from torch import _jit_internal
- from torch._C._jit_tree_views import (
- Apply,
- Assert,
- Assign,
- Attribute,
- AugAssign,
- BinOp,
- Break,
- ClassDef,
- Const,
- Continue,
- Decl,
- Def,
- Delete,
- DictComp,
- DictLiteral,
- Dots,
- EmptyTypeAnnotation,
- ExprStmt,
- FalseLiteral,
- For,
- Ident,
- If,
- ListComp,
- ListLiteral,
- NoneLiteral,
- Param,
- Pass,
- Property,
- Raise,
- Return,
- Select,
- SliceExpr,
- Starred,
- Stmt,
- StringLiteral,
- Subscript,
- TernaryIf,
- TrueLiteral,
- TupleLiteral,
- UnaryOp,
- Var,
- While,
- With,
- WithItem,
- )
- from torch._jit_internal import ( # noqa: F401
- _is_drop_fn,
- FunctionModifiers,
- is_static_fn,
- should_drop,
- )
- from torch._sources import (
- get_source_lines_and_file,
- make_source_context,
- parse_def,
- ParsedDef as _ParsedDef,
- )
- from torch.jit._dataclass_impls import DATACLASS_MAGIC_METHODS
- from torch.jit._monkeytype_config import get_qualified_name, monkeytype_trace
- _IS_ASTUNPARSE_INSTALLED = False
- try:
- import astunparse # type: ignore[import]
- _IS_ASTUNPARSE_INSTALLED = True
- except ImportError:
- pass
- # Borrowed from cPython implementation
- # https://github.com/python/cpython/blob/561612d8456cfab5672c9b445521113b847bd6b3/Lib/textwrap.py#L411#
- _reserved_prefix = "__jit"
- _reserved_names = {"print"}
- _identifier_chars = set(string.ascii_lowercase + string.ascii_uppercase + string.digits)
- def is_reserved_name(name):
- return name.startswith(_reserved_prefix) or name in _reserved_names
# Human-readable feature names for Python AST node types, used when building
# "<feature> aren't supported" frontend error messages.
pretty_node_names = {
    ast.FunctionDef: "function definitions",
    ast.For: "for loops",
    ast.Delete: "del statements",
    ast.ClassDef: "class definitions",
    ast.With: "with statements",
    ast.Raise: "raise statements",
    ast.Assert: "assertions",
    ast.Import: "import statements",
    ast.ImportFrom: "import statements",
    ast.Global: "global variables",
    ast.Break: "break statements",
    ast.Continue: "continue statements",
    ast.AsyncFunctionDef: "async function definitions",
    ast.AsyncFor: "async for loops",
    ast.AsyncWith: "async with statements",
    ast.Try: "try blocks",
    ast.Nonlocal: "nonlocal variables",
    ast.AnnAssign: "annotated assignments",
}

# Leading source keyword(s) of each node type; UnsupportedNodeError uses the
# token's length to size the highlighted source range.
# NB: no specific token for AnnAssign
node_start_tokens = {
    ast.FunctionDef: "def",
    ast.For: "for",
    ast.Delete: "del",
    ast.ClassDef: "class",
    ast.With: "with",
    ast.Raise: "raise",
    ast.Assert: "assert",
    ast.Import: "import",
    ast.ImportFrom: "from",
    ast.Global: "global",
    ast.Break: "break",
    ast.Continue: "continue",
    ast.AsyncFunctionDef: "async def",
    ast.AsyncFor: "async for",
    ast.AsyncWith: "async with",
    ast.Try: "try",
    ast.Nonlocal: "nonlocal",
}
class FrontendError(Exception):
    """Base error raised while translating Python source into a TorchScript AST.

    Args:
        source_range: The source location the error refers to.
        msg: Human-readable description of the problem.
    """

    def __init__(self, source_range, msg):
        self.source_range = source_range
        self.msg = msg
        # Created eagerly, here, so that the ErrorReport captures the call stack
        # as it is at the moment the FrontendError is raised.
        self.error_report = torch._C.ErrorReport(self.source_range)

    def __str__(self):
        # The message, immediately followed by the report text with its leading
        # whitespace stripped.
        message = self.msg
        details = self.error_report.what().lstrip()
        return message + details
class NotSupportedError(FrontendError):
    """FrontendError raised when the source uses a construct TorchScript does not support."""
class UnsupportedNodeError(NotSupportedError):
    """NotSupportedError for a Python AST node kind the frontend cannot translate.

    Args:
        ctx: Source context used to build the error's source range.
        offending_node: The rejected Python AST node (must have ``lineno`` and
            ``col_offset``).
        reason: Optional qualifier placed between the feature name and
            "aren't supported", e.g. "annotated assignments without assigned
            value aren't supported".
    """

    def __init__(self, ctx, offending_node, reason=""):
        kind = type(offending_node)
        # Highlight the node's leading keyword when one is known; the fallback
        # " " gives a one-character range.
        width = len(node_start_tokens.get(kind, " "))
        line = offending_node.lineno
        first_col = offending_node.col_offset
        source_range = ctx.make_range(line, first_col, first_col + width)

        feature = pretty_node_names.get(kind, kind.__name__)
        qualifier = reason + " " if reason else ""
        super().__init__(source_range, f"{feature} {qualifier}aren't supported")
class FrontendTypeError(FrontendError):
    """FrontendError subclass; presumably raised for type-related problems — confirm at raise sites."""
def build_withitems(ctx, items):
    """Translate each ``ast.withitem`` in ``items`` into a TorchScript WithItem, in order."""
    return [build_withitem(ctx, item) for item in items]
def build_stmts(ctx, stmts):
    """Translate Python statements in order, dropping any whose builder returned a falsy value.

    For example, ``StmtBuilder.build_Expr`` returns None for statements it ignores.
    """
    translated = (build_stmt(ctx, stmt) for stmt in stmts)
    return [tree for tree in translated if tree]
def get_class_properties(cls, self_name):
    """
    Get a list of Property objects representing the properties of a class.

    Args:
        cls: The class to get properties of.
        self_name: The name of the class that the properties should belong to.

    Returns:
        A list of Property objects corresponding to the properties of cls. Property
        here refers to the subclass of TreeView.
    """
    members = inspect.getmembers(cls, predicate=lambda m: isinstance(m, property))
    # Any property that should not be compiled must be listed here on the Module.
    skipped = getattr(cls, "__jit_unused_properties__", [])

    result = []
    for name, prop in members:
        # Skip explicitly unused properties and those whose getter is marked to be dropped.
        if name in skipped or should_drop(prop.fget):
            continue
        getter = get_jit_def(prop.fget, f"__{name}_getter", self_name=self_name)
        setter = None
        if prop.fset:
            setter = get_jit_def(prop.fset, f"__{name}_setter", self_name=self_name)
        getter_range = getter.range()
        result.append(Property(getter_range, Ident(getter_range, name), getter, setter))
    return result
def get_class_assigns(ctx, cls_ast):
    """Build TorchScript Assign trees for the class-level assignments in ``cls_ast``.

    Both plain (``x = v``) and annotated (``x: T = v``) assignments are
    considered; any whose translation raises NotSupportedError is skipped.
    """
    assigns = []
    for entry in cls_ast.body:
        if isinstance(entry, ast.Assign):
            build = StmtBuilder.build_Assign
        elif isinstance(entry, ast.AnnAssign):
            build = StmtBuilder.build_AnnAssign
        else:
            continue
        try:
            assigns.append(build(ctx, entry))
        except NotSupportedError:
            # Best effort: class attributes TorchScript can't express are left out.
            pass
    return assigns
def get_jit_class_def(cls, self_name):
    """Build a TorchScript ClassDef tree for the Python class ``cls``.

    Args:
        cls: The Python class to compile.
        self_name: Type name used for ``self`` in the class's methods.

    Returns:
        The ClassDef built from the class's methods, properties and
        class-level assignments.
    """

    def is_compiled_method(member):
        # Functions/bound methods defined directly on cls. Static methods and
        # functions marked to be dropped are excluded.
        return (
            (inspect.ismethod(member) or inspect.isfunction(member))
            and not is_static_fn(cls, member.__name__)
            and member.__name__ in cls.__dict__
            and not _is_drop_fn(member)
        )

    # Get defs for each method within the current class independently
    # TODO: proper overriding analysis when implementing class inheritance
    methods = inspect.getmembers(cls, predicate=is_compiled_method)

    def is_classmethod(fn):
        return inspect.ismethod(fn) and getattr(fn, "__self__", None) == cls

    # Fetch and parse the class's source code.
    sourcelines, file_lineno, filename = get_source_lines_and_file(
        cls, torch._C.ErrorReport.call_stack()
    )
    source = "".join(sourcelines)
    dedent_src = dedent(source)
    class_ast = ast.parse(dedent_src).body[0]
    assert isinstance(class_ast, ast.ClassDef)

    # Dataclasses: compiling normally needs each method's source, but the
    # dataclasses module generates magic methods at runtime, so there is none.
    # Substitute TorchScript-friendly synthesized versions instead, except for
    # magic methods the user wrote explicitly in the class body.
    if dataclasses.is_dataclass(cls):
        user_defined = {
            node.name
            for node in class_ast.body
            if isinstance(node, ast.FunctionDef) and node.name in DATACLASS_MAGIC_METHODS
        }
        for idx, (method_name, _) in enumerate(methods):
            synthesize = DATACLASS_MAGIC_METHODS.get(method_name)
            if not synthesize or method_name in user_defined:
                continue
            synthesized = synthesize(cls)
            methods[idx] = method_name, synthesized
            # Register the synthesized source for the original function object
            # (presumably so later source lookups for it find this code).
            _jit_internal.loader.cache(getattr(cls, method_name), synthesized.source)

    method_defs = [
        get_jit_def(fn, method_name, self_name=self_name, is_classmethod=is_classmethod(fn))
        for method_name, fn in methods
    ]
    properties = get_class_properties(cls, self_name)

    # Number of leading characters dedent removed from the first source line.
    first_line = source.split("\n", 1)[0]
    first_line_dedented = dedent_src.split("\n", 1)[0]
    leading_whitespace_len = len(first_line) - len(first_line_dedented)
    ctx = make_source_context(source, filename, file_lineno, leading_whitespace_len, False)

    assigns = get_class_assigns(ctx, class_ast)
    return build_class_def(ctx, class_ast, method_defs, properties, self_name, assigns)
def get_jit_def(fn, def_name, self_name=None, is_classmethod=False):
    """
    Build a JIT AST (TreeView) from the given function.

    Args:
        fn: A function object to compile or a pre-parsed ParsedDef object
        def_name: The name to give to the resulting AST object. This is not
            always the same as `fn.__name__`, for example:
                def _forward(self):
                    ...
                forward = _forward
            In this case, the `__name__` attribute of the function object is "_forward",
            but we want the result AST to have the name "forward".
        self_name: If this function is a method, what the type name of `self` is.
        is_classmethod: If True, a statement assigning ``self_name`` to the first
            parameter is inserted at the top of the function body.
    """
    parsed_def = fn if isinstance(fn, _ParsedDef) else parse_def(fn)
    type_line = torch.jit.annotations.get_type_line(parsed_def.source)
    fn_def = parsed_def.ast.body[0]

    if is_classmethod:
        # Rebind the first parameter to the class itself.
        first_param = fn_def.args.args[0].arg
        rebind = ast.parse(f"{first_param} = {self_name}").body[0]
        fn_def.body.insert(0, rebind)

    # Swap out the function signature and body if it is unused
    if should_drop(fn):
        stub_module = ast.parse(
            'def unused_fn(self: Any):\n\traise RuntimeError("Cannot call @unused methods")'
        )
        if len(stub_module.body) != 1 or not isinstance(stub_module.body[0], ast.FunctionDef):
            raise RuntimeError(
                f"Expected a single top-level function: {parsed_def.filename}:{parsed_def.file_lineno}"
            )
        stub = stub_module.body[0]
        fn_def.body = stub.body
        # kwarg/vararg not supported by `build_def`
        fn_def.args.kwarg = fn_def.args.vararg = None
        # Replace potentially unsupported type annotations by "Any"
        any_annotation = stub.args.args[0].annotation
        for arg in fn_def.args.args + fn_def.args.kwonlyargs:
            arg.annotation = any_annotation
        if _is_drop_fn(fn):
            # Dropping potentially unsupported return type annotation for jit._drop
            fn_def.returns = None
            fn_def.type_comment = None

    # If MonkeyType is installed, get all the consolidated type traces
    # for the arguments from type_trace_db
    type_trace_db = torch.jit._script._get_type_trace_db()
    pdt_arg_types = None
    if monkeytype_trace and not isinstance(fn, _ParsedDef):  # type: ignore[truthy-function]
        pdt_arg_types = type_trace_db.get_args_types(get_qualified_name(fn))

    return build_def(
        parsed_def.ctx,
        fn_def,
        type_line,
        def_name,
        self_name=self_name,
        pdt_arg_types=pdt_arg_types,
    )
# TODO: more robust handling of recognizing ignore context manager
def is_torch_jit_ignore_context_manager(stmt):
    """Return True if the first item of the ``with`` statement ``stmt`` is a call
    spelled exactly as ``torch.jit._IgnoreContextManager(...)``.

    Aliases such as ``jit._IgnoreContextManager`` are not recognized.
    """
    context_expr = stmt.items[0].context_expr
    if not isinstance(context_expr, ast.Call):
        return False

    func = context_expr.func
    if not isinstance(func, ast.Attribute) or func.attr != "_IgnoreContextManager":
        return False

    # There should be at most two nested attributes (torch.jit._IgnoreContextManager).
    owner = func.value
    if not isinstance(owner, ast.Attribute) or owner.attr != "jit":
        return False
    return isinstance(owner.value, ast.Name) and owner.value.id == "torch"
class Builder:
    """Dispatches an AST node to this builder's ``build_<NodeClassName>`` method.

    Subclasses define one ``build_*`` method per supported node type. Calling
    the builder with a node type that has no such method raises
    UnsupportedNodeError.
    """

    def __call__(self, ctx, node):
        handler = getattr(self, f"build_{node.__class__.__name__}", None)
        if handler is None:
            raise UnsupportedNodeError(ctx, node)
        return handler(ctx, node)
def build_class_def(ctx, py_def, methods, properties, self_name, assigns):
    """Assemble a ClassDef named ``self_name`` from already-built parts.

    The name's source range covers the ``class`` keyword of ``py_def``; each
    method Def is wrapped in a Stmt.
    """
    col = py_def.col_offset
    keyword_range = ctx.make_range(py_def.lineno, col, col + len("class"))
    class_name = Ident(keyword_range, self_name)
    method_stmts = [Stmt(method) for method in methods]
    return ClassDef(class_name, method_stmts, properties, assigns)
def build_def(ctx, py_def, type_line, def_name, self_name=None, pdt_arg_types=None):
    """Translate a Python function AST into a TorchScript Def named ``def_name``.

    Args:
        ctx: Source context used to build ranges.
        py_def: The Python function definition node.
        type_line: A type-comment line, or None; when given it is parsed and
            merged into the declaration.
        def_name: Name for the resulting Def.
        self_name: Type name of ``self`` when translating a method, else None.
        pdt_arg_types: Optional profiled argument types, forwarded to
            ``build_param_list``.
    """
    body = py_def.body
    col = py_def.col_offset
    def_range = ctx.make_range(py_def.lineno, col, col + len("def"))
    params = build_param_list(ctx, py_def.args, self_name, pdt_arg_types)

    returns = getattr(py_def, "returns", None)
    return_type = build_expr(ctx, py_def.returns) if returns is not None else None
    decl = Decl(def_range, params, return_type)

    if type_line is not None:
        comment_decl = torch._C.parse_type_comment(type_line)
        is_method = self_name is not None
        decl = torch._C.merge_type_from_type_comment(decl, comment_decl, is_method)

    return Def(Ident(def_range, def_name), decl, build_stmts(ctx, body))
_vararg_kwarg_err = (
    "Compiled functions can't take variable number of arguments "
    "or use keyword-only arguments with defaults"
)


def build_param_list(ctx, py_args, self_name, pdt_arg_types=None):
    """Translate a function's ``ast.arguments`` into a list of TorchScript Params.

    Positional parameters come first (``kwarg_only=False``), followed by
    keyword-only parameters (``kwarg_only=True``).

    Args:
        ctx: Source context used to build ranges.
        py_args: The function's ``ast.arguments`` node.
        self_name: Type name of ``self`` for methods, else None.
        pdt_arg_types: Optional mapping from argument name to a type name
            inferred by profile-directed typing. Arguments that are absent
            from the mapping, or mapped to a falsy value, get no profiled type.

    Returns:
        A list of Param trees.

    Raises:
        NotSupportedError: if the function takes ``**kwargs`` or ``*args``, or
            has keyword-only arguments with default values.
    """
    # kwarg is checked before vararg. The range starts one column before the
    # name so it covers (one of) the leading stars.
    for star_arg in (py_args.kwarg, py_args.vararg):
        if star_arg is not None:
            ctx_range = ctx.make_range(
                star_arg.lineno,
                star_arg.col_offset - 1,
                star_arg.col_offset + len(star_arg.arg),
            )
            raise NotSupportedError(ctx_range, _vararg_kwarg_err)

    # kw_defaults holds one entry per keyword-only arg (None when it has no
    # default). None entries have no line numbers, so report the first real one.
    for default in py_args.kw_defaults:
        if default is not None:
            raise NotSupportedError(build_expr(ctx, default).range(), _vararg_kwarg_err)

    def profiled_type(arg):
        # Type inferred by profile-directed typing, or None. Uses .get so an
        # argument missing from the profile doesn't raise KeyError.
        if not pdt_arg_types:
            return None
        arg_type = pdt_arg_types.get(arg.arg)
        return arg_type if arg_type else None

    params = [
        build_param(ctx, arg, self_name, kwarg_only=False, pdt_arg_type=profiled_type(arg))
        for arg in py_args.args
    ]
    params += [
        build_param(ctx, arg, self_name, kwarg_only=True, pdt_arg_type=profiled_type(arg))
        for arg in py_args.kwonlyargs
    ]
    return params
def build_param(ctx, py_arg, self_name, kwarg_only, pdt_arg_type=None):
    """Translate one ``ast.arg`` into a TorchScript Param.

    The annotation is chosen in this order:
      1. the argument's explicit annotation;
      2. the profiled type ``pdt_arg_type``, when truthy;
      3. ``self_name``, for a parameter literally named ``self``;
      4. otherwise, an empty type annotation.
    """
    # NB: In Python3 py_arg is a pair of (str arg, expr? annotation)
    name = py_arg.arg
    col = py_arg.col_offset
    name_range = ctx.make_range(py_arg.lineno, col, col + len(name))

    if getattr(py_arg, "annotation", None) is not None:
        annotation_expr = build_expr(ctx, py_arg.annotation)
    elif pdt_arg_type:
        annotation_expr = Var(Ident(name_range, pdt_arg_type))
    elif name == "self" and self_name is not None:
        annotation_expr = Var(Ident(name_range, self_name))
    else:
        annotation_expr = EmptyTypeAnnotation(name_range)

    return Param(annotation_expr, Ident(name_range, name), kwarg_only)
def build_ignore_context_manager(ctx, stmt):
    """Rewrite a ``with torch.jit._IgnoreContextManager(...)`` block as a function call.

    The context manager's keyword arguments declare the block's variables as
    ``name="inp:<type>"`` (inputs) or ``name="out:<type>"`` (outputs). The
    block body is moved into a new function decorated with
    ``@torch.jit.ignore``. That function is created with ``exec`` and stored in
    this module's globals. The ``with`` statement is then replaced by a call to
    ``torch.jit.frontend.<generated name>(<inputs>)``, whose result is assigned
    to the outputs when there are any.

    Note: the ``with`` body's statement list is reused (and a return statement
    appended to it) rather than copied.

    Args:
        ctx: Source context; its filename makes the generated name unique.
        stmt: The ``ast.With`` node.

    Returns:
        The replacement Python AST statement.
    """
    InputType = namedtuple("InputType", ["name", "ann"])
    OutputType = namedtuple("OutputType", ["name", "ann"])

    # Parse the context manager's keywords into typed inputs and outputs.
    # TODO: add input, output validator
    inputs = []
    outputs = []
    for keyword in stmt.items[0].context_expr.keywords:
        role, ann = keyword.value.value.split(":")
        if role == "inp":
            inputs.append(InputType(keyword.arg, ann))
        elif role == "out":
            outputs.append(OutputType(keyword.arg, ann))

    # Unique name: the sanitized source filename plus the line number of the
    # original context manager.
    sanitized_filename = re.sub(r"[^a-zA-Z0-9_]", "_", ctx.filename)
    ignore_function_name = f"func_ignore_{sanitized_filename}_{stmt.lineno}"

    # Return annotation and return statement, depending on the number of outputs.
    if not outputs:
        return_ann = " -> None"
        return_src = "return "
    elif len(outputs) == 1:
        return_ann = " -> " + outputs[0].ann
        return_src = "return " + outputs[0].name
    else:
        return_ann = " -> Tuple[" + ", ".join(var.ann for var in outputs) + "]"
        return_src = "return " + ", ".join(var.name for var in outputs)

    # Parse only the signature (with a placeholder body), then graft in the
    # with-block's statements followed by the return statement.
    params = ", ".join(var.name + " :" + var.ann for var in inputs)
    signature_src = f"\ndef {ignore_function_name}({params}){return_ann}: pass"
    ignore_function = ast.parse(signature_src).body[0]
    ignore_function.body = stmt.body  # type: ignore[attr-defined]
    ignore_function.body.append(ast.parse(return_src).body[0])  # type: ignore[attr-defined]

    # Define the function and register it in this module's globals, so that it
    # is reachable as torch.jit.frontend.<name>.
    ignore_func_src = "@torch.jit.ignore\n" + astunparse.unparse(ignore_function)
    ignore_func_src += f'\nglobals()["{ignore_function_name}"] = {ignore_function_name}'
    exec(ignore_func_src)  # noqa: P204

    # <out_1>, <out_2>, ... = torch.jit.frontend.<func>(<in_1>, <in_2>)
    call_src = (
        f"torch.jit.frontend.{ignore_function_name}("
        + ", ".join(var.name for var in inputs)
        + ")"
    )
    if outputs:
        call_src = ", ".join(var.name for var in outputs) + " = " + call_src
    return ast.parse(call_src).body[0]
def get_default_args(fn):
    """Map each parameter of ``fn`` that has a default value to that default.

    Parameters appear in signature order. Returns ``{}`` when ``fn`` is None.
    """
    if fn is None:
        return {}
    defaults = {}
    for param_name, param in inspect.signature(fn).parameters.items():
        if param.default is not inspect.Parameter.empty:
            defaults[param_name] = param.default
    return defaults
def get_default_args_for_class(cls):
    """
    Get default arguments for all methods in a class (except for static methods).

    Args:
        cls: type - The class type to inspect for default arguments.

    Returns:
        A Dict[str, Dict[str, Any]] which maps each method name to a Dict[str, Any]
        that maps each argument name to its default value.
    """

    def is_candidate(member):
        # Static methods are excluded because they are compiled separately, as
        # if they were independent script functions.
        return (
            (inspect.ismethod(member) or inspect.isfunction(member))
            and not is_static_fn(cls, member.__name__)
            and member.__name__ in cls.__dict__
        )

    # Property defaults do not need to be considered because setters cannot be
    # invoked without a value.
    return {
        method_name: get_default_args(method_impl)
        for method_name, method_impl in inspect.getmembers(cls, predicate=is_candidate)
    }
class WithItemBuilder(Builder):
    """Builder that translates a Python ``ast.withitem`` into a TorchScript WithItem."""

    @staticmethod
    def build_withitem(ctx, item):
        context_expr = item.context_expr
        optional_vars = item.optional_vars
        start = context_expr.col_offset
        # NOTE(review): the range width is len("with statements") (the pretty
        # feature name), not the extent of the context expression — confirm intended.
        end = start + len(pretty_node_names[ast.With])
        item_range = ctx.make_range(context_expr.lineno, start, end)
        value = build_expr(ctx, context_expr)
        target = build_expr(ctx, optional_vars) if optional_vars else None
        return WithItem(item_range, value, target)
class StmtBuilder(Builder):
    """Translates Python statement AST nodes into TorchScript statement trees.

    ``Builder.__call__`` dispatches each node to the ``build_<NodeClassName>``
    static method below. Node types without a method raise UnsupportedNodeError.
    """

    # Augmented-assignment operator types -> TorchScript operator tokens.
    augassign_map = {
        ast.Add: "+",
        ast.Sub: "-",
        ast.Mult: "*",
        ast.Div: "/",
        ast.Mod: "%",
        ast.BitOr: "|",
        ast.BitAnd: "&",
        ast.BitXor: "^",
        ast.LShift: "<<",
        ast.RShift: ">>",
        ast.Pow: "**",
    }

    @staticmethod
    def build_Expr(ctx, stmt):
        value = stmt.value
        # If a statement is a string literal expression, then it is a
        # docstring. Just ignore it. Older Python versions parse string
        # literals as ``ast.Str``; newer ones produce ``ast.Constant`` with a
        # ``str`` value, which the class-name check alone never matches.
        is_string_literal = value.__class__.__name__ == "Str" or (
            isinstance(value, ast.Constant) and isinstance(value.value, str)
        )
        if is_string_literal:
            return None
        return ExprStmt(build_expr(ctx, value))

    @staticmethod
    def build_Assign(ctx, stmt):
        rhs = build_expr(ctx, stmt.value)
        lhs = [build_expr(ctx, x) for x in stmt.targets]
        return Assign(lhs, rhs)

    @staticmethod
    def build_AnnAssign(ctx, stmt):
        if stmt.value is None:
            raise UnsupportedNodeError(ctx, stmt, reason="without assigned value")

        # Disallow type annotations on instance attributes outside of __init__.
        # Require an ast.Name base: for targets like ``a.b.c: T = v`` the base
        # is itself an Attribute and has no ``id``.
        target = stmt.target
        if (
            isinstance(target, ast.Attribute)
            and isinstance(target.value, ast.Name)
            and target.value.id == "self"
            and ctx.funcname != "__init__"
        ):
            start = stmt.col_offset
            end = start + len(f"self.{target.attr}")
            if hasattr(stmt.annotation, "id"):
                end += len(f": {stmt.annotation.id}")
            sr = ctx.make_range(stmt.lineno, start, end)
            raise ValueError(
                "Type annotations on instance attributes must be declared in "
                f"__init__, not '{ctx.funcname}': {sr}"
            )

        rhs = build_expr(ctx, stmt.value)
        lhs = build_expr(ctx, target)
        the_type = build_expr(ctx, stmt.annotation)
        return Assign([lhs], rhs, the_type)

    @staticmethod
    def build_Delete(ctx, stmt):
        r = ctx.make_range(stmt.lineno, stmt.col_offset, stmt.col_offset + len("del"))
        return Delete(r, [build_expr(ctx, target) for target in stmt.targets])

    @staticmethod
    def build_Return(ctx, stmt):
        r = ctx.make_range(
            stmt.lineno, stmt.col_offset, stmt.col_offset + len("return")
        )
        return Return(r, None if stmt.value is None else build_expr(ctx, stmt.value))

    @staticmethod
    def build_Raise(ctx, stmt):
        r = ctx.make_range(stmt.lineno, stmt.col_offset, stmt.col_offset + len("raise"))
        if stmt.exc is None:
            # A bare ``raise`` has no exception expression to translate, so
            # report it here with a proper source range.
            raise NotSupportedError(r, "bare raise statements aren't supported")
        expr = build_expr(ctx, stmt.exc)
        return Raise(r, expr)

    @staticmethod
    def build_Assert(ctx, stmt):
        r = ctx.make_range(
            stmt.lineno, stmt.col_offset, stmt.col_offset + len("assert")
        )
        test = build_expr(ctx, stmt.test)
        msg = build_expr(ctx, stmt.msg) if stmt.msg is not None else None
        return Assert(r, test, msg)

    @staticmethod
    def build_AugAssign(ctx, stmt):
        lhs = build_expr(ctx, stmt.target)
        rhs = build_expr(ctx, stmt.value)
        op = type(stmt.op)
        if op in StmtBuilder.augassign_map:
            op_token = StmtBuilder.augassign_map[op]
        else:
            raise NotSupportedError(
                find_before(ctx, rhs.range().start, "=", offsets=(-1, 0)),
                "unsupported kind of augmented assignment: " + op.__name__,
            )
        return AugAssign(lhs, op_token, rhs)

    @staticmethod
    def build_While(ctx, stmt):
        r = ctx.make_range(stmt.lineno, stmt.col_offset, stmt.col_offset + len("while"))
        if stmt.orelse:
            # TODO: try to recover the location of else:? Python doesn't give us
            # useful annotations in this case. FrontendError hands this range to
            # torch._C.ErrorReport, so point at the `while` keyword rather than
            # passing None.
            raise NotSupportedError(r, "else branches of while loops aren't supported")
        return While(r, build_expr(ctx, stmt.test), build_stmts(ctx, stmt.body))

    @staticmethod
    def build_For(ctx, stmt):
        r = ctx.make_range(stmt.lineno, stmt.col_offset, stmt.col_offset + len("for"))
        if stmt.orelse:
            raise NotSupportedError(r, "else branches of for loops aren't supported")
        return For(
            r,
            [build_expr(ctx, stmt.target)],
            [build_expr(ctx, stmt.iter)],
            build_stmts(ctx, stmt.body),
        )

    @staticmethod
    def build_If(ctx, stmt):
        r = ctx.make_range(stmt.lineno, stmt.col_offset, stmt.col_offset + len("if"))
        return If(
            r,
            build_expr(ctx, stmt.test),
            build_stmts(ctx, stmt.body),
            build_stmts(ctx, stmt.orelse),
        )

    @staticmethod
    def build_Print(ctx, stmt):
        # Handles a Python AST ``Print`` node (which has ``dest`` and ``values``).
        r = ctx.make_range(stmt.lineno, stmt.col_offset, stmt.col_offset + len("print"))
        if stmt.dest:
            raise NotSupportedError(
                r, "print statements with non-default destinations aren't supported"
            )
        args = [build_expr(ctx, val) for val in stmt.values]
        return ExprStmt(Apply(Var(Ident(r, "print")), args, []))

    @staticmethod
    def build_Pass(ctx, stmt):
        r = ctx.make_range(stmt.lineno, stmt.col_offset, stmt.col_offset + len("pass"))
        return Pass(r)

    @staticmethod
    def build_Break(ctx, stmt):
        r = ctx.make_range(stmt.lineno, stmt.col_offset, stmt.col_offset + len("break"))
        return Break(r)

    @staticmethod
    def build_Continue(ctx, stmt):
        r = ctx.make_range(
            stmt.lineno, stmt.col_offset, stmt.col_offset + len("continue")
        )
        return Continue(r)

    @staticmethod
    def build_With(ctx, stmt):
        r = ctx.make_range(stmt.lineno, stmt.col_offset, stmt.col_offset + len("with"))
        # Handle ignore context manager
        if is_torch_jit_ignore_context_manager(stmt):
            if not _IS_ASTUNPARSE_INSTALLED:
                raise RuntimeError(
                    "torch.jit._IgnoreContextManager requires installing Python library `astunparse`, "
                    "please install it in your Python environment"
                )
            assign_ast = build_ignore_context_manager(ctx, stmt)
            return build_stmt(ctx, assign_ast)
        return With(r, build_withitems(ctx, stmt.items), build_stmts(ctx, stmt.body))
- class ExprBuilder(Builder):
- binop_map = {
- ast.Add: "+",
- ast.Sub: "-",
- ast.Mult: "*",
- ast.Div: "/",
- ast.Pow: "**",
- ast.Mod: "%",
- ast.FloorDiv: "//",
- ast.BitAnd: "&",
- ast.BitXor: "^",
- ast.BitOr: "|",
- ast.LShift: "<<",
- ast.RShift: ">>",
- }
- binop_map[ast.MatMult] = "@"
- unop_map = {
- ast.Not: "not",
- ast.USub: "-",
- ast.Invert: "~",
- }
- boolop_map = {
- ast.And: "and",
- ast.Or: "or",
- }
- cmpop_map = {
- ast.Eq: "==",
- ast.NotEq: "!=",
- ast.LtE: "<=",
- ast.Lt: "<",
- ast.GtE: ">=",
- ast.Gt: ">",
- ast.Is: "is",
- ast.IsNot: "is not",
- ast.In: "in",
- ast.NotIn: "not in",
- }
- @staticmethod
- def build_Attribute(ctx, expr):
- base = build_expr(ctx, expr.value)
- # expr.attr is just a string, so it's not annotated in any way, so we have
- # to build the range manually
- source = ctx.source.encode("utf-8")
- def get_char(index):
- return chr(source[index])
- start_pos = base.range().end + 1
- while get_char(start_pos) in string.whitespace: # Skip whitespace
- start_pos += 1
- end_pos = start_pos + len(expr.attr)
- name_range = ctx.make_raw_range(start_pos, end_pos)
- return Select(base, Ident(name_range, expr.attr))
- @staticmethod
- def build_Call(ctx, expr):
- func = build_expr(ctx, expr.func)
- args = [build_expr(ctx, py_arg) for py_arg in expr.args]
- if hasattr(expr, "starargs") and expr.starargs:
- stararg_expr = build_expr(ctx, expr.starargs)
- args += [Starred(stararg_expr.range(), stararg_expr)]
- kwargs = []
- for kw in expr.keywords:
- kw_expr = build_expr(ctx, kw.value)
- # XXX: we could do a better job at figuring out the range for the name here
- if not kw.arg:
- raise NotSupportedError(
- kw_expr.range(), "keyword-arg expansion is not supported"
- )
- kwargs.append(Attribute(Ident(kw_expr.range(), kw.arg), kw_expr))
- return Apply(func, args, kwargs)
- @staticmethod
- def build_Ellipsis(ctx, expr):
- r = ctx.make_range(
- expr.lineno, expr.col_offset, expr.col_offset + 3
- ) # len("...") == 3
- return Dots(r)
- @staticmethod
- def build_Name(ctx, expr):
- r = ctx.make_range(expr.lineno, expr.col_offset, expr.col_offset + len(expr.id))
- if expr.id.startswith(_reserved_prefix):
- raise NotSupportedError(
- r,
- "names of variables used in JIT-ed functions "
- "can't start with " + _reserved_prefix,
- )
- if expr.id == "True":
- return TrueLiteral(r)
- elif expr.id == "False":
- return FalseLiteral(r)
- elif expr.id == "None":
- return NoneLiteral(r)
- elif expr.id == "Ellipsis":
- return Dots(r)
- return Var(Ident(r, expr.id))
- @staticmethod
- def build_NameConstant(ctx, expr):
- r = ctx.make_range(
- expr.lineno, expr.col_offset, expr.col_offset + len(str(expr.value))
- )
- if expr.value is True:
- return TrueLiteral(r)
- elif expr.value is False:
- return FalseLiteral(r)
- elif expr.value is None:
- return NoneLiteral(r)
- elif expr.value == Ellipsis:
- return Dots(r)
- else:
- raise ValueError("Name constant value unsupported: " + str(expr.value))
@staticmethod
def build_BinOp(ctx, expr):
    """Translate an ``ast.BinOp`` into a ``BinOp`` tree-view node.

    Both operands are translated first. The operator is then looked up in
    ``ExprBuilder.binop_map``.

    Raises:
        FrontendError: for ``/`` when ``ctx.uses_true_division`` is false.
        NotSupportedError: when the operator type has no entry in
            ``binop_map``.
    """
    left = build_expr(ctx, expr.left)
    right = build_expr(ctx, expr.right)
    op_type = type(expr.op)

    def operator_gap():
        # Raw range from the end of the left operand to the start of the
        # right one, i.e. where the operator token sits.
        return ctx.make_raw_range(left.range().end, right.range().start)

    if op_type == ast.Div and not ctx.uses_true_division:
        raise FrontendError(
            operator_gap(),
            "Division of ints in TorchScript uses Python 3 true "
            "division semantics. Please put `from __future__ "
            "import division` at the top of your file",
        )
    token = ExprBuilder.binop_map.get(op_type)
    if token is None:
        raise NotSupportedError(
            operator_gap(), "unsupported binary operator: " + op_type.__name__
        )
    return BinOp(token, left, right)
@staticmethod
def build_UnaryOp(ctx, expr):
    """Translate an ``ast.UnaryOp`` into a ``UnaryOp`` tree-view node.

    The operand is translated first. The operator is then looked up in
    ``ExprBuilder.unop_map``. The resulting node's range starts at the
    operator and is ``len(op_token)`` characters wide.

    Raises:
        NotSupportedError: when the operator type has no entry in
            ``unop_map``.
    """
    sub_expr = build_expr(ctx, expr.operand)
    op = type(expr.op)
    op_token = ExprBuilder.unop_map.get(op)
    if op_token is None:
        # ``expr`` is a raw ``ast`` node, not a tree-view node, so it has no
        # ``.range()`` method (calling it raised AttributeError and hid the
        # real error). Build the error range from the node's own source
        # position instead, pointing at the first character of the operator.
        err_range = ctx.make_range(
            expr.lineno, expr.col_offset, expr.col_offset + 1
        )
        raise NotSupportedError(
            err_range, "unsupported unary operator: " + op.__name__
        )
    r = ctx.make_range(
        expr.lineno, expr.col_offset, expr.col_offset + len(op_token)
    )
    return UnaryOp(r, op_token, sub_expr)
@staticmethod
def build_BoolOp(ctx, expr):
    """Translate ``a and b and c`` / ``a or b or c`` into nested ``BinOp``s.

    The operands are folded left-associatively:
    ``BinOp(op, BinOp(op, a, b), c)``.

    Raises:
        AssertionError: if the node has fewer than two values.
        NotSupportedError: when the operator type has no entry in
            ``ExprBuilder.boolop_map``.
    """
    values = expr.values
    if len(values) < 2:
        raise AssertionError(
            "expected at least 2 values in BoolOp, but got " + str(len(values))
        )
    operands = [build_expr(ctx, value) for value in values]
    op_type = type(expr.op)
    token = ExprBuilder.boolop_map.get(op_type)
    if token is None:
        first, second = operands[0], operands[1]
        gap = ctx.make_raw_range(first.range().end, second.range().start)
        raise NotSupportedError(
            gap, "unsupported boolean operator: " + op_type.__name__
        )
    folded, *remaining = operands
    for operand in remaining:
        folded = BinOp(token, folded, operand)
    return folded
@staticmethod
def build_IfExp(ctx, expr):
    """Translate ``body if test else orelse`` into a ``TernaryIf`` node.

    The sub-expressions are translated in the order test, body, orelse.
    """
    condition = build_expr(ctx, expr.test)
    when_true = build_expr(ctx, expr.body)
    when_false = build_expr(ctx, expr.orelse)
    return TernaryIf(condition, when_true, when_false)
@staticmethod
def build_Compare(ctx, expr):
    """Translate a (possibly chained) comparison into tree-view nodes.

    ``a < b < c`` becomes ``BinOp("and", a < b, b < c)``, folded
    left-associatively. ``x not in y`` is expressed as ``not (x in y)``
    rather than a dedicated node. Returns ``None`` if the node has no
    operators.

    Raises:
        NotSupportedError: when an operator type has no entry in
            ``ExprBuilder.cmpop_map``.
    """
    operands = [build_expr(ctx, node) for node in [expr.left, *expr.comparators]]
    adjacent_pairs = zip(operands[:-1], operands[1:])
    chained = None
    for (left, right), op_node in zip(adjacent_pairs, expr.ops):
        op_type = type(op_node)
        token = ExprBuilder.cmpop_map.get(op_type)
        gap = ctx.make_raw_range(left.range().end, right.range().start)
        if token is None:
            raise NotSupportedError(
                gap, "unsupported comparison operator: " + op_type.__name__
            )
        if op_type == ast.NotIn:
            # NB: `not in` is just `not( in )`, so no new tree-view node is
            # introduced; it becomes a nested unary `not` around an `in`.
            link = UnaryOp(gap, "not", BinOp("in", left, right))
        else:
            link = BinOp(token, left, right)
        chained = link if chained is None else BinOp("and", chained, link)
    return chained
@staticmethod
def build_Subscript(ctx, expr):
    """Translate an ``ast.Subscript`` (``base[...]``) into a ``Subscript`` node.

    Two AST layouts are handled:

    * ``expr.slice`` wrapped in ``ast.Index`` / ``ast.Slice`` /
      ``ast.ExtSlice`` (the pre-3.9 layout);
    * on Python >= 3.9, the index expression stored directly in
      ``expr.slice``, with an ``ast.Tuple`` for multiple dimensions.

    Tuple indexing ``x[(i, j, k)]`` is treated like ``x[i, j, k]``. An empty
    tuple (e.g. ``typing.Tuple[()]``) becomes a single empty ``TupleLiteral``
    index.

    Raises:
        NotSupportedError: for a tuple nested inside an ``ExtSlice``
            dimension, for an ``ExtSlice`` dimension of an unsupported type,
            and (on Python < 3.9) for any other ``expr.slice`` node type.
    """

    def build_SliceExpr(ctx, base, slice_expr):
        # ``lower:upper:step`` -> SliceExpr; omitted parts stay ``None``.
        # The subscripted base's range is reused as the slice's range.
        lower = (
            build_expr(ctx, slice_expr.lower)
            if slice_expr.lower is not None
            else None
        )
        upper = (
            build_expr(ctx, slice_expr.upper)
            if slice_expr.upper is not None
            else None
        )
        step = (
            build_expr(ctx, slice_expr.step)
            if slice_expr.step is not None
            else None
        )
        return SliceExpr(base.range(), lower, upper, step)

    def build_Index(ctx, base, index_expr):
        # Unwrap one ``ast.Index`` dimension of an ExtSlice; a tuple inside a
        # single dimension is rejected.
        if isinstance(index_expr.value, ast.Tuple):
            raise NotSupportedError(
                base.range(),
                "slicing multiple dimensions with tuples not supported yet",
            )
        return build_expr(ctx, index_expr.value)

    def build_ExtSlice(ctx, base, extslice):
        # Translate each dimension of an ``ast.ExtSlice`` independently:
        # Index, Slice, or an Ellipsis constant; anything else is rejected.
        sub_exprs = []
        for expr in extslice.dims:
            sub_type = type(expr)
            if sub_type is ast.Index:
                sub_exprs.append(build_Index(ctx, base, expr))
            elif sub_type is ast.Slice:
                sub_exprs.append(build_SliceExpr(ctx, base, expr))
            elif sub_type is ast.Constant and expr.value is Ellipsis:
                sub_exprs.append(Dots(base.range()))
            else:
                raise NotSupportedError(
                    base.range(),
                    f"slicing multiple dimensions with {sub_type} not supported",
                )
        return sub_exprs

    base = build_expr(ctx, expr.value)
    sub_type = type(expr.slice)
    if sub_type is ast.Index:
        if isinstance(expr.slice.value, ast.Tuple):
            # N-dimensional indexing using Tuple: x[(i, j, k)] is equivalent to x[i, j, k]
            # XXX: Indexing using a list is **different**! It triggers advanced indexing.
            indices = [
                build_expr(ctx, index_expr) for index_expr in expr.slice.value.elts
            ]
            if not indices:
                # `col_offset` is an int, but `end_col_offset` is
                # `Optional[int]`. The magic number is here to make
                # sure we can parse `()` on any machine
                r = ctx.make_range(
                    expr.lineno,
                    expr.slice.value.col_offset,
                    expr.slice.value.col_offset + 2,
                )
                tup = TupleLiteral(r, [])
                indices.append(tup)
            return Subscript(base, indices)
        else:
            return Subscript(base, [build_expr(ctx, expr.slice.value)])
    elif sub_type is ast.Slice:
        return Subscript(base, [build_SliceExpr(ctx, base, expr.slice)])
    elif sub_type is ast.ExtSlice:
        return Subscript(base, build_ExtSlice(ctx, base, expr.slice))
    elif sys.version_info >= (
        3,
        9,
    ):  # In Python3.9 array indices are not wrapped in ast.Index
        if sub_type is ast.Tuple:
            # N-dimensional indexing using Tuple: x[(i, j, k)] is equivalent to x[i, j, k]
            # Tuple elements may themselves be slices (e.g. x[1:2, 3]).
            indices = []
            for index_expr in expr.slice.elts:
                if isinstance(index_expr, ast.Slice):
                    indices.append(build_SliceExpr(ctx, base, index_expr))
                else:
                    indices.append(build_expr(ctx, index_expr))
            # Special-case logic for `typing.Tuple[()]`
            if not indices:
                # See note above r.e. magic number
                r = ctx.make_range(
                    expr.lineno, expr.slice.col_offset, expr.slice.col_offset + 2
                )
                tup = TupleLiteral(r, [])
                indices.append(tup)
            return Subscript(base, indices)
        # Any other node is a single index expression.
        return Subscript(base, [build_expr(ctx, expr.slice)])
    else:  # Ellipsis (can only happen in Python 2)
        # Reached only on Python < 3.9 when expr.slice is not an
        # Index / Slice / ExtSlice node.
        raise NotSupportedError(base.range(), "ellipsis is not supported")
@staticmethod
def build_List(ctx, expr):
    """Translate a ``[...]`` display into a ``ListLiteral``.

    The range covers only the opening bracket. Elements are translated in
    order.
    """
    bracket_range = ctx.make_range(expr.lineno, expr.col_offset, expr.col_offset + 1)
    items = [build_expr(ctx, item) for item in expr.elts]
    return ListLiteral(bracket_range, items)
@staticmethod
def build_Tuple(ctx, expr):
    """Translate a tuple display into a ``TupleLiteral``.

    The range is the single character at the node's column offset. Elements
    are translated in order.
    """
    start_range = ctx.make_range(expr.lineno, expr.col_offset, expr.col_offset + 1)
    items = [build_expr(ctx, item) for item in expr.elts]
    return TupleLiteral(start_range, items)
@staticmethod
def build_Dict(ctx, expr):
    """Translate a ``{k: v, ...}`` display into a ``DictLiteral``.

    The range covers only the opening brace. Keys and values are translated
    in order.

    Raises:
        NotSupportedError: if any entry is a ``**mapping`` expansion.
    """
    # Named ``dict_range`` rather than ``range`` to avoid shadowing the builtin.
    dict_range = ctx.make_range(expr.lineno, expr.col_offset, expr.col_offset + 1)
    # ``ast.Dict`` records each ``**mapping`` entry as a ``None`` key at the
    # entry's position. That can be anywhere (e.g. ``{"a": 1, **d}``), not
    # just first, so every key must be checked before translating them.
    if any(key is None for key in expr.keys):
        raise NotSupportedError(
            dict_range, "Dict expansion (e.g. `{**dict}`) is not supported"
        )
    return DictLiteral(
        dict_range,
        [build_expr(ctx, key) for key in expr.keys],
        [build_expr(ctx, value) for value in expr.values],
    )
@staticmethod
def build_Num(ctx, expr):
    """Translate a numeric constant into a ``Const`` node.

    The node carries ``str(expr.value)``, and its range is as wide as that
    text.
    """
    text = str(expr.value)
    end_col = expr.col_offset + len(text)
    return Const(ctx.make_range(expr.lineno, expr.col_offset, end_col), text)
@staticmethod
def build_Constant(ctx, expr):
    """Dispatch an ``ast.Constant`` to the builder for its value's type.

    ``None``/``bool`` -> ``build_NameConstant``, ``int``/``float``/``complex``
    -> ``build_Num``, ``str`` -> ``build_Str``, ``Ellipsis`` ->
    ``build_Ellipsis``.

    Raises:
        FrontendError: for any other constant type (e.g. ``bytes``).
    """
    value = expr.value
    # The bool test must precede the int test: bool is a subclass of int.
    if value is None or isinstance(value, bool):
        builder = ExprBuilder.build_NameConstant
    elif isinstance(value, (int, float, complex)):
        builder = ExprBuilder.build_Num
    elif isinstance(value, str):
        builder = ExprBuilder.build_Str
    elif isinstance(value, type(Ellipsis)):
        builder = ExprBuilder.build_Ellipsis
    else:
        unknown_range = ctx.make_range(
            expr.lineno, expr.col_offset, expr.col_offset + len(str(value))
        )
        raise FrontendError(unknown_range, "Unknown Constant expression type")
    return builder(ctx, expr)
@staticmethod
def build_Str(ctx, expr):
    """Translate a string constant into a ``StringLiteral`` node.

    The range is ``len(value) + 1`` characters wide from the node's column.
    NOTE(review): a quoted literal spans at least ``len(value) + 2``
    characters (two quotes, more with escapes or prefixes), so this range is
    probably approximate -- confirm whether it is relied on anywhere.
    """
    text = str(expr.value)
    literal_range = ctx.make_range(
        expr.lineno, expr.col_offset, expr.col_offset + 1 + len(text)
    )
    return StringLiteral(literal_range, text)
@staticmethod
def build_JoinedStr(ctx, expr):
    """Lower an f-string into ``"<template>".format(*args)``.

    Each ``{value}`` replacement field becomes a ``{}`` placeholder, and its
    expression is appended to the call arguments. Literal parts are copied
    into the template verbatim. NOTE(review): braces in literal text
    (e.g. from ``{{``/``}}``) are therefore not re-escaped -- confirm how
    the downstream ``format`` treats them.

    Raises:
        NotSupportedError: for ``!r``/``!s``/``!a`` conversions, format specs,
            or any part that is neither a FormattedValue nor a Constant.
    """
    template_parts = []
    format_args = []
    for part in expr.values:
        part_range = ctx.make_range(part.lineno, part.col_offset, part.col_offset + 1)
        if isinstance(part, ast.FormattedValue):
            if part.conversion != -1:
                raise NotSupportedError(part_range, "Don't support conversion in JoinedStr")
            if part.format_spec is not None:
                raise NotSupportedError(part_range, "Don't support formatting in JoinedStr")
            template_parts.append("{}")
            format_args.append(build_expr(ctx, part.value))
        elif isinstance(part, ast.Constant):
            template_parts.append(part.value)
        else:
            raise NotSupportedError(part_range, "Unsupported value in JoinedStr")
    template = "".join(template_parts)
    fstr_range = ctx.make_range(expr.lineno, expr.col_offset, expr.col_offset + 1)
    format_method = Select(StringLiteral(fstr_range, template), Ident(fstr_range, "format"))
    return Apply(format_method, format_args, [])
@staticmethod
def build_ListComp(ctx, stmt):
    """Translate ``[elt for target in iter]`` into a ``ListComp`` node.

    The range is zero-width at the node's start.

    Raises:
        NotSupportedError: if there is more than one ``for`` clause, or the
            single clause has ``if`` filters.
    """
    start = ctx.make_range(stmt.lineno, stmt.col_offset, stmt.col_offset)
    if len(stmt.generators) != 1:
        raise NotSupportedError(start, "Only a single generator is currently supported")
    generator = stmt.generators[0]
    if generator.ifs:
        raise NotSupportedError(start, "Comprehension ifs are not supported yet")
    element = build_expr(ctx, stmt.elt)
    target = build_expr(ctx, generator.target)
    iterable = build_expr(ctx, generator.iter)
    return ListComp(start, element, target, iterable)
@staticmethod
def build_GeneratorExp(ctx, stmt):
    """Lower a generator expression exactly as a list comprehension.

    The node is passed unchanged to ``build_ListComp``, so the result is a
    ``ListComp`` tree-view node.
    """
    lower_as_list_comp = ExprBuilder.build_ListComp
    return lower_as_list_comp(ctx, stmt)
@staticmethod
def build_DictComp(ctx, stmt):
    """Translate ``{key: value for target in iter}`` into a ``DictComp`` node.

    The range is zero-width at the node's start.

    Raises:
        NotSupportedError: if there is more than one ``for`` clause, or the
            single clause has ``if`` filters.
    """
    start = ctx.make_range(stmt.lineno, stmt.col_offset, stmt.col_offset)
    if len(stmt.generators) != 1:
        raise NotSupportedError(start, "Only a single generator is currently supported")
    generator = stmt.generators[0]
    if generator.ifs:
        raise NotSupportedError(start, "Comprehension ifs are not supported yet")
    key = build_expr(ctx, stmt.key)
    value = build_expr(ctx, stmt.value)
    target = build_expr(ctx, generator.target)
    iterable = build_expr(ctx, generator.iter)
    return DictComp(start, key, value, target, iterable)
@staticmethod
def build_Starred(ctx, expr):
    """Translate ``*value`` into a ``Starred`` node.

    The range covers the ``*`` character.
    """
    star_range = ctx.make_range(expr.lineno, expr.col_offset, expr.col_offset + 1)
    inner = build_expr(ctx, expr.value)
    return Starred(star_range, inner)
# Module-level builder instances. The ExprBuilder methods in this file call
# ``build_expr(ctx, node)`` as a function, so these objects are presumably
# callable dispatchers over AST node types (the builder classes' dispatch
# logic is not visible here -- confirm).
build_expr = ExprBuilder()
build_stmt = StmtBuilder()
build_withitem = WithItemBuilder()
def find_before(ctx, pos, substr, offsets=(0, 0)):
    """Return a raw range around the last ``substr`` that ends at or before ``pos``.

    Searches ``ctx.source[:pos]`` from the right. ``offsets`` is a
    ``(start_shift, end_shift)`` pair added to the match's start and end
    positions before ``ctx.make_raw_range`` is called.

    Raises:
        ValueError: if ``substr`` does not occur in ``ctx.source[:pos]``.
    """
    # str.rindex(sub, 0, pos) searches exactly the slice source[0:pos].
    match_start = ctx.source.rindex(substr, 0, pos)
    match_end = match_start + len(substr)
    return ctx.make_raw_range(match_start + offsets[0], match_end + offsets[1])
|