# debug_prims.py
  1. # mypy: allow-untyped-defs
  2. import contextlib
  3. from typing import Optional
  4. import torch
  5. from torch.utils._content_store import ContentStoreReader
  6. LOAD_TENSOR_READER: Optional[ContentStoreReader] = None
  7. @contextlib.contextmanager
  8. def load_tensor_reader(loc):
  9. global LOAD_TENSOR_READER
  10. assert LOAD_TENSOR_READER is None
  11. # load_tensor is an "op", and we will play merry hell on
  12. # Inductor's memory planning if we return a tensor that
  13. # aliases another tensor that we previously returned from
  14. # an operator. So unlike standard ContentStoreReader use,
  15. # we disable the cache so that you always get fresh storages
  16. # (no aliasing for you!)
  17. LOAD_TENSOR_READER = ContentStoreReader(loc, cache=False)
  18. try:
  19. yield
  20. finally:
  21. LOAD_TENSOR_READER = None
  22. def register_debug_prims():
  23. torch.library.define(
  24. "debugprims::load_tensor",
  25. "(str name, int[] size, int[] stride, *, ScalarType dtype, Device device) -> Tensor",
  26. )
  27. @torch.library.impl("debugprims::load_tensor", "BackendSelect")
  28. def load_tensor_factory(name, size, stride, dtype, device):
  29. if LOAD_TENSOR_READER is None:
  30. from torch._dynamo.testing import rand_strided
  31. return rand_strided(size, stride, dtype, device)
  32. else:
  33. from torch._dynamo.utils import clone_input
  34. # device argument here takes care of coercion
  35. r = LOAD_TENSOR_READER.read_tensor(name, device=device)
  36. assert list(r.size()) == size, f"{r.size()} != {size}"
  37. assert list(r.stride()) == stride, f"{r.stride()} != {stride}"
  38. assert r.device == device, f"{r.device} != {device}"
  39. # Unlike the other properties, we will do coercions for dtype
  40. # mismatch
  41. if r.dtype != dtype:
  42. r = clone_input(r, dtype=dtype)
  43. return r