# mypy: allow-untyped-defs
import contextlib
import ctypes
import importlib
import inspect
import sys
import types
from typing import Any, Callable, Dict, List, Set, Type, Union

import torch._C
import torch.utils._pytree as pytree
from torch import _utils_internal
from torch._functorch.pyfunctorch import dispatch_functorch
from torch.utils._python_dispatch import TorchDispatchMode

# Query `hasattr` only once.
_SET_GLOBAL_FLAGS = hasattr(sys, "getdlopenflags") and hasattr(sys, "setdlopenflags")


@contextlib.contextmanager
def dl_open_guard():
    """
    Context manager to set the RTLD_GLOBAL dynamic linker flag while we open a
    shared library to load custom operators.
    """
    if not _SET_GLOBAL_FLAGS:
        yield
        return
    old_flags = sys.getdlopenflags()
    sys.setdlopenflags(old_flags | ctypes.RTLD_GLOBAL)
    try:
        yield
    finally:
        sys.setdlopenflags(old_flags)
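
# Example (illustrative sketch; the library path below is hypothetical):
#   with dl_open_guard():
#       # The library is opened with RTLD_GLOBAL so its static registration
#       # code can see the dispatcher's symbols (this mirrors how
#       # torch.ops.load_library below uses this guard).
#       ctypes.CDLL("/path/to/libcustom_ops.so")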


class OperatorBase:
    """
    Base class for OpOverload (which represents C++ ATen operators) and HigherOrderOperator
    (which represents Python-only operators that are unrepresentable in TorchScript).
    """

    def __init__(self):
        # The dispatch cache precomputes a mapping of dispatch key that the
        # dispatcher wants to dispatch to, to an actual implementation of the
        # dispatch key. Confusingly, the actual implementation could *also* be a
        # dispatch key, but in this case, this refers to the C++ kernel that
        # was registered to some dispatch key. Aliases are permitted in the
        # latter but not the former; for example, you might lookup the
        # entry for AutogradCPU, and this maps you to the Autograd key for
        # the generic autograd kernel that works for all devices. Since this
        # is the Python dispatcher, you can also put an arbitrary Python
        # callable to call instead. This handler gets precisely the
        # args/kwargs that the operator was __call__'ed with.
        # NB: This name is hard-coded in torch/csrc/autograd/python_variable.cpp
        # for use with OpOverload; cache lookup is done entirely from C++
        # for speed.
        # TODO: The cache is NOT currently used by HigherOrderOperator, but it should!
        self._dispatch_cache: Dict[
            torch._C.DispatchKey, Union[torch._C.DispatchKey, Callable[..., Any]]
        ] = {}

        # This table allows you to override the behavior of a particular
        # dispatch key to call a custom Python function, rather than the
        # ordinary C++ configured behavior. This is the raison d'etre of
        # Python dispatcher: to let you program the dispatcher from Python
        # in case you need something unusual, and don't want to clobber
        # the existing registrations using the Python operator registration
        # API.
        self.py_kernels: Dict[torch._C.DispatchKey, Callable[..., Any]] = {}

        # This table allows you to override the behavior of a particular
        # operator for a particular TorchDispatchMode. In practice,
        # we are using this mostly for ProxyTensorMode. Modes can be
        # thought of as an open world extension of dispatch keys, so it
        # makes sense that you should be able to register them, the same
        # way you can register dispatch keys.
        self.python_key_mode_table: Dict[
            Type[TorchDispatchMode], Callable[..., Any]
        ] = {}

        # This table allows you to override the behavior of functorch
        # transformations. NB: this currently only does something for
        # HigherOrderOperator
        self.functorch_table = {}

    def __call__(self, *args, **kwargs):
        raise NotImplementedError

    def has_kernel_for_dispatch_key(self, k):
        return k in self.py_kernels

    def has_kernel_for_any_dispatch_key(self, ks):
        for k in self.py_kernels:
            if not torch._C._dispatch_is_alias_key(k) and ks.has(k):
                return True
        return False

    def py_impl(self, k):
        def inner(fn):
            if inspect.isclass(k) and issubclass(k, TorchDispatchMode):
                assert k not in self.python_key_mode_table
                # TODO(voz): Should we replace setting torch._C.DispatchKey.Python entirely with setting mode keys?
                self.python_key_mode_table[k] = fn
                self._dispatch_cache.clear()
                return fn

            if isinstance(k, torch._C._functorch.TransformType):
                assert k not in self.functorch_table
                self.functorch_table[k] = fn
                return fn

            assert isinstance(k, torch._C.DispatchKey)
            assert (
                k != torch._C.DispatchKey.Python
            ), "Please register a mode for the torch._C.DispatchKey.Python key instead."

            if k in self.py_kernels:
                raise RuntimeError(
                    f"Trying to override a python impl for {k} on operator {self.name()}"
                )
            self.py_kernels[k] = fn
            self._dispatch_cache.clear()
            return fn

        return inner
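
    # Example (illustrative sketch; `my_op` is a hypothetical operator instance,
    # and FakeTensorMode is just one possible TorchDispatchMode):
    #   @my_op.py_impl(torch._C.DispatchKey.CPU)
    #   def my_op_cpu(*args, **kwargs):
    #       ...  # Python kernel invoked when my_op dispatches to the CPU key
    #
    #   @my_op.py_impl(FakeTensorMode)
    #   def my_op_fake(mode, *args, **kwargs):
    #       ...  # handler invoked when a FakeTensorMode is the active mode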

    # Registers an implementation to all **3** variants of functionalization that we have:
    # - DispatchKey.Functionalize
    # - functorch.TransformType.Functionalize
    # - FunctionalTensorMode
    # Example:
    #   @py_functionalize_impl
    #   def functionalize_rule(ctx, inner_f, *args):
    #       args_unwrapped = ctx.unwrap_tensors(args)
    #       with ctx.redispatch_to_next():
    #           out = ctx.functionalize(inner_f)(*args_unwrapped)
    #       return ctx.wrap_tensors(out)
    def py_functionalize_impl(self, fn):
        from torch._subclasses.functional_tensor import (
            CppFunctionalizeAPI as _CppFunctionalizeAPI,
            FunctorchFunctionalizeAPI as _FunctorchFunctionalizeAPI,
            PythonFunctionalizeAPI as _PythonFunctionalizeAPI,
        )

        # Construct our three flavors of functionalization,
        # each of which has slightly different wrap/unwrap/redispatch policies
        def functionalize_dk_fn(*args, **kwargs):
            return fn(_CppFunctionalizeAPI(), *args, **kwargs)

        def functionalize_dispatch_mode_fn(mode, *args, **kwargs):
            return fn(_PythonFunctionalizeAPI(mode), *args, **kwargs)

        def functionalize_functorch_fn(interpreter, *args, **kwargs):
            return fn(_FunctorchFunctionalizeAPI(interpreter), *args, **kwargs)

        self.py_impl(torch._C.DispatchKey.Functionalize)(functionalize_dk_fn)
        self.py_impl(torch._subclasses.functional_tensor.FunctionalTensorMode)(
            functionalize_dispatch_mode_fn
        )
        self.py_impl(torch._C._functorch.TransformType.Functionalize)(
            functionalize_functorch_fn
        )

        return fn

    def name(self):
        raise NotImplementedError


is_included_in_alias = torch._C._dispatch_is_included_in_alias

DispatchKey = torch._C.DispatchKey


# Equivalent to computeDispatchTableEntryWithDebug
def resolve_key(op: OperatorBase, k: DispatchKey):  # type: ignore[valid-type]
    # 1. (Direct) operator registration
    if op.has_kernel_for_dispatch_key(k):
        return k
    # 2.1 Use CompositeExplicitAutogradNonFunctional kernel if available
    cand = DispatchKey.CompositeExplicitAutogradNonFunctional
    if (
        k == DispatchKey.Undefined or is_included_in_alias(k, cand)
    ) and op.has_kernel_for_dispatch_key(cand):
        return cand
    # 2.2 Use CompositeExplicitAutograd kernel if available
    cand = DispatchKey.CompositeExplicitAutograd
    if (
        k == DispatchKey.Undefined or is_included_in_alias(k, cand)
    ) and op.has_kernel_for_dispatch_key(cand):
        return cand
    has_backend_kernel = op.has_kernel_for_any_dispatch_key(
        torch._C._dispatch_get_backend_keyset_from_autograd(k)
    ) or op.has_kernel_for_dispatch_key(DispatchKey.CompositeExplicitAutograd)
    # 2.3. Use CompositeImplicitAutograd kernel if available
    cand = DispatchKey.CompositeImplicitAutogradNestedTensor
    if (
        (k != DispatchKey.Undefined and is_included_in_alias(k, cand))
        and op.has_kernel_for_dispatch_key(cand)
        and not has_backend_kernel
    ):
        return cand
    cand = DispatchKey.CompositeImplicitAutograd
    if (
        k == DispatchKey.Undefined or is_included_in_alias(k, cand)
    ) and op.has_kernel_for_dispatch_key(cand):
        if k == DispatchKey.AutogradOther and op.has_kernel_for_any_dispatch_key(
            torch._C._dispatch_autogradother_backends
        ):
            raise RuntimeError("ambiguous autogradother kernel")
        elif not has_backend_kernel:
            return cand
    # 2.4. For autograd backend keys, use kernel from DispatchKey::Autograd if available
    cand = DispatchKey.Autograd
    if is_included_in_alias(k, cand) and op.has_kernel_for_dispatch_key(cand):
        return cand
    # 2.5 Use kernel from DispatchKey::FuncTorchBatchedDecomposition if available
    cand = DispatchKey.FuncTorchBatchedDecomposition
    if is_included_in_alias(k, cand) and op.has_kernel_for_dispatch_key(cand):
        return cand
    # Backend fallback
    if torch._C._dispatch_has_backend_fallback(k):
        # The dispatch key itself will implicitly route to backend fallback.
        # This is probably not great for the pure Python implementation.
        return k
    raise NotImplementedError(f"could not find kernel for {op} at dispatch key {k}")
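
# Example resolution order (illustrative, not exhaustive): for an op that has
# only a CompositeImplicitAutograd kernel and no backend kernels,
# resolve_key(op, DispatchKey.AutogradCPU) falls through steps 1-2.2 and
# returns DispatchKey.CompositeImplicitAutograd; if the op instead has a CPU
# kernel plus an Autograd kernel, resolving AutogradCPU returns
# DispatchKey.Autograd via step 2.4.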


_higher_order_ops: Dict[str, "HigherOrderOperator"] = {}

_HIGHER_ORDER_OP_DEFAULT_FALLTHROUGH_DISPATCH_KEYS = [
    DispatchKey.PythonDispatcher,  # type: ignore[attr-defined]
    DispatchKey.PythonTLSSnapshot,  # type: ignore[attr-defined]
    DispatchKey.ADInplaceOrView,
    DispatchKey.BackendSelect,
    DispatchKey.AutocastCPU,  # type: ignore[attr-defined]
    DispatchKey.AutocastCUDA,  # type: ignore[attr-defined]
]


class HigherOrderOperator(OperatorBase):
    # The HigherOrderOperator will appear as torch.ops.higher_order.{name}
    #
    # If you're creating a new HigherOrderOperator, please do not change the
    # default. Adding operators to the global torch.ops namespace is a bad
    # practice due to name collisions.
    def __init__(self, name):
        super().__init__()
        self._name = name

        # Make _OpNamespace not scream; this whole name-based association needs a good hard look
        self.__name__ = name
        _higher_order_ops[name] = self
        self._ns = "higher_order"

        # For a normal HigherOrderOperator instance, we will change its __module__ from torch._ops to
        # torch._ops.higher_order.
        # For an instance of a subclass of HigherOrderOperator (e.g. a customized higher order op),
        # the __module__ attribute will be kept unchanged.
        if self.__class__ is HigherOrderOperator:
            self_name_space = "." + self.namespace if self.namespace else ""
            self.__module__ = self.__module__ + self_name_space

        self.non_fallthrough_keys = torch._C._dispatch_keyset_full()

        for dispatch_key in _HIGHER_ORDER_OP_DEFAULT_FALLTHROUGH_DISPATCH_KEYS:
            self.fallthrough(dispatch_key)

        # [NOTE] We have to register a pre-dispatch key implementation
        # because sometimes HOPs use aot-dispatch tracing to detect certain
        # mutations. This is problematic when we are functionalizing HOPs
        # during pre-dispatch because when the inner tracer starts, it will see
        # that the PreDispatch key is still active. In that case, we just redispatch
        # to the next key. This is only safe to do when the PreDispatch key stack has no
        # active modes.

    def py_impl(self, k):
        if isinstance(k, torch._C.DispatchKey) and not self.non_fallthrough_keys.has(k):
            self.non_fallthrough_keys = self.non_fallthrough_keys.add(k)
        return super().py_impl(k)

    @property
    def namespace(self):
        return self._ns

    def fallthrough(self, dispatch_key):
        self.non_fallthrough_keys = self.non_fallthrough_keys.remove(dispatch_key)

    def dispatch(self, dispatch_key, *args, **kwargs):
        from torch.utils._python_dispatch import _get_current_dispatch_mode

        if dispatch_key in self._dispatch_cache:
            kernel = self._dispatch_cache[dispatch_key]
            assert not isinstance(kernel, torch._C.DispatchKey)
            return kernel(*args, **kwargs)

        if dispatch_key == torch._C.DispatchKey.FuncTorchDynamicLayerFrontMode:
            return dispatch_functorch(self, args, kwargs)

        if dispatch_key == torch._C.DispatchKey.Python:
            # The place to handle ProxyTorchDispatchMode, FakeTensorMode, etc
            from torch.utils._python_dispatch import _pop_mode_temporarily

            curr_mode = _get_current_dispatch_mode()
            assert (
                curr_mode is not None
            ), "Illegal invocation of dispatch on torch._C.DispatchKey.Python without a mode."
            assert (
                type(curr_mode) in self.python_key_mode_table
            ), f"Current active mode {curr_mode} not registered"
            handler = self.python_key_mode_table[type(curr_mode)]
            with _pop_mode_temporarily() as mode:
                return handler(mode, *args, **kwargs)

        functionality_key = torch._C._to_functionality_key(dispatch_key)  # type: ignore[attr-defined]
        if functionality_key == torch._C.DispatchKey.PreDispatch:
            from torch.utils._python_dispatch import _pop_mode_temporarily

            # The check for Python in the exclude set is so we properly respect `with no_dispatch()`
            # calls inside of a mode.
            if (
                _len_torch_dispatch_stack_pre_dispatch() > 0
            ) and not torch._C._dispatch_tls_is_dispatch_key_excluded(
                DispatchKey.Python
            ):
                curr_mode = _get_current_dispatch_mode_pre_dispatch()
                assert (
                    curr_mode is not None
                ), "Illegal invocation of dispatch on torch._C.DispatchKey.PreDispatch without a mode."
                assert (
                    type(curr_mode) in self.python_key_mode_table
                ), f"Current active mode {curr_mode} not registered"
                handler = self.python_key_mode_table[type(curr_mode)]
                with _pop_mode_temporarily(functionality_key) as mode:
                    return handler(mode, *args, **kwargs)

        final_key = resolve_key(self, dispatch_key)

        # This can currently fail due to backend fallbacks. You just have to
        # register them by hand for HigherOrderOperator.
        if final_key not in self.py_kernels:
            raise NotImplementedError(
                f"could not find kernel for HigherOrderOperator {self._name} "
                f"at dispatch key {final_key} (resolved from {dispatch_key})"
            )

        # [NOTE] We shouldn't cache the PreDispatch kernel here because, depending
        # on what modes are active, pre-dispatch behaviour is different.
        # We also do the same thing for normal ops:
        # See Note [Not Caching Per-Dispatch-Key Mode Handlers]
        if dispatch_key != torch._C.DispatchKey.PreDispatch:
            self._dispatch_cache[dispatch_key] = self.py_kernels[final_key]
        kernel = self.py_kernels[final_key]
        # It's illegal to register a DispatchKey to py_kernels, since there's no
        # C++ kernel to call into
        assert not isinstance(kernel, torch._C.DispatchKey)
        return kernel(*args, **kwargs)

    def __call__(self, *args, **kwargs):
        # Dynamo already traces the body of a HigherOrderOp beforehand when it
        # encounters one, so there is no need to trace into it here.
        import torch._dynamo
        from torch._dynamo import disable

        @disable
        def wrapper():
            flat_args = _to_flat_tuple(args, kwargs)
            if torch.overrides.has_torch_function(flat_args):
                return torch.overrides.handle_torch_function(
                    self, flat_args, *args, **kwargs
                )

            dispatch_key_set = _compute_keyset(args, kwargs, self.non_fallthrough_keys)
            return self.dispatch(
                dispatch_key_set.highestPriorityTypeId(), *args, **kwargs
            )

        return wrapper()

    def __str__(self):
        return f"{self.name()}"

    def name(self):
        return self._name
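
# Example (illustrative sketch; `my_hop` and its kernel are hypothetical):
#   my_hop = HigherOrderOperator("my_hop")
#
#   @my_hop.py_impl(torch._C.DispatchKey.CompositeExplicitAutograd)
#   def my_hop_dense(fn, *args):
#       return fn(*args)
#
#   # Callable as torch.ops.higher_order.my_hop(fn, x); dispatch() above picks
#   # the kernel for the highest-priority non-fallthrough key of the inputs.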


def _to_flat_tuple(args, kwargs):
    return pytree.arg_tree_leaves(*args, **kwargs)


def _compute_keyset(args, kwargs, non_fallthrough_keys):
    tensors = _get_tensors(args, kwargs)
    return key_extractor(tensors, non_fallthrough_keys)


def _get_tensors(args, kwargs):
    flat_all = _to_flat_tuple(args, kwargs)
    tensor_args = [t for t in flat_all if isinstance(t, torch.Tensor)]
    return tuple(tensor_args)


# Note - this should maintain identical impl to the C++ dispatcher key extraction logic
# at ATen/core/dispatch/DispatchKeyExtractor.h
def key_extractor(tensors, key_mask):
    key_set = torch._C._dispatch_tls_local_include_set()
    for tensor in tensors:
        key_set = key_set | torch._C._dispatch_keys(tensor)
    key_set = key_set - torch._C._dispatch_tls_local_exclude_set()
    key_set = key_set & key_mask
    return key_set
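
# Example (illustrative): for a plain CPU tensor with empty TLS include/exclude
# sets, key_extractor((t,), key_mask) is just the tensor's dispatch key set
# (e.g. CPU plus its autograd key) intersected with key_mask; masking with
# non_fallthrough_keys is what lets HigherOrderOperator skip keys it has
# marked as fallthrough above.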


# Mode stack for PreDispatchKey.
# It holds at most three modes, with priority given to
# FunctionalTensorMode and then ProxyTorchDispatchMode:
# slot 0 belongs to ProxyTorchDispatchMode and
# slot 1 belongs to FunctionalTensorMode.
#
# SchemaCheckMode is separate from the other two,
# and is only valid when the stack is empty.
# SchemaCheckMode is for testing purposes, and
# is meant to run in eager mode on concrete inputs,
# checking for incorrect schemas with regard to
# aliasing or mutating ops.
class _ModeStackStateForPreDispatch:
    def __init__(self):
        self.__infra_modes = [None, None]
        self._schema_check_mode = None

    def set(self, index, mode):
        assert index < len(self.__infra_modes)
        self.__infra_modes[index] = mode

    def get(self, index):
        assert index < len(self.__infra_modes)
        return self.__infra_modes[index]

    def count(self):
        return len([i for i in self.__infra_modes if i is not None]) + int(
            self._schema_check_mode is not None
        )


_mode_stack_state_for_pre_dispatch = _ModeStackStateForPreDispatch()


def unset_mode_pre_dispatch(mode_key, schema_check=False):
    current_mode_stack_pre_dispatch = mode_stack_state_for_pre_dispatch()
    assert mode_key is None or mode_key in (
        torch._C._TorchDispatchModeKey.PROXY,
        torch._C._TorchDispatchModeKey.FUNCTIONAL,
    )
    if schema_check:
        assert mode_key is None

    def _unset_mode():
        if mode_key == torch._C._TorchDispatchModeKey.PROXY:
            current_mode = current_mode_stack_pre_dispatch.get(0)
            mode_stack_state_for_pre_dispatch().set(0, None)
            return current_mode
        elif mode_key == torch._C._TorchDispatchModeKey.FUNCTIONAL:
            current_mode = current_mode_stack_pre_dispatch.get(1)
            mode_stack_state_for_pre_dispatch().set(1, None)
            return current_mode
        else:
            current_mode = mode_stack_state_for_pre_dispatch()._schema_check_mode
            mode_stack_state_for_pre_dispatch()._schema_check_mode = None
            return current_mode

    current_mode = _unset_mode()

    new_pre_dispatch_len = _len_torch_dispatch_stack_pre_dispatch()
    # When we are unsetting a mode, we need to check if there is any
    # active mode left on the PreDispatch key. If there is nothing
    # active, we need to remove the PreDispatch key from the local dispatch
    # include set.
    if new_pre_dispatch_len == 0:
        torch._C._dispatch_tls_set_dispatch_key_included(
            torch._C.DispatchKey.PreDispatch, False
        )

    return current_mode


def _set_mode_pre_dispatch(mode):
    from torch._subclasses.functional_tensor import FunctionalTensorMode
    from torch._subclasses.schema_check_mode import SchemaCheckMode
    from torch.fx.experimental.proxy_tensor import ProxyTorchDispatchMode

    assert isinstance(
        mode,
        (
            FunctionalTensorMode,
            ProxyTorchDispatchMode,
            SchemaCheckMode,
        ),
    )

    previous_mode_stack_len = _len_torch_dispatch_stack_pre_dispatch()
    if isinstance(mode, SchemaCheckMode):
        current_mode = mode_stack_state_for_pre_dispatch()._schema_check_mode
        if previous_mode_stack_len > 0:
            raise AssertionError(
                "SchemaCheckMode for pre-dispatch must be used exclusively, found other modes on the stack"
            )
        mode_stack_state_for_pre_dispatch()._schema_check_mode = mode
    elif isinstance(mode, FunctionalTensorMode):
        current_mode = mode_stack_state_for_pre_dispatch().get(1)
        assert current_mode is None
        mode_stack_state_for_pre_dispatch().set(1, mode)
    else:
        current_mode = mode_stack_state_for_pre_dispatch().get(0)
        assert current_mode is None
        mode_stack_state_for_pre_dispatch().set(0, mode)

    # When we are setting a mode, we need to check if there was any
    # active mode on the PreDispatch key. If there was nothing
    # active before setting this mode, it means that the PreDispatch key
    # was turned off. So we need to turn it on again.
    if previous_mode_stack_len == 0:
        torch._C._dispatch_tls_set_dispatch_key_included(
            torch._C.DispatchKey.PreDispatch, True
        )


def _pop_mode_from_pre_dispatch():
    mode_stack = mode_stack_state_for_pre_dispatch()
    pre_dispatch_len = _len_torch_dispatch_stack_pre_dispatch()

    if pre_dispatch_len == 0:
        raise AssertionError("Trying to pop empty mode stack")

    if mode_stack._schema_check_mode is not None:
        return unset_mode_pre_dispatch(None, schema_check=True)
    if mode_stack.get(1) is not None:
        return unset_mode_pre_dispatch(torch._C._TorchDispatchModeKey.FUNCTIONAL)
    if mode_stack.get(0) is not None:
        return unset_mode_pre_dispatch(torch._C._TorchDispatchModeKey.PROXY)


def _len_torch_dispatch_stack_pre_dispatch():
    return mode_stack_state_for_pre_dispatch().count()


def _get_dispatch_mode_pre_dispatch(mode_key):
    assert mode_key in (
        torch._C._TorchDispatchModeKey.PROXY,
        torch._C._TorchDispatchModeKey.FUNCTIONAL,
    )
    if mode_key == torch._C._TorchDispatchModeKey.PROXY:
        return mode_stack_state_for_pre_dispatch().get(0)
    else:
        return mode_stack_state_for_pre_dispatch().get(1)


def _get_current_dispatch_mode_pre_dispatch():
    if mode_stack_state_for_pre_dispatch()._schema_check_mode is not None:
        return mode_stack_state_for_pre_dispatch()._schema_check_mode
    else:
        stack_len = mode_stack_state_for_pre_dispatch().count()
        if stack_len == 2:
            return mode_stack_state_for_pre_dispatch().get(1)
        if stack_len == 1:
            return (
                mode_stack_state_for_pre_dispatch().get(1)
                if mode_stack_state_for_pre_dispatch().get(1) is not None
                else mode_stack_state_for_pre_dispatch().get(0)
            )
    return None


def mode_stack_state_for_pre_dispatch():
    global _mode_stack_state_for_pre_dispatch
    return _mode_stack_state_for_pre_dispatch


cached_ops: Set["OpOverload"] = set()


def add_cached_op(op_overload):
    global cached_ops
    cached_ops.add(op_overload)


def reset_cached_ops():
    global cached_ops
    cached_ops.clear()


def get_cached_ops():
    global cached_ops
    return cached_ops


# Each OpOverload object contains a pointer to a specific operator overload
# and a pointer to the parent `OpOverloadPacket` object.
# You can obtain an OpOverload object through attribute query on an OpOverloadPacket.
class OpOverload(OperatorBase):
    def __init__(self, overloadpacket, op, op_dk, schema, tags):
        super().__init__()
        self._op = op
        self._op_dk = op_dk
        self._schema = schema
        self._overloadpacket = overloadpacket
        self._tags = tags
        self._overloadname = (
            "default" if schema.overload_name == "" else schema.overload_name
        )
        self._name = self._schema.name
        if schema.overload_name:
            self._name += "." + schema.overload_name
        self.__name__ = f"{self._schema.name.split('::')[1]}.{self._overloadname}"
        self.__module__ = overloadpacket.__module__
        op.__module__ = overloadpacket.__module__
        self.__qualname__ = self._name
        self.__annotations__ = {}

        # Only compute the OperatorHandle when we need it. Not all OpOverloads have
        # OperatorHandles (the TorchScript ones don't...)
        self._lazy_handle = None

        # If the OpOverload was constructed from a Library.def in Python.
        self._defined_in_python = self.__qualname__ in torch.library._defs

        # Logic replicated from aten/src/ATen/native/MathBitsFallback.h
        is_write = None
        for a in self._schema.arguments:
            if a.alias_info is None:
                continue
            if is_write is None:
                is_write = a.alias_info.is_write
            else:
                # We will conservatively call mixed mutable/non-mutable
                # aliased inputs as NOT a view
                is_write = a.alias_info.is_write or is_write
        self.is_view = is_write is not None and not is_write

    @property
    def _namespace(self):
        return self._schema.name.split("::")[0]

    @property
    def _opname(self):
        return self._schema.name.split("::")[1]

    @property
    def _handle(self):
        if self._lazy_handle is None:
            self._lazy_handle = torch._C._dispatch_find_schema_or_throw(
                self._schema.name, self._schema.overload_name
            )
        return self._lazy_handle

    # it's a no-op since the OpOverload object is immutable and must be unique for a given op overload.
    def __deepcopy__(self, memo=None):
        return self

    def __repr__(self):
        return "<OpOverload(op='{}.{}', overload='{}')>".format(
            *self._schema.name.split("::"), self._overloadname
        )

    def __call__(self_, *args, **kwargs):  # noqa: B902
        # use `self_` to avoid a name collision with aten ops arguments that
        # are named "self". This way, all the aten ops can be called by kwargs.
        return self_._op(*args, **kwargs)

    def redispatch(self_, keyset, *args, **kwargs):  # noqa: B902
        # use `self_` to avoid a name collision with aten ops arguments that
        # are named "self". This way, all the aten ops can be called by kwargs.
        return self_._handle.redispatch_boxed(keyset, *args, **kwargs)

    def __hash__(self):
        return hash(self._op)

    # `my_namespace.my_op_name.overload_name`
    def __str__(self):
        return "{}.{}.{}".format(*self._schema.name.split("::"), self._overloadname)

    def has_kernel_for_dispatch_key(self, k):
        return super().has_kernel_for_dispatch_key(
            k
        ) or torch._C._dispatch_has_kernel_for_dispatch_key(self.name(), k)

    def has_kernel_for_any_dispatch_key(self, ks):
        return torch._C._dispatch_has_kernel_for_any_dispatch_key(
            self.name(), ks
        ) or super().has_kernel_for_any_dispatch_key(ks)

    @property
    def namespace(self):
        return self._schema.name.split("::")[0]

    def decompose(self, *args, **kwargs):
        dk = torch._C.DispatchKey.CompositeImplicitAutograd
        if dk in self.py_kernels:
            # NB: This branch is not too necessary anymore, because we can
            # apply Python CompositeImplicitAutograd *before* tracing
            # using Python dispatcher (also taking advantage of the autograd
            # formula). But it's included for completeness
            return self.py_kernels[dk](*args, **kwargs)
        elif torch._C._dispatch_has_kernel_for_dispatch_key(self.name(), dk):
            return self._op_dk(dk, *args, **kwargs)
        else:
            return NotImplemented
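
    # Example (illustrative): for a CompositeImplicitAutograd op such as
    # torch.ops.aten.linear.default, `op.decompose(x, w, b)` runs its registered
    # decomposition instead of a fused backend kernel; for ops without such a
    # kernel it returns NotImplemented.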

    # Remove a dispatch key from the dispatch cache. This will force it to get
    # recomputed the next time. Does nothing if the key is not cached.
    # WARNING: if you register a dispatch key to py_kernels of an OpOverload,
    # calling _del_dispatch on that key is NOT sufficient to apply your change,
    # because a single registration may affect MULTIPLE dispatch keys (e.g.,
    # registering Autograd affects AutogradCPU). del_dispatch is to be used
    # only if you are specifically modifying how get_dispatch handles a
    # particular input 'key'.
    def _uncache_dispatch(self, key):
        self._dispatch_cache.pop(key, None)

    # This implements the pre-computation logic for the Python dispatcher.
    def _get_dispatch(self, key):
        # This is only called upon a cache miss
        assert key not in self._dispatch_cache, f"{self} {key}"

        if key == torch._C.DispatchKey.Python:
            if (
                not isinstance(self, TorchBindOpOverload)
                and not self.python_key_mode_table
            ):
                self._dispatch_cache[key] = key
                add_cached_op(self)
                return key

            def handler(*args, **kwargs):
                from torch.utils._python_dispatch import _get_current_dispatch_mode

                # TODO: We also need to handle tensor subclasses here
                # TODO(voz): We should walk all the nodes here / turn it into a list, topmode is ok for now.
                curr_mode = type(_get_current_dispatch_mode())
                assert (
                    curr_mode is not None
                ), "Illegal invocation of dispatch on torch._C.DispatchKey.Python without a mode."
                if curr_mode not in self.python_key_mode_table:
                    if isinstance(self, TorchBindOpOverload):
                        with torch.utils._python_dispatch._pop_mode_temporarily() as mode:
                            return torch._library.utils.handle_dispatch_mode(
                                mode, self, *args, **kwargs
                            )
                    else:
                        return self._op_dk(key, *args, **kwargs)
                with torch.utils._python_dispatch._pop_mode_temporarily() as mode:
                    return self.python_key_mode_table[curr_mode](mode, *args, **kwargs)

            self._dispatch_cache[key] = handler
            add_cached_op(self)
            return handler

        functionality_key = torch._C._to_functionality_key(key)  # type: ignore[attr-defined]
        if functionality_key == torch._C.DispatchKey.PreDispatch:
            curr_stack_len = _len_torch_dispatch_stack_pre_dispatch()
            # The check for Python in the exclude set is so we properly respect `with no_dispatch()`
            # calls inside of a mode.
            if (
                curr_stack_len > 0
                and not torch._C._dispatch_tls_is_dispatch_key_excluded(
                    DispatchKey.Python
                )
            ):

                def handler(*args, **kwargs):
                    @contextlib.contextmanager
                    def _temporarily_pop_modes_from_pre_dispatch():
                        top_mode = _pop_mode_from_pre_dispatch()
                        try:
                            yield top_mode
                        finally:
                            _set_mode_pre_dispatch(top_mode)

                    with _temporarily_pop_modes_from_pre_dispatch() as curr_mode:
                        return torch._library.utils.handle_dispatch_mode(
                            curr_mode, self, *args, **kwargs
                        )

                # Note [Not Caching Per-Dispatch-Key Mode Handlers]
                # Note that we're not caching this handler. There isn't really a point, since the slow bit
                # is the handler itself (in python).
                # Also, not caching means that we don't have to reset the cache when any existing
                # modes go out of scope (which in and of itself takes time to loop through all operators).
                return handler

        final_key = resolve_key(self, key)

        # See Note [Not Caching Per-Dispatch-Key Mode Handlers]
        cache_result = key != torch._C.DispatchKey.PreDispatch

        # TODO: We could potentially have lots of debugging wrappers against
        # dispatch keys; design some general registration mechanism instead of
        # having an if statement for each of them
        if key == torch._C.DispatchKey.Functionalize:
            import torch._dispatch.python as pydispatch

            if pydispatch.CROSSREF_FUNCTIONALIZE:
                handler = pydispatch.make_crossref_functionalize(self, final_key)
                if cache_result:
                    self._dispatch_cache[key] = handler
                    add_cached_op(self)
                return handler

        r = self.py_kernels.get(final_key, final_key)
        if cache_result:
            self._dispatch_cache[key] = r
            add_cached_op(self)
        return r

    def name(self):
        return self._name

    @property
    def overloadpacket(self):
        return self._overloadpacket

    @property
    def op(self):
        return self._op

    @property
    def tags(self):
        return self._tags

    # TODO: add more methods to expose information about input and output arguments


# TorchBindOpOverload is used for custom ops where at least one overload's
# schema contains a torch.ScriptObject (i.e. custom class) input.
# A TorchBindOpOverload skips the C++ dispatcher and is dispatched purely in Python
# when its inputs contain a FakeScriptObject, in a similar way to higher order ops.
class TorchBindOpOverload(OpOverload):
    def _fallthrough_keys(self) -> List[DispatchKey]:
        # TODO: we should be calling the fallback for these, but a fallthrough is almost close
        # enough to the fallback in most cases that we care about.
        _DEFAULT_FALLTHROUGH_KEYS = [
            DispatchKey.Autograd,
            DispatchKey.AutogradCPU,
            DispatchKey.AutogradCUDA,
            DispatchKey.ADInplaceOrView,
            DispatchKey.BackendSelect,
            DispatchKey.PythonTLSSnapshot,
            DispatchKey.PythonDispatcher,
        ]

        def _may_use_fallthrough_instead_of_fallback(key: DispatchKey):
            if torch._C._dispatch_has_kernel_for_dispatch_key(self.name(), key):
                return torch._C._dispatch_kernel_for_dispatch_key_is_fallthrough(
                    self.name(), key
                )

            return (
                key not in self.py_kernels
                or self.py_kernels[key] is torch.library.fallthrough_kernel
            )

        return [
            key
            for key in _DEFAULT_FALLTHROUGH_KEYS
            if _may_use_fallthrough_instead_of_fallback(key)
        ]

    @contextlib.contextmanager
    def _register_as_effectful_op_temporarily(self):
        from torch._higher_order_ops.effects import (
            _EffectType,
            _register_effectful_op,
            SIDE_EFFECTS,
        )

        try:
            if self not in SIDE_EFFECTS:
                _register_effectful_op(self, _EffectType.ORDERED)
            yield
        finally:
            if self in SIDE_EFFECTS:
                del SIDE_EFFECTS[self]

    # use `self_` to avoid a name collision with arguments that
    # are named "self". This way, they can be called by kwargs.
    def __call__(self_, *args, **kwargs):  # noqa: B902
        if _must_dispatch_in_python(args, kwargs):
            # When any inputs are FakeScriptObject, we need to
            # skip the C++ dispatcher and dispatch in Python through _get_dispatch of the python_dispatcher
            # because the C++ dispatcher will check the schema and cannot recognize FakeScriptObject.
            #
            # Note:
            # 1. We only register the torchbind op temporarily as an effectful op because we only want
            #    the effect token functionalization logic to be applied during tracing. Otherwise, the behavior
            #    of eagerly executing the op might change after tracing.
            # 2. We don't want to register the op as effectful for all torchbind ops in the constructor because this might
            #    cause unexpected behavior for some autograd.profiler ops e.g. profiler._record_function_exit._RecordFunction.
            with self_._register_as_effectful_op_temporarily():
                return self_._dispatch_in_python(
                    args, kwargs, self_._fallthrough_keys()
                )
        return self_._op(*args, **kwargs)

    def _dispatch_in_python(self, args, kwargs, fallthrough_keys):
        non_fallthrough_keys = torch._C._dispatch_keyset_full()
        for key in fallthrough_keys:
            non_fallthrough_keys = non_fallthrough_keys.remove(key)

        dispatch_key_set = _compute_keyset(args, kwargs, non_fallthrough_keys)
        dispatch_key = dispatch_key_set.highestPriorityTypeId()

        handler = (
            self._get_dispatch(dispatch_key)
            if dispatch_key not in self._dispatch_cache
            else self._dispatch_cache[dispatch_key]
        )

        if isinstance(handler, DispatchKey):
            # fallthrough keys can be registered at runtime via torch.library.impl
            # so we need to add them to fallthrough_keys and re-dispatch.
            if torch._C._dispatch_kernel_for_dispatch_key_is_fallthrough(
                self.name(), dispatch_key
            ):
                return self._dispatch_in_python(
                    args, kwargs, fallthrough_keys + [dispatch_key]
                )

            raise RuntimeError(
                f"Torchbind op {self} received a FakeScriptObject input when dispatching {handler},"
                f" but no python implementation was found."
                f" Please file an issue on this when you encounter this error."
                f" This error can happen when you export or compile the model."
                f" It can still happen even if a C++ implementation for {dispatch_key}"
                f" has been registered. That's because FakeScriptObject purely lives in python and cannot work"
                f" with a C++ implementation."
            )

        assert isinstance(handler, Callable)  # type: ignore[arg-type]
        return handler(*args, **kwargs)


def _must_dispatch_in_python(args, kwargs):
    return pytree.tree_any(
        lambda obj: isinstance(
            obj, torch._library.fake_class_registry.FakeScriptObject
        ),
        (args, kwargs),
    )


def _has_script_object_arg(schema: torch.FunctionSchema) -> bool:
    return any(isinstance(arg.type, torch.ClassType) for arg in schema.arguments)


# An OpOverloadPacket object contains a pointer to a base, unresolved operator that doesn't
# correspond to a specific operator overload.
# You can obtain an OpOverload object through attribute query.
class OpOverloadPacket:
    def __init__(self, qualified_op_name, op_name, op, overload_names):
        # These attributes are accessible on the object through the properties
        # defined below but are immutable
        self._qualified_op_name = qualified_op_name
        self.__name__ = op_name
        self._op = op
        self._overload_names = overload_names
        self._dir = []
        self._has_torchbind_op_overload = any(
            _has_script_object_arg(schema) for schema in self._schemas.values()
        )

    # it's a no-op since the OpOverloadPacket object is immutable and must be unique for a given op.
    def __deepcopy__(self, memo=None):
        return self

    def __repr__(self):
        return "<OpOverloadPacket(op='{}.{}')>".format(
            *self._qualified_op_name.split("::")
        )

    def __hash__(self):
        return hash(self._op)

    def __str__(self):
        return "{}.{}".format(*self._qualified_op_name.split("::"))

    @property
    def op(self):
        return self._op

    @property
    def _schemas(self):
        return {
            overload_name: torch._C._get_schema(self._qualified_op_name, overload_name)
            for overload_name in self._overload_names
        }

    def __getattr__(self, key):
        # It is not a valid op_name when __file__ is passed in
        if key == "__file__":
            return "torch.ops"

        # Ensure that queries for dunder attributes that do not exist on the
        # opoverloadpacket but instead exist on the self._op object do not unnecessarily call
        # `_get_operation_overload` (which is an expensive operation).
        # This is done to prevent any potential slowdown. This list can be extended
        # if there exist other attributes like `__name__` that only exist on self._op and not on the
        # opoverloadpacket.
        # This is ok since we are guaranteed that an overload name for an aten op can't start with '__'
        try:
            if key.startswith("__"):
                return getattr(self._op, key)
        except AttributeError:
            # for consistency because it seems weird to
            # throw an attribute error with a message containing
            # an object name different from the one the attribute
            # query was performed on.
            raise AttributeError(
                f"'{str(self)}' can't have an overload name beginning with '__' and the "
                f"underlying op {str(self._op)} has no attribute {key} either."
            ) from None

        try:
            # This is ok since we are guaranteed that an overload name for an aten op can't be 'default'
            use_key = "" if key == "default" else key
            # TODO: disallow access to overloads registered by JIT
            op_, op_dk_, tags = torch._C._get_operation_overload(
                self._qualified_op_name, use_key
            )
            schema = torch._C._get_schema(self._qualified_op_name, use_key)
            overload = (
                OpOverload(self, op_, op_dk_, schema, tags)
                if not _has_script_object_arg(schema)
                else TorchBindOpOverload(self, op_, op_dk_, schema, tags)
            )
            # cache the overload object
            setattr(self, key, overload)
            self._dir.append(key)
            return overload
        except RuntimeError:
            raise AttributeError(
                f"The underlying op of '{str(self)}' has no overload name '{key}'"
            ) from None

    def __iter__(self):
        return iter(self._dir)

    def __call__(self_, *args, **kwargs):  # noqa: B902
        # use `self_` to avoid a name collision with aten ops arguments that are
        # named "self". This way, all the aten ops can be called by kwargs.

        # overloading __call__ to ensure torch.ops.foo.bar()
        # is still callable from JIT
        # We save the function ptr as the `op` attribute on
        # OpOverloadPacket to access it here.

        # Directly calling an OverloadPacket goes into C++, which will check
        # the schema and cause an error for torchbind ops when inputs consist of FakeScriptObject, so we
        # intercept it here and call TorchBindOpOverload instead.
        if self_._has_torchbind_op_overload and _must_dispatch_in_python(args, kwargs):
            return _call_overload_packet_from_python(self_, args, kwargs)
        return self_._op(*args, **(kwargs or {}))

    # TODO: use this to make a __dir__
    def overloads(self):
        return [n if n else "default" for n in self._overload_names]
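
# Example (illustrative): `torch.ops.aten.add` is an OpOverloadPacket;
# `torch.ops.aten.add.overloads()` lists overload names such as "Tensor" and
# "out", and `torch.ops.aten.add.Tensor` resolves (and caches) the
# corresponding OpOverload via __getattr__ above.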


# Note - this mirrors the logic of the cpp_function defined in jit/python/init.cpp
# _jit_get_operations, which calls _get_operation_for_overload_or_packet.
def _call_overload_packet_from_python(op: OpOverloadPacket, args, kwargs):
    # Re-use the torch function handling logic in cpp
    torch_function_called, ret = torch._C._maybe_call_torch_function_for_op_packet(
        op, *args, **kwargs
    )

    if torch_function_called:
        return ret

    # The following mirrors getOpWithStack.
    # In cpp, we do schema matching for the arguments and call toIValue to
    # check whether the arguments are valid. We need to do something similar here:
    # check against the schema whether each FakeScriptObject is the corresponding fake class
    # of the actual class used in the schema.
    exceptions = {}
    found_op = None
    for overload_name in op.overloads():
        op_overload = getattr(op, overload_name)
        try:
            _ = torch._C._check_schema_allow_fake_script_object(
                op_overload._schema, *args, **kwargs
            )
            found_op = op_overload
            break
        except RuntimeError as e:
            exceptions[overload_name] = e

    if found_op:
        return found_op(*args, **kwargs)

    err_msg = (
        f"Failed to match any TorchBindOpOverload of {op} with the following exceptions:\n"
    )
    for i, (key, msg) in enumerate(exceptions.items()):
        err_msg += f"Overload name {key}:\n {msg}\n"
    raise RuntimeError(err_msg)


# Resolution of torch.fn is different from torch.ops.aten.fn.
# torch.fn uses the Python argparser, matches with the
# appropriate schema, and calls into the unboxed version of the method.
# torch.ops.aten.fn resolution is done via the mechanism defined in JIT.
# JIT creates a stack of all the overloads and then tries to match the
# correct one at runtime and always calls into the boxed version of the method.
# Autograd codegen creates VariableType, TracerType,
# inplace or view type and python bindings.
# Aten codegen generates tensor methods for the tensor class.


# _OpNamespace is a subclass of ModuleType because TorchScript
# only allows attribute lookups on modules. Since we want torch.ops.foo.bar()
# to work from script, we need to ensure ops and foo are modules
class _OpNamespace(types.ModuleType):
    """
    An op namespace to dynamically bind Operators into Python.

    Say a user has created a custom Operator called "my_namespace::my_op". To
    call this op, the user will write torch.ops.my_namespace.my_op(...).
    At startup, this operation will not yet be bound into Python. Instead, the
    following sequence of magic tricks will occur:

    1. `torch.ops.my_namespace` will invoke the `__getattr__` magic method
       on the `torch.ops` object, which will create a new `_OpNamespace`
       object called `my_namespace` and set it as an attribute on the `ops`
       object.
    2. `torch.ops.my_namespace.my_op` will then invoke `__getattr__` on
       the `my_namespace` object, which will retrieve the operation via
       `torch.get_operation`, a function bound from C++, and then in a similar
       fashion bind this new object onto the `my_namespace` object.
    3. `torch.ops.my_namespace.my_op(...)` then calls this new operation
       and subsequent accesses will incur no further lookup (the namespace and
       operation will already exist).
    """

    def __init__(self, name):
        super().__init__("torch.ops." + name)
        self.name = name
        self._dir = []

    def __iter__(self):
        return iter(self._dir)

    def __getattr__(self, op_name):
        # It is not a valid op_name when __file__ is passed in
        if op_name == "__file__":
            return "torch.ops"
        elif op_name in ["__origin__", "__self__"]:
            raise AttributeError(
                f"Invalid attribute '{op_name}' for '_OpNamespace' '{self.name}'"
            )

        # Get the op `my_namespace::my_op` if available. This will also check
        # for overloads and raise an exception if there are more than one.
        namespace_name = self.name
        qualified_op_name = f"{namespace_name}::{op_name}"
        module_name = self.__module__ + "." + namespace_name

        try:
            op, overload_names = _get_packet(qualified_op_name, module_name)
            if op is None:
                raise AttributeError(
                    f"'_OpNamespace' '{self.name}' object has no attribute '{op_name}'"
                )
        except RuntimeError as e:
            # Turn this into AttributeError so getattr(obj, key, default)
            # works (this is called by TorchScript with __origin__)
            raise AttributeError(
                f"'_OpNamespace' '{self.name}' object has no attribute '{op_name}'"
            ) from e

        op.__module__ = module_name
        opoverloadpacket = OpOverloadPacket(
            qualified_op_name, op_name, op, overload_names
        )
        opoverloadpacket.__module__ = self.__module__ + "." + namespace_name
        # cache the opoverloadpacket to ensure that each op corresponds to
        # a unique OpOverloadPacket object
        setattr(self, op_name, opoverloadpacket)
        self._dir.append(op_name)
        return opoverloadpacket
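
# Example (illustrative; "my_namespace::my_op" is a hypothetical custom op,
# matching the docstring above):
#   torch.ops.my_namespace            # creates and caches an _OpNamespace
#   torch.ops.my_namespace.my_op      # binds an OpOverloadPacket via __getattr__
#   torch.ops.my_namespace.my_op(x)   # calls the underlying operator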


def _get_packet(qualname, op_module):
    op, overload_names = torch._C._jit_get_operation(qualname)
    if op is not None:
        # let the script frontend know that op is identical to the builtin op
        # with qualified_op_name
        torch.jit._builtins._register_builtin(op, qualname)
        op.__module__ = op_module
    return op, overload_names


def _refresh_packet(packet):
    op, overload_names = _get_packet(packet._qualified_op_name, packet._op.__module__)
    assert op is not None
    packet._op = op
    packet._overload_names = overload_names


class _PyOpNamespace(_OpNamespace):
    def __init__(self, name, ops):
        super().__init__(name)
        self._ops = ops

    def __getattr__(self, name):
        # Following _OpNamespace.__getattr__, we cache the op on the _PyOpNamespace object.
        op = self._ops.get(name, None)
        if op is None:
            raise AttributeError(
                f"'_PyOpNamespace' '{self.name}' object has no attribute '{name}'"
            )
        setattr(self, name, op)
        return op


class _Ops(types.ModuleType):
    __file__ = "_ops.py"

    def __init__(self):
        super().__init__("torch.ops")
        self.loaded_libraries = set()
        self._higher_order_op_namespace = _PyOpNamespace(
            "torch.ops.higher_order", _higher_order_ops
        )
        self._dir = []

    def __getattr__(self, name):
        # Check if the name is a HigherOrderOperator namespace
        if name == "higher_order":
            return self._higher_order_op_namespace

        # Here we are creating `torch.ops.my_namespace`
        namespace = _OpNamespace(name)
        setattr(self, name, namespace)
        self._dir.append(name)
        return namespace

    def __iter__(self):
        return iter(self._dir)

    def import_module(self, module):
        """
        Imports a Python module that has torch.library registrations.

        Generally, to extend PyTorch with custom operators, a user will
        create a Python module whose import triggers registration of
        the custom operators via a torch.ops.load_library call or a call
        to one or more torch.library.* APIs.

        It is unexpected for Python modules to have side effects, so some
        linters and formatters will complain. Use this API to import Python
        modules that contain these torch.library side effects.

        Args:
            module (str): The name of the Python module to import
        """
        importlib.import_module(module)

    def load_library(self, path):
        """
        Loads a shared library from the given path into the current process.

        The library being loaded may run global initialization code to register
        custom operators with the PyTorch JIT runtime. This allows dynamically
        loading custom operators. For this, you should compile your operator
        and the static registration code into a shared library object, and then
        call ``torch.ops.load_library('path/to/libcustom.so')`` to load the
        shared object.

        After the library is loaded, it is added to the
        ``torch.ops.loaded_libraries`` attribute, a set that may be inspected
        for the paths of all libraries loaded using this function.

        Args:
            path (str): A path to a shared library to load.
        """
        if torch._running_with_deploy():
            return

        path = _utils_internal.resolve_library_path(path)
        with dl_open_guard():
            # Import the shared library into the process, thus running its
            # static (global) initialization code in order to register custom
            # operators with the JIT.
            ctypes.CDLL(path)
        self.loaded_libraries.add(path)
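
# Example (illustrative; the module name and library path are hypothetical):
#   torch.ops.import_module("my_pkg.custom_ops")       # triggers torch.library registrations
#   torch.ops.load_library("build/libcustom_ops.so")   # runs static registration code
#   torch.ops.loaded_libraries                         # contains the resolved .so path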


# The ops "namespace"
ops: _Ops = _Ops()