| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551 |
- # mypy: allow-untyped-defs
- import ast
- import builtins
- import dis
- import enum
- import inspect
- import re
- import typing
- import warnings
- from textwrap import dedent
- from typing import Type
- import torch
- from torch._C import (
- _GeneratorType,
- AnyType,
- AwaitType,
- BoolType,
- ComplexType,
- DeviceObjType,
- DictType,
- EnumType,
- FloatType,
- FutureType,
- InterfaceType,
- IntType,
- ListType,
- NoneType,
- NumberType,
- OptionalType,
- StreamObjType,
- StringType,
- TensorType,
- TupleType,
- UnionType,
- )
- from torch._sources import get_source_lines_and_file
- from .._jit_internal import ( # type: ignore[attr-defined]
- _Await,
- _qualified_name,
- Any,
- BroadcastingList1,
- BroadcastingList2,
- BroadcastingList3,
- Dict,
- Future,
- is_await,
- is_dict,
- is_future,
- is_ignored_fn,
- is_list,
- is_optional,
- is_tuple,
- is_union,
- List,
- Optional,
- Tuple,
- Union,
- )
- from ._state import _get_script_class
- if torch.distributed.rpc.is_available():
- from torch._C import RRefType
- from .._jit_internal import is_rref, RRef
- from torch._ops import OpOverloadPacket
class Module:
    """Minimal module-like namespace used when evaluating type comments.

    Attribute lookups are answered from the ``members`` mapping; a name that
    is not present raises ``RuntimeError`` naming both the module and member.
    """

    def __init__(self, name, members):
        self.name = name
        self.members = members

    def __getattr__(self, name):
        # Only reached for attributes not found through normal lookup
        # (i.e. anything other than ``name`` / ``members``).
        table = self.members
        if name not in table:
            raise RuntimeError(
                f"Module {self.name} has no member called {name}"
            ) from None
        return table[name]
class EvalEnv:
    """Mapping used as the ``locals`` of ``eval`` when parsing type comments.

    A name is resolved from the fixed ``env`` table first, then through the
    resolution callback ``rcb`` when one was given, and finally from Python
    builtins (``None`` if the builtin does not exist either).
    """

    env = {
        "torch": Module("torch", {"Tensor": torch.Tensor}),
        "Tensor": torch.Tensor,
        "typing": Module("typing", {"Tuple": Tuple}),
        "Tuple": Tuple,
        "List": List,
        "Dict": Dict,
        "Optional": Optional,
        "Union": Union,
        "Future": Future,
        "Await": _Await,
    }

    def __init__(self, rcb):
        self.rcb = rcb
        # NOTE: this writes into the class-level ``env`` dict, so the entry
        # is shared by every EvalEnv instance once any one is created.
        if torch.distributed.rpc.is_available():
            self.env["RRef"] = RRef

    def __getitem__(self, name):
        try:
            return self.env[name]
        except KeyError:
            pass
        if self.rcb is None:
            return getattr(builtins, name, None)
        return self.rcb(name)
def get_signature(fn, rcb, loc, is_method):
    """Work out ``(param_types, return_type)`` for ``fn``, or return ``None``.

    Real annotations are consulted first (on ``fn.op`` when ``fn`` is an
    ``OpOverloadPacket``). When ``is_method`` is true, the leading ``self``
    entry is dropped from those results. Type comments never list ``self``,
    and both sources must line up. (``inspect.ismethod`` cannot be used
    because ``fn`` is still unbound here.) If there are no real annotations,
    the function's source is searched for a ``# type:`` comment instead.
    """
    annotated = fn.op if isinstance(fn, OpOverloadPacket) else fn
    signature = try_real_annotations(annotated, loc)
    if signature is not None:
        if not is_method:
            return signature
        param_types, return_type = signature
        return param_types[1:], return_type

    # No usable annotations: fall back to a type comment in the source.
    try:
        source = dedent("".join(get_source_lines_and_file(fn)[0]))
        type_line = get_type_line(source)
    except TypeError:
        # The source could not be retrieved (e.g. ``inspect`` raises
        # TypeError for built-ins). Other exceptions propagate.
        return None
    if type_line is None:
        # The source carries no type comment either.
        return None
    return parse_type_line(type_line, rcb, loc)
def is_function_or_method(the_callable):
    """Return True only for Python functions and methods.

    Stricter than ``inspect.isroutine``: built-in functions do not pass.
    """
    checks = (inspect.isfunction, inspect.ismethod)
    return any(check(the_callable) for check in checks)
def is_vararg(the_callable):
    """Return True if ``the_callable`` declares a ``*args`` parameter.

    A callable that is not itself a Python function or method (e.g. a class
    instance) is inspected through its ``__call__``. Anything that still is
    not a Python function or method yields False.
    """
    target = the_callable
    if callable(target) and not is_function_or_method(target):  # noqa: B004
        # De-sugar the call so the signature of ``__call__`` can be read.
        target = target.__call__
    if not is_function_or_method(target):
        return False
    return inspect.getfullargspec(target).varargs is not None
def get_param_names(fn, n_args):
    """Return the parameter names of ``fn`` as a list of strings.

    ``OpOverloadPacket`` objects are unwrapped to ``fn.op``. A callable
    object whose ``__call__`` is a Python function/method is replaced by
    that ``__call__``, and ignored functions are ``inspect.unwrap``-ed. When
    no Python function/method is found, the positional placeholders
    ``"0" .. str(n_args - 1)`` are returned instead.
    """
    target = fn.op if isinstance(fn, OpOverloadPacket) else fn
    if not is_function_or_method(target) and callable(target):  # noqa: B004
        # De-sugar calls to classes / callable objects.
        dunder_call = target.__call__
        if is_function_or_method(dunder_call):
            target = dunder_call
    if not is_function_or_method(target):
        # Not a function or method (maybe a class with a __call__ method),
        # so use a default param name list.
        return list(map(str, range(n_args)))
    if is_ignored_fn(target):
        target = inspect.unwrap(target)
    return inspect.getfullargspec(target).args
def check_fn(fn, loc):
    """Check that ``fn``'s source is exactly one top-level function definition.

    Nothing is checked if looking up the source raises ``OSError`` or
    ``TypeError``.

    Raises:
        torch.jit.frontend.FrontendError: if the source is a single class
            definition (a class instantiation used as a script function), or
            anything other than a single ``def``.
    """
    try:
        source = dedent("".join(get_source_lines_and_file(fn)[0]))
    except (OSError, TypeError):
        return
    # ``dedent`` always returns a str, so no ``None`` case needs handling.
    tree = ast.parse(source)
    if len(tree.body) == 1:
        (node,) = tree.body
        if isinstance(node, ast.FunctionDef):
            return
        if isinstance(node, ast.ClassDef):
            raise torch.jit.frontend.FrontendError(
                loc,
                f"Cannot instantiate class '{node.name}' in a script function",
            )
    raise torch.jit.frontend.FrontendError(
        loc, "Expected a single top-level function"
    )
- def _eval_no_call(stmt, glob, loc):
- """Evaluate statement as long as it does not contain any method/function calls."""
- bytecode = compile(stmt, "", mode="eval")
- for insn in dis.get_instructions(bytecode):
- if "CALL" in insn.opname:
- raise RuntimeError(
- f"Type annotation should not contain calls, but '{stmt}' does"
- )
- return eval(bytecode, glob, loc) # type: ignore[arg-type] # noqa: P204
def parse_type_line(type_line, rcb, loc):
    """Parse a type annotation specified as a comment.

    Example inputs:
        # type: (Tensor, torch.Tensor) -> Tuple[Tensor]
        # type: (Tensor, Tuple[Tensor, Tensor]) -> Tensor

    Both halves are evaluated (calls are rejected) against an ``EvalEnv``
    built from ``rcb``, and the results are converted with ``ann_to_type``.

    Returns:
        ``(arg_types, return_type)``, where ``arg_types`` is a list.

    Raises:
        RuntimeError: if either half fails with ``NameError`` or
            ``SyntaxError`` during evaluation.
    """
    arg_src, ret_src = split_type_line(type_line)

    def evaluate(text, part):
        try:
            return _eval_no_call(text, {}, EvalEnv(rcb))
        except (NameError, SyntaxError) as err:
            raise RuntimeError(
                f"Failed to parse the {part} of a type annotation"
            ) from err

    parsed_args = evaluate(arg_src, "argument list")
    # A single parenthesised annotation evaluates to the bare type, not a tuple.
    if not isinstance(parsed_args, tuple):
        parsed_args = (parsed_args,)
    parsed_ret = evaluate(ret_src, "return type")

    arg_types = [ann_to_type(ann, loc) for ann in parsed_args]
    return arg_types, ann_to_type(parsed_ret, loc)
def get_type_line(source):
    """Try to find the line containing a comment with the type annotation.

    Args:
        source: the (dedented) source text of a function.

    Returns:
        ``None`` if no ``# type:`` annotation comment is present. mypy ignore
        comments are skipped. With a single annotation line, that line is
        returned stripped of surrounding whitespace. With several lines
        (the PEP 484 per-argument style), the ``# type: (...) -> ...``
        return line is returned with its ``(...)`` placeholder filled in
        from the per-argument comments, also stripped.

    Raises:
        RuntimeError: if only a malformed ``type`` comment prefix is found,
            or if a multi-line annotation lacks the ``(...) ->`` return line.
    """
    type_comment = "# type:"
    lines = list(enumerate(source.split("\n")))
    type_lines = [line for line in lines if type_comment in line[1]]

    # `type: ignore` comments may be needed in JIT'ed functions for mypy, due
    # to the hack in torch/_VF.py.
    # An ignore type comment can be of following format:
    # 1) type: ignore
    # 2) type: ignore[rule-code]
    # 3) type: ignore[rule-code, other-rule-code]
    # This ignore statement must be at the end of the line (trailing
    # whitespace is tolerated).
    # adding an extra backslash before the space, to avoid triggering
    # one of the checks in .github/workflows/lint.yml
    type_pattern = re.compile(r"# type:\ ignore(\[[\w\-, ]+\])?\s*$")
    type_lines = [line for line in type_lines if not type_pattern.search(line[1])]

    if not type_lines:
        # Catch common typo patterns like extra spaces, typo in 'ignore', etc.
        wrong_type_pattern = re.compile(r"#[\t ]*type[\t ]*(?!: ignore(\[.*\])?\s*$):")
        wrong_type_lines = [line for line in lines if wrong_type_pattern.search(line[1])]
        if wrong_type_lines:
            raise RuntimeError(
                "The annotation prefix in line "
                + str(wrong_type_lines[0][0])
                + " is probably invalid.\nIt must be '# type:'"
                + "\nSee PEP 484 (https://www.python.org/dev/peps/pep-0484/#suggested-syntax-for-python-2-7-and-straddling-code)"  # noqa: B950
                + "\nfor examples"
            )
        return None

    if len(type_lines) == 1:
        # Only 1 type line, quit now
        return type_lines[0][1].strip()

    # Parse split up argument types according to PEP 484
    # https://www.python.org/dev/peps/pep-0484/#suggested-syntax-for-python-2-7-and-straddling-code
    return_marker = "# type: (...) -> "
    return_line = None
    parameter_type_lines = []
    for line_num, line in type_lines:
        if return_marker in line:
            return_line = (line_num, line)
            break
        # Every remaining line contains `type_comment` (filtered above).
        parameter_type_lines.append(line)

    if return_line is None:
        raise RuntimeError(
            "Return type line '# type: (...) -> ...' not found on multiline "
            "type annotation\nfor type lines:\n"
            + "\n".join([line[1] for line in type_lines])
            + "\n(See PEP 484 https://www.python.org/dev/peps/pep-0484/#suggested-syntax-for-python-2-7-and-straddling-code)"
        )

    parameter_types = ", ".join(
        line[line.find(type_comment) + len(type_comment) :].strip()
        for line in parameter_type_lines
    )
    # Replace only the `(...)` placeholder. A blanket replace of "..." would
    # also rewrite an ellipsis in the return type (e.g. `Tuple[int, ...]`).
    # Strip like the single-line case so the result begins at the comment
    # rather than with the source indentation.
    combined = return_line[1].replace(
        return_marker, f"# type: ({parameter_types}) -> ", 1
    )
    return combined.strip()
def split_type_line(type_line):
    """Split the comment with the type annotation into parts for argument and return types.

    For example, for an input of:
        # type: (Tensor, torch.Tensor) -> Tuple[Tensor, Tensor]
    This function will return:
        ("(Tensor, torch.Tensor)", "Tuple[Tensor, Tensor]")

    The ``# type:`` prefix is located wherever it occurs in ``type_line``, so
    leading indentation or code before the comment is tolerated. The ``->``
    separator is searched for only after that prefix. If the prefix is
    absent, the historical behaviour applies: the prefix is assumed to occupy
    the first 7 characters, and ``->`` is searched for from the start.

    Raises:
        RuntimeError: if no ``->`` separator can be found.
    """
    type_comment = "# type:"
    comment_pos = type_line.find(type_comment)
    if comment_pos < 0:
        start_offset, search_from = len(type_comment), 0
    else:
        start_offset = search_from = comment_pos + len(type_comment)
    try:
        arrow_pos = type_line.index("->", search_from)
    except ValueError:
        raise RuntimeError(
            "Syntax error in type annotation (couldn't find `->`)"
        ) from None
    return type_line[start_offset:arrow_pos].strip(), type_line[arrow_pos + 2 :].strip()
def try_real_annotations(fn, loc):
    """Try to use the Py3.5+ annotation syntax to get the type.

    Returns ``None`` if ``inspect.signature`` raises ``ValueError`` for
    ``fn``, or if neither the parameters nor the return value carry an
    annotation. Otherwise returns ``(param_types, return_type)`` converted
    with ``ann_to_type``; unannotated entries are converted as well.
    """
    try:
        # Note: anything annotated as `Optional[T]` will automatically
        # be returned as `Union[T, None]` per
        # https://github.com/python/typing/blob/master/src/typing.py#L850
        sig = inspect.signature(fn)
    except ValueError:
        return None

    params = list(sig.parameters.values())
    missing = sig.empty
    has_any_annotation = sig.return_annotation is not missing or any(
        p.annotation is not missing for p in params
    )
    if not has_any_annotation:
        return None

    arg_types = [ann_to_type(p.annotation, loc) for p in params]
    return arg_types, ann_to_type(sig.return_annotation, loc)
def get_enum_value_type(e: Type[enum.Enum], loc):
    """Find the common TorchScript type of the values of Enum class ``e``.

    Each distinct Python type among the member values is converted with
    ``try_ann_to_type``, and the results are unified with
    ``torch._C.unify_type_list``.

    Args:
        e: the Enum class to inspect.
        loc: source location forwarded to ``try_ann_to_type``.

    Returns:
        The unified type, or ``AnyType`` if unification returns a falsy value.

    Raises:
        ValueError: if ``e`` defines no members.
    """
    enum_values: List[enum.Enum] = list(e)
    if not enum_values:
        # Name the Enum class itself; ``e.__class__`` would only name its
        # metaclass (e.g. ``EnumMeta``), which does not identify the culprit.
        raise ValueError(f"No enum values defined for: '{e.__qualname__}'")

    types = {type(v.value) for v in enum_values}
    ir_types = [try_ann_to_type(t, loc) for t in types]

    # If Enum values are of different types, an exception will be raised here.
    # Even though Python supports this case, we chose to not implement it to
    # avoid overcomplicate logic here for a rare use case. Please report a
    # feature request if you find it necessary.
    # NOTE(review): given the note above, the AnyType fallback below may not
    # be reached for mixed value types — confirm unify_type_list's behaviour.
    res = torch._C.unify_type_list(ir_types)
    if not res:
        return AnyType.get()
    return res
def is_tensor(ann):
    """Return True if the class ``ann`` should be treated as a Tensor type.

    Subclasses of ``torch.Tensor`` qualify directly. The legacy
    dtype-specific tensor classes (``torch.LongTensor`` etc.) also qualify,
    but emit a warning, because the dtype constraint is not enforced.
    """
    if issubclass(ann, torch.Tensor):
        return True
    legacy_tensor_types = (
        torch.LongTensor,
        torch.DoubleTensor,
        torch.FloatTensor,
        torch.IntTensor,
        torch.ShortTensor,
        torch.HalfTensor,
        torch.CharTensor,
        torch.ByteTensor,
        torch.BoolTensor,
    )
    if not issubclass(ann, legacy_tensor_types):
        return False
    warnings.warn(
        "TorchScript will treat type annotations of Tensor "
        "dtype-specific subtypes as if they are normal Tensors. "
        "dtype constraints are not enforced in compilation either."
    )
    return True
- def _fake_rcb(inp):
- return None
def try_ann_to_type(ann, loc, rcb=None):
    """Try to convert the Python annotation ``ann`` into a TorchScript type.

    The checks below run in order and the first match wins. The result is a
    ``torch._C`` type object. It may be ``None`` when ``ann`` cannot be
    resolved; ``ann_to_type`` turns that case into a ``ValueError``.

    Args:
        ann: the annotation object. ``inspect.Signature.empty`` (a missing
            annotation) maps to an inferred ``Tensor`` type.
        loc: source location used in error messages and when recursively
            compiling classes / Enums.
        rcb: optional resolution callback, used only for the final
            ``torch._C._resolve_type_from_object`` fallback. Nested element
            annotations are converted without it.

    Raises:
        ValueError: if a ``Dict`` key or value type cannot be resolved.
        AssertionError: if the member of an ``Optional``, or any member of a
            ``Union``, cannot be resolved.
    """
    ann_args = typing.get_args(ann)  # always returns a tuple!

    if ann is inspect.Signature.empty:
        return TensorType.getInferred()
    if ann is None:
        return NoneType.get()
    if inspect.isclass(ann) and is_tensor(ann):
        return TensorType.get()
    if is_tuple(ann):
        # Special case for the empty Tuple type annotation `Tuple[()]`
        if len(ann_args) == 1 and ann_args[0] == ():
            return TupleType([])
        return TupleType([try_ann_to_type(a, loc) for a in ann_args])
    if is_list(ann):
        elem_type = try_ann_to_type(ann_args[0], loc)
        if elem_type:
            return ListType(elem_type)
        # An unresolvable element type falls through to the checks below.
    if is_dict(ann):
        key = try_ann_to_type(ann_args[0], loc)
        value = try_ann_to_type(ann_args[1], loc)
        # Raise error if key or value is None
        if key is None:
            raise ValueError(
                f"Unknown type annotation: '{ann_args[0]}' at {loc.highlight()}"
            )
        if value is None:
            raise ValueError(
                f"Unknown type annotation: '{ann_args[1]}' at {loc.highlight()}"
            )
        return DictType(key, value)
    if is_optional(ann):
        # An Optional is a two-member Union whose None member may come first
        # (`Union[None, T]`) or second. Compare by identity: `issubclass`
        # raises TypeError when the other member is a generic alias such as
        # `List[int]`, which is not a class.
        if ann_args[1] is type(None) or ann_args[1] is None:
            contained = ann_args[0]
        else:
            contained = ann_args[1]
        valid_type = try_ann_to_type(contained, loc)
        msg = "Unsupported annotation {} could not be resolved because {} could not be resolved. At\n{}"
        assert valid_type, msg.format(repr(ann), repr(contained), repr(loc))
        return OptionalType(valid_type)
    if is_union(ann):
        # TODO: this is hack to recognize NumberType
        if set(ann_args) == {int, float, complex}:
            return NumberType.get()
        inner: List = []
        msg = "Unsupported annotation {} could not be resolved because {} could not be resolved. At\n{}"
        # We need these extra checks because both `None` and invalid
        # values will return `None`
        # TODO: Determine if the other cases need to be fixed as well
        for a in ann_args:
            if a is None:
                inner.append(NoneType.get())
                # Already handled; converting `None` again would append a
                # second NoneType.
                continue
            maybe_type = try_ann_to_type(a, loc)
            # Report the member that failed (`maybe_type` is falsy here).
            assert maybe_type, msg.format(repr(ann), repr(a), repr(loc))
            inner.append(maybe_type)
        return UnionType(inner)  # type: ignore[arg-type]
    if torch.distributed.rpc.is_available() and is_rref(ann):
        return RRefType(try_ann_to_type(ann_args[0], loc))
    if is_future(ann):
        return FutureType(try_ann_to_type(ann_args[0], loc))
    if is_await(ann):
        elementType = try_ann_to_type(ann_args[0], loc) if ann_args else AnyType.get()
        return AwaitType(elementType)
    if ann is float:
        return FloatType.get()
    if ann is complex:
        return ComplexType.get()
    if ann is int or ann is torch.SymInt:
        return IntType.get()
    if ann is str:
        return StringType.get()
    if ann is bool:
        return BoolType.get()
    if ann is Any:
        return AnyType.get()
    if ann is type(None):
        return NoneType.get()
    if inspect.isclass(ann) and hasattr(ann, "__torch_script_interface__"):
        return InterfaceType(ann.__torch_script_interface__)
    if ann is torch.device:
        return DeviceObjType.get()
    if ann is torch.Generator:
        return _GeneratorType.get()
    if ann is torch.Stream:
        return StreamObjType.get()
    if ann is torch.dtype:
        return IntType.get()  # dtype not yet bound in as its own type
    if inspect.isclass(ann) and issubclass(ann, enum.Enum):
        if _get_script_class(ann) is None:
            scripted_class = torch.jit._script._recursive_compile_class(ann, loc)
            name = scripted_class.qualified_name()
        else:
            name = _qualified_name(ann)
        return EnumType(name, get_enum_value_type(ann, loc), list(ann))
    if inspect.isclass(ann):
        maybe_script_class = _get_script_class(ann)
        if maybe_script_class is not None:
            return maybe_script_class
        if torch._jit_internal.can_compile_class(ann):
            return torch.jit._script._recursive_compile_class(ann, loc)

    # Maybe resolve a NamedTuple to a Tuple Type
    if rcb is None:
        rcb = _fake_rcb
    return torch._C._resolve_type_from_object(ann, loc, rcb)
def ann_to_type(ann, loc, rcb=None):
    """Convert ``ann`` to a TorchScript type, or fail loudly.

    This is a strict wrapper around ``try_ann_to_type``.

    Raises:
        ValueError: if ``try_ann_to_type`` returns ``None`` for ``ann``.
    """
    resolved = try_ann_to_type(ann, loc, rcb)
    if resolved is None:
        raise ValueError(f"Unknown type annotation: '{ann}' at {loc.highlight()}")
    return resolved
# Public typing helpers and TorchScript type classes re-exported by this module.
__all__ = [
    "Any",
    "List",
    "BroadcastingList1",
    "BroadcastingList2",
    "BroadcastingList3",
    "Tuple",
    "is_tuple",
    "is_list",
    "Dict",
    "is_dict",
    "is_optional",
    "is_union",
    "TensorType",
    "TupleType",
    "FloatType",
    "ComplexType",
    "IntType",
    "ListType",
    "StringType",
    "DictType",
    "AnyType",
    "Module",
]
# TODO: Consider not exporting these during wildcard import (reserve
# that for the types; for idiomatic typing code.)
__all__ += [
    "get_signature",
    "check_fn",
    "get_param_names",
    "parse_type_line",
    "get_type_line",
    "split_type_line",
    "try_real_annotations",
    "try_ann_to_type",
    "ann_to_type",
]
|