utils.py 1.8 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152
  1. # mypy: allow-untyped-defs
  2. import copy
  3. import warnings
  4. from torch.utils.data.datapipes.datapipe import IterDataPipe
  5. __all__ = ["IterableWrapperIterDataPipe", ]
  6. class IterableWrapperIterDataPipe(IterDataPipe):
  7. r"""
  8. Wraps an iterable object to create an IterDataPipe.
  9. Args:
  10. iterable: Iterable object to be wrapped into an IterDataPipe
  11. deepcopy: Option to deepcopy input iterable object for each
  12. iterator. The copy is made when the first element is read in ``iter()``.
  13. .. note::
  14. If ``deepcopy`` is explicitly set to ``False``, users should ensure
  15. that the data pipeline doesn't contain any in-place operations over
  16. the iterable instance to prevent data inconsistency across iterations.
  17. Example:
  18. >>> # xdoctest: +SKIP
  19. >>> from torchdata.datapipes.iter import IterableWrapper
  20. >>> dp = IterableWrapper(range(10))
  21. >>> list(dp)
  22. [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
  23. """
  24. def __init__(self, iterable, deepcopy=True):
  25. self.iterable = iterable
  26. self.deepcopy = deepcopy
  27. def __iter__(self):
  28. source_data = self.iterable
  29. if self.deepcopy:
  30. try:
  31. source_data = copy.deepcopy(self.iterable)
  32. # For the case that data cannot be deep-copied,
  33. # all in-place operations will affect iterable variable.
  34. # When this DataPipe is iterated second time, it will
  35. # yield modified items.
  36. except TypeError:
  37. warnings.warn(
  38. "The input iterable can not be deepcopied, "
  39. "please be aware of in-place modification would affect source data."
  40. )
  41. yield from source_data
  42. def __len__(self):
  43. return len(self.iterable)