| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345 |
- # mypy: allow-untyped-defs
- import inspect
- import textwrap
- import torch.jit
- from torch.jit._builtins import _find_builtin
- # this file is for generating documentation using sphinx autodoc
- # > help(torch.jit.supported_ops) will also give a nice listed of the
- # supported ops programmatically
- def _hidden(name):
- return name.startswith("_") and not name.startswith("__")
- def _emit_type(type):
- return str(type)
- def _emit_arg(indent, i, arg):
- v = f"{arg.name} : {_emit_type(arg.type)}"
- default = arg.default_value
- if default is not None:
- v = f"{v}={str(default)}"
- if i > 0:
- v = f"\n{' ' * indent}{v}"
- return v
- def _emit_args(indent, arguments):
- return ",".join(_emit_arg(indent, i, arg) for i, arg in enumerate(arguments))
- def _emit_ret(ret):
- return _emit_type(ret.type)
- def _emit_rets(returns):
- if len(returns) == 1:
- return _emit_ret(returns[0])
- return f"Tuple[{', '.join(_emit_ret(r) for r in returns)}]"
- def _emit_schema(mod, name, schema, arg_start=0, padding=4):
- if mod is None:
- qualified_name = name
- else:
- qualified_name = f"{mod}.{name}"
- schema_str = (
- f"{qualified_name}"
- f"({_emit_args(len(qualified_name) + 1 + padding, schema.arguments[arg_start:])}) "
- f"-> {_emit_rets(schema.returns)}"
- )
- return schema_str
- def _get_tensor_ops():
- def is_tensor_method(schema):
- if len(schema.arguments) == 0:
- return False
- self = schema.arguments[0]
- if self.name != "self":
- return False
- if not self.type.isSubtypeOf(torch._C.TensorType.get()):
- return False
- return True
- methods = []
- # discover methods
- for elem in dir(torch.Tensor):
- if not _hidden(elem):
- schemas = torch._C._jit_get_schemas_for_operator("aten::" + elem)
- for schema in schemas:
- if is_tensor_method(schema):
- methods.append(_emit_schema("Tensor", elem, schema, arg_start=1))
- return "Supported Tensor Methods", methods
- def _get_nn_functional_ops():
- functions = []
- # Iterate over torch.nn.functional
- mod = torch.nn.functional
- name = mod.__name__
- for elem in dir(torch.nn.functional):
- attr = getattr(mod, elem)
- if not inspect.isfunction(attr) or _hidden(elem[0]):
- # Ignore non-functions and internal methods
- continue
- attr_module = inspect.getmodule(attr)
- if not attr_module:
- raise RuntimeError(f"Module for {attr} not found")
- if "torch.nn.functional" not in attr_module.__name__:
- # Ignore functions from outside torch.nn.functional
- continue
- try:
- # compile fn, get schema
- scripted = torch.jit.script(attr)
- scripted_schema = scripted.schema
- functions.append(_emit_schema(name, elem, scripted_schema))
- except: # noqa: B001,E722
- # Skip interpolate / boolean dispatched things
- pass
- # Iterate over modules that we know contain a lot of builtins
- for mod in torch.jit._builtins._modules_containing_builtins:
- name = mod.__name__
- for elem in dir(mod):
- builtin = _find_builtin(getattr(mod, elem))
- if builtin is not None:
- schemas = torch._C._jit_get_schemas_for_operator(builtin)
- for schema in schemas:
- # remove _tan but not __and__
- if not _hidden(elem):
- functions.append(_emit_schema(name, elem, schema))
- return "Supported PyTorch Functions", functions
- def _get_builtins_helper():
- builtins = []
- for fn, _builtin_name in torch.jit._builtins._builtin_ops:
- mod = inspect.getmodule(fn)
- if not hasattr(fn, "__name__"):
- # typing classes
- continue
- if not mod:
- continue
- if _hidden(fn.__name__) or _hidden(fn.__qualname__) or _hidden(mod.__name__):
- # skip internal-only methods
- continue
- if "torch._C" in mod.__name__:
- continue
- builtins.append((fn, _builtin_name))
- return builtins
- def _is_math_fn(fn):
- mod = inspect.getmodule(fn)
- if not mod:
- raise RuntimeError(f"Module for {fn} not found")
- return mod.__name__ == "math"
- def _get_torchscript_builtins():
- functions = []
- builtins = filter(lambda fn: not _is_math_fn(fn[0]), _get_builtins_helper())
- builtins_list = list(builtins)
- # Iterate over the specially added builtins
- for fn, _builtin_name in builtins_list:
- mod = inspect.getmodule(fn)
- if not mod:
- raise RuntimeError(f"Module for {fn} not found")
- builtin = _find_builtin(fn)
- if builtin is not None:
- schemas = torch._C._jit_get_schemas_for_operator(builtin)
- for schema in schemas:
- functions.append(_emit_schema(mod.__name__, fn.__name__, schema))
- pass
- return "TorchScript Builtin Functions", functions
- def _get_math_builtins():
- functions = []
- builtins = filter(lambda fn: _is_math_fn(fn[0]), _get_builtins_helper())
- builtins_list = list(builtins)
- # Iterate over the specially added builtins
- for fn, _builtin_name in builtins_list:
- mod = inspect.getmodule(fn)
- if not mod:
- raise RuntimeError(f"Module for {fn} not found")
- builtin = _find_builtin(fn)
- if builtin is not None:
- schemas = torch._C._jit_get_schemas_for_operator(builtin)
- for schema in schemas:
- schema_str = _emit_schema(mod.__name__, fn.__name__, schema)
- if "Tensor" in schema_str:
- # Skip Tensor ops that have the same name as math functions
- # (they will show up in the tensor methods section)
- continue
- functions.append(schema)
- pass
- return "``math`` Module", functions
- def _get_global_builtins():
- # Taken from the 'globals' map in torch/csrc/jit/frontend/ir_emitter.cpp
- supported_builtins = [
- "print",
- "tuple",
- "float",
- "complex",
- "int",
- "bool",
- "str",
- "getattr",
- "hasattr",
- "isinstance",
- "len",
- "hex",
- "oct",
- "round",
- "hash",
- "min",
- "max",
- "abs",
- "all",
- "divmod",
- "list",
- "ord",
- "chr",
- "bin",
- "range",
- "zip",
- "enumerate",
- "sorted",
- ]
- op_renames = {
- "bool": "aten::Bool",
- "int": "aten::Int",
- "float": "aten::Float",
- "complex": "aten::Complex",
- "abs": "prim::abs",
- "max": "prim::max",
- "min": "prim::min",
- "range": "fake::does_not_exist",
- }
- schemaless_op_explanations = {
- "print": "Print any value",
- "tuple": "Lists cannot be converted to tuples with this method since their size is not statically known",
- "getattr": "Attribute name must be a literal string",
- "hasattr": "Attribute name must be a literal string",
- "isinstance": "Result is static",
- "zip": "Arguments must be iterable. See :ref:`Iterables <jit_iterables>` for details.",
- "enumerate": "Arguments must be iterable. See :ref:`Iterables <jit_iterables>` for details.",
- "range": "Can only be used as an iterator in a for loop",
- }
- magic_methods = [
- ("complex", "__complex__"),
- ("float", "__float__"),
- ("int", "__int__"),
- ("bool", "__bool__"),
- ("str", "__str__"),
- ("len", "__len__"),
- ("hex", "__hex__"),
- ("oct", "__oct__"),
- ]
- magic_methods_rows = []
- for fn, magic_method in magic_methods:
- magic_methods_rows.append(f'"{fn}", "``{magic_method}``"')
- schematized_ops = []
- schemaless_ops = []
- for fn in supported_builtins:
- op_name = f"aten::{fn}"
- if fn in op_renames:
- op_name = op_renames[fn]
- schemas = torch._C._jit_get_schemas_for_operator(op_name)
- for s in schemas:
- schematized_ops.append(_emit_schema(None, fn, s, padding=0))
- if len(schemas) > 0:
- schematized_ops.append("")
- else:
- table_row = (
- f'":external+python:py:obj:`{fn}`", "{schemaless_op_explanations[fn]}"'
- )
- schemaless_ops.append(table_row)
- schematized_ops_str = "\n".join(schematized_ops)
- schemaless_ops_str = "\n".join(schemaless_ops)
- magic_methods_rows_str = "\n".join(magic_methods_rows)
- schematized_ops_str = textwrap.indent(schematized_ops_str, "\t")
- schemaless_ops_str = textwrap.indent(schemaless_ops_str, "\t")
- magic_methods_rows_str = textwrap.indent(magic_methods_rows_str, "\t")
- section = f"""
- The functions in the following table are supported but do not have a static schema
- .. csv-table::
- :header: "Function", "Note"
- {schemaless_ops_str}
- The following functions will use the corresponding magic method on :any:`TorchScript classes`
- .. csv-table::
- :header: "Function", "Magic Method"
- {magic_methods_rows_str}
- These built-in functions use the schema
- .. rst-class:: codeblock-height-limiter
- ::
- {schematized_ops_str}
- """
- return "Python Built-in Functions", section
- def _list_supported_ops():
- def emit_block(decls):
- return "\n.. rst-class:: codeblock-height-limiter\n\n::\n\n{}\n".format(
- "".join(f" {d}\n\n" for d in decls)
- )
- body = ""
- op_gathering_fns = (
- _get_tensor_ops,
- _get_nn_functional_ops,
- _get_torchscript_builtins,
- _get_global_builtins,
- _get_math_builtins,
- )
- for fn in op_gathering_fns:
- header, items = fn()
- link_target = header.replace("`", "").replace("-", "").lower().replace(" ", "-")
- if isinstance(items, str):
- section = f"{header}\n{'~' * len(header)}\n{items}\n"
- else:
- section = f"{header}\n{'~' * len(header)}\n{emit_block(items)}"
- section = f".. _{link_target}:" + "\n\n" + section
- body += section
- return body
- __doc__ = _list_supported_ops()
|