grouping.py

# mypy: allow-untyped-defs
from typing import List, Sized, TypeVar

from torch.utils.data.datapipes._decorator import functional_datapipe
from torch.utils.data.datapipes.datapipe import DataChunk, MapDataPipe

__all__ = ["BatcherMapDataPipe"]

T = TypeVar("T")


@functional_datapipe("batch")
class BatcherMapDataPipe(MapDataPipe[DataChunk]):
    r"""
    Create mini-batches of data (functional name: ``batch``).

    An outer dimension of size ``batch_size`` is added. If ``drop_last`` is set to
    ``False`` and the length of the datapipe is not divisible by ``batch_size``,
    the last batch holds the remaining ``length % batch_size`` elements.

    Args:
        datapipe: Map-style DataPipe being batched
        batch_size: The size of each batch
        drop_last: Option to drop the last batch if it's not full

    Example:
        >>> # xdoctest: +SKIP
        >>> from torchdata.datapipes.map import SequenceWrapper
        >>> dp = SequenceWrapper(range(10))
        >>> batch_dp = dp.batch(batch_size=2)
        >>> list(batch_dp)
        [[0, 1], [2, 3], [4, 5], [6, 7], [8, 9]]
    """

    datapipe: MapDataPipe
    batch_size: int
    drop_last: bool

    def __init__(
        self,
        datapipe: MapDataPipe[T],
        batch_size: int,
        drop_last: bool = False,
        wrapper_class=DataChunk,
    ) -> None:
        assert batch_size > 0, "Batch size is required to be larger than 0!"
        super().__init__()
        self.datapipe = datapipe
        self.batch_size = batch_size
        self.drop_last = drop_last
        self.wrapper_class = wrapper_class

    def __getitem__(self, index) -> DataChunk:
        # Gather the source elements whose indices fall into this batch.
        batch: List = []
        indices = range(index * self.batch_size, (index + 1) * self.batch_size)
        try:
            for i in indices:
                batch.append(self.datapipe[i])
            return self.wrapper_class(batch)
        except IndexError as e:
            # Ran past the end of the source: return the partial batch
            # unless drop_last is set or the batch is empty.
            if not self.drop_last and len(batch) > 0:
                return self.wrapper_class(batch)
            else:
                raise IndexError(f"Index {index} is out of bounds.") from e

    def __len__(self) -> int:
        if isinstance(self.datapipe, Sized):
            if self.drop_last:
                return len(self.datapipe) // self.batch_size
            else:
                # Ceiling division to count the trailing partial batch.
                return (len(self.datapipe) + self.batch_size - 1) // self.batch_size
        else:
            raise TypeError(f"{type(self).__name__} instance doesn't have valid length")
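A minimal usage sketch (not part of grouping.py) showing how ``drop_last`` changes both ``__len__`` and the trailing batch. It assumes ``SequenceWrapper`` from ``torch.utils.data.datapipes.map`` is available and that importing that package registers the functional ``.batch`` form defined above.

    # Usage sketch only; assumes SequenceWrapper and the 'batch' registration.
    from torch.utils.data.datapipes.map import SequenceWrapper

    source = SequenceWrapper(range(10))

    # drop_last=False (default): the trailing partial batch is kept,
    # so len() rounds up and the last batch holds 10 % 3 == 1 element.
    batches = source.batch(batch_size=3)
    print(len(batches))    # 4
    print(list(batches))   # [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]  (DataChunk is a list subclass)

    # drop_last=True: only full batches are indexable and len() rounds down.
    full_only = source.batch(batch_size=3, drop_last=True)
    print(len(full_only))  # 3
    print(full_only[2])    # [6, 7, 8]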