import numpy as np
import torch
from torch.utils.data import Dataset, IterableDataset

from ..utils.generic import ModelOutput


class PipelineDataset(Dataset):
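    """
    Map-style dataset that lazily applies `process` (with `params`) to each element of the wrapped `dataset`.

    A minimal usage sketch (the uppercasing `process` function below is purely illustrative):

    ```
    ds = PipelineDataset(["hello", "world"], lambda text, **kwargs: text.upper(), {})
    len(ds)  # 2
    ds[0]  # "HELLO"
    ```
    """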
    def __init__(self, dataset, process, params):
        self.dataset = dataset
        self.process = process
        self.params = params

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, i):
        item = self.dataset[i]
        processed = self.process(item, **self.params)
        return processed


class PipelineIterator(IterableDataset):
    def __init__(self, loader, infer, params, loader_batch_size=None):
        """
        Roughly equivalent to

        ```
        for item in loader:
            yield infer(item, **params)
        ```

        Arguments:
            loader (`torch.utils.data.DataLoader` or `Iterable`):
                The iterator on which `infer` will be applied.
            infer (any function):
                The function to apply to each element of `loader`.
            params (`dict`):
                The parameters passed to `infer` along with every item.
            loader_batch_size (`int`, *optional*):
                If specified, the items of `loader` are supposed to come as batches, and are unbatched here,
                making it roughly behave as

                ```
                for items in loader:
                    for i in range(loader_batch_size):
                        item = items[i]
                        yield infer(item, **params)
                ```
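
        Example (a minimal sketch; the identity `infer` and the tensor shapes below are illustrative only):

        ```
        batches = [torch.zeros(4, 10), torch.zeros(2, 10)]
        it = PipelineIterator(batches, lambda t, **kwargs: t, {}, loader_batch_size=4)
        items = list(it)  # 6 tensors, each of shape (1, 10)
        ```
        """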
        self.loader = loader
        self.infer = infer
        self.params = params
        if loader_batch_size == 1:
            # Let's spare some time by deactivating altogether
            loader_batch_size = None
        self.loader_batch_size = loader_batch_size

        # Internal bookkeeping
        self._loader_batch_index = None
        self._loader_batch_data = None

    def __len__(self):
        return len(self.loader)

    def __iter__(self):
        self.iterator = iter(self.loader)
        return self

    def loader_batch_item(self):
        """
        Return the item located at `self._loader_batch_index` within the current `self._loader_batch_data`,
        reshaped so it looks like a batch of size 1.
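
        Example (an illustrative sketch of the unbatching step; `it` is a `PipelineIterator` whose private
        state is set by hand here, which is not a public API):

        ```
        it._loader_batch_data = {"input_ids": torch.zeros(4, 7, dtype=torch.long)}
        it._loader_batch_index = 1
        item = it.loader_batch_item()  # {"input_ids": tensor of shape (1, 7)}, index advances to 2
        ```
        """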
        if isinstance(self._loader_batch_data, torch.Tensor):
            # Batch data is a simple tensor, just fetch the slice
            result = self._loader_batch_data[self._loader_batch_index].unsqueeze(0)
        else:
            # Batch data is assumed to be BaseModelOutput (or dict)
            loader_batched = {}
            for k, element in self._loader_batch_data.items():
                if isinstance(element, ModelOutput):
                    # Convert ModelOutput to tuple first
                    element = element.to_tuple()
                    if isinstance(element[0], torch.Tensor):
                        loader_batched[k] = tuple(el[self._loader_batch_index].unsqueeze(0) for el in element)
                    elif isinstance(element[0], np.ndarray):
                        loader_batched[k] = tuple(np.expand_dims(el[self._loader_batch_index], 0) for el in element)
                    continue
                if k in {"hidden_states", "past_key_values", "attentions"} and isinstance(element, tuple):
                    # Those are stored as lists of tensors so need specific unbatching.
                    if isinstance(element[0], torch.Tensor):
                        loader_batched[k] = tuple(el[self._loader_batch_index].unsqueeze(0) for el in element)
                    elif isinstance(element[0], np.ndarray):
                        loader_batched[k] = tuple(np.expand_dims(el[self._loader_batch_index], 0) for el in element)
                    continue
                if element is None:
                    # This can happen for optional data that get passed around
                    loader_batched[k] = None
                elif isinstance(element[self._loader_batch_index], torch.Tensor):
                    # Take the correct batch data, but make it look like batch_size=1
                    # for compatibility with other methods within transformers
                    loader_batched[k] = element[self._loader_batch_index].unsqueeze(0)
                elif isinstance(element[self._loader_batch_index], np.ndarray):
                    # Take the correct batch data, but make it look like batch_size=1
                    # for compatibility with other methods within transformers
                    loader_batched[k] = np.expand_dims(element[self._loader_batch_index], 0)
                else:
                    # This is typically a list, so no need to `unsqueeze`.
                    loader_batched[k] = element[self._loader_batch_index]
            # Recreate the element by reusing the original class to make it look like batch_size=1
            result = self._loader_batch_data.__class__(loader_batched)
        self._loader_batch_index += 1
        return result

    def __next__(self):
        if self._loader_batch_index is not None and self._loader_batch_index < self.loader_batch_size:
            # We are currently unrolling a batch so we just need to return
            # the current item within that batch
            return self.loader_batch_item()

        # We're out of items within a batch
        item = next(self.iterator)
        processed = self.infer(item, **self.params)
        # We now have a batch of "inferred things".
        if self.loader_batch_size is not None:
            # Try to infer the size of the batch
            if isinstance(processed, torch.Tensor):
                first_tensor = processed
            elif isinstance(processed, tuple):
                first_tensor = processed[0]
            else:
                key = list(processed.keys())[0]
                first_tensor = processed[key]
            if isinstance(first_tensor, list):
                observed_batch_size = len(first_tensor)
            else:
                observed_batch_size = first_tensor.shape[0]
            if 0 < observed_batch_size < self.loader_batch_size:
                # This could be the last batch, so we can't unroll as many elements.
                self.loader_batch_size = observed_batch_size
            # Set the internal index to unwrap the batch
            self._loader_batch_data = processed[0] if isinstance(processed, tuple) else processed
            self._loader_batch_index = 0
            return self.loader_batch_item()
        else:
            # We're not unrolling batches
            return processed

class PipelineChunkIterator(PipelineIterator):
    def __init__(self, loader, infer, params, loader_batch_size=None):
        """
        Roughly equivalent to

        ```
        for item in loader:
            for processed in infer(item, **params):
                yield processed
        ```

        Arguments:
            loader (`torch.utils.data.DataLoader` or `Iterable`):
                The iterator on which `infer` will be applied.
            infer (any function):
                The function to apply to each element of `loader`. It is expected to return an iterator
                (e.g. a generator).
            params (`dict`):
                The parameters passed to `infer` along with every item.
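
        Example (a minimal sketch where `infer` is a hypothetical generator function that splits its input
        into chunks):

        ```
        def chunk(item, **kwargs):
            yield from item

        it = PipelineChunkIterator([[1, 2], [3]], chunk, {})
        list(it)  # [1, 2, 3]
        ```
        """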
        # Note: `loader_batch_size` is accepted for signature compatibility but not used by the chunk iterator.
        super().__init__(loader, infer, params)

    def __iter__(self):
        self.iterator = iter(self.loader)
        self.subiterator = None
        return self
    def __next__(self):
        if self.subiterator is None:
            # A `None` subiterator means we haven't started the `preprocess` iterator yet, so start it.
            self.subiterator = self.infer(next(self.iterator), **self.params)
        try:
            # Try to return the next item
            processed = next(self.subiterator)
        except StopIteration:
            # When a preprocess iterator ends, we can start looking at the next item.
            # ChunkIterator will keep feeding until ALL elements of `iterator`
            # have created their subiterator and have been iterated through.
            #
            # Another way to look at it is that we're basically flattening lists of lists
            # into a single list, but with generators.
            self.subiterator = self.infer(next(self.iterator), **self.params)
            processed = next(self.subiterator)
        return processed


class PipelinePackIterator(PipelineIterator):
    """
    Roughly equivalent to

    ```
    packed = []
    for item in loader:
        packed.append(item)
        if item["is_last"]:
            yield packed
            packed = []
    ```

    but it also handles cases where `item` comes batched (meaning it's a dict of `Tensor`s with first
    dimension > 1). In that case it does

    ```
    packed = []
    for batch in loader:
        # item is batched
        for item in batch:
            packed.append(item)
            if item["is_last"]:
                yield packed
                packed = []
    ```

    Arguments:
        loader (`torch.utils.data.DataLoader` or `Iterable`):
            The iterator on which `infer` will be applied.
        infer (any function):
            The function to apply to each element of `loader`.
        params (`dict`):
            The parameters passed to `infer` along with every item.
        loader_batch_size (`int`, *optional*):
            If specified, the items of `loader` are supposed to come as batches, and are unbatched here,
            making it roughly behave as

            ```
            for items in loader:
                for i in range(loader_batch_size):
                    item = items[i]
                    yield infer(item, **params)
            ```
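
    Example (a minimal sketch with an identity `infer`; the `is_last` flags below are illustrative):

    ```
    loader = [
        {"x": 0, "is_last": False},
        {"x": 1, "is_last": True},
        {"x": 2, "is_last": True},
    ]
    it = PipelinePackIterator(loader, lambda d, **kwargs: d, {})
    list(it)  # [[{"x": 0}, {"x": 1}], [{"x": 2}]]
    ```
    """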

    def __iter__(self):
        self.iterator = iter(self.loader)
        return self
    def __next__(self):
        # Extremely similar to PipelineIterator in its unpacking mechanism,
        # BUT we have an extra required item, which is the presence of `is_last`.
        # Because everything is flattened by `PipelineChunkIterator`, we need to
        # keep track of how to regroup here, in the original `process` boundaries,
        # so that `process` and `postprocess` see the same data.
        # This iterator accumulates items (possibly while unbatching) until it
        # hits `is_last` and then just passes the accumulator on to the caller.
        is_last = False
        accumulator = []
        if self._loader_batch_index is not None and self._loader_batch_index < self.loader_batch_size:
            while self._loader_batch_index < self.loader_batch_size:
                item = self.loader_batch_item()
                is_last = item.pop("is_last")
                accumulator.append(item)
                if is_last:
                    return accumulator

        while not is_last:
            processed = self.infer(next(self.iterator), **self.params)
            if self.loader_batch_size is not None:
                if isinstance(processed, torch.Tensor):
                    first_tensor = processed
                else:
                    key = list(processed.keys())[0]
                    first_tensor = processed[key]
                if isinstance(first_tensor, list):
                    observed_batch_size = len(first_tensor)
                else:
                    observed_batch_size = first_tensor.shape[0]
                if 0 < observed_batch_size < self.loader_batch_size:
                    # This could be the last batch, so we can't unroll as many elements.
                    self.loader_batch_size = observed_batch_size
                self._loader_batch_data = processed
                self._loader_batch_index = 0
                while self._loader_batch_index < self.loader_batch_size:
                    item = self.loader_batch_item()
                    is_last = item.pop("is_last")
                    accumulator.append(item)
                    if is_last:
                        return accumulator
            else:
                item = processed
                is_last = item.pop("is_last")
                accumulator.append(item)
        return accumulator

class KeyDataset(Dataset):
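    """
    Map-style dataset that extracts a single `key` from each element of the wrapped `dataset`, so a pipeline
    can consume plain values from dict-shaped rows.

    A minimal usage sketch (the rows below are illustrative):

    ```
    rows = [{"text": "hello", "label": 0}, {"text": "world", "label": 1}]
    ds = KeyDataset(rows, "text")
    ds[0]  # "hello"
    ```
    """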
    def __init__(self, dataset: Dataset, key: str):
        self.dataset = dataset
        self.key = key

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, i):
        return self.dataset[i][self.key]

class KeyPairDataset(Dataset):
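    """
    Map-style dataset that extracts two keys from each element of the wrapped `dataset` and repackages them
    as a `{"text": ..., "text_pair": ...}` dict, the input shape expected by text-pair pipelines.

    A minimal usage sketch (the rows below are illustrative):

    ```
    rows = [{"premise": "It rains.", "hypothesis": "The ground is wet."}]
    ds = KeyPairDataset(rows, "premise", "hypothesis")
    ds[0]  # {"text": "It rains.", "text_pair": "The ground is wet."}
    ```
    """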
    def __init__(self, dataset: Dataset, key1: str, key2: str):
        self.dataset = dataset
        self.key1 = key1
        self.key2 = key2

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, i):
        return {"text": self.dataset[i][self.key1], "text_pair": self.dataset[i][self.key2]}