_utils.py 2.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051
  1. # mypy: allow-untyped-defs
  2. import contextlib
  3. import torch
# Common testing utilities for use in public testing APIs.
# NB: these should all be importable without optional dependencies
# (like numpy and expecttest).
  7. def wrapper_set_seed(op, *args, **kwargs):
  8. """Wrapper to set seed manually for some functions like dropout
  9. See: https://github.com/pytorch/pytorch/pull/62315#issuecomment-896143189 for more details.
  10. """
  11. with freeze_rng_state():
  12. torch.manual_seed(42)
  13. output = op(*args, **kwargs)
  14. if isinstance(output, torch.Tensor) and output.device.type == "lazy":
  15. # We need to call mark step inside freeze_rng_state so that numerics
  16. # match eager execution
  17. torch._lazy.mark_step() # type: ignore[attr-defined]
  18. return output
  19. @contextlib.contextmanager
  20. def freeze_rng_state():
  21. # no_dispatch needed for test_composite_compliance
  22. # Some OpInfos use freeze_rng_state for rng determinism, but
  23. # test_composite_compliance overrides dispatch for all torch functions
  24. # which we need to disable to get and set rng state
  25. with torch.utils._mode_utils.no_dispatch(), torch._C._DisableFuncTorch():
  26. rng_state = torch.get_rng_state()
  27. if torch.cuda.is_available():
  28. cuda_rng_state = torch.cuda.get_rng_state()
  29. try:
  30. yield
  31. finally:
  32. # Modes are not happy with torch.cuda.set_rng_state
  33. # because it clones the state (which could produce a Tensor Subclass)
  34. # and then grabs the new tensor's data pointer in generator.set_state.
  35. #
  36. # In the long run torch.cuda.set_rng_state should probably be
  37. # an operator.
  38. #
  39. # NB: Mode disable is to avoid running cross-ref tests on thes seeding
  40. with torch.utils._mode_utils.no_dispatch(), torch._C._DisableFuncTorch():
  41. if torch.cuda.is_available():
  42. torch.cuda.set_rng_state(cuda_rng_state) # type: ignore[possibly-undefined]
  43. torch.set_rng_state(rng_state)