# mypy: allow-untyped-defs
import itertools
import unittest.mock
from contextlib import contextmanager
from typing import Iterator

import torch
import torch._C
import torch._ops
import torch.utils._python_dispatch
import torch.utils._pytree as pytree


__all__ = ["enable_python_dispatcher", "no_python_dispatcher", "enable_pre_dispatch"]

no_python_dispatcher = torch._C._DisablePythonDispatcher
enable_python_dispatcher = torch._C._EnablePythonDispatcher
enable_pre_dispatch = torch._C._EnablePreDispatch

CROSSREF_FUNCTIONALIZE = False


def all_py_loaded_overloads() -> Iterator[torch._ops.OpOverload]:
    """
    Warning: the set of overloads this will report is very subtle. It is precisely
    the set of torch.ops functions that have actually been accessed from Python
    (e.g., we actually called torch.ops.aten.blah at some point). This is DIFFERENT
    from the set of registered operators, which will in general be a larger set,
    as it would include all operators for which we ran C++ static initializers or
    Python operator registration. This does not eagerly populate the list on
    torch.ops.aten; this list is lazy!

    In other words, this is good for traversing over everything that has an
    OpOverload object allocated in Python. We use it for cache invalidation, but
    don't rely on this list being complete.

    Note that even if we did report all C++ registered overloads, this isn't guaranteed
    to be complete either, as a subsequent lazy load of a library which triggers more
    registrations could add more things to the set.
    """
    for ns in torch.ops:
        packets = getattr(torch.ops, ns)
        for op_name in packets:
            packet = getattr(packets, op_name)
            for overload in packet:
                yield getattr(packet, overload)
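
# Example (illustrative sketch, not executed as part of this module): callers
# typically iterate the loaded overloads to invalidate a per-overload dispatch
# cache, as enable_crossref_functionalize below does:
#
#     for op in all_py_loaded_overloads():
#         op._uncache_dispatch(torch._C.DispatchKey.Functionalize)
#
# Only overloads that have already been touched from Python are visited; see the
# docstring above for the caveats.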


@contextmanager
def suspend_functionalization():
    f_tls = torch._C._dispatch_tls_is_dispatch_key_included(
        torch._C.DispatchKey.Functionalize
    )
    f_rv = torch._C._functionalization_reapply_views_tls()
    if f_tls:
        torch._disable_functionalization()
    try:
        yield
    finally:
        if f_tls:
            torch._enable_functionalization(reapply_views=f_rv)
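
# Example (illustrative sketch of how a caller might use this; the check is an
# assumption, not part of this module): functionalization is disabled inside the
# block and the previous TLS state, including reapply_views, is restored on exit.
#
#     with suspend_functionalization():
#         assert not torch._C._dispatch_tls_is_dispatch_key_included(
#             torch._C.DispatchKey.Functionalize
#         )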


def check_tensor_metadata_matches(nv, rv, desc):
    assert callable(desc)
    assert nv.size() == rv.size(), f"{desc()}: sizes {nv.size()} != {rv.size()}"
    assert nv.dtype == rv.dtype, f"{desc()}: dtype {nv.dtype} != {rv.dtype}"
    same_strides, idx = torch._prims_common.check_significant_strides(
        nv, rv, only_cuda=False
    )
    assert (
        same_strides
    ), f"{desc()}: strides {nv.stride()} != {rv.stride()} (mismatch at index {idx})"


def check_metadata_matches(n, r, desc):
    assert callable(desc)
    n_vals, n_spec = pytree.tree_flatten(n)
    r_vals, r_spec = pytree.tree_flatten(r)
    # TODO: test the specs match; empirically sometimes we have a tuple
    # on one side and a list on the other
    assert len(n_vals) == len(r_vals), f"{len(n_vals)} != {len(r_vals)}"
    for i, nv, rv in zip(range(len(n_vals)), n_vals, r_vals):
        if not isinstance(rv, torch.Tensor):
            continue
        check_tensor_metadata_matches(nv, rv, lambda: f"{desc()} output {i}")


class Lit:
    def __init__(self, s):
        self.s = s

    def __repr__(self):
        return self.s


def _fmt(a: object) -> object:
    if isinstance(a, torch.Tensor):
        return Lit(
            f"torch.empty_strided({tuple(a.size())}, {a.stride()}, dtype={a.dtype})"
        )
    else:
        return a


def make_crossref_functionalize(op, final_key):
    from torch._subclasses.fake_tensor import FakeTensorMode

    # This case is pretty weird, suppress it for now
    if op == torch.ops.aten.lift_fresh.default:
        return final_key

    def handler(*args, **kwargs):
        fake_mode = FakeTensorMode()

        def fakeify_defun(t):
            if isinstance(t, torch.Tensor):
                if torch._is_functional_tensor(t):
                    r = torch._from_functional_tensor(t)
                    # NB: This assumes that the inner tensor sizes/strides match
                    # the outer tensor sizes/strides. This doesn't necessarily
                    # have to be the case, see discussion at
                    # https://github.com/pytorch/pytorch/pull/87610/files/401ddeda1d769bedc88a12de332c7357b60e51a4#r1007264456
                    assert t.size() == r.size()
                    assert t.stride() == r.stride()
                else:
                    r = t
                # TODO: suppress guards
                return fake_mode.from_tensor(r)
            return t

        def maybe_detach(t):
            if isinstance(t, torch.Tensor):
                return t.detach()
            else:
                return t

        # TODO: This probably does the wrong thing if you're running other
        # substantive modes with the normal op outside here
        with torch.utils._python_dispatch._disable_current_modes(), suspend_functionalization():
            f_args, f_kwargs = pytree.tree_map(fakeify_defun, (args, kwargs))
            orig_f_args, orig_f_kwargs = pytree.tree_map(
                maybe_detach, (f_args, f_kwargs)
            )
            with fake_mode:
                f_r = op(*f_args, **f_kwargs)
        r = op._op_dk(final_key, *args, **kwargs)

        def desc():
            fmt_args = ", ".join(
                itertools.chain(
                    (repr(pytree.tree_map(_fmt, a)) for a in orig_f_args),
                    (
                        f"{k}={pytree.tree_map(_fmt, v)}"
                        for k, v in orig_f_kwargs.items()
                    ),
                )
            )
            return f"{op}({fmt_args})"

        check_metadata_matches(f_r, r, desc)
        return r

    return handler
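
# Illustrative sketch (an assumption about the caller, not part of this module):
# when CROSSREF_FUNCTIONALIZE has been patched to True, the Python dispatcher is
# expected to route the Functionalize key through this factory, roughly:
#
#     if key == torch._C.DispatchKey.Functionalize and CROSSREF_FUNCTIONALIZE:
#         return make_crossref_functionalize(op, final_key)
#
# The returned handler runs `op` on fake, defunctionalized copies of the inputs,
# runs the real kernel for `final_key`, and asserts that the output metadata match.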


# NB: enabling this is slow, don't do it in a hot loop. This is purely
# for debugging purposes.
@contextmanager
def enable_crossref_functionalize():
    for op in all_py_loaded_overloads():
        op._uncache_dispatch(torch._C.DispatchKey.Functionalize)
    try:
        with enable_python_dispatcher(), unittest.mock.patch(
            "torch._dispatch.python.CROSSREF_FUNCTIONALIZE", True
        ):
            yield
    finally:
        for op in all_py_loaded_overloads():
            op._uncache_dispatch(torch._C.DispatchKey.Functionalize)
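
# Example usage (illustrative sketch; the function and input are hypothetical):
# wrap a functionalized run you want to cross-check, e.g. via torch.func.functionalize,
# so ops actually dispatch through the Functionalize key inside the context.
#
#     def f(x):
#         return torch.ops.aten.add.Tensor(x, 1)
#
#     with enable_crossref_functionalize():
#         y = torch.func.functionalize(f)(torch.randn(3))
#
# Inside the context, every loaded overload has its Functionalize dispatch cache
# invalidated and CROSSREF_FUNCTIONALIZE is patched to True, so metadata mismatches
# between the functionalized and regular kernels trip an assert.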