
# mypy: allow-untyped-defs
import warnings

import torch
from torch import Tensor
from typing import Generic, Iterable, Iterator, List, Optional, Sequence, Sized, TypeVar, Union

__all__ = [
    "BatchSampler",
    "RandomSampler",
    "Sampler",
    "SequentialSampler",
    "SubsetRandomSampler",
    "WeightedRandomSampler",
]

T_co = TypeVar('T_co', covariant=True)
class Sampler(Generic[T_co]):
    r"""Base class for all Samplers.

    Every Sampler subclass has to provide an :meth:`__iter__` method, providing a
    way to iterate over indices or lists of indices (batches) of dataset elements,
    and may provide a :meth:`__len__` method that returns the length of the returned
    iterators.

    Args:
        data_source (Dataset): This argument is not used and will be removed in 2.2.0.
            You may still have a custom implementation that utilizes it.

    Example:
        >>> # xdoctest: +SKIP
        >>> class AscendingSequenceLengthSampler(Sampler[int]):
        >>>     def __init__(self, data: List[str]) -> None:
        >>>         self.data = data
        >>>
        >>>     def __len__(self) -> int:
        >>>         return len(self.data)
        >>>
        >>>     def __iter__(self) -> Iterator[int]:
        >>>         sizes = torch.tensor([len(x) for x in self.data])
        >>>         yield from torch.argsort(sizes).tolist()
        >>>
        >>> class AscendingSequenceLengthBatchSampler(Sampler[List[int]]):
        >>>     def __init__(self, data: List[str], batch_size: int) -> None:
        >>>         self.data = data
        >>>         self.batch_size = batch_size
        >>>
        >>>     def __len__(self) -> int:
        >>>         return (len(self.data) + self.batch_size - 1) // self.batch_size
        >>>
        >>>     def __iter__(self) -> Iterator[List[int]]:
        >>>         sizes = torch.tensor([len(x) for x in self.data])
        >>>         for batch in torch.chunk(torch.argsort(sizes), len(self)):
        >>>             yield batch.tolist()

    .. note:: The :meth:`__len__` method isn't strictly required by
              :class:`~torch.utils.data.DataLoader`, but is expected in any
              calculation involving the length of a :class:`~torch.utils.data.DataLoader`.
    """

    def __init__(self, data_source: Optional[Sized] = None) -> None:
        if data_source is not None:
            warnings.warn("`data_source` argument is not used and will be removed in 2.2.0. "
                          "You may still have a custom implementation that utilizes it.")

    def __iter__(self) -> Iterator[T_co]:
        raise NotImplementedError
    # NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ]
    #
    # Many times we have an abstract class representing a collection/iterable of
    # data, e.g., `torch.utils.data.Sampler`, with its subclasses optionally
    # implementing a `__len__` method. In such cases, we must make sure to not
    # provide a default implementation, because both straightforward default
    # implementations have their issues:
    #
    #   + `return NotImplemented`:
    #     Calling `len(subclass_instance)` raises:
    #       TypeError: 'NotImplementedType' object cannot be interpreted as an integer
    #
    #   + `raise NotImplementedError`:
    #     This prevents triggering some fallback behavior. E.g., the built-in
    #     `list(X)` tries to call `len(X)` first, and executes a different code
    #     path if the method is not found or `NotImplemented` is returned, while
    #     raising a `NotImplementedError` will propagate and make the call fail
    #     where it could have used `__iter__` to complete the call.
    #
    # Thus, the only two sensible things to do are
    #
    #   + **not** provide a default `__len__`.
    #
    #   + raise a `TypeError` instead, which is what Python uses when users call
    #     a method that is not defined on an object.
    #     (@ssnl verifies that this works on at least Python 3.7.)
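
    # For illustration (a hedged sketch, not part of the original note):
    # `Broken` below is a hypothetical subclass whose `__len__` raises
    # `NotImplementedError`. `list()` then fails even though plain iteration
    # would have succeeded, because `list()` probes the length first to
    # presize its storage; omitting `__len__` entirely (or raising
    # `TypeError`) lets `list()` fall back to `__iter__`:
    #
    #     >>> class Broken(Sampler[int]):
    #     ...     def __iter__(self) -> Iterator[int]:
    #     ...         return iter([0, 1])
    #     ...     def __len__(self) -> int:
    #     ...         raise NotImplementedError
    #     >>> list(Broken())  # raises NotImplementedError via the len() probe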
class SequentialSampler(Sampler[int]):
    r"""Samples elements sequentially, always in the same order.

    Args:
        data_source (Dataset): dataset to sample from
    """

    data_source: Sized

    def __init__(self, data_source: Sized) -> None:
        self.data_source = data_source

    def __iter__(self) -> Iterator[int]:
        return iter(range(len(self.data_source)))

    def __len__(self) -> int:
        return len(self.data_source)
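
# A minimal usage sketch (illustrative, not part of the original file): any
# ``Sized`` object works as the data source; the sampler simply yields the
# indices ``0 .. len(data_source) - 1`` in order:
#
#     >>> list(SequentialSampler(range(4)))
#     [0, 1, 2, 3]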
class RandomSampler(Sampler[int]):
    r"""Samples elements randomly. If without replacement, then sample from a shuffled dataset.

    If with replacement, then user can specify :attr:`num_samples` to draw.

    Args:
        data_source (Dataset): dataset to sample from
        replacement (bool): samples are drawn on-demand with replacement if ``True``, default=``False``
        num_samples (int): number of samples to draw, default=``len(dataset)``.
        generator (Generator): Generator used in sampling.
    """

    data_source: Sized
    replacement: bool

    def __init__(self, data_source: Sized, replacement: bool = False,
                 num_samples: Optional[int] = None, generator=None) -> None:
        self.data_source = data_source
        self.replacement = replacement
        self._num_samples = num_samples
        self.generator = generator

        if not isinstance(self.replacement, bool):
            raise TypeError(f"replacement should be a boolean value, but got replacement={self.replacement}")

        if not isinstance(self.num_samples, int) or self.num_samples <= 0:
            raise ValueError(f"num_samples should be a positive integer value, but got num_samples={self.num_samples}")

    @property
    def num_samples(self) -> int:
        # dataset size might change at runtime
        if self._num_samples is None:
            return len(self.data_source)
        return self._num_samples

    def __iter__(self) -> Iterator[int]:
        n = len(self.data_source)
        if self.generator is None:
            # Seed a fresh generator from the global RNG so each iteration
            # (e.g., each epoch) gets a new random order.
            seed = int(torch.empty((), dtype=torch.int64).random_().item())
            generator = torch.Generator()
            generator.manual_seed(seed)
        else:
            generator = self.generator

        if self.replacement:
            # Draw indices in chunks of 32 to amortize the per-call overhead
            # of torch.randint while still yielding samples on demand.
            for _ in range(self.num_samples // 32):
                yield from torch.randint(high=n, size=(32,), dtype=torch.int64, generator=generator).tolist()
            yield from torch.randint(high=n, size=(self.num_samples % 32,), dtype=torch.int64, generator=generator).tolist()
        else:
            # Without replacement: emit whole permutations, then a truncated
            # one to reach exactly num_samples indices.
            for _ in range(self.num_samples // n):
                yield from torch.randperm(n, generator=generator).tolist()
            yield from torch.randperm(n, generator=generator).tolist()[:self.num_samples % n]

    def __len__(self) -> int:
        return self.num_samples
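
# Usage sketch (illustrative; the seed below is arbitrary and the outputs are
# omitted because they depend on it): passing a manually seeded
# ``torch.Generator`` makes the sampling order reproducible, and
# ``replacement=True`` with ``num_samples`` decouples the number of draws from
# the dataset size:
#
#     >>> # xdoctest: +SKIP
#     >>> g = torch.Generator()
#     >>> g.manual_seed(0)
#     >>> list(RandomSampler(range(5), generator=g))
#     >>> list(RandomSampler(range(5), replacement=True, num_samples=8, generator=g))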
class SubsetRandomSampler(Sampler[int]):
    r"""Samples elements randomly from a given list of indices, without replacement.

    Args:
        indices (sequence): a sequence of indices
        generator (Generator): Generator used in sampling.
    """

    indices: Sequence[int]

    def __init__(self, indices: Sequence[int], generator=None) -> None:
        self.indices = indices
        self.generator = generator

    def __iter__(self) -> Iterator[int]:
        for i in torch.randperm(len(self.indices), generator=self.generator):
            yield self.indices[i]

    def __len__(self) -> int:
        return len(self.indices)
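
# Usage sketch (illustrative, not part of the original file): a common pattern
# is carving train/val splits out of one dataset by handing each loader a
# disjoint index subset. The sampler yields a permutation of exactly the
# indices it was given, so sorting the draws recovers the input:
#
#     >>> sorted(SubsetRandomSampler([2, 5, 7]))
#     [2, 5, 7]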
class WeightedRandomSampler(Sampler[int]):
    r"""Samples elements from ``[0,..,len(weights)-1]`` with given probabilities (weights).

    Args:
        weights (sequence): a sequence of weights, not necessarily summing up to one
        num_samples (int): number of samples to draw
        replacement (bool): if ``True``, samples are drawn with replacement.
            If not, they are drawn without replacement, which means that when a
            sample index is drawn for a row, it cannot be drawn again for that row.
        generator (Generator): Generator used in sampling.

    Example:
        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
        >>> list(WeightedRandomSampler([0.1, 0.9, 0.4, 0.7, 3.0, 0.6], 5, replacement=True))
        [4, 4, 1, 4, 5]
        >>> list(WeightedRandomSampler([0.9, 0.4, 0.05, 0.2, 0.3, 0.1], 5, replacement=False))
        [0, 1, 4, 3, 2]
    """

    weights: Tensor
    num_samples: int
    replacement: bool

    def __init__(self, weights: Sequence[float], num_samples: int,
                 replacement: bool = True, generator=None) -> None:
        if not isinstance(num_samples, int) or isinstance(num_samples, bool) or \
                num_samples <= 0:
            raise ValueError(f"num_samples should be a positive integer value, but got num_samples={num_samples}")
        if not isinstance(replacement, bool):
            raise ValueError(f"replacement should be a boolean value, but got replacement={replacement}")

        weights_tensor = torch.as_tensor(weights, dtype=torch.double)
        if len(weights_tensor.shape) != 1:
            raise ValueError("weights should be a 1d sequence but given "
                             f"weights have shape {tuple(weights_tensor.shape)}")

        self.weights = weights_tensor
        self.num_samples = num_samples
        self.replacement = replacement
        self.generator = generator

    def __iter__(self) -> Iterator[int]:
        rand_tensor = torch.multinomial(self.weights, self.num_samples, self.replacement, generator=self.generator)
        yield from rand_tensor.tolist()

    def __len__(self) -> int:
        return self.num_samples
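
# Recipe sketch (illustrative; ``labels`` is a hypothetical list you supply):
# for class-balanced sampling, weight each example by the inverse frequency of
# its class, so minority-class examples are drawn as often as majority ones.
# The weights need not be normalized; torch.multinomial handles that:
#
#     >>> labels = [0, 0, 0, 1]
#     >>> counts = torch.bincount(torch.tensor(labels)).double()
#     >>> weights = (1.0 / counts)[labels]
#     >>> sampler = WeightedRandomSampler(weights, num_samples=len(labels))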
class BatchSampler(Sampler[List[int]]):
    r"""Wraps another sampler to yield a mini-batch of indices.

    Args:
        sampler (Sampler or Iterable): Base sampler. Can be any iterable object
        batch_size (int): Size of mini-batch.
        drop_last (bool): If ``True``, the sampler will drop the last batch if
            its size would be less than ``batch_size``

    Example:
        >>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=False))
        [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
        >>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=True))
        [[0, 1, 2], [3, 4, 5], [6, 7, 8]]
    """

    def __init__(self, sampler: Union[Sampler[int], Iterable[int]], batch_size: int, drop_last: bool) -> None:
        # Since collections.abc.Iterable does not check for `__getitem__`, which
        # is one way for an object to be an iterable, we don't do an `isinstance`
        # check here.
        if not isinstance(batch_size, int) or isinstance(batch_size, bool) or \
                batch_size <= 0:
            raise ValueError(f"batch_size should be a positive integer value, but got batch_size={batch_size}")
        if not isinstance(drop_last, bool):
            raise ValueError(f"drop_last should be a boolean value, but got drop_last={drop_last}")
        self.sampler = sampler
        self.batch_size = batch_size
        self.drop_last = drop_last

    def __iter__(self) -> Iterator[List[int]]:
        # Implemented based on the benchmarking in https://github.com/pytorch/pytorch/pull/76951
        if self.drop_last:
            sampler_iter = iter(self.sampler)
            while True:
                try:
                    # StopIteration raised by `next` inside the comprehension is
                    # caught here, discarding any partially filled final batch.
                    batch = [next(sampler_iter) for _ in range(self.batch_size)]
                    yield batch
                except StopIteration:
                    break
        else:
            batch = [0] * self.batch_size
            idx_in_batch = 0
            for idx in self.sampler:
                batch[idx_in_batch] = idx
                idx_in_batch += 1
                if idx_in_batch == self.batch_size:
                    yield batch
                    idx_in_batch = 0
                    batch = [0] * self.batch_size
            if idx_in_batch > 0:
                yield batch[:idx_in_batch]

    def __len__(self) -> int:
        # Can only be called if self.sampler has __len__ implemented
        # We cannot enforce this condition, so we turn off typechecking for the
        # implementation below.
        # Somewhat related: see NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ]
        if self.drop_last:
            return len(self.sampler) // self.batch_size  # type: ignore[arg-type]
        else:
            return (len(self.sampler) + self.batch_size - 1) // self.batch_size  # type: ignore[arg-type]
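
# Usage sketch (illustrative; ``dataset`` is a hypothetical Dataset instance):
# a BatchSampler plugs into DataLoader through the ``batch_sampler`` argument,
# which is mutually exclusive with ``batch_size``, ``shuffle``, ``sampler``,
# and ``drop_last``:
#
#     >>> # xdoctest: +SKIP
#     >>> from torch.utils.data import DataLoader
#     >>> batch_sampler = BatchSampler(RandomSampler(dataset), batch_size=32, drop_last=False)
#     >>> loader = DataLoader(dataset, batch_sampler=batch_sampler)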