| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183 |
- # mypy: allow-untyped-defs
- import contextlib
- from typing import Generator
- import warnings
- from torch._C import default_generator
- import torch
def set_rng_state(new_state: torch.Tensor) -> None:
    r"""Restore the state of the default CPU random number generator.

    .. note:: Only the default CPU generator is affected. For CUDA, use
        :func:`torch.manual_seed`, which seeds both CPU and CUDA.

    Args:
        new_state (torch.ByteTensor): The state to install, for example one
            previously captured with :func:`get_rng_state`.
    """
    cpu_generator = default_generator
    cpu_generator.set_state(new_state)
def get_rng_state() -> torch.Tensor:
    r"""Return the state of the default CPU random number generator.

    The state is returned as a `torch.ByteTensor`.

    .. note:: Only the default CPU generator is captured; device generators
        are not included. See also: :func:`torch.random.fork_rng`.
    """
    state = default_generator.get_state()
    return state
- def manual_seed(seed) -> torch._C.Generator:
- r"""Sets the seed for generating random numbers on all devices. Returns a
- `torch.Generator` object.
- Args:
- seed (int): The desired seed. Value must be within the inclusive range
- `[-0x8000_0000_0000_0000, 0xffff_ffff_ffff_ffff]`. Otherwise, a RuntimeError
- is raised. Negative inputs are remapped to positive values with the formula
- `0xffff_ffff_ffff_ffff + seed`.
- """
- seed = int(seed)
- import torch.cuda
- if not torch.cuda._is_in_bad_fork():
- torch.cuda.manual_seed_all(seed)
- import torch.mps
- if not torch.mps._is_in_bad_fork():
- torch.mps.manual_seed(seed)
- import torch.xpu
- if not torch.xpu._is_in_bad_fork():
- torch.xpu.manual_seed_all(seed)
- _seed_custom_device(seed)
- return default_generator.manual_seed(seed)
- def seed() -> int:
- r"""Sets the seed for generating random numbers to a non-deterministic
- random number on all devices. Returns a 64 bit number used to seed the RNG.
- """
- seed = default_generator.seed()
- import torch.cuda
- if not torch.cuda._is_in_bad_fork():
- torch.cuda.manual_seed_all(seed)
- import torch.mps
- if not torch.mps._is_in_bad_fork():
- torch.mps.manual_seed(seed)
- import torch.xpu
- if not torch.xpu._is_in_bad_fork():
- torch.xpu.manual_seed_all(seed)
- _seed_custom_device(seed)
- return seed
- def _seed_custom_device(seed) -> None:
- r"""Sets the seed to generate random numbers for custom device.
- Args:
- seed (int): The desired seed.
- See [Note: support the custom device with privateuse1]
- """
- seed = int(seed)
- custom_backend_name = torch._C._get_privateuse1_backend_name()
- if hasattr(torch, custom_backend_name):
- custom_device_mod = getattr(torch, custom_backend_name)
- _bad_fork_name = "_is_in_bad_fork"
- _seed_all_name = "manual_seed_all"
- if hasattr(custom_device_mod, _bad_fork_name) and hasattr(custom_device_mod, _seed_all_name):
- if not getattr(custom_device_mod, _bad_fork_name)():
- getattr(custom_device_mod, _seed_all_name)(seed)
- else:
- message = f"Set seed for `{custom_backend_name}` device does not take effect, please add API's "
- message += f"`{_bad_fork_name}` and `{_seed_all_name}` to `{custom_backend_name}` device module."
- warnings.warn(message, UserWarning, stacklevel=3)
def initial_seed() -> int:
    r"""Return the seed the default CPU generator was last seeded with.

    The value is returned as a Python ``int``.

    .. note:: Only the default CPU generator is queried; device generators
        may hold different seeds.
    """
    cpu_generator = default_generator
    return cpu_generator.initial_seed()
# Set once the "many devices" warning in fork_rng has been emitted, so it is
# shown at most once per process.
_fork_rng_warned_already = False


@contextlib.contextmanager
def fork_rng(devices=None, enabled=True, _caller="fork_rng", _devices_kw="devices", device_type="cuda") -> Generator:
    """
    Forks the RNG, so that when you return, the RNG is reset
    to the state that it was previously in.

    Args:
        devices (iterable of Device IDs): devices for which to fork
            the RNG. CPU RNG state is always forked. By default, :meth:`fork_rng` operates
            on all devices, but will emit a warning if your machine has a lot
            of devices, since this function will run very slowly in that case.
            If you explicitly specify devices, this warning will be suppressed.
        enabled (bool): if ``False``, the RNG is not forked. This is a convenience
            argument for easily disabling the context manager without having
            to delete it and unindent your Python code under it.
        device_type (str): device type str, default is `cuda`. As for custom device,
            see details in [Note: support the custom device with privateuse1]

    Yields:
        None. On exit (normal or via exception), the CPU RNG state and the RNG
        state of every forked device are restored.

    Raises:
        RuntimeError: if ``torch`` has no module named after ``device_type``.
            This check runs even when ``enabled`` is ``False``.
    """
    device_type = torch.device(device_type).type
    device_mod = getattr(torch, device_type, None)
    if device_mod is None:
        raise RuntimeError(f"torch has no module of `{device_type}`, you should register " +
                           "a module by `torch._register_device_module`.")

    global _fork_rng_warned_already

    # Internal arguments:
    #   _caller: the function which called fork_rng, which the user used
    #   _devices_kw: the devices keyword of _caller

    if not enabled:
        yield
        return

    if devices is None:
        num_devices = device_mod.device_count()
        if num_devices > 1 and not _fork_rng_warned_already:
            dev = device_type.upper()
            message = (f"{dev} reports that you have {num_devices} available devices, and "
                       f"you have used {_caller} without explicitly specifying which devices are being used. "
                       f"For safety, we initialize *every* {dev} device by default, which can "
                       f"be quite slow if you have a lot of {dev}s. If you know that you are only"
                       f" making use of a few {dev} devices, set the environment variable "
                       f"{dev}_VISIBLE_DEVICES or the '{_devices_kw}' keyword argument of {_caller} "
                       "with the set of devices you are actually using. For example, if you are using CPU only, "
                       # Previously a plain string, so it printed "device.upper()" verbatim.
                       f"set {dev}_VISIBLE_DEVICES= or devices=[]; if you are using device 0 only, "
                       f"set {dev}_VISIBLE_DEVICES=0 or devices=[0]. To initialize all devices "
                       f"and suppress this warning, set the '{_devices_kw}' keyword argument to "
                       f"`range(torch.{device_type}.device_count())`.")
            warnings.warn(message)
            _fork_rng_warned_already = True
        devices = list(range(num_devices))
    else:
        # Protect against user passing us a generator; we need to traverse this
        # multiple times but a generator will be exhausted upon first traversal
        devices = list(devices)

    cpu_rng_state = torch.get_rng_state()
    device_rng_states = [device_mod.get_rng_state(device) for device in devices]

    try:
        yield
    finally:
        # Restore even if the body raised, so callers never observe a
        # perturbed RNG.
        torch.set_rng_state(cpu_rng_state)
        for device, device_rng_state in zip(devices, device_rng_states):
            device_mod.set_rng_state(device_rng_state, device)
|