callable.py 1.8 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162
  1. # mypy: allow-untyped-defs
  2. from torch.utils.data.datapipes.utils.common import _check_unpickable_fn
  3. from typing import Callable, TypeVar
  4. from torch.utils.data.datapipes._decorator import functional_datapipe
  5. from torch.utils.data.datapipes.datapipe import MapDataPipe
# Names exported by ``from <this module> import *``.
__all__ = ["MapperMapDataPipe", "default_fn"]

# Covariant type variable used as the item type of ``MapperMapDataPipe``
# (the ``MapDataPipe[T_co]`` base and the ``__getitem__`` return annotation).
T_co = TypeVar('T_co', covariant=True)
  8. # Default function to return each item directly
  9. # In order to keep datapipe picklable, eliminates the usage
  10. # of python lambda function
  11. def default_fn(data):
  12. return data
  13. @functional_datapipe('map')
  14. class MapperMapDataPipe(MapDataPipe[T_co]):
  15. r"""
  16. Apply the input function over each item from the source DataPipe (functional name: ``map``).
  17. The function can be any regular Python function or partial object. Lambda
  18. function is not recommended as it is not supported by pickle.
  19. Args:
  20. datapipe: Source MapDataPipe
  21. fn: Function being applied to each item
  22. Example:
  23. >>> # xdoctest: +SKIP
  24. >>> from torchdata.datapipes.map import SequenceWrapper, Mapper
  25. >>> def add_one(x):
  26. ... return x + 1
  27. >>> dp = SequenceWrapper(range(10))
  28. >>> map_dp_1 = dp.map(add_one)
  29. >>> list(map_dp_1)
  30. [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
  31. >>> map_dp_2 = Mapper(dp, lambda x: x + 1)
  32. >>> list(map_dp_2)
  33. [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
  34. """
  35. datapipe: MapDataPipe
  36. fn: Callable
  37. def __init__(
  38. self,
  39. datapipe: MapDataPipe,
  40. fn: Callable = default_fn,
  41. ) -> None:
  42. super().__init__()
  43. self.datapipe = datapipe
  44. _check_unpickable_fn(fn)
  45. self.fn = fn # type: ignore[assignment]
  46. def __len__(self) -> int:
  47. return len(self.datapipe)
  48. def __getitem__(self, index) -> T_co:
  49. return self.fn(self.datapipe[index])