# mypy: allow-untyped-defs
import warnings
from collections import defaultdict
from typing import Any, Callable, DefaultDict, Iterator, List, Optional, Sized, TypeVar

import torch.utils.data.datapipes.iter.sharding
from torch.utils.data.datapipes._decorator import functional_datapipe
from torch.utils.data.datapipes.datapipe import DataChunk, IterDataPipe
from torch.utils.data.datapipes.utils.common import _check_unpickable_fn

__all__ = [
    "BatcherIterDataPipe",
    "GrouperIterDataPipe",
    "UnBatcherIterDataPipe",
]

T_co = TypeVar("T_co", covariant=True)


def __getattr__(name: str):
    if name in ["SHARDING_PRIORITIES", "ShardingFilterIterDataPipe"]:
        warnings.warn(
            f"`{name}` from `torch.utils.data.datapipes.iter.grouping` is going to be removed in PyTorch 2.1. "
            f"Please use `{name}` from `torch.utils.data.datapipes.iter.sharding` instead.",
            category=FutureWarning,
            stacklevel=2,
        )
        return getattr(torch.utils.data.datapipes.iter.sharding, name)
    raise AttributeError(f"module {__name__} has no attribute {name}")
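

# For illustration: accessing one of the deprecated names goes through the
# module-level ``__getattr__`` above, which emits a ``FutureWarning`` and then
# resolves the attribute from ``torch.utils.data.datapipes.iter.sharding``.
# A minimal sketch of the expected behavior:
#
#   from torch.utils.data.datapipes.iter import grouping
#   grouping.SHARDING_PRIORITIES  # warns, then forwards to iter.sharding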


@functional_datapipe('batch')
class BatcherIterDataPipe(IterDataPipe[DataChunk]):
    r"""
    Creates mini-batches of data (functional name: ``batch``).

    An outer dimension will be added as ``batch_size`` if ``drop_last`` is set to ``True``,
    or ``length % batch_size`` for the last batch if ``drop_last`` is set to ``False``.

    Args:
        datapipe: Iterable DataPipe being batched
        batch_size: The size of each batch
        drop_last: Option to drop the last batch if it's not full
        wrapper_class: wrapper to apply onto each batch (type ``List``) before yielding,
            defaults to ``DataChunk``

    Example:
        >>> # xdoctest: +SKIP
        >>> from torchdata.datapipes.iter import IterableWrapper
        >>> dp = IterableWrapper(range(10))
        >>> dp = dp.batch(batch_size=3, drop_last=True)
        >>> list(dp)
        [[0, 1, 2], [3, 4, 5], [6, 7, 8]]
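        >>> # With ``drop_last=False`` (the default), the last partial batch is
        >>> # kept and ``len()`` rounds up (expected output, based on the
        >>> # implementation below):
        >>> dp2 = IterableWrapper(range(10)).batch(batch_size=3)
        >>> list(dp2)
        [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
        >>> len(dp2)
        4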
  41. """
  42. datapipe: IterDataPipe
  43. batch_size: int
  44. drop_last: bool
  45. def __init__(self,
  46. datapipe: IterDataPipe,
  47. batch_size: int,
  48. drop_last: bool = False,
  49. wrapper_class=DataChunk,
  50. ) -> None:
  51. assert batch_size > 0, "Batch size is required to be larger than 0!"
  52. super().__init__()
  53. self.datapipe = datapipe
  54. self.batch_size = batch_size
  55. self.drop_last = drop_last
  56. self.wrapper_class = wrapper_class
  57. def __iter__(self) -> Iterator[DataChunk]:
  58. batch: List = []
  59. for x in self.datapipe:
  60. batch.append(x)
  61. if len(batch) == self.batch_size:
  62. yield self.wrapper_class(batch)
  63. batch = []
  64. if len(batch) > 0:
  65. if not self.drop_last:
  66. yield self.wrapper_class(batch)
  67. def __len__(self) -> int:
  68. if isinstance(self.datapipe, Sized):
  69. if self.drop_last:
  70. return len(self.datapipe) // self.batch_size
  71. else:
  72. return (len(self.datapipe) + self.batch_size - 1) // self.batch_size
  73. else:
  74. raise TypeError(f"{type(self).__name__} instance doesn't have valid length")


@functional_datapipe('unbatch')
class UnBatcherIterDataPipe(IterDataPipe):
    r"""
    Undoes batching of data (functional name: ``unbatch``).

    In other words, it flattens the data up to the specified level within a batched DataPipe.

    Args:
        datapipe: Iterable DataPipe being un-batched
        unbatch_level: Defaults to ``1`` (only flattening the top level). If set to ``2``,
            it will flatten the top two levels, and ``-1`` will flatten the entire DataPipe.

    Example:
        >>> # xdoctest: +SKIP
        >>> from torchdata.datapipes.iter import IterableWrapper
        >>> source_dp = IterableWrapper([[[0, 1], [2]], [[3, 4], [5]], [[6]]])
        >>> dp1 = source_dp.unbatch()
        >>> list(dp1)
        [[0, 1], [2], [3, 4], [5], [6]]
        >>> dp2 = source_dp.unbatch(unbatch_level=2)
        >>> list(dp2)
        [0, 1, 2, 3, 4, 5, 6]
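        >>> # ``unbatch_level=-1`` flattens every nesting level recursively
        >>> # (expected output, based on ``_dive`` below):
        >>> dp3 = source_dp.unbatch(unbatch_level=-1)
        >>> list(dp3)
        [0, 1, 2, 3, 4, 5, 6]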
  94. """
  95. def __init__(self,
  96. datapipe: IterDataPipe,
  97. unbatch_level: int = 1):
  98. self.datapipe = datapipe
  99. self.unbatch_level = unbatch_level
  100. def __iter__(self):
  101. for element in self.datapipe:
  102. yield from self._dive(element, unbatch_level=self.unbatch_level)
  103. def _dive(self, element, unbatch_level):
  104. if unbatch_level < -1:
  105. raise ValueError("unbatch_level must be -1 or >= 0")
  106. if unbatch_level == -1:
  107. if isinstance(element, (list, DataChunk)):
  108. for item in element:
  109. yield from self._dive(item, unbatch_level=-1)
  110. else:
  111. yield element
  112. elif unbatch_level == 0:
  113. yield element
  114. else:
  115. if isinstance(element, (list, DataChunk)):
  116. for item in element:
  117. yield from self._dive(item, unbatch_level=unbatch_level - 1)
  118. else:
  119. raise IndexError(f"unbatch_level {self.unbatch_level} exceeds the depth of the DataPipe")


@functional_datapipe('groupby')
class GrouperIterDataPipe(IterDataPipe[DataChunk]):
    r"""
    Groups data from the input IterDataPipe by keys from ``group_key_fn``, yielding a ``DataChunk``
    with batch size up to ``group_size`` (functional name: ``groupby``).

    The samples are read sequentially from the source ``datapipe``, and a batch of samples belonging to the same group
    will be yielded as soon as the size of the batch reaches ``group_size``. When the buffer is full,
    the DataPipe will yield the largest group in the buffer, provided that its size is at least
    ``guaranteed_group_size``. If its size is smaller, the group is dropped when ``drop_remaining=True``,
    and a ``RuntimeError`` is raised otherwise.

    After iterating through the entirety of the source ``datapipe``, everything not dropped due to the buffer capacity
    will be yielded from the buffer, even if the group sizes are smaller than ``guaranteed_group_size``.

    Args:
        datapipe: Iterable datapipe to be grouped
        group_key_fn: Function used to generate the group key from the data of the source datapipe
        keep_key: Option to yield the matching key along with the items in a tuple,
            resulting in ``(key, [items])``; otherwise only ``[items]`` is returned
        buffer_size: The size of the buffer for ungrouped data
        group_size: The max size of each group; a batch is yielded as soon as it reaches this size
        guaranteed_group_size: The guaranteed minimum group size to be yielded in case the buffer is full
        drop_remaining: Specifies whether a group smaller than ``guaranteed_group_size`` will be dropped
            from the buffer when the buffer is full

    Example:
        >>> import os
        >>> # xdoctest: +SKIP
        >>> from torchdata.datapipes.iter import IterableWrapper
        >>> def group_fn(file):
        ...     return os.path.basename(file).split(".")[0]
        >>> source_dp = IterableWrapper(["a.png", "b.png", "a.json", "b.json", "a.jpg", "c.json"])
        >>> dp0 = source_dp.groupby(group_key_fn=group_fn)
        >>> list(dp0)
        [['a.png', 'a.json', 'a.jpg'], ['b.png', 'b.json'], ['c.json']]
        >>> # A group is yielded as soon as its size equals ``group_size``
        >>> dp1 = source_dp.groupby(group_key_fn=group_fn, group_size=2)
        >>> list(dp1)
        [['a.png', 'a.json'], ['b.png', 'b.json'], ['a.jpg'], ['c.json']]
        >>> # Scenario where the buffer is full, and group 'a' is yielded since its size >= ``guaranteed_group_size``
        >>> dp2 = source_dp.groupby(group_key_fn=group_fn, buffer_size=3, group_size=3, guaranteed_group_size=2)
        >>> list(dp2)
        [['a.png', 'a.json'], ['b.png', 'b.json'], ['a.jpg'], ['c.json']]
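        >>> # ``keep_key=True`` yields ``(key, group)`` tuples instead (expected
        >>> # output, based on the implementation below):
        >>> dp3 = source_dp.groupby(group_key_fn=group_fn, keep_key=True)
        >>> list(dp3)
        [('a', ['a.png', 'a.json', 'a.jpg']), ('b', ['b.png', 'b.json']), ('c', ['c.json'])]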
  159. """
  160. def __init__(self,
  161. datapipe: IterDataPipe[T_co],
  162. group_key_fn: Callable[[T_co], Any],
  163. *,
  164. keep_key: bool = False,
  165. buffer_size: int = 10000,
  166. group_size: Optional[int] = None,
  167. guaranteed_group_size: Optional[int] = None,
  168. drop_remaining: bool = False):
  169. _check_unpickable_fn(group_key_fn)
  170. self.datapipe = datapipe
  171. self.group_key_fn = group_key_fn
  172. self.keep_key = keep_key
  173. self.max_buffer_size = buffer_size
  174. self.buffer_elements: DefaultDict[Any, List] = defaultdict(list)
  175. self.curr_buffer_size = 0
  176. self.group_size = group_size
  177. self.guaranteed_group_size = None
  178. if group_size is not None and buffer_size is not None:
  179. assert 0 < group_size <= buffer_size
  180. self.guaranteed_group_size = group_size
  181. if guaranteed_group_size is not None:
  182. assert group_size is not None and 0 < guaranteed_group_size <= group_size
  183. self.guaranteed_group_size = guaranteed_group_size
  184. self.drop_remaining = drop_remaining
  185. self.wrapper_class = DataChunk
  186. def _remove_biggest_key(self):
  187. biggest_key = None
  188. biggest_size = 0
  189. result_to_yield = None
  190. for findkey in self.buffer_elements.keys():
  191. if len(self.buffer_elements[findkey]) > biggest_size:
  192. biggest_size = len(self.buffer_elements[findkey])
  193. biggest_key = findkey
  194. if self.guaranteed_group_size is not None and biggest_size < self.guaranteed_group_size and not self.drop_remaining:
  195. raise RuntimeError('Failed to group items', str(self.buffer_elements[biggest_key]))
  196. if self.guaranteed_group_size is None or biggest_size >= self.guaranteed_group_size:
  197. result_to_yield = self.buffer_elements[biggest_key]
  198. self.curr_buffer_size -= biggest_size
  199. del self.buffer_elements[biggest_key]
  200. return result_to_yield
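
    # Iteration flow (summary of the logic below): each incoming sample is
    # appended to the bucket for its key; a bucket is flushed as soon as it
    # reaches ``group_size``; when the total buffered count reaches
    # ``max_buffer_size``, the largest bucket is evicted via
    # ``_remove_biggest_key`` (yielded if it meets ``guaranteed_group_size``,
    # dropped otherwise).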
    def __iter__(self):
        for x in self.datapipe:
            key = self.group_key_fn(x)

            self.buffer_elements[key].append(x)
            self.curr_buffer_size += 1

            if self.group_size is not None and self.group_size == len(self.buffer_elements[key]):
                result: DataChunk[Any] = self.wrapper_class(self.buffer_elements[key])
                yield (key, result) if self.keep_key else result
                self.curr_buffer_size -= len(self.buffer_elements[key])
                del self.buffer_elements[key]

            if self.curr_buffer_size == self.max_buffer_size:
                evicted_key, result_to_yield = self._remove_biggest_key()
                if result_to_yield is not None:
                    result = self.wrapper_class(result_to_yield)
                    # Pair the evicted group with its own key, not with the key
                    # of the sample that triggered the eviction.
                    yield (evicted_key, result) if self.keep_key else result

        for key in tuple(self.buffer_elements.keys()):
            result = self.wrapper_class(self.buffer_elements.pop(key))
            self.curr_buffer_size -= len(result)
            yield (key, result) if self.keep_key else result

    def reset(self) -> None:
        self.curr_buffer_size = 0
        self.buffer_elements = defaultdict(list)

    def __getstate__(self):
        state = (
            self.datapipe,
            self.group_key_fn,
            self.keep_key,
            self.max_buffer_size,
            self.group_size,
            self.guaranteed_group_size,
            self.drop_remaining,
            self.wrapper_class,
            self._valid_iterator_id,
            self._number_of_samples_yielded,
        )
        if IterDataPipe.getstate_hook is not None:
            return IterDataPipe.getstate_hook(state)
        return state

    def __setstate__(self, state):
        (
            self.datapipe,
            self.group_key_fn,
            self.keep_key,
            self.max_buffer_size,
            self.group_size,
            self.guaranteed_group_size,
            self.drop_remaining,
            self.wrapper_class,
            self._valid_iterator_id,
            self._number_of_samples_yielded,
        ) = state
        self.curr_buffer_size = 0
        self.buffer_elements = defaultdict(list)

    def __del__(self):
        self.buffer_elements.clear()