random.py 4.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191
  1. # mypy: ignore-errors
"""Wrapper to mimic (parts of) the np.random API surface.

NumPy has strict guarantees on reproducibility etc.; here we give none.
Open question: NumPy's default float dtype is float64, whereas the samplers
here use ``_dtypes_impl.default_dtypes().float_dtype``.
"""
  6. from __future__ import annotations
  7. import functools
  8. from math import sqrt
  9. from typing import Optional
  10. import torch
  11. from . import _dtypes_impl, _util
  12. from ._normalizations import array_or_scalar, ArrayLike, normalizer
# Public API of this module. The names follow ``numpy.random``; the decorated
# samplers can be redirected to the same-named ``numpy.random`` function.
__all__ = [
    "seed",
    "random_sample",
    "sample",
    "random",
    "rand",
    "randn",
    "normal",
    "choice",
    "randint",
    "shuffle",
    "uniform",
]
  26. def use_numpy_random():
  27. # local import to avoid ref cycles
  28. import torch._dynamo.config as config
  29. return config.use_numpy_random_stream
  30. def deco_stream(func):
  31. @functools.wraps(func)
  32. def inner(*args, **kwds):
  33. if not use_numpy_random():
  34. return func(*args, **kwds)
  35. else:
  36. import numpy
  37. from ._ndarray import ndarray
  38. f = getattr(numpy.random, func.__name__)
  39. # numpy funcs accept numpy ndarrays, unwrap
  40. args = tuple(
  41. arg.tensor.numpy() if isinstance(arg, ndarray) else arg for arg in args
  42. )
  43. kwds = {
  44. key: val.tensor.numpy() if isinstance(val, ndarray) else val
  45. for key, val in kwds.items()
  46. }
  47. value = f(*args, **kwds)
  48. # `value` can be either numpy.ndarray or python scalar (or None)
  49. if isinstance(value, numpy.ndarray):
  50. value = ndarray(torch.as_tensor(value))
  51. return value
  52. return inner
  53. @deco_stream
  54. def seed(seed=None):
  55. if seed is not None:
  56. torch.random.manual_seed(seed)
  57. @deco_stream
  58. def random_sample(size=None):
  59. if size is None:
  60. size = ()
  61. dtype = _dtypes_impl.default_dtypes().float_dtype
  62. values = torch.empty(size, dtype=dtype).uniform_()
  63. return array_or_scalar(values, return_scalar=size == ())
  64. def rand(*size):
  65. if size == ():
  66. size = None
  67. return random_sample(size)
  68. sample = random_sample
  69. random = random_sample
  70. @deco_stream
  71. def uniform(low=0.0, high=1.0, size=None):
  72. if size is None:
  73. size = ()
  74. dtype = _dtypes_impl.default_dtypes().float_dtype
  75. values = torch.empty(size, dtype=dtype).uniform_(low, high)
  76. return array_or_scalar(values, return_scalar=size == ())
  77. @deco_stream
  78. def randn(*size):
  79. dtype = _dtypes_impl.default_dtypes().float_dtype
  80. values = torch.randn(size, dtype=dtype)
  81. return array_or_scalar(values, return_scalar=size == ())
  82. @deco_stream
  83. def normal(loc=0.0, scale=1.0, size=None):
  84. if size is None:
  85. size = ()
  86. dtype = _dtypes_impl.default_dtypes().float_dtype
  87. values = torch.empty(size, dtype=dtype).normal_(loc, scale)
  88. return array_or_scalar(values, return_scalar=size == ())
  89. @deco_stream
  90. def shuffle(x):
  91. # no @normalizer because we do not cast e.g. lists to tensors
  92. from ._ndarray import ndarray
  93. if isinstance(x, torch.Tensor):
  94. tensor = x
  95. elif isinstance(x, ndarray):
  96. tensor = x.tensor
  97. else:
  98. raise NotImplementedError("We do not random.shuffle lists in-place")
  99. perm = torch.randperm(tensor.shape[0])
  100. xp = tensor[perm]
  101. tensor.copy_(xp)
  102. @deco_stream
  103. def randint(low, high=None, size=None):
  104. if size is None:
  105. size = ()
  106. if not isinstance(size, (tuple, list)):
  107. size = (size,)
  108. if high is None:
  109. low, high = 0, low
  110. values = torch.randint(low, high, size=size)
  111. return array_or_scalar(values, int, return_scalar=size == ())
  112. @deco_stream
  113. @normalizer
  114. def choice(a: ArrayLike, size=None, replace=True, p: Optional[ArrayLike] = None):
  115. # https://stackoverflow.com/questions/59461811/random-choice-with-pytorch
  116. if a.numel() == 1:
  117. a = torch.arange(a)
  118. # TODO: check a.dtype is integer -- cf np.random.choice(3.4) which raises
  119. # number of draws
  120. if size is None:
  121. num_el = 1
  122. elif _util.is_sequence(size):
  123. num_el = 1
  124. for el in size:
  125. num_el *= el
  126. else:
  127. num_el = size
  128. # prepare the probabilities
  129. if p is None:
  130. p = torch.ones_like(a) / a.shape[0]
  131. # cf https://github.com/numpy/numpy/blob/main/numpy/random/mtrand.pyx#L973
  132. atol = sqrt(torch.finfo(p.dtype).eps)
  133. if abs(p.sum() - 1.0) > atol:
  134. raise ValueError("probabilities do not sum to 1.")
  135. # actually sample
  136. indices = torch.multinomial(p, num_el, replacement=replace)
  137. if _util.is_sequence(size):
  138. indices = indices.reshape(size)
  139. samples = a[indices]
  140. return samples