# mypy: allow-untyped-defs
import inspect
import textwrap

import torch.jit
from torch.jit._builtins import _find_builtin


# This file is for generating documentation using sphinx autodoc.
# > help(torch.jit.supported_ops) will also give a nice listing of the
# supported ops programmatically.
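# Illustrative usage sketch (not part of the generated listing; assumes the
# module is importable as torch.jit.supported_ops in your torch build):
#
#     import torch.jit.supported_ops
#     print(torch.jit.supported_ops.__doc__)  # reStructuredText listing of ops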
def _hidden(name):
    return name.startswith("_") and not name.startswith("__")


def _emit_type(type):
    return str(type)


def _emit_arg(indent, i, arg):
    v = f"{arg.name} : {_emit_type(arg.type)}"
    default = arg.default_value
    if default is not None:
        v = f"{v}={str(default)}"
    if i > 0:
        v = f"\n{' ' * indent}{v}"
    return v


def _emit_args(indent, arguments):
    return ",".join(_emit_arg(indent, i, arg) for i, arg in enumerate(arguments))
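# For illustration only (an assumed example, not produced at import time): with
# indent=10, every argument after the first is placed on its own line prefixed
# with `indent` spaces, so three joined arguments look roughly like
#
#     "x : Tensor,\n          y : Tensor,\n          alpha : number=1"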
def _emit_ret(ret):
    return _emit_type(ret.type)


def _emit_rets(returns):
    if len(returns) == 1:
        return _emit_ret(returns[0])
    return f"Tuple[{', '.join(_emit_ret(r) for r in returns)}]"


def _emit_schema(mod, name, schema, arg_start=0, padding=4):
    if mod is None:
        qualified_name = name
    else:
        qualified_name = f"{mod}.{name}"
    schema_str = (
        f"{qualified_name}"
        f"({_emit_args(len(qualified_name) + 1 + padding, schema.arguments[arg_start:])}) "
        f"-> {_emit_rets(schema.returns)}"
    )
    return schema_str
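# Rough example of the emitted format (hedged; the exact schema text depends on
# the installed torch build):
#
#     _emit_schema("Tensor", "abs", schema, arg_start=1)
#     # -> "Tensor.abs() -> Tensor"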
def _get_tensor_ops():
    def is_tensor_method(schema):
        if len(schema.arguments) == 0:
            return False
        self = schema.arguments[0]
        if self.name != "self":
            return False
        if not self.type.isSubtypeOf(torch._C.TensorType.get()):
            return False
        return True

    methods = []
    # discover methods
    for elem in dir(torch.Tensor):
        if not _hidden(elem):
            schemas = torch._C._jit_get_schemas_for_operator("aten::" + elem)
            for schema in schemas:
                if is_tensor_method(schema):
                    methods.append(_emit_schema("Tensor", elem, schema, arg_start=1))

    return "Supported Tensor Methods", methods


def _get_nn_functional_ops():
    functions = []

    # Iterate over torch.nn.functional
    mod = torch.nn.functional
    name = mod.__name__
    for elem in dir(torch.nn.functional):
        attr = getattr(mod, elem)
        if not inspect.isfunction(attr) or _hidden(elem[0]):
            # Ignore non-functions and internal methods
            continue

        attr_module = inspect.getmodule(attr)
        if not attr_module:
            raise RuntimeError(f"Module for {attr} not found")

        if "torch.nn.functional" not in attr_module.__name__:
            # Ignore functions from outside torch.nn.functional
            continue

        try:
            # compile fn, get schema
            scripted = torch.jit.script(attr)
            scripted_schema = scripted.schema
            functions.append(_emit_schema(name, elem, scripted_schema))
        except:  # noqa: B001,E722
            # Skip interpolate / boolean dispatched things
            pass

    # Iterate over modules that we know contain a lot of builtins
    for mod in torch.jit._builtins._modules_containing_builtins:
        name = mod.__name__
        for elem in dir(mod):
            builtin = _find_builtin(getattr(mod, elem))
            if builtin is not None:
                schemas = torch._C._jit_get_schemas_for_operator(builtin)
                for schema in schemas:
                    # remove _tan but not __and__
                    if not _hidden(elem):
                        functions.append(_emit_schema(name, elem, schema))

    return "Supported PyTorch Functions", functions
def _get_builtins_helper():
    builtins = []
    for fn, _builtin_name in torch.jit._builtins._builtin_ops:
        mod = inspect.getmodule(fn)

        if not hasattr(fn, "__name__"):
            # typing classes
            continue
        if not mod:
            continue
        if _hidden(fn.__name__) or _hidden(fn.__qualname__) or _hidden(mod.__name__):
            # skip internal-only methods
            continue
        if "torch._C" in mod.__name__:
            continue

        builtins.append((fn, _builtin_name))

    return builtins


def _is_math_fn(fn):
    mod = inspect.getmodule(fn)
    if not mod:
        raise RuntimeError(f"Module for {fn} not found")

    return mod.__name__ == "math"
def _get_torchscript_builtins():
    functions = []
    builtins = filter(lambda fn: not _is_math_fn(fn[0]), _get_builtins_helper())
    builtins_list = list(builtins)
    # Iterate over the specially added builtins
    for fn, _builtin_name in builtins_list:
        mod = inspect.getmodule(fn)
        if not mod:
            raise RuntimeError(f"Module for {fn} not found")
        builtin = _find_builtin(fn)
        if builtin is not None:
            schemas = torch._C._jit_get_schemas_for_operator(builtin)
            for schema in schemas:
                functions.append(_emit_schema(mod.__name__, fn.__name__, schema))

    return "TorchScript Builtin Functions", functions
def _get_math_builtins():
    functions = []
    builtins = filter(lambda fn: _is_math_fn(fn[0]), _get_builtins_helper())
    builtins_list = list(builtins)
    # Iterate over the specially added builtins
    for fn, _builtin_name in builtins_list:
        mod = inspect.getmodule(fn)
        if not mod:
            raise RuntimeError(f"Module for {fn} not found")
        builtin = _find_builtin(fn)
        if builtin is not None:
            schemas = torch._C._jit_get_schemas_for_operator(builtin)
            for schema in schemas:
                schema_str = _emit_schema(mod.__name__, fn.__name__, schema)
                if "Tensor" in schema_str:
                    # Skip Tensor ops that have the same name as math functions
                    # (they will show up in the tensor methods section)
                    continue
                # Append the formatted string (not the raw schema object) so the
                # entries match the other sections
                functions.append(schema_str)

    return "``math`` Module", functions
def _get_global_builtins():
    # Taken from the 'globals' map in torch/csrc/jit/frontend/ir_emitter.cpp
    supported_builtins = [
        "print",
        "tuple",
        "float",
        "complex",
        "int",
        "bool",
        "str",
        "getattr",
        "hasattr",
        "isinstance",
        "len",
        "hex",
        "oct",
        "round",
        "hash",
        "min",
        "max",
        "abs",
        "all",
        "divmod",
        "list",
        "ord",
        "chr",
        "bin",
        "range",
        "zip",
        "enumerate",
        "sorted",
    ]
    op_renames = {
        "bool": "aten::Bool",
        "int": "aten::Int",
        "float": "aten::Float",
        "complex": "aten::Complex",
        "abs": "prim::abs",
        "max": "prim::max",
        "min": "prim::min",
        "range": "fake::does_not_exist",
    }
    schemaless_op_explanations = {
        "print": "Print any value",
        "tuple": "Lists cannot be converted to tuples with this method since their size is not statically known",
        "getattr": "Attribute name must be a literal string",
        "hasattr": "Attribute name must be a literal string",
        "isinstance": "Result is static",
        "zip": "Arguments must be iterable. See :ref:`Iterables <jit_iterables>` for details.",
        "enumerate": "Arguments must be iterable. See :ref:`Iterables <jit_iterables>` for details.",
        "range": "Can only be used as an iterator in a for loop",
    }
    magic_methods = [
        ("complex", "__complex__"),
        ("float", "__float__"),
        ("int", "__int__"),
        ("bool", "__bool__"),
        ("str", "__str__"),
        ("len", "__len__"),
        ("hex", "__hex__"),
        ("oct", "__oct__"),
    ]
    magic_methods_rows = []
    for fn, magic_method in magic_methods:
        magic_methods_rows.append(f'"{fn}", "``{magic_method}``"')

    schematized_ops = []
    schemaless_ops = []

    for fn in supported_builtins:
        op_name = f"aten::{fn}"
        if fn in op_renames:
            op_name = op_renames[fn]
        schemas = torch._C._jit_get_schemas_for_operator(op_name)
        for s in schemas:
            schematized_ops.append(_emit_schema(None, fn, s, padding=0))
        if len(schemas) > 0:
            schematized_ops.append("")
        else:
            table_row = (
                f'":external+python:py:obj:`{fn}`", "{schemaless_op_explanations[fn]}"'
            )
            schemaless_ops.append(table_row)

    schematized_ops_str = "\n".join(schematized_ops)
    schemaless_ops_str = "\n".join(schemaless_ops)
    magic_methods_rows_str = "\n".join(magic_methods_rows)
    schematized_ops_str = textwrap.indent(schematized_ops_str, "\t")
    schemaless_ops_str = textwrap.indent(schemaless_ops_str, "\t")
    magic_methods_rows_str = textwrap.indent(magic_methods_rows_str, "\t")
    section = f"""
The functions in the following table are supported but do not have a static schema

.. csv-table::
    :header: "Function", "Note"

{schemaless_ops_str}

The following functions will use the corresponding magic method on :any:`TorchScript classes`

.. csv-table::
    :header: "Function", "Magic Method"

{magic_methods_rows_str}

These built-in functions use the schema

.. rst-class:: codeblock-height-limiter

::

{schematized_ops_str}
"""
    return "Python Built-in Functions", section
def _list_supported_ops():
    def emit_block(decls):
        return "\n.. rst-class:: codeblock-height-limiter\n\n::\n\n{}\n".format(
            "".join(f" {d}\n\n" for d in decls)
        )

    body = ""
    op_gathering_fns = (
        _get_tensor_ops,
        _get_nn_functional_ops,
        _get_torchscript_builtins,
        _get_global_builtins,
        _get_math_builtins,
    )
    for fn in op_gathering_fns:
        header, items = fn()
        link_target = header.replace("`", "").replace("-", "").lower().replace(" ", "-")
        if isinstance(items, str):
            section = f"{header}\n{'~' * len(header)}\n{items}\n"
        else:
            section = f"{header}\n{'~' * len(header)}\n{emit_block(items)}"
        section = f".. _{link_target}:" + "\n\n" + section
        body += section

    return body


__doc__ = _list_supported_ops()
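# Debugging sketch (an assumption, not part of the module's normal use): run
# the listing generator directly to preview the reStructuredText output.
#
# if __name__ == "__main__":
#     print(_list_supported_ops())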