# utils.py
  1. # mypy: allow-untyped-defs
  2. import copy
  3. import warnings
  4. from torch.utils.data.datapipes.datapipe import MapDataPipe
  5. __all__ = ["SequenceWrapperMapDataPipe", ]
  6. class SequenceWrapperMapDataPipe(MapDataPipe):
  7. r"""
  8. Wraps a sequence object into a MapDataPipe.
  9. Args:
  10. sequence: Sequence object to be wrapped into an MapDataPipe
  11. deepcopy: Option to deepcopy input sequence object
  12. .. note::
  13. If ``deepcopy`` is set to False explicitly, users should ensure
  14. that data pipeline doesn't contain any in-place operations over
  15. the iterable instance, in order to prevent data inconsistency
  16. across iterations.
  17. Example:
  18. >>> # xdoctest: +SKIP
  19. >>> from torchdata.datapipes.map import SequenceWrapper
  20. >>> dp = SequenceWrapper(range(10))
  21. >>> list(dp)
  22. [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
  23. >>> dp = SequenceWrapper({'a': 100, 'b': 200, 'c': 300, 'd': 400})
  24. >>> dp['a']
  25. 100
  26. """
  27. def __init__(self, sequence, deepcopy=True):
  28. if deepcopy:
  29. try:
  30. self.sequence = copy.deepcopy(sequence)
  31. except TypeError:
  32. warnings.warn(
  33. "The input sequence can not be deepcopied, "
  34. "please be aware of in-place modification would affect source data"
  35. )
  36. self.sequence = sequence
  37. else:
  38. self.sequence = sequence
  39. def __getitem__(self, index):
  40. return self.sequence[index]
  41. def __len__(self):
  42. return len(self.sequence)