# pt_utils.py

import numpy as np
import torch
from torch.utils.data import Dataset, IterableDataset

from ..utils.generic import ModelOutput


class PipelineDataset(Dataset):
    """
    Map-style dataset that lazily applies `process` (with `params`) to each item of `dataset`.
    """

    def __init__(self, dataset, process, params):
        self.dataset = dataset
        self.process = process
        self.params = params

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, i):
        item = self.dataset[i]
        processed = self.process(item, **self.params)
        return processed
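
# A minimal usage sketch (illustrative only; the toy data and the multiplying
# `process` below are hypothetical, not part of this module). It shows that
# `PipelineDataset` just defers `process(item, **params)` to indexing time:
#
#     ds = PipelineDataset([1, 2, 3], lambda x, factor: x * factor, {"factor": 10})
#     ds[1]  # -> 20
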

class PipelineIterator(IterableDataset):
    def __init__(self, loader, infer, params, loader_batch_size=None):
        """
        Roughly equivalent to

        ```
        for item in loader:
            yield infer(item, **params)
        ```

        Arguments:
            loader (`torch.utils.data.DataLoader` or `Iterable`):
                The iterator that will be used to apply `infer` on.
            infer (any function):
                The function to apply to each element of `loader`.
            params (`dict`):
                The parameters passed to `infer` along with every item.
            loader_batch_size (`int`, *optional*):
                If specified, the items of `loader` are supposed to come as batches, and are unbatched here,
                making it roughly behave as

                ```
                for items in loader:
                    for i in range(loader_batch_size):
                        item = items[i]
                        yield infer(item, **params)
                ```
        """
        self.loader = loader
        self.infer = infer
        self.params = params
        if loader_batch_size == 1:
            # Let's spare some time by deactivating batch unrolling altogether
            loader_batch_size = None
        self.loader_batch_size = loader_batch_size

        # Internal bookkeeping
        self._loader_batch_index = None
        self._loader_batch_data = None

    def __len__(self):
        return len(self.loader)

    def __iter__(self):
        self.iterator = iter(self.loader)
        return self
    def loader_batch_item(self):
        """
        Return the item located at `_loader_batch_index` within the current `_loader_batch_data`.
        """
        if isinstance(self._loader_batch_data, torch.Tensor):
            # Batch data is a simple tensor, just fetch the slice
            result = self._loader_batch_data[self._loader_batch_index].unsqueeze(0)
        else:
            # Batch data is assumed to be BaseModelOutput (or dict)
            loader_batched = {}
            for k, element in self._loader_batch_data.items():
                if isinstance(element, ModelOutput):
                    # Convert ModelOutput to tuple first
                    element = element.to_tuple()
                    if isinstance(element[0], torch.Tensor):
                        loader_batched[k] = tuple(el[self._loader_batch_index].unsqueeze(0) for el in element)
                    elif isinstance(element[0], np.ndarray):
                        loader_batched[k] = tuple(np.expand_dims(el[self._loader_batch_index], 0) for el in element)
                    continue
                if k in {"hidden_states", "past_key_values", "attentions"} and isinstance(element, tuple):
                    # These are stored as tuples of tensors, so they need specific unbatching.
                    if isinstance(element[0], torch.Tensor):
                        loader_batched[k] = tuple(el[self._loader_batch_index].unsqueeze(0) for el in element)
                    elif isinstance(element[0], np.ndarray):
                        loader_batched[k] = tuple(np.expand_dims(el[self._loader_batch_index], 0) for el in element)
                    continue
                if element is None:
                    # This can happen for optional data that get passed around
                    loader_batched[k] = None
                elif isinstance(element[self._loader_batch_index], torch.Tensor):
                    # Take the correct batch data, but make it look like batch_size=1
                    # for compatibility with other methods within transformers
                    loader_batched[k] = element[self._loader_batch_index].unsqueeze(0)
                elif isinstance(element[self._loader_batch_index], np.ndarray):
                    # Take the correct batch data, but make it look like batch_size=1
                    # for compatibility with other methods within transformers
                    loader_batched[k] = np.expand_dims(element[self._loader_batch_index], 0)
                else:
                    # This is typically a list, so no need to `unsqueeze`.
                    loader_batched[k] = element[self._loader_batch_index]
            # Recreate the element by reusing the original class to make it look
            # like batch_size=1
            result = self._loader_batch_data.__class__(loader_batched)
        self._loader_batch_index += 1
        return result
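
    # Illustrative note (hypothetical shapes): with
    #     self._loader_batch_data = {"logits": torch.zeros(4, 7)}
    # and `self._loader_batch_index == 2`, `loader_batch_item()` returns
    #     {"logits": <tensor of shape (1, 7)>}
    # i.e. one row of the batch, re-wrapped so it looks like batch_size=1.
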
    def __next__(self):
        if self._loader_batch_index is not None and self._loader_batch_index < self.loader_batch_size:
            # We are currently unrolling a batch, so we just need to return
            # the current item within that batch
            return self.loader_batch_item()

        # We're out of items within a batch
        item = next(self.iterator)
        processed = self.infer(item, **self.params)
        # We now have a batch of "inferred things".
        if self.loader_batch_size is not None:
            # Try to infer the size of the batch
            if isinstance(processed, torch.Tensor):
                first_tensor = processed
            elif isinstance(processed, tuple):
                first_tensor = processed[0]
            else:
                key = list(processed.keys())[0]
                first_tensor = processed[key]
            if isinstance(first_tensor, list):
                observed_batch_size = len(first_tensor)
            else:
                observed_batch_size = first_tensor.shape[0]
            if 0 < observed_batch_size < self.loader_batch_size:
                # This could be the last batch, so we can't unroll as many
                # elements.
                self.loader_batch_size = observed_batch_size
            # Set the internal index to unwrap the batch
            self._loader_batch_data = processed[0] if isinstance(processed, tuple) else processed
            self._loader_batch_index = 0
            return self.loader_batch_item()
        else:
            # We're not unrolling batches
            return processed
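
# Usage sketch (illustrative only; the toy loader and `infer` are hypothetical).
# With `loader_batch_size` set, each batched `infer` output is unrolled into
# batch_size=1 items, one per `__next__` call:
#
#     loader = torch.utils.data.DataLoader(list(range(8)), batch_size=4)
#     it = PipelineIterator(loader, lambda batch: batch * 2, {}, loader_batch_size=4)
#     [item.item() for item in it]  # -> [0, 2, 4, 6, 8, 10, 12, 14]
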

class PipelineChunkIterator(PipelineIterator):
    def __init__(self, loader, infer, params, loader_batch_size=None):
        """
        Roughly equivalent to

        ```
        for iterator in loader:
            for item in iterator:
                yield infer(item, **params)
        ```

        Arguments:
            loader (`torch.utils.data.DataLoader` or `Iterable`):
                The iterator that will be used to apply `infer` on.
            infer (any function):
                The function to apply to each element of `loader`.
            params (`dict`):
                The parameters passed to `infer` along with every item.
        """
        super().__init__(loader, infer, params)

    def __iter__(self):
        self.iterator = iter(self.loader)
        self.subiterator = None
        return self

    def __next__(self):
        if self.subiterator is None:
            # `subiterator` being None means we haven't started a `preprocess` iterator yet, so start it
            self.subiterator = self.infer(next(self.iterator), **self.params)
        try:
            # Try to return the next item
            processed = next(self.subiterator)
        except StopIteration:
            # When a preprocess iterator ends, we can start looking at the next item.
            # ChunkIterator will keep feeding until ALL elements of `iterator`
            # have created their subiterator and have been iterated through.
            #
            # Another way to look at it is that we're basically flattening lists of lists
            # into a single list, but with generators
            self.subiterator = self.infer(next(self.iterator), **self.params)
            processed = next(self.subiterator)
        return processed
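
# Usage sketch (illustrative only; the generator `expand` is hypothetical):
# `infer` turns each element of `loader` into a sub-iterator, and the chunk
# iterator flattens all sub-iterators into a single stream:
#
#     def expand(n):
#         yield from range(n)
#
#     list(PipelineChunkIterator([2, 3], expand, {}))  # -> [0, 1, 0, 1, 2]
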

class PipelinePackIterator(PipelineIterator):
    """
    Roughly equivalent to

    ```
    packed = []
    for item in loader:
        packed.append(item)
        if item["is_last"]:
            yield packed
            packed = []
    ```

    but it also handles cases where `item` is batched (meaning it's a dict of Tensors with first dimension > 1). In
    that case it does

    ```
    packed = []
    for batch in loader:
        # item is batched
        for item in batch:
            packed.append(item)
            if item["is_last"]:
                yield packed
                packed = []
    ```

    Arguments:
        loader (`torch.utils.data.DataLoader` or `Iterable`):
            The iterator that will be used to apply `infer` on.
        infer (any function):
            The function to apply to each element of `loader`.
        params (`dict`):
            The parameters passed to `infer` along with every item.
        loader_batch_size (`int`, *optional*):
            If specified, the items of `loader` are supposed to come as batches, and are unbatched here, making
            it roughly behave as

            ```
            for items in loader:
                for i in range(loader_batch_size):
                    item = items[i]
                    yield infer(item, **params)
            ```
    """

    def __iter__(self):
        self.iterator = iter(self.loader)
        return self

    def __next__(self):
        # Extremely similar to PipelineIterator in its unpacking mechanism,
        # BUT we have an extra required item, which is the presence of `is_last`.
        # Because everything is flattened by `PipelineChunkIterator`, we need
        # to keep track of the original `process` boundaries when regrouping
        # here, so that `process` and `postprocess` see the same data.
        # This iterator accumulates items (possibly while unbatching) until it
        # hits `is_last`, then just passes the accumulated items on to the caller.
        is_last = False
        accumulator = []
        if self._loader_batch_index is not None and self._loader_batch_index < self.loader_batch_size:
            while self._loader_batch_index < self.loader_batch_size:
                item = self.loader_batch_item()
                is_last = item.pop("is_last")
                accumulator.append(item)
                if is_last:
                    return accumulator

        while not is_last:
            processed = self.infer(next(self.iterator), **self.params)
            if self.loader_batch_size is not None:
                if isinstance(processed, torch.Tensor):
                    first_tensor = processed
                else:
                    key = list(processed.keys())[0]
                    first_tensor = processed[key]
                if isinstance(first_tensor, list):
                    observed_batch_size = len(first_tensor)
                else:
                    observed_batch_size = first_tensor.shape[0]
                if 0 < observed_batch_size < self.loader_batch_size:
                    # This could be the last batch, so we can't unroll as many
                    # elements.
                    self.loader_batch_size = observed_batch_size
                self._loader_batch_data = processed
                self._loader_batch_index = 0
                while self._loader_batch_index < self.loader_batch_size:
                    item = self.loader_batch_item()
                    is_last = item.pop("is_last")
                    accumulator.append(item)
                    if is_last:
                        return accumulator
            else:
                item = processed
                is_last = item.pop("is_last")
                accumulator.append(item)
        return accumulator
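
# Usage sketch (illustrative only; the stream below is hypothetical and `infer`
# is the identity). Items are accumulated until one carries `is_last=True`,
# then the whole group is yielded at once:
#
#     stream = [
#         {"v": 0, "is_last": False},
#         {"v": 1, "is_last": True},
#         {"v": 2, "is_last": True},
#     ]
#     list(PipelinePackIterator(stream, lambda x: x, {}))
#     # -> [[{"v": 0}, {"v": 1}], [{"v": 2}]]
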

class KeyDataset(Dataset):
    """
    Map-style dataset that exposes a single column (`key`) of an underlying dict-style `dataset`.
    """

    def __init__(self, dataset: Dataset, key: str):
        self.dataset = dataset
        self.key = key

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, i):
        return self.dataset[i][self.key]

class KeyPairDataset(Dataset):
    """
    Map-style dataset that exposes two columns (`key1`, `key2`) of an underlying dict-style `dataset` as a
    `{"text": ..., "text_pair": ...}` pair.
    """

    def __init__(self, dataset: Dataset, key1: str, key2: str):
        self.dataset = dataset
        self.key1 = key1
        self.key2 = key2

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, i):
        return {"text": self.dataset[i][self.key1], "text_pair": self.dataset[i][self.key2]}