# mypy: allow-untyped-defs
from typing import Iterable, List, Union

import torch

from .. import Tensor
from . import _lazy_call, _lazy_init, current_device, device_count


__all__ = [
    "get_rng_state",
    "get_rng_state_all",
    "set_rng_state",
    "set_rng_state_all",
    "manual_seed",
    "manual_seed_all",
    "seed",
    "seed_all",
    "initial_seed",
]


def get_rng_state(device: Union[int, str, torch.device] = "cuda") -> Tensor:
    r"""Return the random number generator state of the specified GPU as a ByteTensor.

    Args:
        device (torch.device or int, optional): The device to return the RNG state of.
            Default: ``'cuda'`` (i.e., ``torch.device('cuda')``, the current CUDA device).

    .. warning::
        This function eagerly initializes CUDA.
    """
    _lazy_init()
    if isinstance(device, str):
        device = torch.device(device)
    elif isinstance(device, int):
        device = torch.device("cuda", device)
    idx = device.index
    if idx is None:
        idx = current_device()
    default_generator = torch.cuda.default_generators[idx]
    return default_generator.get_state()
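

# Illustrative sketch, not part of the original module: the returned state is
# a CPU ByteTensor (dtype torch.uint8), so it can be stored in an ordinary
# checkpoint dict alongside model weights. The helper name is hypothetical.
def _example_snapshot_current_device() -> Tensor:
    state = get_rng_state()  # current CUDA device by default
    assert state.dtype == torch.uint8 and state.device.type == "cpu"
    return state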


def get_rng_state_all() -> List[Tensor]:
    r"""Return a list of ByteTensor representing the random number states of all devices."""
    results = []
    for i in range(device_count()):
        results.append(get_rng_state(i))
    return results


def set_rng_state(
    new_state: Tensor, device: Union[int, str, torch.device] = "cuda"
) -> None:
    r"""Set the random number generator state of the specified GPU.

    Args:
        new_state (torch.ByteTensor): The desired state
        device (torch.device or int, optional): The device to set the RNG state.
            Default: ``'cuda'`` (i.e., ``torch.device('cuda')``, the current CUDA device).
    """
    with torch._C._DisableFuncTorch():
        new_state_copy = new_state.clone(memory_format=torch.contiguous_format)
    if isinstance(device, str):
        device = torch.device(device)
    elif isinstance(device, int):
        device = torch.device("cuda", device)

    def cb():
        idx = device.index
        if idx is None:
            idx = current_device()
        default_generator = torch.cuda.default_generators[idx]
        default_generator.set_state(new_state_copy)

    _lazy_call(cb)
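

# Illustrative sketch, not part of the original module: restoring a snapshot
# taken with get_rng_state rewinds the device's generator, so the same draws
# repeat. Assumes device 0 exists; the helper name is hypothetical.
def _example_rng_state_roundtrip(device: int = 0) -> None:
    state = get_rng_state(device)
    first = torch.rand(3, device=f"cuda:{device}")   # advances the generator
    set_rng_state(state, device)                     # rewind to the snapshot
    second = torch.rand(3, device=f"cuda:{device}")  # replays the same values
    assert torch.equal(first, second)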


def set_rng_state_all(new_states: Iterable[Tensor]) -> None:
    r"""Set the random number generator state of all devices.

    Args:
        new_states (Iterable of torch.ByteTensor): The desired state for each device.
    """
    for i, state in enumerate(new_states):
        set_rng_state(state, i)
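

# Illustrative sketch, not part of the original module: get_rng_state_all and
# set_rng_state_all pair naturally for checkpointing. The "rng" key and the
# helper names are arbitrary choices.
def _example_save_rng(checkpoint: dict) -> None:
    checkpoint["rng"] = get_rng_state_all()  # one ByteTensor per device


def _example_restore_rng(checkpoint: dict) -> None:
    set_rng_state_all(checkpoint["rng"])  # state i is applied to device i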


def manual_seed(seed: int) -> None:
    r"""Set the seed for generating random numbers for the current GPU.

    It's safe to call this function if CUDA is not available; in that
    case, it is silently ignored.

    Args:
        seed (int): The desired seed.

    .. warning::
        If you are working with a multi-GPU model, this function is insufficient
        to get determinism. To seed all GPUs, use :func:`manual_seed_all`.
    """
    seed = int(seed)

    def cb():
        idx = current_device()
        default_generator = torch.cuda.default_generators[idx]
        default_generator.manual_seed(seed)

    _lazy_call(cb, seed=True)
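

# Illustrative sketch, not part of the original module: re-applying the same
# seed replays the same draws on the current device. The seed value 42 and the
# helper name are arbitrary.
def _example_reproducible_draws() -> None:
    manual_seed(42)
    a = torch.rand(3, device="cuda")
    manual_seed(42)
    b = torch.rand(3, device="cuda")
    assert torch.equal(a, b)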


def manual_seed_all(seed: int) -> None:
    r"""Set the seed for generating random numbers on all GPUs.

    It's safe to call this function if CUDA is not available; in that
    case, it is silently ignored.

    Args:
        seed (int): The desired seed.
    """
    seed = int(seed)

    def cb():
        for i in range(device_count()):
            default_generator = torch.cuda.default_generators[i]
            default_generator.manual_seed(seed)

    _lazy_call(cb, seed_all=True)
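

# Illustrative sketch, not part of the original module: unlike manual_seed,
# manual_seed_all seeds every visible device, so draws on any of them are
# reproducible. The seed value 7 and the helper name are arbitrary.
def _example_seed_every_device() -> None:
    manual_seed_all(7)
    draws = [torch.rand(2, device=f"cuda:{i}") for i in range(device_count())]
    manual_seed_all(7)
    again = [torch.rand(2, device=f"cuda:{i}") for i in range(device_count())]
    assert all(torch.equal(a, b) for a, b in zip(draws, again))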


def seed() -> None:
    r"""Set the seed for generating random numbers to a random number for the current GPU.

    It's safe to call this function if CUDA is not available; in that
    case, it is silently ignored.

    .. warning::
        If you are working with a multi-GPU model, this function will only initialize
        the seed on one GPU. To initialize all GPUs, use :func:`seed_all`.
    """

    def cb():
        idx = current_device()
        default_generator = torch.cuda.default_generators[idx]
        default_generator.seed()

    _lazy_call(cb)
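

# Illustrative sketch, not part of the original module: seed() draws a fresh
# random seed for the current device, so two consecutive calls almost surely
# leave the generator with different seeds. Uses initial_seed(), defined
# below; the helper name is hypothetical.
def _example_reseed_randomly() -> bool:
    seed()
    first = initial_seed()   # forces CUDA init, running the queued callback
    seed()                   # CUDA now initialized; reseeds immediately
    second = initial_seed()
    return first != second   # True with overwhelming probability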


def seed_all() -> None:
    r"""Set the seed for generating random numbers to a random number on all GPUs.

    It's safe to call this function if CUDA is not available; in that
    case, it is silently ignored.
    """

    def cb():
        # Draw one random seed on the first device, then propagate it to the
        # remaining devices so that all generators agree.
        random_seed = 0
        seeded = False
        for i in range(device_count()):
            default_generator = torch.cuda.default_generators[i]
            if not seeded:
                default_generator.seed()
                random_seed = default_generator.initial_seed()
                seeded = True
            else:
                default_generator.manual_seed(random_seed)

    _lazy_call(cb)
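

# Illustrative sketch, not part of the original module: because seed_all
# propagates the first device's random seed to the rest, every device ends up
# with an identical generator state. The helper name is hypothetical.
def _example_seed_all_agrees() -> bool:
    seed_all()
    states = get_rng_state_all()  # forces CUDA init, running the callback
    return all(torch.equal(states[0], s) for s in states)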


def initial_seed() -> int:
    r"""Return the current random seed of the current GPU.

    .. warning::
        This function eagerly initializes CUDA.
    """
    _lazy_init()
    idx = current_device()
    default_generator = torch.cuda.default_generators[idx]
    return default_generator.initial_seed()
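

# Illustrative sketch, not part of the original module: initial_seed reads
# back whatever seed was last applied to the current device's generator. The
# seed value 1234 and the helper name are arbitrary.
def _example_read_back_seed() -> None:
    manual_seed(1234)  # queued until CUDA initializes
    assert initial_seed() == 1234  # initial_seed() forces initialization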