| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051 |
- # mypy: allow-untyped-defs
- import contextlib
- import torch
- # Common testing utilities for use in public testing APIs.
- # NB: these should all be importable without optional dependencies
- # (like numpy and expecttest).
def wrapper_set_seed(op, *args, **kwargs):
    """Invoke ``op(*args, **kwargs)`` with the RNG seeded to 42, without disturbing global RNG state.

    Useful for nondeterministic ops such as dropout, whose results must be
    comparable across runs. The call happens inside ``freeze_rng_state``, so
    whatever RNG state existed before the call is restored afterwards.
    See: https://github.com/pytorch/pytorch/pull/62315#issuecomment-896143189 for more details.

    Returns:
        Whatever ``op`` returns.
    """
    with freeze_rng_state():
        torch.manual_seed(42)
        result = op(*args, **kwargs)
        produced_lazy_tensor = (
            isinstance(result, torch.Tensor) and result.device.type == "lazy"
        )
        if produced_lazy_tensor:
            # The lazy graph must be executed while the frozen/seeded RNG state
            # is still active; otherwise its numerics would not match eager mode.
            torch._lazy.mark_step()  # type: ignore[attr-defined]
    return result
- @contextlib.contextmanager
- def freeze_rng_state():
- # no_dispatch needed for test_composite_compliance
- # Some OpInfos use freeze_rng_state for rng determinism, but
- # test_composite_compliance overrides dispatch for all torch functions
- # which we need to disable to get and set rng state
- with torch.utils._mode_utils.no_dispatch(), torch._C._DisableFuncTorch():
- rng_state = torch.get_rng_state()
- if torch.cuda.is_available():
- cuda_rng_state = torch.cuda.get_rng_state()
- try:
- yield
- finally:
- # Modes are not happy with torch.cuda.set_rng_state
- # because it clones the state (which could produce a Tensor Subclass)
- # and then grabs the new tensor's data pointer in generator.set_state.
- #
- # In the long run torch.cuda.set_rng_state should probably be
- # an operator.
- #
- # NB: Mode disable is to avoid running cross-ref tests on thes seeding
- with torch.utils._mode_utils.no_dispatch(), torch._C._DisableFuncTorch():
- if torch.cuda.is_available():
- torch.cuda.set_rng_state(cuda_rng_state) # type: ignore[possibly-undefined]
- torch.set_rng_state(rng_state)
|