_builtins.py

# mypy: allow-untyped-defs
import cmath
import math
import warnings
from collections import OrderedDict
from typing import Dict, Optional

import torch
import torch.backends.cudnn as cudnn

from ..nn.modules.utils import _list_with_default, _pair, _quadruple, _single, _triple


_builtin_table: Optional[Dict[int, str]] = None

_modules_containing_builtins = (torch, torch._C._nn, torch._C._fft, torch._C._linalg, torch._C._nested, torch._C._sparse, torch._C._special)  # type: ignore[attr-defined] # noqa: B950

_builtin_ops = [
    # Pairs of (function, op_name)
    (_pair, "aten::_pair"),
    (_quadruple, "aten::_quadruple"),
    (_single, "aten::_single"),
    (_triple, "aten::_triple"),
    (_list_with_default, "aten::list_with_default"),
    (OrderedDict, "aten::dict"),
    (dict, "aten::dict"),
    (cudnn.is_acceptable, "aten::cudnn_is_acceptable"),
    (math.ceil, "aten::ceil"),
    (math.copysign, "aten::copysign"),
    (math.erf, "aten::erf"),
    (math.erfc, "aten::erfc"),
    (math.exp, "aten::exp"),
    (math.expm1, "aten::expm1"),
    (math.fabs, "aten::fabs"),
    (math.floor, "aten::floor"),
    (math.gamma, "aten::gamma"),
    (math.lgamma, "aten::lgamma"),
    (math.log, "aten::log"),
    (math.log10, "aten::log10"),
    (math.log1p, "aten::log1p"),
    (math.pow, "aten::pow"),
    (math.sqrt, "aten::sqrt"),
    (math.isnan, "aten::isnan"),
    (math.asinh, "aten::asinh"),
    (math.atanh, "aten::atanh"),
    (math.cosh, "aten::cosh"),
    (math.sinh, "aten::sinh"),
    (math.tanh, "aten::tanh"),
    (math.acos, "aten::acos"),
    (math.asin, "aten::asin"),
    (math.atan, "aten::atan"),
    (math.atan2, "aten::atan2"),
    (math.cos, "aten::cos"),
    (math.sin, "aten::sin"),
    (math.tan, "aten::tan"),
    (math.asinh, "aten::asinh"),
    (math.atanh, "aten::atanh"),
    (math.acosh, "aten::acosh"),
    (math.fmod, "aten::fmod"),
    (math.modf, "aten::modf"),
    (math.factorial, "aten::factorial"),
    (math.frexp, "aten::frexp"),
    (math.isinf, "aten::isinf"),
    (math.degrees, "aten::degrees"),
    (math.radians, "aten::radians"),
    (cmath.isnan, "aten::isnan"),
    (cmath.isfinite, "aten::isfinite"),
    (cmath.isinf, "aten::isinf"),
    (cmath.phase, "aten::angle"),
    (cmath.rect, "aten::polar"),
    (cmath.log, "aten::log"),
    (cmath.log10, "aten::log10"),
    (cmath.sqrt, "aten::sqrt"),
    (cmath.exp, "aten::exp"),
    (cmath.sin, "aten::sin"),
    (cmath.tan, "aten::tan"),
    (cmath.cos, "aten::cos"),
    (cmath.asin, "aten::asin"),
    (cmath.acos, "aten::acos"),
    (cmath.atan, "aten::atan"),
    (cmath.sinh, "aten::sinh"),
    (cmath.cosh, "aten::cosh"),
    (cmath.tanh, "aten::tanh"),
    (cmath.asinh, "aten::asinh"),
    (cmath.acosh, "aten::acosh"),
    (cmath.atanh, "aten::atanh"),
    (math.ldexp, "aten::ldexp"),
    (torch._assert, "aten::_assert"),
    (torch.autograd.grad, "aten::grad"),
    (torch.autograd.backward, "aten::backward"),
    (torch._C._infer_size, "aten::_infer_size"),
    (torch.nn.functional._no_grad_embedding_renorm_, "aten::_no_grad_embedding_renorm_"),  # type: ignore[attr-defined]
    (torch.nn.functional.assert_int_or_pair, "aten::_assert_int_or_pair"),
    (torch.nn.init._no_grad_fill_, "aten::_no_grad_fill_"),
    (torch.nn.init._no_grad_normal_, "aten::_no_grad_normal_"),
    (torch.nn.init._no_grad_uniform_, "aten::_no_grad_uniform_"),
    (torch.nn.init._no_grad_zero_, "aten::_no_grad_zero_"),
    (torch._C._get_tracing_state, "aten::_get_tracing_state"),
    (torch._C._get_cpu_capability, "aten::_get_cpu_capability"),
    (warnings.warn, "aten::warn"),
    (torch._VF.stft, "aten::stft"),  # type: ignore[attr-defined]
    (torch._VF.istft, "aten::istft"),  # type: ignore[attr-defined]
    (torch._VF.cdist, "aten::cdist"),  # type: ignore[attr-defined]
    (torch._VF.norm, "aten::norm"),  # type: ignore[attr-defined]
    (torch._VF.unique_dim, "aten::unique_dim"),
    (torch._VF.unique_consecutive, "aten::unique_consecutive"),  # type: ignore[attr-defined]
    (torch._VF.nuclear_norm, "aten::nuclear_norm"),
    (torch._VF.frobenius_norm, "aten::frobenius_norm"),
    (torch._VF.tensordot, "aten::tensordot"),  # type: ignore[attr-defined]
]
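
# For illustration only: each (function, op_name) pair above later becomes an
# id-keyed entry in _builtin_table, so that once _get_builtin_table() (defined
# below) has run,
#   _get_builtin_table()[id(math.sqrt)] == "aten::sqrt"
# and a Python callable can be mapped straight onto a builtin "aten::" schema.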


# ops in torch.functional are bound to torch; in these cases we want to resolve
# the functions to their Python implementations instead of looking up a builtin
# "aten::" schema
def _gen_torch_functional_registered_ops():
    # eventually ops should encompass all of torch/functional.py (torch.functional.__all__),
    # but we are currently only able to compile some of the functions. additionally,
    # some functions directly map to their aten:: implementations.
    # TODO: add support for more ops
    ops = [
        "stft",
        "istft",
        "lu",
        "cdist",
        "norm",
        "unique",
        "unique_consecutive",
        "tensordot",
    ]
    return {getattr(torch.functional, name) for name in ops}


_functional_registered_ops = _gen_torch_functional_registered_ops()


def _is_special_functional_bound_op(fn):
    return fn in _functional_registered_ops
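
# For illustration only: torch.stft is re-exported from torch.functional, so it is
# treated as "special" here and compiled from its Python source rather than being
# bound directly to an "aten::" builtin. A minimal sketch of the check:
#   _is_special_functional_bound_op(torch.functional.stft)  # expected: True
#   _is_special_functional_bound_op(math.sqrt)              # expected: False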


# lazily built to ensure the correct initialization order
def _get_builtin_table():
    global _builtin_table
    if _builtin_table is not None:
        return _builtin_table
    _builtin_table = {}

    def register_all(mod):
        for name in dir(mod):
            v = getattr(mod, name)
            if (
                callable(v)
                and not _is_special_functional_bound_op(v)
                and v is not torch.no_grad
                and v is not torch.autocast
            ):
                # Fixup inconsistency in segment_reduce
                if name == "_segment_reduce":
                    name = name[1:]
                _builtin_ops.append((v, "aten::" + name))

    for mod in _modules_containing_builtins:
        register_all(mod)

    _builtin_ops.append((math.gcd, "aten::gcd"))
    _builtin_ops.append((math.isfinite, "aten::isfinite"))
    _builtin_ops.append((math.remainder, "aten::mathremainder"))  # type: ignore[attr-defined]

    import torch.distributed.autograd as dist_autograd

    if dist_autograd.is_available():
        _builtin_ops.append((dist_autograd.get_gradients, "aten::get_gradients"))
        _builtin_ops.append((dist_autograd.backward, "aten::dist_backward"))

    # populate the _builtin_table from _builtin_ops
    for builtin, aten_op in _builtin_ops:
        _builtin_table[id(builtin)] = aten_op

    return _builtin_table
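
# For illustration only: the table keys on id(callable), so lookups go by object
# identity rather than by name. A minimal sketch of what the lazy build provides:
#   table = _get_builtin_table()
#   table[id(torch.add)]  # expected: "aten::add" (registered via register_all(torch))
#   table[id(math.gcd)]   # expected: "aten::gcd" (appended explicitly above)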


def _register_builtin(fn, op):
    _get_builtin_table()[id(fn)] = op


def _find_builtin(fn):
    return _get_builtin_table().get(id(fn))
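

# Usage sketch, not part of the upstream module: _find_builtin answers "does this
# Python callable correspond to a TorchScript builtin schema?", and _register_builtin
# adds such a mapping by hand. The wrapper name _demo_relu below is made up purely
# to demonstrate registration; expected outputs are annotated as comments.
if __name__ == "__main__":
    print(_find_builtin(torch.add))              # expected: "aten::add"
    print(_find_builtin(math.sqrt))              # expected: "aten::sqrt"
    print(_find_builtin(torch.functional.stft))  # expected: None (compiled from Python instead)

    def _demo_relu(x):
        return torch.relu(x)

    # Map the wrapper onto an existing schema; afterwards the lookup resolves it.
    _register_builtin(_demo_relu, "aten::relu")
    print(_find_builtin(_demo_relu))             # expected: "aten::relu"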