streamreader.py 1.4 KB

1234567891011121314151617181920212223242526272829303132333435363738394041
  1. # mypy: allow-untyped-defs
  2. from typing import Tuple
  3. from torch.utils.data.datapipes._decorator import functional_datapipe
  4. from torch.utils.data.datapipes.datapipe import IterDataPipe
  5. __all__ = ["StreamReaderIterDataPipe", ]
  6. @functional_datapipe('read_from_stream')
  7. class StreamReaderIterDataPipe(IterDataPipe[Tuple[str, bytes]]):
  8. r"""
  9. Given IO streams and their label names, yield bytes with label name as tuple.
  10. (functional name: ``read_from_stream``).
  11. Args:
  12. datapipe: Iterable DataPipe provides label/URL and byte stream
  13. chunk: Number of bytes to be read from stream per iteration.
  14. If ``None``, all bytes will be read until the EOF.
  15. Example:
  16. >>> # xdoctest: +SKIP
  17. >>> from torchdata.datapipes.iter import IterableWrapper, StreamReader
  18. >>> from io import StringIO
  19. >>> dp = IterableWrapper([("alphabet", StringIO("abcde"))])
  20. >>> list(StreamReader(dp, chunk=1))
  21. [('alphabet', 'a'), ('alphabet', 'b'), ('alphabet', 'c'), ('alphabet', 'd'), ('alphabet', 'e')]
  22. """
  23. def __init__(self, datapipe, chunk=None):
  24. self.datapipe = datapipe
  25. self.chunk = chunk
  26. def __iter__(self):
  27. for furl, stream in self.datapipe:
  28. while True:
  29. d = stream.read(self.chunk)
  30. if not d:
  31. stream.close()
  32. break
  33. yield (furl, d)