| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258 |
- # mypy: allow-untyped-defs
- import dataclasses
- import inspect
- import sys
- from typing import Any, Callable, Dict, Iterable, Tuple
- import torch
- import torch._utils_internal as _utils_internal
- from torch import _C
@dataclasses.dataclass
class Kernel:
    """A callable paired with the source location where it was registered.

    Calling the Kernel forwards directly to ``func``; ``source`` is kept
    around for error messages and debugging (e.g. "/path/to/foo.py:42").
    """

    func: Callable
    source: str

    def __call__(self, *args, **kwargs):
        # Pure delegation: the Kernel is transparent to its callers.
        return self.func(*args, **kwargs)
class RegistrationHandle:
    """Holds a cleanup callback and runs it when ``destroy()`` is called."""

    def __init__(self, on_destroy: Callable):
        # Stored verbatim; invoked (with no arguments) by destroy().
        self._on_destroy = on_destroy

    def destroy(self) -> None:
        """Run the cleanup callback supplied at construction time."""
        self._on_destroy()
def get_source(stacklevel: int) -> str:
    """Return a "filename:lineno" string identifying a caller.

    Example: "/path/to/foo.py:42"

    Use stacklevel=1 for the immediate caller, stacklevel=2 for the
    caller's caller, and so on.
    """
    # NOTE: do not add extra function calls between here and _getframe,
    # or the stacklevel offset would shift.
    caller_frame = sys._getframe(stacklevel)
    info = inspect.getframeinfo(caller_frame)
    return ":".join([info.filename, str(info.lineno)])
def parse_namespace(qualname: str) -> Tuple[str, str]:
    """Split a "namespace::name" qualname into its two components.

    Raises:
        ValueError: if ``qualname`` does not contain exactly one "::".
    """
    parts = qualname.split("::")
    if len(parts) != 2:
        raise ValueError(
            f"Expected `qualname` to be of the form "
            f'"namespace::name", but got {qualname}. '
            f"The qualname passed to the torch.library APIs must consist "
            f"of a namespace and a name, e.g. aten::sin"
        )
    namespace, name = parts
    return namespace, name
def lookup_op(qualname: str) -> torch._ops.OpOverload:
    """Resolve "namespace::name[.overload]" to the torch.ops OpOverload.

    When no ".overload" suffix is present, the "default" overload is used.
    """
    # Split the namespace off the qualified name (same contract as
    # parse_namespace, inlined here).
    parts = qualname.split("::")
    if len(parts) != 2:
        raise ValueError(
            f"Expected `qualname` to be of the form "
            f'"namespace::name", but got {qualname}. '
            f"The qualname passed to the torch.library APIs must consist "
            f"of a namespace and a name, e.g. aten::sin"
        )
    namespace, name = parts
    overload = "default"
    if "." in name:
        name, overload = name.split(".")
    packet = getattr(getattr(torch.ops, namespace), name)
    return getattr(packet, overload)
def is_builtin(op: torch._ops.OpOverload) -> bool:
    """True if ``op`` lives in a PyTorch-controlled namespace (aten/prim/prims)."""
    assert isinstance(op, torch._ops.OpOverload)
    builtin_namespaces = ("aten", "prim", "prims")
    return op.namespace in builtin_namespaces
def is_functional_schema(schema: Any) -> bool:
    """Check if the schema is functional.

    An operator is functional if:
    - it does not mutate any of its inputs
    - it does not return a view on any of its inputs
    - it has at least one return

    Accepts a ``torch._C.FunctionSchema``, a ``torchgen`` FunctionSchema,
    or a schema string (parsed via torchgen).
    """

    def check(s):
        if s.is_mutable:
            return False
        # A non-write alias on any return means the op returns a view.
        for ret in s.returns:
            info = ret.alias_info
            if info is not None and not info.is_write:
                return False
        # Functional ops must return something.
        return len(s.returns) > 0

    if isinstance(schema, torch._C.FunctionSchema):
        return check(schema)

    # Lazy import because not all PyTorch builds have torchgen
    from torchgen.model import FunctionSchema

    if isinstance(schema, str):
        schema = FunctionSchema.parse(schema)
    assert isinstance(schema, FunctionSchema)
    return check(schema)
- # should be torch._C.JitType but that annotation is busted
- def is_tensorlist_like_type(typ: Any) -> bool:
- return (
- typ == _C.ListType(_C.TensorType.get())
- or typ == _C.ListType(_C.OptionalType(_C.TensorType.get()))
- or typ == _C.OptionalType(_C.ListType(_C.TensorType.get()))
- or typ == _C.OptionalType(_C.ListType(_C.OptionalType(_C.TensorType.get())))
- )
- # should be torch._C.JitType but that annotation is busted
- def is_tensor_like_type(typ: Any) -> bool:
- return typ == _C.TensorType.get() or typ == _C.OptionalType(_C.TensorType.get())
def mutates_and_returns_first_arg(op: torch._ops.OpOverload):
    """Check if an op is an inplace aten op, i.e. it mutates and returns the first arg.

    TODO: torchgen/model.py's FunctionSchema.parse is the source of truth for this,
    but not all PyTorch builds have torchgen (due to the yaml dependency being weird).
    Figure this out.

    Example: add_(Tensor(a!) x, Tensor y) -> Tensor(a)
    """
    if op.namespace != "aten":
        return False
    schema = op._schema

    # Exactly one return, and it must carry alias info.
    returns = schema.returns
    if len(returns) != 1 or returns[0].alias_info is None:
        return False
    ret_set = returns[0].alias_info.after_set
    if len(ret_set) != 1:
        return False

    # The first argument must be written to, with a single alias that
    # matches the return's alias.
    arguments = schema.arguments
    if not arguments:
        return False
    first_info = arguments[0].alias_info
    if first_info is None or not first_info.is_write:
        return False
    arg_set = first_info.after_set
    if len(arg_set) != 1:
        return False
    if next(iter(ret_set)) != next(iter(arg_set)):
        return False

    # No other argument may alias anything.
    return all(arg.alias_info is None for arg in arguments[1:])
def fill_defaults(schema, args, kwargs):
    """Normalize ``(args, kwargs)`` against ``schema``.

    Returns a (tuple, dict) pair where every positional argument and every
    keyword-only argument is populated, falling back to the schema's
    declared default value when the caller omitted it.
    """
    positional = []
    keyword = {}
    for idx, info in enumerate(schema.arguments):
        if info.kwarg_only:
            # Keyword-only: caller-supplied value wins, else the default.
            keyword[info.name] = kwargs.get(info.name, info.default_value)
        elif idx < len(args):
            positional.append(args[idx])
        else:
            # Trailing positional args may be omitted by the caller.
            positional.append(info.default_value)
    return tuple(positional), keyword
def zip_schema(
    schema: _C.FunctionSchema, args: Tuple[Any, ...], kwargs: Dict[str, Any]
) -> Iterable[Tuple[_C.Argument, Any]]:
    """zips schema.arguments and (args, kwargs) together.

    Assumes that (args, kwargs) were the inputs to some torch._ops.OpOverload:
    that is, kwargs must be keyword-only arguments and default values may be omitted.
    """
    assert len(schema.arguments) >= len(args) + len(kwargs)
    for idx, info in enumerate(schema.arguments):
        if info.kwarg_only:
            # Only yield keyword-only args the caller actually supplied.
            if info.name in kwargs:
                yield info, kwargs[info.name]
            continue
        if idx < len(args):
            yield info, args[idx]
        # else: args that are equal to their default values are not
        # populated if they are followed by args that are equal to their
        # defaults. Skip these.
def can_generate_trivial_fake_impl(op: torch._ops.OpOverload) -> bool:
    """True if ``op`` mutates, returns nothing, and is a custom op — in
    which case a no-op fake impl is valid for it."""
    assert isinstance(op, torch._ops.OpOverload)
    # We control the built-ins (aten/prim/prims). These may (in rare
    # cases) do input metadata mutation (which we have banned on custom
    # ops), so never auto-generate for them.
    if op.namespace in {"aten", "prim", "prims"}:
        return False
    schema = op._schema
    # It's suspicious if the op is not mutable but returns nothing, so we
    # return False out of an abundance of caution
    if not schema.is_mutable:
        return False
    # An op that mutates and returns nothing has a trivial fake impl.
    return len(schema.returns) == 0
def requires_set_python_module() -> bool:
    """If an op was defined in C++ and extended from Python using the
    torch.library APIs, returns if we require that there have been a
    m.set_python_module("mylib.ops") call from C++ that associates
    the C++ op with a python module.
    """
    # Internal builds may override this flag; default to requiring it.
    default = True
    return getattr(_utils_internal, "REQUIRES_SET_PYTHON_MODULE", default)
def handle_dispatch_mode(curr_mode, op_overload, *args, **kwargs):
    """Invoke ``curr_mode.__torch_dispatch__`` for ``op_overload(*args, **kwargs)``.

    The ``types`` argument is built from the flattened inputs: only Tensors
    carrying the Python dispatch key contribute their type.
    """
    assert isinstance(curr_mode, torch.utils._python_dispatch.TorchDispatchMode)
    flat_args, _ = torch.utils._pytree.tree_flatten((args, kwargs.values()))
    # TODO: need to double check the semantics of the "types" argument to torch_dispatch.
    # It's generated in PyInterpreter.cpp, but seems to be generated in two places,
    # where in one case we only include tensors with the python key, and in another
    # we include **all** tensors.
    overload_types = [
        type(arg)
        for arg in flat_args
        if isinstance(arg, torch.Tensor)
        and torch._C._dispatch_keys(arg).has(torch._C.DispatchKey.Python)
    ]
    # TODO: check that I got these args correct (in C++, we pass in "0000"??)
    return curr_mode.__torch_dispatch__(op_overload, overload_types, args, kwargs)
def has_kwarg_only_args(schema: _C.FunctionSchema):
    """Whether ``schema`` declares at least one keyword-only argument."""
    for arg in schema.arguments:
        if arg.kwarg_only:
            return True
    return False
def has_kwarg_only_tensors(schema: _C.FunctionSchema):
    """Whether ``schema`` declares a keyword-only argument whose type is
    Tensor-like or TensorList-like."""
    return any(
        arg.kwarg_only
        and (is_tensor_like_type(arg.type) or is_tensorlist_like_type(arg.type))
        for arg in schema.arguments
    )
|