# mypy: allow-untyped-defs
from torch.utils.data.datapipes._decorator import functional_datapipe
from torch.utils.data.datapipes.datapipe import MapDataPipe
from typing import Sized, Tuple, TypeVar

__all__ = ["ConcaterMapDataPipe", "ZipperMapDataPipe"]

T_co = TypeVar('T_co', covariant=True)


@functional_datapipe('concat')
class ConcaterMapDataPipe(MapDataPipe):
    r"""
    Concatenate multiple Map DataPipes (functional name: ``concat``).

    The new index is based on the cumulative sum of the lengths of the source DataPipes.
    For example, if there are 2 source DataPipes both with length 5,
    indices 0 to 4 of the resulting ``ConcaterMapDataPipe`` refer to elements
    of the first DataPipe, and indices 5 to 9 refer to elements
    of the second DataPipe.

    Args:
        datapipes: Map DataPipes being concatenated

    Example:
        >>> # xdoctest: +SKIP
        >>> from torchdata.datapipes.map import SequenceWrapper
        >>> dp1 = SequenceWrapper(range(3))
        >>> dp2 = SequenceWrapper(range(3))
        >>> concat_dp = dp1.concat(dp2)
        >>> list(concat_dp)
        [0, 1, 2, 0, 1, 2]
    """

    datapipes: Tuple[MapDataPipe]

    def __init__(self, *datapipes: MapDataPipe):
        if len(datapipes) == 0:
            raise ValueError("Expected at least one DataPipe, but got nothing")
        if not all(isinstance(dp, MapDataPipe) for dp in datapipes):
            raise TypeError("Expected all inputs to be `MapDataPipe`")
        if not all(isinstance(dp, Sized) for dp in datapipes):
            raise TypeError("Expected all inputs to be `Sized`")
        self.datapipes = datapipes  # type: ignore[assignment]

    def __getitem__(self, index) -> T_co:  # type: ignore[type-var]
        # Walk the source pipes in order, accumulating their lengths as an offset,
        # until the requested global index falls inside one of them.
        offset = 0
        for dp in self.datapipes:
            if index - offset < len(dp):
                return dp[index - offset]
            else:
                offset += len(dp)
        raise IndexError(f"Index {index} is out of range.")

    def __len__(self) -> int:
        return sum(len(dp) for dp in self.datapipes)
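

# A minimal illustration (not part of the original module) of the offset
# arithmetic in ``ConcaterMapDataPipe.__getitem__``: with source pipes of
# lengths 3 and 2, global index 4 lies past the first pipe (offset becomes 3)
# and resolves to index 1 of the second pipe. ``SequenceWrapper`` is assumed
# to be importable from ``torch.utils.data.datapipes.map``.
#
#     from torch.utils.data.datapipes.map import SequenceWrapper
#     dp = ConcaterMapDataPipe(SequenceWrapper(range(3)), SequenceWrapper(range(10, 12)))
#     assert len(dp) == 5       # 3 + 2
#     assert dp[4] == 11        # 4 - offset 3 -> index 1 of the second pipe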


@functional_datapipe('zip')
class ZipperMapDataPipe(MapDataPipe[Tuple[T_co, ...]]):
    r"""
    Aggregates elements into a tuple from each of the input DataPipes (functional name: ``zip``).

    This MapDataPipe is out of bounds as soon as the shortest input DataPipe is exhausted.

    Args:
        *datapipes: Map DataPipes being aggregated

    Example:
        >>> # xdoctest: +SKIP
        >>> from torchdata.datapipes.map import SequenceWrapper
        >>> dp1 = SequenceWrapper(range(3))
        >>> dp2 = SequenceWrapper(range(10, 13))
        >>> zip_dp = dp1.zip(dp2)
        >>> list(zip_dp)
        [(0, 10), (1, 11), (2, 12)]
    """

    datapipes: Tuple[MapDataPipe[T_co], ...]

    def __init__(self, *datapipes: MapDataPipe[T_co]) -> None:
        if len(datapipes) == 0:
            raise ValueError("Expected at least one DataPipe, but got nothing")
        if not all(isinstance(dp, MapDataPipe) for dp in datapipes):
            raise TypeError("Expected all inputs to be `MapDataPipe`")
        if not all(isinstance(dp, Sized) for dp in datapipes):
            raise TypeError("Expected all inputs to be `Sized`")
        self.datapipes = datapipes

    def __getitem__(self, index) -> Tuple[T_co, ...]:
        # Gather the element at ``index`` from every input pipe into one tuple.
        res = []
        for dp in self.datapipes:
            try:
                res.append(dp[index])
            except IndexError as e:
                raise IndexError(
                    f"Index {index} is out of range for one of the input MapDataPipes {dp}."
                ) from e
        return tuple(res)

    def __len__(self) -> int:
        return min(len(dp) for dp in self.datapipes)
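

# A small, self-contained usage sketch (not part of the upstream file) showing how the
# two DataPipes compose. It assumes ``SequenceWrapper`` is importable from
# ``torch.utils.data.datapipes.map``; the ``__main__`` guard keeps it from running on import.
if __name__ == "__main__":
    from torch.utils.data.datapipes.map import SequenceWrapper

    dp1 = SequenceWrapper(range(3))        # 0, 1, 2
    dp2 = SequenceWrapper(range(10, 13))   # 10, 11, 12

    concat_dp = ConcaterMapDataPipe(dp1, dp2)
    print(list(concat_dp))  # [0, 1, 2, 10, 11, 12] -- length is the sum of the lengths

    zip_dp = ZipperMapDataPipe(dp1, dp2)
    print(list(zip_dp))     # [(0, 10), (1, 11), (2, 12)] -- length is the min of the lengths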