# mypy: ignore-errors

"""Wrapper to mimic (parts of) np.random API surface.

NumPy has strict guarantees on reproducibility etc.; here we make no such guarantees.

Q: default dtype is float64 in numpy
"""
from __future__ import annotations

import functools
from math import sqrt
from typing import Optional

import torch

from . import _dtypes_impl, _util
from ._normalizations import array_or_scalar, ArrayLike, normalizer

__all__ = [
    "seed",
    "random_sample",
    "sample",
    "random",
    "rand",
    "randn",
    "normal",
    "choice",
    "randint",
    "shuffle",
    "uniform",
]

def use_numpy_random():
    # local import to avoid ref cycles
    import torch._dynamo.config as config

    return config.use_numpy_random_stream

def deco_stream(func):
    @functools.wraps(func)
    def inner(*args, **kwds):
        if not use_numpy_random():
            return func(*args, **kwds)
        else:
            import numpy

            from ._ndarray import ndarray

            f = getattr(numpy.random, func.__name__)

            # numpy funcs accept numpy ndarrays, unwrap
            args = tuple(
                arg.tensor.numpy() if isinstance(arg, ndarray) else arg for arg in args
            )
            kwds = {
                key: val.tensor.numpy() if isinstance(val, ndarray) else val
                for key, val in kwds.items()
            }

            value = f(*args, **kwds)

            # `value` can be either numpy.ndarray or python scalar (or None)
            if isinstance(value, numpy.ndarray):
                value = ndarray(torch.as_tensor(value))

            return value

    return inner

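# Illustrative sketch (not part of the original module): when
# torch._dynamo.config.use_numpy_random_stream is True, a @deco_stream-wrapped
# function defers to the same-named numpy.random function and wraps any
# numpy.ndarray result back into the torch-backed ndarray wrapper, e.g.
#
#     import torch._dynamo.config as config
#     config.use_numpy_random_stream = True
#     x = uniform(0.0, 1.0, size=(3,))  # sampled by numpy.random.uniform,
#                                       # then wrapped back into an ndarray
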
@deco_stream
def seed(seed=None):
    if seed is not None:
        torch.random.manual_seed(seed)

@deco_stream
def random_sample(size=None):
    if size is None:
        size = ()
    dtype = _dtypes_impl.default_dtypes().float_dtype
    values = torch.empty(size, dtype=dtype).uniform_()
    return array_or_scalar(values, return_scalar=size == ())

def rand(*size):
    if size == ():
        size = None
    return random_sample(size)


sample = random_sample
random = random_sample

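# Usage sketch (illustrative, mirroring np.random semantics):
#
#     rand()                 # a Python scalar drawn from U[0, 1)
#     rand(2, 3)             # a (2, 3) array of U[0, 1) samples
#     random_sample((2, 3))  # same, but takes the shape as a single tuple
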
@deco_stream
def uniform(low=0.0, high=1.0, size=None):
    if size is None:
        size = ()
    dtype = _dtypes_impl.default_dtypes().float_dtype
    values = torch.empty(size, dtype=dtype).uniform_(low, high)
    return array_or_scalar(values, return_scalar=size == ())

@deco_stream
def randn(*size):
    dtype = _dtypes_impl.default_dtypes().float_dtype
    values = torch.randn(size, dtype=dtype)
    return array_or_scalar(values, return_scalar=size == ())

@deco_stream
def normal(loc=0.0, scale=1.0, size=None):
    if size is None:
        size = ()
    dtype = _dtypes_impl.default_dtypes().float_dtype
    values = torch.empty(size, dtype=dtype).normal_(loc, scale)
    return array_or_scalar(values, return_scalar=size == ())

@deco_stream
def shuffle(x):
    # no @normalizer because we do not cast e.g. lists to tensors
    from ._ndarray import ndarray

    if isinstance(x, torch.Tensor):
        tensor = x
    elif isinstance(x, ndarray):
        tensor = x.tensor
    else:
        raise NotImplementedError("We do not random.shuffle lists in-place")

    perm = torch.randperm(tensor.shape[0])
    xp = tensor[perm]
    tensor.copy_(xp)

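# Note (illustrative, not in the original source): like np.random.shuffle,
# this permutes the input in place along its first axis, e.g.
#
#     a = torch.arange(6).reshape(3, 2)
#     shuffle(a)   # the rows of `a` are now in a random order
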
@deco_stream
def randint(low, high=None, size=None):
    if size is None:
        size = ()
    if not isinstance(size, (tuple, list)):
        size = (size,)
    if high is None:
        low, high = 0, low
    values = torch.randint(low, high, size=size)
    return array_or_scalar(values, int, return_scalar=size == ())

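# Usage sketch (illustrative, following np.random.randint semantics): with a
# single argument the range is [0, low), and `size` may be an int or a shape:
#
#     randint(5)                   # scalar in {0, ..., 4}
#     randint(1, 10, size=(2, 2))  # 2x2 array of ints in {1, ..., 9}
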
@deco_stream
@normalizer
def choice(a: ArrayLike, size=None, replace=True, p: Optional[ArrayLike] = None):
    # https://stackoverflow.com/questions/59461811/random-choice-with-pytorch
    if a.numel() == 1:
        a = torch.arange(a)

    # TODO: check a.dtype is integer -- cf np.random.choice(3.4) which raises

    # number of draws
    if size is None:
        num_el = 1
    elif _util.is_sequence(size):
        num_el = 1
        for el in size:
            num_el *= el
    else:
        num_el = size

    # prepare the probabilities
    if p is None:
        p = torch.ones_like(a) / a.shape[0]
    else:
        # cf https://github.com/numpy/numpy/blob/main/numpy/random/mtrand.pyx#L973
        atol = sqrt(torch.finfo(p.dtype).eps)
        if abs(p.sum() - 1.0) > atol:
            raise ValueError("probabilities do not sum to 1.")
    # actually sample
    indices = torch.multinomial(p, num_el, replacement=replace)

    if _util.is_sequence(size):
        indices = indices.reshape(size)

    samples = a[indices]

    return samples

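# Usage sketch (illustrative, not part of the original file): like
# np.random.choice, a single-element first argument is expanded via arange,
# and `p` gives per-element probabilities, e.g.
#
#     choice([10, 20, 30], size=2, p=[0.1, 0.1, 0.8])
#     choice([10, 20, 30], size=(2, 2), replace=True)   # 2x2 array of draws
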