dataset.py

# mypy: allow-untyped-defs
import bisect
import itertools
import math
import warnings
from typing import (
    cast,
    Dict,
    Generic,
    Iterable,
    List,
    Optional,
    Sequence,
    Tuple,
    TypeVar,
    Union,
)

from typing_extensions import deprecated

# No 'default_generator' in torch/__init__.pyi
from torch import default_generator, randperm

from ... import Generator, Tensor

__all__ = [
    "Dataset",
    "IterableDataset",
    "TensorDataset",
    "StackDataset",
    "ConcatDataset",
    "ChainDataset",
    "Subset",
    "random_split",
]


T_co = TypeVar("T_co", covariant=True)
T = TypeVar("T")

T_dict = Dict[str, T_co]
T_tuple = Tuple[T_co, ...]
T_stack = TypeVar("T_stack", T_tuple, T_dict)


class Dataset(Generic[T_co]):
    r"""An abstract class representing a :class:`Dataset`.

    All datasets that represent a map from keys to data samples should subclass
    it. All subclasses should overwrite :meth:`__getitem__`, supporting fetching a
    data sample for a given key. Subclasses could also optionally overwrite
    :meth:`__len__`, which is expected to return the size of the dataset by many
    :class:`~torch.utils.data.Sampler` implementations and the default options
    of :class:`~torch.utils.data.DataLoader`. Subclasses could also optionally
    implement :meth:`__getitems__` to speed up batched sample loading. This
    method accepts a list of sample indices for a batch and returns the list of
    corresponding samples.

    .. note::
      :class:`~torch.utils.data.DataLoader` by default constructs an index
      sampler that yields integral indices. To make it work with a map-style
      dataset with non-integral indices/keys, a custom sampler must be provided.
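
    Example:
        A minimal sketch of a map-style dataset backed by a Python list
        (``ListDataset`` is purely illustrative and not part of ``torch``):

        >>> # xdoctest: +SKIP
        >>> class ListDataset(Dataset):
        ...     def __init__(self, data):
        ...         self.data = data
        ...     def __getitem__(self, index):
        ...         return self.data[index]
        ...     def __len__(self):
        ...         return len(self.data)
        ...
        >>> ds = ListDataset(["a", "b", "c"])
        >>> ds[1], len(ds)
        ('b', 3)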
    """

    def __getitem__(self, index) -> T_co:
        raise NotImplementedError("Subclasses of Dataset should implement __getitem__.")

    # def __getitems__(self, indices: List) -> List[T_co]:
    # Not implemented to prevent false-positives in fetcher check in
    # torch.utils.data._utils.fetch._MapDatasetFetcher

    def __add__(self, other: "Dataset[T_co]") -> "ConcatDataset[T_co]":
        return ConcatDataset([self, other])

    # No `def __len__(self)` default?
    # See NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ]
    # in pytorch/torch/utils/data/sampler.py


class IterableDataset(Dataset[T_co], Iterable[T_co]):
    r"""An iterable Dataset.

    All datasets that represent an iterable of data samples should subclass it.
    This form of dataset is particularly useful when data come from a stream.

    All subclasses should overwrite :meth:`__iter__`, which would return an
    iterator of samples in this dataset.

    When a subclass is used with :class:`~torch.utils.data.DataLoader`, each
    item in the dataset will be yielded from the :class:`~torch.utils.data.DataLoader`
    iterator. When :attr:`num_workers > 0`, each worker process will have a
    different copy of the dataset object, so it is often desired to configure
    each copy independently to avoid having duplicate data returned from the
    workers. :func:`~torch.utils.data.get_worker_info`, when called in a worker
    process, returns information about the worker. It can be used in either the
    dataset's :meth:`__iter__` method or the :class:`~torch.utils.data.DataLoader` 's
    :attr:`worker_init_fn` option to modify each copy's behavior.

    Example 1: splitting workload across all workers in :meth:`__iter__`::

        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_DATALOADER)
        >>> # xdoctest: +SKIP("Fails on MacOS12")
        >>> class MyIterableDataset(torch.utils.data.IterableDataset):
        ...     def __init__(self, start, end):
        ...         super().__init__()
        ...         assert end > start, "this example code only works with end > start"
        ...         self.start = start
        ...         self.end = end
        ...
        ...     def __iter__(self):
        ...         worker_info = torch.utils.data.get_worker_info()
        ...         if worker_info is None:  # single-process data loading, return the full iterator
        ...             iter_start = self.start
        ...             iter_end = self.end
        ...         else:  # in a worker process
        ...             # split workload
        ...             per_worker = int(math.ceil((self.end - self.start) / float(worker_info.num_workers)))
        ...             worker_id = worker_info.id
        ...             iter_start = self.start + worker_id * per_worker
        ...             iter_end = min(iter_start + per_worker, self.end)
        ...         return iter(range(iter_start, iter_end))
        ...
        >>> # should give the same set of data as range(3, 7), i.e., [3, 4, 5, 6].
        >>> ds = MyIterableDataset(start=3, end=7)
        >>> # Single-process loading
        >>> print(list(torch.utils.data.DataLoader(ds, num_workers=0)))
        [tensor([3]), tensor([4]), tensor([5]), tensor([6])]
        >>> # xdoctest: +REQUIRES(POSIX)
        >>> # Multi-process loading with two worker processes
        >>> # Worker 0 fetched [3, 4]. Worker 1 fetched [5, 6].
        >>> # xdoctest: +IGNORE_WANT("non deterministic")
        >>> print(list(torch.utils.data.DataLoader(ds, num_workers=2)))
        [tensor([3]), tensor([5]), tensor([4]), tensor([6])]
        >>> # With even more workers
        >>> # xdoctest: +IGNORE_WANT("non deterministic")
        >>> print(list(torch.utils.data.DataLoader(ds, num_workers=12)))
        [tensor([3]), tensor([5]), tensor([4]), tensor([6])]

    Example 2: splitting workload across all workers using :attr:`worker_init_fn`::

        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_DATALOADER)
        >>> class MyIterableDataset(torch.utils.data.IterableDataset):
        ...     def __init__(self, start, end):
        ...         super().__init__()
        ...         assert end > start, "this example code only works with end > start"
        ...         self.start = start
        ...         self.end = end
        ...
        ...     def __iter__(self):
        ...         return iter(range(self.start, self.end))
        ...
        >>> # should give the same set of data as range(3, 7), i.e., [3, 4, 5, 6].
        >>> ds = MyIterableDataset(start=3, end=7)
        >>> # Single-process loading
        >>> print(list(torch.utils.data.DataLoader(ds, num_workers=0)))
        [3, 4, 5, 6]
        >>>
        >>> # Directly doing multi-process loading yields duplicate data
        >>> print(list(torch.utils.data.DataLoader(ds, num_workers=2)))
        [3, 3, 4, 4, 5, 5, 6, 6]
        >>> # Define a `worker_init_fn` that configures each dataset copy differently
        >>> def worker_init_fn(worker_id):
        ...     worker_info = torch.utils.data.get_worker_info()
        ...     dataset = worker_info.dataset  # the dataset copy in this worker process
        ...     overall_start = dataset.start
        ...     overall_end = dataset.end
        ...     # configure the dataset to only process the split workload
        ...     per_worker = int(math.ceil((overall_end - overall_start) / float(worker_info.num_workers)))
        ...     worker_id = worker_info.id
        ...     dataset.start = overall_start + worker_id * per_worker
        ...     dataset.end = min(dataset.start + per_worker, overall_end)
        ...
        >>> # Multi-process loading with the custom `worker_init_fn`
        >>> # Worker 0 fetched [3, 4]. Worker 1 fetched [5, 6].
        >>> print(list(torch.utils.data.DataLoader(ds, num_workers=2, worker_init_fn=worker_init_fn)))
        [3, 5, 4, 6]
        >>> # With even more workers
        >>> print(list(torch.utils.data.DataLoader(ds, num_workers=12, worker_init_fn=worker_init_fn)))
        [3, 4, 5, 6]
    """

    def __add__(self, other: Dataset[T_co]):
        return ChainDataset([self, other])

    # No `def __len__(self)` default? Subclasses raise `TypeError` when needed.
    # See NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ]


class TensorDataset(Dataset[Tuple[Tensor, ...]]):
    r"""Dataset wrapping tensors.

    Each sample will be retrieved by indexing tensors along the first dimension.

    Args:
        *tensors (Tensor): tensors that have the same size of the first dimension.
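
    Example:
        A minimal sketch (the tensors below are arbitrary stand-ins for real
        features and labels):

        >>> # xdoctest: +SKIP
        >>> features = torch.randn(4, 2)
        >>> labels = torch.arange(4)
        >>> ds = TensorDataset(features, labels)
        >>> len(ds)
        4
        >>> sample = ds[0]  # (features[0], labels[0])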
    """

    tensors: Tuple[Tensor, ...]

    def __init__(self, *tensors: Tensor) -> None:
        assert all(
            tensors[0].size(0) == tensor.size(0) for tensor in tensors
        ), "Size mismatch between tensors"
        self.tensors = tensors

    def __getitem__(self, index):
        return tuple(tensor[index] for tensor in self.tensors)

    def __len__(self):
        return self.tensors[0].size(0)


class StackDataset(Dataset[T_stack]):
    r"""Dataset as a stacking of multiple datasets.

    This class is useful to assemble different parts of complex input data, given as datasets.

    Example:
        >>> # xdoctest: +SKIP
        >>> images = ImageDataset()
        >>> texts = TextDataset()
        >>> tuple_stack = StackDataset(images, texts)
        >>> tuple_stack[0] == (images[0], texts[0])
        >>> dict_stack = StackDataset(image=images, text=texts)
        >>> dict_stack[0] == {'image': images[0], 'text': texts[0]}

    Args:
        *args (Dataset): Datasets for stacking returned as tuple.
        **kwargs (Dataset): Datasets for stacking returned as dict.
    """

    datasets: Union[tuple, dict]

    def __init__(self, *args: Dataset[T_co], **kwargs: Dataset[T_co]) -> None:
        if args:
            if kwargs:
                raise ValueError(
                    "Supported either ``tuple``- (via ``args``) or "
                    "``dict``- (via ``kwargs``) like input/output, but both types are given."
                )
            self._length = len(args[0])  # type: ignore[arg-type]
            if any(self._length != len(dataset) for dataset in args):  # type: ignore[arg-type]
                raise ValueError("Size mismatch between datasets")
            self.datasets = args
        elif kwargs:
            tmp = list(kwargs.values())
            self._length = len(tmp[0])  # type: ignore[arg-type]
            if any(self._length != len(dataset) for dataset in tmp):  # type: ignore[arg-type]
                raise ValueError("Size mismatch between datasets")
            self.datasets = kwargs
        else:
            raise ValueError("At least one dataset should be passed")

    def __getitem__(self, index):
        if isinstance(self.datasets, dict):
            return {k: dataset[index] for k, dataset in self.datasets.items()}
        return tuple(dataset[index] for dataset in self.datasets)

    def __getitems__(self, indices: list):
        # add batched sampling support when parent datasets support it.
        if isinstance(self.datasets, dict):
            dict_batch: List[T_dict] = [{} for _ in indices]
            for k, dataset in self.datasets.items():
                if callable(getattr(dataset, "__getitems__", None)):
                    items = dataset.__getitems__(indices)  # type: ignore[attr-defined]
                    if len(items) != len(indices):
                        raise ValueError(
                            "Nested dataset's output size mismatch."
                            f" Expected {len(indices)}, got {len(items)}"
                        )
                    for data, d_sample in zip(items, dict_batch):
                        d_sample[k] = data
                else:
                    for idx, d_sample in zip(indices, dict_batch):
                        d_sample[k] = dataset[idx]
            return dict_batch

        # tuple data
        list_batch: List[list] = [[] for _ in indices]
        for dataset in self.datasets:
            if callable(getattr(dataset, "__getitems__", None)):
                items = dataset.__getitems__(indices)  # type: ignore[attr-defined]
                if len(items) != len(indices):
                    raise ValueError(
                        "Nested dataset's output size mismatch."
                        f" Expected {len(indices)}, got {len(items)}"
                    )
                for data, t_sample in zip(items, list_batch):
                    t_sample.append(data)
            else:
                for idx, t_sample in zip(indices, list_batch):
                    t_sample.append(dataset[idx])
        tuple_batch: List[T_tuple] = [tuple(sample) for sample in list_batch]
        return tuple_batch

    def __len__(self):
        return self._length


class ConcatDataset(Dataset[T_co]):
    r"""Dataset as a concatenation of multiple datasets.

    This class is useful to assemble different existing datasets.

    Args:
        datasets (sequence): List of datasets to be concatenated
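
    Example:
        A minimal sketch using two small ``TensorDataset`` s; indices past the
        end of the first dataset fall through to the second:

        >>> # xdoctest: +SKIP
        >>> a = TensorDataset(torch.arange(3))
        >>> b = TensorDataset(torch.arange(3, 5))
        >>> ds = ConcatDataset([a, b])
        >>> len(ds)
        5
        >>> ds[3]  # first sample of ``b``
        (tensor(3),)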
    """

    datasets: List[Dataset[T_co]]
    cumulative_sizes: List[int]

    @staticmethod
    def cumsum(sequence):
        r, s = [], 0
        for e in sequence:
            l = len(e)
            r.append(l + s)
            s += l
        return r

    def __init__(self, datasets: Iterable[Dataset]) -> None:
        super().__init__()
        self.datasets = list(datasets)
        assert len(self.datasets) > 0, "datasets should not be an empty iterable"  # type: ignore[arg-type]
        for d in self.datasets:
            assert not isinstance(
                d, IterableDataset
            ), "ConcatDataset does not support IterableDataset"
        self.cumulative_sizes = self.cumsum(self.datasets)

    def __len__(self):
        return self.cumulative_sizes[-1]

    def __getitem__(self, idx):
        if idx < 0:
            if -idx > len(self):
                raise ValueError(
                    "absolute value of index should not exceed dataset length"
                )
            idx = len(self) + idx
        dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
        if dataset_idx == 0:
            sample_idx = idx
        else:
            sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
        return self.datasets[dataset_idx][sample_idx]

    @property
    @deprecated(
        "`cummulative_sizes` attribute is renamed to `cumulative_sizes`",
        category=FutureWarning,
    )
    def cummulative_sizes(self):
        return self.cumulative_sizes


class ChainDataset(IterableDataset):
    r"""Dataset for chaining multiple :class:`IterableDataset` s.

    This class is useful to assemble different existing dataset streams. The
    chaining operation is done on-the-fly, so concatenating large-scale
    datasets with this class will be efficient.

    Args:
        datasets (iterable of IterableDataset): datasets to be chained together
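
    Example:
        A minimal sketch; ``RangeIterable`` is a hypothetical
        :class:`IterableDataset` used only for illustration:

        >>> # xdoctest: +SKIP
        >>> class RangeIterable(IterableDataset):
        ...     def __init__(self, start, end):
        ...         self.start, self.end = start, end
        ...     def __iter__(self):
        ...         return iter(range(self.start, self.end))
        ...
        >>> ds = ChainDataset([RangeIterable(0, 2), RangeIterable(2, 4)])
        >>> list(ds)
        [0, 1, 2, 3]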
    """

    def __init__(self, datasets: Iterable[Dataset]) -> None:
        super().__init__()
        self.datasets = datasets

    def __iter__(self):
        for d in self.datasets:
            assert isinstance(
                d, IterableDataset
            ), "ChainDataset only supports IterableDataset"
            yield from d

    def __len__(self):
        total = 0
        for d in self.datasets:
            assert isinstance(
                d, IterableDataset
            ), "ChainDataset only supports IterableDataset"
            total += len(d)  # type: ignore[arg-type]
        return total


class Subset(Dataset[T_co]):
    r"""
    Subset of a dataset at specified indices.

    Args:
        dataset (Dataset): The whole Dataset
        indices (sequence): Indices in the whole set selected for subset
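
    Example:
        A minimal sketch selecting every other sample of a small
        ``TensorDataset``:

        >>> # xdoctest: +SKIP
        >>> base = TensorDataset(torch.arange(6))
        >>> every_other = Subset(base, indices=[0, 2, 4])
        >>> len(every_other)
        3
        >>> every_other[1]  # maps to base[2]
        (tensor(2),)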
    """

    dataset: Dataset[T_co]
    indices: Sequence[int]

    def __init__(self, dataset: Dataset[T_co], indices: Sequence[int]) -> None:
        self.dataset = dataset
        self.indices = indices

    def __getitem__(self, idx):
        if isinstance(idx, list):
            return self.dataset[[self.indices[i] for i in idx]]
        return self.dataset[self.indices[idx]]

    def __getitems__(self, indices: List[int]) -> List[T_co]:
        # add batched sampling support when parent dataset supports it.
        # see torch.utils.data._utils.fetch._MapDatasetFetcher
        if callable(getattr(self.dataset, "__getitems__", None)):
            return self.dataset.__getitems__([self.indices[idx] for idx in indices])  # type: ignore[attr-defined]
        else:
            return [self.dataset[self.indices[idx]] for idx in indices]

    def __len__(self):
        return len(self.indices)


def random_split(
    dataset: Dataset[T],
    lengths: Sequence[Union[int, float]],
    generator: Optional[Generator] = default_generator,
) -> List[Subset[T]]:
    r"""
    Randomly split a dataset into non-overlapping new datasets of given lengths.

    If a list of fractions that sum up to 1 is given,
    the lengths will be computed automatically as
    floor(frac * len(dataset)) for each fraction provided.

    After computing the lengths, if there are any remainders, 1 count will be
    distributed in round-robin fashion to the lengths
    until there are no remainders left.

    Optionally fix the generator for reproducible results, e.g.:

    Example:
        >>> # xdoctest: +SKIP
        >>> generator1 = torch.Generator().manual_seed(42)
        >>> generator2 = torch.Generator().manual_seed(42)
        >>> random_split(range(10), [3, 7], generator=generator1)
        >>> random_split(range(30), [0.3, 0.3, 0.4], generator=generator2)

    Args:
        dataset (Dataset): Dataset to be split
        lengths (sequence): lengths or fractions of splits to be produced
        generator (Generator): Generator used for the random permutation.
    """
    if math.isclose(sum(lengths), 1) and sum(lengths) <= 1:
        subset_lengths: List[int] = []
        for i, frac in enumerate(lengths):
            if frac < 0 or frac > 1:
                raise ValueError(f"Fraction at index {i} is not between 0 and 1")
            n_items_in_split = int(
                math.floor(len(dataset) * frac)  # type: ignore[arg-type]
            )
            subset_lengths.append(n_items_in_split)
        remainder = len(dataset) - sum(subset_lengths)  # type: ignore[arg-type]
        # add 1 to all the lengths in round-robin fashion until the remainder is 0
        for i in range(remainder):
            idx_to_add_at = i % len(subset_lengths)
            subset_lengths[idx_to_add_at] += 1
        lengths = subset_lengths
        for i, length in enumerate(lengths):
            if length == 0:
                warnings.warn(
                    f"Length of split at index {i} is 0. "
                    f"This might result in an empty dataset."
                )

    # Cannot verify that dataset is Sized
    if sum(lengths) != len(dataset):  # type: ignore[arg-type]
        raise ValueError(
            "Sum of input lengths does not equal the length of the input dataset!"
        )

    indices = randperm(sum(lengths), generator=generator).tolist()  # type: ignore[arg-type, call-overload]
    lengths = cast(Sequence[int], lengths)
    return [
        Subset(dataset, indices[offset - length : offset])
        for offset, length in zip(itertools.accumulate(lengths), lengths)
    ]