graph_settings.py 5.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163
  1. # mypy: allow-untyped-defs
  2. import inspect
  3. import warnings
  4. from typing import Any, List, Optional, Set
  5. from typing_extensions import deprecated
  6. import torch
  7. from torch.utils.data.datapipes.iter.sharding import (
  8. _ShardingIterDataPipe,
  9. SHARDING_PRIORITIES,
  10. )
  11. from torch.utils.data.graph import DataPipe, DataPipeGraph, traverse_dps
  12. __all__ = [
  13. "apply_random_seed",
  14. "apply_sharding",
  15. "apply_shuffle_seed",
  16. "apply_shuffle_settings",
  17. "get_all_graph_pipes",
  18. ]
  19. def get_all_graph_pipes(graph: DataPipeGraph) -> List[DataPipe]:
  20. return _get_all_graph_pipes_helper(graph, set())
  21. def _get_all_graph_pipes_helper(graph: DataPipeGraph, id_cache: Set[int]) -> List[DataPipe]:
  22. results: List[DataPipe] = []
  23. for dp_id, (datapipe, sub_graph) in graph.items():
  24. if dp_id in id_cache:
  25. continue
  26. id_cache.add(dp_id)
  27. results.append(datapipe)
  28. results.extend(_get_all_graph_pipes_helper(sub_graph, id_cache))
  29. return results
  30. def _is_sharding_datapipe(datapipe: DataPipe) -> bool:
  31. if isinstance(datapipe, _ShardingIterDataPipe):
  32. return True
  33. if hasattr(datapipe, "apply_sharding") and inspect.ismethod(datapipe.apply_sharding):
  34. return True
  35. return False
  36. def apply_sharding(datapipe: DataPipe,
  37. num_of_instances: int,
  38. instance_id: int,
  39. sharding_group=SHARDING_PRIORITIES.DEFAULT) -> DataPipe:
  40. r"""
  41. Apply dynamic sharding over the ``sharding_filter`` DataPipe that has a method ``apply_sharding``.
  42. RuntimeError will be raised when multiple ``sharding_filter`` are presented in the same branch.
  43. """
  44. graph = traverse_dps(datapipe)
  45. def _helper(graph, prev_applied=None):
  46. for (dp, sub_graph) in graph.values():
  47. applied = None
  48. if _is_sharding_datapipe(dp):
  49. if prev_applied is not None:
  50. raise RuntimeError("Sharding twice on a single pipeline is likely unintended and will cause data loss. "
  51. f"Sharding already applied to {prev_applied} while trying to apply to {dp}")
  52. # For BC, only provide sharding_group if accepted
  53. sig = inspect.signature(dp.apply_sharding)
  54. if len(sig.parameters) < 3:
  55. dp.apply_sharding(num_of_instances, instance_id)
  56. else:
  57. dp.apply_sharding(num_of_instances, instance_id, sharding_group=sharding_group)
  58. applied = dp
  59. if applied is None:
  60. applied = prev_applied
  61. _helper(sub_graph, applied)
  62. _helper(graph)
  63. return datapipe
  64. def _is_shuffle_datapipe(datapipe: DataPipe) -> bool:
  65. if not hasattr(datapipe, "set_shuffle") or not hasattr(datapipe, "set_seed"):
  66. return False
  67. if not inspect.ismethod(datapipe.set_shuffle) or not inspect.ismethod(datapipe.set_seed):
  68. return False
  69. return True
  70. def apply_shuffle_settings(datapipe: DataPipe, shuffle: Optional[bool] = None) -> DataPipe:
  71. r"""
  72. Traverse the graph of ``DataPipes`` to find and set shuffle attribute.
  73. Apply the method to each `DataPipe` that has APIs of ``set_shuffle``
  74. and ``set_seed``.
  75. Args:
  76. datapipe: DataPipe that needs to set shuffle attribute
  77. shuffle: Shuffle option (default: ``None`` and no-op to the graph)
  78. """
  79. if shuffle is None:
  80. return datapipe
  81. graph = traverse_dps(datapipe)
  82. all_pipes = get_all_graph_pipes(graph)
  83. shufflers = [pipe for pipe in all_pipes if _is_shuffle_datapipe(pipe)]
  84. if not shufflers and shuffle:
  85. warnings.warn(
  86. "`shuffle=True` was set, but the datapipe does not contain a `Shuffler`. Adding one at the end. "
  87. "Be aware that the default buffer size might not be sufficient for your task."
  88. )
  89. datapipe = datapipe.shuffle()
  90. shufflers = [datapipe, ] # type: ignore[list-item]
  91. for shuffler in shufflers:
  92. shuffler.set_shuffle(shuffle)
  93. return datapipe
  94. @deprecated(
  95. "`apply_shuffle_seed` is deprecated since 1.12 and will be removed in the future releases. "
  96. "Please use `apply_random_seed` instead.",
  97. category=FutureWarning,
  98. )
  99. def apply_shuffle_seed(datapipe: DataPipe, rng: Any) -> DataPipe:
  100. return apply_random_seed(datapipe, rng)
  101. def _is_random_datapipe(datapipe: DataPipe) -> bool:
  102. if hasattr(datapipe, "set_seed") and inspect.ismethod(datapipe.set_seed):
  103. return True
  104. return False
  105. def apply_random_seed(datapipe: DataPipe, rng: torch.Generator) -> DataPipe:
  106. r"""
  107. Traverse the graph of ``DataPipes`` to find random ``DataPipe`` with an API of ``set_seed``.
  108. Then set the random seed based on the provided RNG to those ``DataPipe``.
  109. Args:
  110. datapipe: DataPipe that needs to set randomness
  111. rng: Random number generator to generate random seeds
  112. """
  113. graph = traverse_dps(datapipe)
  114. all_pipes = get_all_graph_pipes(graph)
  115. # Using a set to track id of DataPipe to prevent setting randomness per DataPipe more than once.
  116. # And, `id` is used in case of unhashable DataPipe
  117. cache = set()
  118. random_datapipes = []
  119. for pipe in all_pipes:
  120. if id(pipe) in cache:
  121. continue
  122. if _is_random_datapipe(pipe):
  123. random_datapipes.append(pipe)
  124. cache.add(id(pipe))
  125. for pipe in random_datapipes:
  126. random_seed = int(torch.empty((), dtype=torch.int64).random_(generator=rng).item())
  127. pipe.set_seed(random_seed)
  128. return datapipe