selecting.py 3.2 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697
  1. # mypy: allow-untyped-defs
  2. from typing import Callable, Iterator, Tuple, TypeVar
  3. from torch.utils.data.datapipes._decorator import functional_datapipe
  4. from torch.utils.data.datapipes.datapipe import IterDataPipe
  5. from torch.utils.data.datapipes.dataframe import dataframe_wrapper as df_wrapper
  6. from torch.utils.data.datapipes.utils.common import (
  7. _check_unpickable_fn,
  8. StreamWrapper,
  9. validate_input_col
  10. )
  # Only the filtering DataPipe is exported from this module.
  __all__ = ["FilterIterDataPipe"]

  # Generic element type, used by ``FilterIterDataPipe._returnIfTrue``.
  T = TypeVar("T")
  # Covariant element type that parameterizes ``FilterIterDataPipe``.
  T_co = TypeVar("T_co", covariant=True)
  @functional_datapipe('filter')
  class FilterIterDataPipe(IterDataPipe[T_co]):
      r"""
      Filters out elements from the source datapipe according to input ``filter_fn`` (functional name: ``filter``).

      Elements for which ``filter_fn`` returns ``True`` are yielded unchanged.
      Rejected elements are passed to ``StreamWrapper.close_streams`` before
      being dropped.

      If ``filter_fn`` returns a DataFrame column (as detected by the dataframe
      wrapper), it is treated as a per-row mask. The matching rows of the element
      are concatenated and yielded. If no row matches, the element is dropped.

      Args:
          datapipe: Iterable DataPipe being filtered
          filter_fn: Customized function mapping an element to a boolean.
          input_col: Index or indices of data to which ``filter_fn`` is applied, such as:

              - ``None`` as default to apply ``filter_fn`` to the data directly.
              - Integer(s) are used for list/tuple.
              - Key(s) are used for dict.

      Raises:
          ValueError: during iteration, if ``filter_fn`` returns something that
              is neither a ``bool`` nor a DataFrame column.

      Example:
          >>> # xdoctest: +SKIP
          >>> from torchdata.datapipes.iter import IterableWrapper
          >>> def is_even(n):
          ...     return n % 2 == 0
          >>> dp = IterableWrapper(range(5))
          >>> filter_dp = dp.filter(filter_fn=is_even)
          >>> list(filter_dp)
          [0, 2, 4]
      """

      datapipe: IterDataPipe[T_co]
      filter_fn: Callable

      def __init__(
          self,
          datapipe: IterDataPipe[T_co],
          filter_fn: Callable,
          input_col=None,
      ) -> None:
          super().__init__()
          self.datapipe = datapipe

          # NOTE(review): presumably rejects callables that cannot be pickled
          # (e.g. lambdas/local functions) so the pipeline stays serializable;
          # confirm against ``_check_unpickable_fn``.
          _check_unpickable_fn(filter_fn)
          self.filter_fn = filter_fn  # type: ignore[assignment]

          self.input_col = input_col
          # NOTE(review): presumably checks that ``input_col`` is compatible with
          # ``filter_fn``'s arity; confirm against ``validate_input_col``.
          validate_input_col(filter_fn, input_col)

      def _apply_filter_fn(self, data) -> bool:
          """Call ``filter_fn`` on the part of ``data`` selected by ``input_col``.

          - ``input_col is None``: the whole element is passed.
          - list/tuple ``input_col``: each selected item is passed as a separate
            positional argument, in ``input_col`` order.
          - otherwise: the single item ``data[input_col]`` is passed.
          """
          if self.input_col is None:
              return self.filter_fn(data)
          elif isinstance(self.input_col, (list, tuple)):
              args = tuple(data[col] for col in self.input_col)
              return self.filter_fn(*args)
          else:
              return self.filter_fn(data[self.input_col])

      def __iter__(self) -> Iterator[T_co]:
          """Yield accepted elements (or row subsets) from the source datapipe."""
          for data in self.datapipe:
              condition, filtered = self._returnIfTrue(data)
              if condition:
                  yield filtered
              else:
                  # Rejected elements are handed to ``close_streams`` so any
                  # streams they wrap are not left open (assumes elements may
                  # contain ``StreamWrapper`` objects).
                  StreamWrapper.close_streams(data)

      def _returnIfTrue(self, data: T) -> Tuple[bool, T]:
          """Evaluate the filter on ``data``.

          Returns:
              ``(keep, value)``. For a plain boolean result, ``value`` is ``data``
              itself. For a DataFrame-column result, ``value`` is the
              concatenation of the rows whose mask entry is truthy. It is
              ``None`` (with ``keep=False``) when no row matches.

          Raises:
              ValueError: if the filter result is neither a column nor a ``bool``.
          """
          condition = self._apply_filter_fn(data)
          if df_wrapper.is_column(condition):
              # We are operating on DataFrames filter here: treat the result as
              # a row mask and keep only the selected rows.
              result = []
              for idx, mask in enumerate(df_wrapper.iterate(condition)):
                  if mask:
                      result.append(df_wrapper.get_item(data, idx))
              if result:
                  return True, df_wrapper.concat(result)
              else:
                  return False, None  # type: ignore[return-value]
          if not isinstance(condition, bool):
              # Format the offending type into the message. Passing it as a
              # second positional argument would render the message as a tuple.
              raise ValueError(
                  f"Boolean output is required for `filter_fn` of FilterIterDataPipe, got {type(condition)}"
              )
          return condition, data