# mypy: allow-untyped-defs
import functools
from collections import namedtuple
from typing import Callable, Iterator, Sized, TypeVar, Optional, Union, Any, Dict, List

from torch.utils.data.datapipes._decorator import functional_datapipe
from torch.utils.data._utils.collate import default_collate
from torch.utils.data.datapipes.dataframe import dataframe_wrapper as df_wrapper
from torch.utils.data.datapipes.datapipe import IterDataPipe
from torch.utils.data.datapipes.utils.common import (
    _check_unpickable_fn,
    validate_input_col,
)

__all__ = [
    "CollatorIterDataPipe",
    "MapperIterDataPipe",
]

T_co = TypeVar("T_co", covariant=True)


@functional_datapipe("map")
class MapperIterDataPipe(IterDataPipe[T_co]):
    r"""
    Applies a function over each item from the source DataPipe (functional name: ``map``).

    The function can be any regular Python function or partial object. Lambda
    functions are not recommended as they are not supported by pickle.

    Args:
        datapipe: Source Iterable DataPipe
        fn: Function being applied over each item
        input_col: Index or indices of the data to which ``fn`` is applied, such as:

            - ``None`` as default to apply ``fn`` to the data directly.
            - Integer(s) is used for list/tuple.
            - Key(s) is used for dict.

        output_col: Index of data where the result of ``fn`` is placed. ``output_col`` can be specified
            only when ``input_col`` is not ``None``.

            - ``None`` as default to replace the index that ``input_col`` specified; for ``input_col`` with
              multiple indices, the left-most one is used, and the other indices are removed.
            - Integer is used for list/tuple. ``-1`` represents appending the result at the end.
            - Key is used for dict. A new key is acceptable.

    Example:
        >>> # xdoctest: +SKIP
        >>> from torchdata.datapipes.iter import IterableWrapper, Mapper
        >>> def add_one(x):
        ...     return x + 1
        >>> dp = IterableWrapper(range(10))
        >>> map_dp_1 = dp.map(add_one)  # Invocation via functional form is preferred
        >>> list(map_dp_1)
        [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
        >>> # We discourage the usage of `lambda` functions as they are not serializable with `pickle`
        >>> # Use `functools.partial` or explicitly define the function instead
        >>> map_dp_2 = Mapper(dp, lambda x: x + 1)
        >>> list(map_dp_2)
        [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
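        >>> # Sketch of `input_col`/`output_col` on tuple rows (reuses `add_one` from above)
        >>> dp_rows = IterableWrapper([(0, 10), (1, 20)])
        >>> list(dp_rows.map(add_one, input_col=1))
        [(0, 11), (1, 21)]
        >>> # `output_col=-1` appends the result instead of replacing the selected column
        >>> list(dp_rows.map(add_one, input_col=1, output_col=-1))
        [(0, 10, 11), (1, 20, 21)]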
  49. """
  50. datapipe: IterDataPipe
  51. fn: Callable
  52. def __init__(
  53. self,
  54. datapipe: IterDataPipe,
  55. fn: Callable,
  56. input_col=None,
  57. output_col=None,
  58. ) -> None:
  59. super().__init__()
  60. self.datapipe = datapipe
  61. _check_unpickable_fn(fn)
  62. self.fn = fn # type: ignore[assignment]
  63. self.input_col = input_col
  64. if input_col is None and output_col is not None:
  65. raise ValueError("`output_col` must be None when `input_col` is None.")
  66. if isinstance(output_col, (list, tuple)):
  67. if len(output_col) > 1:
  68. raise ValueError("`output_col` must be a single-element list or tuple")
  69. output_col = output_col[0]
  70. self.output_col = output_col
  71. validate_input_col(fn, input_col)

    def _apply_fn(self, data):
        if self.input_col is None and self.output_col is None:
            return self.fn(data)

        if self.input_col is None:
            res = self.fn(data)
        elif isinstance(self.input_col, (list, tuple)):
            args = tuple(data[col] for col in self.input_col)
            res = self.fn(*args)
        else:
            res = self.fn(data[self.input_col])

        # Copy tuple to list and run in-place modification because tuple is immutable.
        if isinstance(data, tuple):
            t_flag = True
            data = list(data)
        else:
            t_flag = False

        if self.output_col is None:
            if isinstance(self.input_col, (list, tuple)):
                data[self.input_col[0]] = res
                for idx in sorted(self.input_col[1:], reverse=True):
                    del data[idx]
            else:
                data[self.input_col] = res
        else:
            if self.output_col == -1:
                data.append(res)
            else:
                data[self.output_col] = res

        # Convert list back to tuple
        return tuple(data) if t_flag else data

    def __iter__(self) -> Iterator[T_co]:
        for data in self.datapipe:
            yield self._apply_fn(data)

    def __len__(self) -> int:
        if isinstance(self.datapipe, Sized):
            return len(self.datapipe)
        raise TypeError(f"{type(self).__name__} instance doesn't have valid length")


def _collate_helper(conversion, item):
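    """Collate a single-DataFrame batch into a ``CollateResult`` namedtuple.

    Each column is converted by the callable mapped to its name in ``conversion``;
    columns without an entry fall back to TorchArrow's default collation.
    """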
    # TODO(VitalyFedyunin): Verify that item is any sort of batch
    if len(item.items) > 1:
        # TODO(VitalyFedyunin): Compact all batch dataframes into one
        raise Exception("Only supports one DataFrame per batch")  # noqa: TRY002
    df = item[0]
    columns_name = df_wrapper.get_columns(df)
    tuple_names: List = []
    tuple_values: List = []

    for name in conversion.keys():
        if name not in columns_name:
            raise Exception("Conversion keys mismatch")  # noqa: TRY002

    for name in columns_name:
        if name in conversion:
            if not callable(conversion[name]):
                raise Exception("Collate (DF)DataPipe requires callable as dict values")  # noqa: TRY002
            collation_fn = conversion[name]
        else:
            # TODO(VitalyFedyunin): Add default collation into df_wrapper
            try:
                import torcharrow.pytorch as tap  # type: ignore[import]

                collation_fn = tap.rec.Default()
            except Exception as e:
                raise Exception("Unable to import default collation function from TorchArrow") from e  # noqa: TRY002

        tuple_names.append(str(name))
        value = collation_fn(df[name])
        tuple_values.append(value)

    # TODO(VitalyFedyunin): We can dynamically extract types from the tuple_values here
    # TODO(VitalyFedyunin): Instead of ignoring mypy error, make sure tuple_names is not empty
    tpl_cls = namedtuple("CollateResult", tuple_names)  # type: ignore[misc]
    return tpl_cls(*tuple_values)
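

# Illustrative sketch of how a ``conversion`` dict is used by ``collate`` below
# (hypothetical names; assumes a torcharrow-backed DataFrame DataPipe ``df_datapipe``
# and per-column callables ``collate_label`` / ``collate_feature``):
#
#     conversion = {"label": collate_label, "feature": collate_feature}
#     collated_dp = df_datapipe.collate(conversion)
#
# Each yielded batch is then a ``CollateResult`` namedtuple with one field per
# DataFrame column; columns missing from ``conversion`` use TorchArrow's default
# collation via ``torcharrow.pytorch.rec.Default()``.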


@functional_datapipe("collate")
class CollatorIterDataPipe(MapperIterDataPipe):
    r"""
    Collates samples from DataPipe to Tensor(s) by a custom collate function (functional name: ``collate``).

    By default, it uses :func:`torch.utils.data.default_collate`.

    .. note::
        While writing a custom collate function, you can import :func:`torch.utils.data.default_collate` for the
        default behavior and `functools.partial` to specify any additional arguments.

    Args:
        datapipe: Iterable DataPipe being collated
        collate_fn: Customized collate function to collect and combine data or a batch of data.
            Default function collates to Tensor(s) based on data type.

    Example:
        >>> # xdoctest: +SKIP
        >>> # Convert integer data to float Tensor
        >>> class MyIterDataPipe(torch.utils.data.IterDataPipe):
        ...     def __init__(self, start, end):
        ...         super().__init__()
        ...         assert end > start, "this example code only works with end > start"
        ...         self.start = start
        ...         self.end = end
        ...
        ...     def __iter__(self):
        ...         return iter(range(self.start, self.end))
        ...
        ...     def __len__(self):
        ...         return self.end - self.start
        ...
        >>> ds = MyIterDataPipe(start=3, end=7)
        >>> print(list(ds))
        [3, 4, 5, 6]
        >>> def collate_fn(batch):
        ...     return torch.tensor(batch, dtype=torch.float)
        ...
        >>> collated_ds = CollatorIterDataPipe(ds, collate_fn=collate_fn)
        >>> print(list(collated_ds))
        [tensor(3.), tensor(4.), tensor(5.), tensor(6.)]
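        >>> # Sketch: with no `collate_fn` given, the default `default_collate` conversion is used
        >>> collated_default = CollatorIterDataPipe(ds)
        >>> print(list(collated_default))
        [tensor(3), tensor(4), tensor(5), tensor(6)]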
  180. """
  181. def __init__(
  182. self,
  183. datapipe: IterDataPipe,
  184. conversion: Optional[
  185. Union[
  186. Callable[..., Any],
  187. Dict[Union[str, Any], Union[Callable, Any]],
  188. ]
  189. ] = default_collate,
  190. collate_fn: Optional[Callable] = None,
  191. ) -> None:
  192. # TODO(VitalyFedyunin): Replace `Callable[..., Any]` with `Callable[[IColumn], Any]`
  193. # TODO(VitalyFedyunin): Replace with `Dict[Union[str, IColumn], Union[Callable, Enum]]`
  194. if collate_fn is not None:
  195. super().__init__(datapipe, fn=collate_fn)
  196. else:
  197. if callable(conversion):
  198. super().__init__(datapipe, fn=conversion)
  199. else:
  200. # TODO(VitalyFedyunin): Validate passed dictionary
  201. collate_fn = functools.partial(_collate_helper, conversion)
  202. super().__init__(datapipe, fn=collate_fn)