# mypy: allow-untyped-defs
import itertools
import unittest.mock
from contextlib import contextmanager
from typing import Iterator

import torch
import torch._C
import torch._ops
import torch.utils._python_dispatch
import torch.utils._pytree as pytree


__all__ = ["enable_python_dispatcher", "no_python_dispatcher", "enable_pre_dispatch"]

no_python_dispatcher = torch._C._DisablePythonDispatcher
enable_python_dispatcher = torch._C._EnablePythonDispatcher
enable_pre_dispatch = torch._C._EnablePreDispatch

CROSSREF_FUNCTIONALIZE = False

def all_py_loaded_overloads() -> Iterator[torch._ops.OpOverload]:
    """
    Warning: the set of overloads this will report is very subtle.  It is precisely
    the set of torch.ops functions that have actually been accessed from Python
    (e.g., we actually called torch.ops.aten.blah at some point).  This is DIFFERENT
    from the set of registered operators, which will in general be a larger set,
    as it would include all operators for which we ran C++ static initializers or
    Python operator registration.  This does not eagerly populate the list on
    torch.ops.aten; this list is lazy!

    In other words, this is good for traversing over everything that has an
    OpOverload object allocated in Python.  We use it for cache invalidation, but
    don't rely on this list being complete.

    Note that even if we did report all C++ registered overloads, this isn't
    guaranteed to be complete either, as a subsequent lazy load of a library which
    triggers more registrations could add more things to the set.
    """
    for ns in torch.ops:
        packets = getattr(torch.ops, ns)
        for op_name in packets:
            packet = getattr(packets, op_name)
            for overload in packet:
                yield getattr(packet, overload)

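# Illustrative usage sketch (not part of the original module): because the
# generator above only sees OpOverload objects that Python has already
# materialized, merely touching torch.ops.aten.add.Tensor makes it show up.
# The helper name below is hypothetical.
def _example_count_loaded_overloads() -> int:
    _ = torch.ops.aten.add.Tensor  # attribute access materializes the OpOverload
    return sum(1 for _ in all_py_loaded_overloads())
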
@contextmanager
def suspend_functionalization():
    f_tls = torch._C._dispatch_tls_is_dispatch_key_included(
        torch._C.DispatchKey.Functionalize
    )
    f_rv = torch._C._functionalization_reapply_views_tls()
    if f_tls:
        torch._disable_functionalization()
    try:
        yield
    finally:
        if f_tls:
            torch._enable_functionalization(reapply_views=f_rv)

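# Illustrative sketch (not part of the original module): suspend_functionalization
# temporarily removes the Functionalize dispatch key from TLS and restores the
# previous state (including the reapply_views setting) on exit.  The crossref
# handler below uses it to step outside functionalization while it unwraps and
# fakeifies tensors.  The helper name and its argument are hypothetical.
def _example_unwrap_outside_functionalization(t: torch.Tensor) -> torch.Tensor:
    with suspend_functionalization():
        # Inside the block, ops dispatch without the Functionalize key, so a
        # functional tensor can be unwrapped and inspected directly.
        if torch._is_functional_tensor(t):
            t = torch._from_functional_tensor(t)
        return t.clone()
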
def check_tensor_metadata_matches(nv, rv, desc):
    assert callable(desc)
    assert nv.size() == rv.size(), f"{desc()}: sizes {nv.size()} != {rv.size()}"
    assert nv.dtype == rv.dtype, f"{desc()}: dtype {nv.dtype} != {rv.dtype}"
    same_strides, idx = torch._prims_common.check_significant_strides(
        nv, rv, only_cuda=False
    )
    assert (
        same_strides
    ), f"{desc()}: strides {nv.stride()} != {rv.stride()} (mismatch at index {idx})"


def check_metadata_matches(n, r, desc):
    assert callable(desc)
    n_vals, n_spec = pytree.tree_flatten(n)
    r_vals, r_spec = pytree.tree_flatten(r)
    # TODO: test the specs match; empirically sometimes we have a tuple
    # on one side and a list on the other
    assert len(n_vals) == len(r_vals), f"{len(n_vals)} != {len(r_vals)}"
    for i, nv, rv in zip(range(len(n_vals)), n_vals, r_vals):
        if not isinstance(rv, torch.Tensor):
            continue
        check_tensor_metadata_matches(nv, rv, lambda: f"{desc()} output {i}")

class Lit:
    def __init__(self, s):
        self.s = s

    def __repr__(self):
        return self.s


def _fmt(a: object) -> object:
    if isinstance(a, torch.Tensor):
        return Lit(
            f"torch.empty_strided({tuple(a.size())}, {a.stride()}, dtype={a.dtype})"
        )
    else:
        return a

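# Illustrative sketch (not part of the original module): _fmt swaps tensors for
# a Lit whose repr is a torch.empty_strided(...) expression, so the argument
# dump built by desc() below reads as copy-pasteable code.  The helper name is
# hypothetical.
def _example_fmt_args() -> str:
    t = torch.empty_strided((2, 3), (3, 1), dtype=torch.float32)
    # Produces something like:
    # "{'x': torch.empty_strided((2, 3), (3, 1), dtype=torch.float32), 'alpha': 2}"
    return repr(pytree.tree_map(_fmt, {"x": t, "alpha": 2}))
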
def make_crossref_functionalize(op, final_key):
    from torch._subclasses.fake_tensor import FakeTensorMode

    # This case is pretty weird, suppress it for now
    if op == torch.ops.aten.lift_fresh.default:
        return final_key

    def handler(*args, **kwargs):
        fake_mode = FakeTensorMode()

        def fakeify_defun(t):
            if isinstance(t, torch.Tensor):
                if torch._is_functional_tensor(t):
                    r = torch._from_functional_tensor(t)
                    # NB: This assumes that the inner tensor sizes/strides match
                    # the outer tensor sizes/strides.  This doesn't necessarily
                    # have to be the case, see discussion at
                    # https://github.com/pytorch/pytorch/pull/87610/files/401ddeda1d769bedc88a12de332c7357b60e51a4#r1007264456
                    assert t.size() == r.size()
                    assert t.stride() == r.stride()
                else:
                    r = t
                # TODO: suppress guards
                return fake_mode.from_tensor(r)
            return t

        def maybe_detach(t):
            if isinstance(t, torch.Tensor):
                return t.detach()
            else:
                return t

        # TODO: This probably does the wrong thing if you're running other
        # substantive modes with the normal op outside here
        with torch.utils._python_dispatch._disable_current_modes(), suspend_functionalization():
            f_args, f_kwargs = pytree.tree_map(fakeify_defun, (args, kwargs))
            orig_f_args, orig_f_kwargs = pytree.tree_map(
                maybe_detach, (f_args, f_kwargs)
            )
            with fake_mode:
                f_r = op(*f_args, **f_kwargs)
        r = op._op_dk(final_key, *args, **kwargs)

        def desc():
            fmt_args = ", ".join(
                itertools.chain(
                    (repr(pytree.tree_map(_fmt, a)) for a in orig_f_args),
                    (
                        f"{k}={pytree.tree_map(_fmt, v)}"
                        for k, v in orig_f_kwargs.items()
                    ),
                )
            )
            return f"{op}({fmt_args})"

        check_metadata_matches(f_r, r, desc)
        return r

    return handler

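# Illustrative sketch (not part of the original module): the handler above
# compares the metadata of a FakeTensorMode run against the real dispatch of the
# same op.  The same comparison can be done by hand with check_metadata_matches;
# this is just a hand-rolled restatement of that flow, and the helper name is
# hypothetical.
def _example_manual_crossref(x: torch.Tensor) -> None:
    from torch._subclasses.fake_tensor import FakeTensorMode

    fake_mode = FakeTensorMode()
    fx = fake_mode.from_tensor(x.detach())
    with fake_mode:
        fake_out = torch.ops.aten.mul.Tensor(fx, fx)
    real_out = torch.ops.aten.mul.Tensor(x, x)
    check_metadata_matches(fake_out, real_out, lambda: "aten.mul.Tensor example")
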
# NB: enabling this is slow, don't do it in a hot loop.  This is purely
# for debugging purposes.
@contextmanager
def enable_crossref_functionalize():
    for op in all_py_loaded_overloads():
        op._uncache_dispatch(torch._C.DispatchKey.Functionalize)
    try:
        with enable_python_dispatcher(), unittest.mock.patch(
            "torch._dispatch.python.CROSSREF_FUNCTIONALIZE", True
        ):
            yield
    finally:
        for op in all_py_loaded_overloads():
            op._uncache_dispatch(torch._C.DispatchKey.Functionalize)

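# Illustrative sketch (not part of the original module): one way to exercise the
# crossref checks is to run a functionalized function while the context manager
# above is active.  Whether the crossref handler actually fires depends on the
# ops dispatching through the Functionalize key with the Python dispatcher
# enabled; treat this as a debugging aid, not a guaranteed API.  The helper name
# is hypothetical.
def _example_run_with_crossref_checks() -> torch.Tensor:
    def f(x):
        return torch.relu(x) + 1

    with enable_crossref_functionalize():
        # torch.func.functionalize routes ops through functionalization, where
        # the crossref handler can compare fake vs. real output metadata.
        return torch.func.functionalize(f)(torch.randn(3))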
|