# mypy: allow-untyped-defs
import bisect
import itertools
import math
import warnings
from typing import (
    cast,
    Dict,
    Generic,
    Iterable,
    List,
    Optional,
    Sequence,
    Tuple,
    TypeVar,
    Union,
)

from typing_extensions import deprecated

# No 'default_generator' in torch/__init__.pyi
from torch import default_generator, randperm

from ... import Generator, Tensor
__all__ = [
    "Dataset",
    "IterableDataset",
    "TensorDataset",
    "StackDataset",
    "ConcatDataset",
    "ChainDataset",
    "Subset",
    "random_split",
]


T_co = TypeVar("T_co", covariant=True)
T = TypeVar("T")
T_dict = Dict[str, T_co]
T_tuple = Tuple[T_co, ...]
T_stack = TypeVar("T_stack", T_tuple, T_dict)


class Dataset(Generic[T_co]):
    r"""An abstract class representing a :class:`Dataset`.

    All datasets that represent a map from keys to data samples should subclass
    it. All subclasses should overwrite :meth:`__getitem__`, supporting fetching a
    data sample for a given key. Subclasses could also optionally overwrite
    :meth:`__len__`, which is expected to return the size of the dataset by many
    :class:`~torch.utils.data.Sampler` implementations and the default options
    of :class:`~torch.utils.data.DataLoader`. Subclasses could also optionally
    implement :meth:`__getitems__` to speed up batched sample loading. This
    method accepts a list of sample indices for a batch and returns the list of
    corresponding samples.

    .. note::
      :class:`~torch.utils.data.DataLoader` by default constructs an index
      sampler that yields integral indices. To make it work with a map-style
      dataset with non-integral indices/keys, a custom sampler must be provided.
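
    For illustration, a minimal map-style dataset only needs :meth:`__getitem__`
    and :meth:`__len__`; the sketch below uses a hypothetical ``SquaresDataset``.

    Example::

        >>> # xdoctest: +SKIP
        >>> class SquaresDataset(torch.utils.data.Dataset):
        ...     def __init__(self, n):
        ...         self.n = n
        ...
        ...     def __getitem__(self, index):
        ...         return index ** 2  # map an integral key to a sample
        ...
        ...     def __len__(self):
        ...         return self.n
        ...
        >>> ds = SquaresDataset(4)
        >>> [ds[i] for i in range(len(ds))]
        [0, 1, 4, 9]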
- """
- def __getitem__(self, index) -> T_co:
- raise NotImplementedError("Subclasses of Dataset should implement __getitem__.")
- # def __getitems__(self, indices: List) -> List[T_co]:
- # Not implemented to prevent false-positives in fetcher check in
- # torch.utils.data._utils.fetch._MapDatasetFetcher
- def __add__(self, other: "Dataset[T_co]") -> "ConcatDataset[T_co]":
- return ConcatDataset([self, other])
- # No `def __len__(self)` default?
- # See NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ]
- # in pytorch/torch/utils/data/sampler.py

class IterableDataset(Dataset[T_co], Iterable[T_co]):
    r"""An iterable Dataset.

    All datasets that represent an iterable of data samples should subclass it.
    Such form of datasets is particularly useful when data come from a stream.

    All subclasses should overwrite :meth:`__iter__`, which would return an
    iterator of samples in this dataset.

    When a subclass is used with :class:`~torch.utils.data.DataLoader`, each
    item in the dataset will be yielded from the :class:`~torch.utils.data.DataLoader`
    iterator. When :attr:`num_workers > 0`, each worker process will have a
    different copy of the dataset object, so it is often desired to configure
    each copy independently to avoid having duplicate data returned from the
    workers. :func:`~torch.utils.data.get_worker_info`, when called in a worker
    process, returns information about the worker. It can be used in either the
    dataset's :meth:`__iter__` method or the :class:`~torch.utils.data.DataLoader` 's
    :attr:`worker_init_fn` option to modify each copy's behavior.

    Example 1: splitting workload across all workers in :meth:`__iter__`::

        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_DATALOADER)
        >>> # xdoctest: +SKIP("Fails on MacOS12")
        >>> class MyIterableDataset(torch.utils.data.IterableDataset):
        ...     def __init__(self, start, end):
        ...         super().__init__()
        ...         assert end > start, "this example code only works with end > start"
        ...         self.start = start
        ...         self.end = end
        ...
        ...     def __iter__(self):
        ...         worker_info = torch.utils.data.get_worker_info()
        ...         if worker_info is None:  # single-process data loading, return the full iterator
        ...             iter_start = self.start
        ...             iter_end = self.end
        ...         else:  # in a worker process
        ...             # split workload
        ...             per_worker = int(math.ceil((self.end - self.start) / float(worker_info.num_workers)))
        ...             worker_id = worker_info.id
        ...             iter_start = self.start + worker_id * per_worker
        ...             iter_end = min(iter_start + per_worker, self.end)
        ...         return iter(range(iter_start, iter_end))
        ...
        >>> # should give same set of data as range(3, 7), i.e., [3, 4, 5, 6].
        >>> ds = MyIterableDataset(start=3, end=7)

        >>> # Single-process loading
        >>> print(list(torch.utils.data.DataLoader(ds, num_workers=0)))
        [tensor([3]), tensor([4]), tensor([5]), tensor([6])]

        >>> # xdoctest: +REQUIRES(POSIX)
        >>> # Multi-process loading with two worker processes
        >>> # Worker 0 fetched [3, 4]. Worker 1 fetched [5, 6].
        >>> # xdoctest: +IGNORE_WANT("non deterministic")
        >>> print(list(torch.utils.data.DataLoader(ds, num_workers=2)))
        [tensor([3]), tensor([5]), tensor([4]), tensor([6])]

        >>> # With even more workers
        >>> # xdoctest: +IGNORE_WANT("non deterministic")
        >>> print(list(torch.utils.data.DataLoader(ds, num_workers=12)))
        [tensor([3]), tensor([5]), tensor([4]), tensor([6])]
    Example 2: splitting workload across all workers using :attr:`worker_init_fn`::

        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_DATALOADER)
        >>> class MyIterableDataset(torch.utils.data.IterableDataset):
        ...     def __init__(self, start, end):
        ...         super().__init__()
        ...         assert end > start, "this example code only works with end > start"
        ...         self.start = start
        ...         self.end = end
        ...
        ...     def __iter__(self):
        ...         return iter(range(self.start, self.end))
        ...
        >>> # should give same set of data as range(3, 7), i.e., [3, 4, 5, 6].
        >>> ds = MyIterableDataset(start=3, end=7)

        >>> # Single-process loading
        >>> print(list(torch.utils.data.DataLoader(ds, num_workers=0)))
        [3, 4, 5, 6]
        >>>
        >>> # Directly doing multi-process loading yields duplicate data
        >>> print(list(torch.utils.data.DataLoader(ds, num_workers=2)))
        [3, 3, 4, 4, 5, 5, 6, 6]

        >>> # Define a `worker_init_fn` that configures each dataset copy differently
        >>> def worker_init_fn(worker_id):
        ...     worker_info = torch.utils.data.get_worker_info()
        ...     dataset = worker_info.dataset  # the dataset copy in this worker process
        ...     overall_start = dataset.start
        ...     overall_end = dataset.end
        ...     # configure the dataset to only process the split workload
        ...     per_worker = int(math.ceil((overall_end - overall_start) / float(worker_info.num_workers)))
        ...     worker_id = worker_info.id
        ...     dataset.start = overall_start + worker_id * per_worker
        ...     dataset.end = min(dataset.start + per_worker, overall_end)
        ...
        >>> # Multi-process loading with the custom `worker_init_fn`
        >>> # Worker 0 fetched [3, 4]. Worker 1 fetched [5, 6].
        >>> print(list(torch.utils.data.DataLoader(ds, num_workers=2, worker_init_fn=worker_init_fn)))
        [3, 5, 4, 6]

        >>> # With even more workers
        >>> print(list(torch.utils.data.DataLoader(ds, num_workers=12, worker_init_fn=worker_init_fn)))
        [3, 4, 5, 6]
    """

    def __add__(self, other: Dataset[T_co]):
        return ChainDataset([self, other])

    # No `def __len__(self)` default? Subclasses raise `TypeError` when needed.
    # See NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ]


class TensorDataset(Dataset[Tuple[Tensor, ...]]):
    r"""Dataset wrapping tensors.

    Each sample will be retrieved by indexing tensors along the first dimension.

    Args:
        *tensors (Tensor): tensors that have the same size in the first dimension.
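
    For illustration, a minimal usage sketch with arbitrary tensors::

        >>> # xdoctest: +SKIP
        >>> features = torch.arange(12).reshape(6, 2)  # 6 samples, 2 features each
        >>> labels = torch.arange(6)                   # 6 labels
        >>> ds = TensorDataset(features, labels)
        >>> ds[2]  # a (feature, label) pair
        (tensor([4, 5]), tensor(2))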
- """
- tensors: Tuple[Tensor, ...]
- def __init__(self, *tensors: Tensor) -> None:
- assert all(
- tensors[0].size(0) == tensor.size(0) for tensor in tensors
- ), "Size mismatch between tensors"
- self.tensors = tensors
- def __getitem__(self, index):
- return tuple(tensor[index] for tensor in self.tensors)
- def __len__(self):
- return self.tensors[0].size(0)

class StackDataset(Dataset[T_stack]):
    r"""Dataset as a stacking of multiple datasets.

    This class is useful to assemble different parts of complex input data, given as datasets.

    Example:
        >>> # xdoctest: +SKIP
        >>> images = ImageDataset()
        >>> texts = TextDataset()
        >>> tuple_stack = StackDataset(images, texts)
        >>> tuple_stack[0] == (images[0], texts[0])
        >>> dict_stack = StackDataset(image=images, text=texts)
        >>> dict_stack[0] == {'image': images[0], 'text': texts[0]}

    Args:
        *args (Dataset): Datasets for stacking returned as tuple.
        **kwargs (Dataset): Datasets for stacking returned as dict.
    """

    datasets: Union[tuple, dict]

    def __init__(self, *args: Dataset[T_co], **kwargs: Dataset[T_co]) -> None:
        if args:
            if kwargs:
                raise ValueError(
                    "Supported either ``tuple``- (via ``args``) or "
                    "``dict``- (via ``kwargs``) like input/output, but both types are given."
                )
            self._length = len(args[0])  # type: ignore[arg-type]
            if any(self._length != len(dataset) for dataset in args):  # type: ignore[arg-type]
                raise ValueError("Size mismatch between datasets")
            self.datasets = args
        elif kwargs:
            tmp = list(kwargs.values())
            self._length = len(tmp[0])  # type: ignore[arg-type]
            if any(self._length != len(dataset) for dataset in tmp):  # type: ignore[arg-type]
                raise ValueError("Size mismatch between datasets")
            self.datasets = kwargs
        else:
            raise ValueError("At least one dataset should be passed")

    def __getitem__(self, index):
        if isinstance(self.datasets, dict):
            return {k: dataset[index] for k, dataset in self.datasets.items()}
        return tuple(dataset[index] for dataset in self.datasets)

    def __getitems__(self, indices: list):
        # add batched sampling support when the parent datasets support it.
        if isinstance(self.datasets, dict):
            dict_batch: List[T_dict] = [{} for _ in indices]
            for k, dataset in self.datasets.items():
                if callable(getattr(dataset, "__getitems__", None)):
                    items = dataset.__getitems__(indices)  # type: ignore[attr-defined]
                    if len(items) != len(indices):
                        raise ValueError(
                            "Nested dataset's output size mismatch."
                            f" Expected {len(indices)}, got {len(items)}"
                        )
                    for data, d_sample in zip(items, dict_batch):
                        d_sample[k] = data
                else:
                    for idx, d_sample in zip(indices, dict_batch):
                        d_sample[k] = dataset[idx]
            return dict_batch

        # tuple data
        list_batch: List[list] = [[] for _ in indices]
        for dataset in self.datasets:
            if callable(getattr(dataset, "__getitems__", None)):
                items = dataset.__getitems__(indices)  # type: ignore[attr-defined]
                if len(items) != len(indices):
                    raise ValueError(
                        "Nested dataset's output size mismatch."
                        f" Expected {len(indices)}, got {len(items)}"
                    )
                for data, t_sample in zip(items, list_batch):
                    t_sample.append(data)
            else:
                for idx, t_sample in zip(indices, list_batch):
                    t_sample.append(dataset[idx])
        tuple_batch: List[T_tuple] = [tuple(sample) for sample in list_batch]
        return tuple_batch

    def __len__(self):
        return self._length


class ConcatDataset(Dataset[T_co]):
    r"""Dataset as a concatenation of multiple datasets.

    This class is useful to assemble different existing datasets.

    Args:
        datasets (sequence): List of datasets to be concatenated
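
    For illustration, a minimal sketch using :class:`TensorDataset`::

        >>> # xdoctest: +SKIP
        >>> ds_a = TensorDataset(torch.zeros(3, 2))
        >>> ds_b = TensorDataset(torch.ones(2, 2))
        >>> concat = ConcatDataset([ds_a, ds_b])
        >>> len(concat)  # 3 + 2
        5
        >>> concat[3]  # index 3 falls into ``ds_b`` at offset 0
        (tensor([1., 1.]),)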
- """
- datasets: List[Dataset[T_co]]
- cumulative_sizes: List[int]
- @staticmethod
- def cumsum(sequence):
- r, s = [], 0
- for e in sequence:
- l = len(e)
- r.append(l + s)
- s += l
- return r
- def __init__(self, datasets: Iterable[Dataset]) -> None:
- super().__init__()
- self.datasets = list(datasets)
- assert len(self.datasets) > 0, "datasets should not be an empty iterable" # type: ignore[arg-type]
- for d in self.datasets:
- assert not isinstance(
- d, IterableDataset
- ), "ConcatDataset does not support IterableDataset"
- self.cumulative_sizes = self.cumsum(self.datasets)
- def __len__(self):
- return self.cumulative_sizes[-1]
- def __getitem__(self, idx):
- if idx < 0:
- if -idx > len(self):
- raise ValueError(
- "absolute value of index should not exceed dataset length"
- )
- idx = len(self) + idx
- dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
- if dataset_idx == 0:
- sample_idx = idx
- else:
- sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
- return self.datasets[dataset_idx][sample_idx]
- @property
- @deprecated(
- "`cummulative_sizes` attribute is renamed to `cumulative_sizes`",
- category=FutureWarning,
- )
- def cummulative_sizes(self):
- return self.cumulative_sizes

class ChainDataset(IterableDataset):
    r"""Dataset for chaining multiple :class:`IterableDataset` s.

    This class is useful to assemble different existing dataset streams. The
    chaining operation is done on-the-fly, so concatenating large-scale
    datasets with this class will be efficient.

    Args:
        datasets (iterable of IterableDataset): datasets to be chained together
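
    For illustration, a minimal sketch, assuming ``ds1`` and ``ds2`` are
    :class:`IterableDataset` instances yielding ``0, 1, 2`` and ``10, 11``::

        >>> # xdoctest: +SKIP
        >>> list(ChainDataset([ds1, ds2]))  # streams are chained lazily, in order
        [0, 1, 2, 10, 11]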
- """
- def __init__(self, datasets: Iterable[Dataset]) -> None:
- super().__init__()
- self.datasets = datasets
- def __iter__(self):
- for d in self.datasets:
- assert isinstance(
- d, IterableDataset
- ), "ChainDataset only supports IterableDataset"
- yield from d
- def __len__(self):
- total = 0
- for d in self.datasets:
- assert isinstance(
- d, IterableDataset
- ), "ChainDataset only supports IterableDataset"
- total += len(d) # type: ignore[arg-type]
- return total

class Subset(Dataset[T_co]):
    r"""
    Subset of a dataset at specified indices.

    Args:
        dataset (Dataset): The whole Dataset
        indices (sequence): Indices in the whole set selected for subset
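
    For illustration, a minimal sketch using :class:`TensorDataset`::

        >>> # xdoctest: +SKIP
        >>> base = TensorDataset(torch.arange(10))
        >>> evens = Subset(base, indices=[0, 2, 4, 6, 8])
        >>> len(evens)
        5
        >>> evens[1]  # maps to base[2]
        (tensor(2),)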
- """
- dataset: Dataset[T_co]
- indices: Sequence[int]
- def __init__(self, dataset: Dataset[T_co], indices: Sequence[int]) -> None:
- self.dataset = dataset
- self.indices = indices
- def __getitem__(self, idx):
- if isinstance(idx, list):
- return self.dataset[[self.indices[i] for i in idx]]
- return self.dataset[self.indices[idx]]
- def __getitems__(self, indices: List[int]) -> List[T_co]:
- # add batched sampling support when parent dataset supports it.
- # see torch.utils.data._utils.fetch._MapDatasetFetcher
- if callable(getattr(self.dataset, "__getitems__", None)):
- return self.dataset.__getitems__([self.indices[idx] for idx in indices]) # type: ignore[attr-defined]
- else:
- return [self.dataset[self.indices[idx]] for idx in indices]
- def __len__(self):
- return len(self.indices)

def random_split(
    dataset: Dataset[T],
    lengths: Sequence[Union[int, float]],
    generator: Optional[Generator] = default_generator,
) -> List[Subset[T]]:
    r"""
    Randomly split a dataset into non-overlapping new datasets of given lengths.

    If a list of fractions that sum up to 1 is given, the lengths will be
    computed automatically as floor(frac * len(dataset)) for each fraction
    provided.

    After computing the lengths, if there are any remainders, 1 count will be
    distributed in round-robin fashion to the lengths until there are no
    remainders left.

    Optionally fix the generator for reproducible results, e.g.:

    Example:
        >>> # xdoctest: +SKIP
        >>> generator1 = torch.Generator().manual_seed(42)
        >>> generator2 = torch.Generator().manual_seed(42)
        >>> random_split(range(10), [3, 7], generator=generator1)
        >>> random_split(range(30), [0.3, 0.3, 0.4], generator=generator2)
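        >>> # Illustrative sketch: fractional lengths are floored and the
        >>> # remainder is handed out round-robin, so [0.35, 0.65] on 10
        >>> # samples produces splits of length 4 and 6.
        >>> random_split(range(10), [0.35, 0.65], generator=generator1)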

    Args:
        dataset (Dataset): Dataset to be split
        lengths (sequence): lengths or fractions of splits to be produced
        generator (Generator): Generator used for the random permutation.
    """
    if math.isclose(sum(lengths), 1) and sum(lengths) <= 1:
        subset_lengths: List[int] = []
        for i, frac in enumerate(lengths):
            if frac < 0 or frac > 1:
                raise ValueError(f"Fraction at index {i} is not between 0 and 1")
            n_items_in_split = int(
                math.floor(len(dataset) * frac)  # type: ignore[arg-type]
            )
            subset_lengths.append(n_items_in_split)
        remainder = len(dataset) - sum(subset_lengths)  # type: ignore[arg-type]
        # add 1 to all the lengths in round-robin fashion until the remainder is 0
        for i in range(remainder):
            idx_to_add_at = i % len(subset_lengths)
            subset_lengths[idx_to_add_at] += 1
        lengths = subset_lengths
        for i, length in enumerate(lengths):
            if length == 0:
                warnings.warn(
                    f"Length of split at index {i} is 0. "
                    f"This might result in an empty dataset."
                )

    # Cannot verify that dataset is Sized
    if sum(lengths) != len(dataset):  # type: ignore[arg-type]
        raise ValueError(
            "Sum of input lengths does not equal the length of the input dataset!"
        )

    indices = randperm(sum(lengths), generator=generator).tolist()  # type: ignore[arg-type, call-overload]
    lengths = cast(Sequence[int], lengths)
    return [
        Subset(dataset, indices[offset - length : offset])
        for offset, length in zip(itertools.accumulate(lengths), lengths)
    ]