graph.py 5.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150
  1. # mypy: allow-untyped-defs
  2. import io
  3. import pickle
  4. import warnings
  5. from collections.abc import Collection
  6. from typing import Dict, List, Optional, Set, Tuple, Type, Union
  7. from torch.utils.data import IterDataPipe, MapDataPipe
  8. from torch.utils._import_utils import dill_available
__all__ = ["traverse", "traverse_dps"]

# A DataPipe node is either an iterable-style or a map-style pipe.
DataPipe = Union[IterDataPipe, MapDataPipe]
# Recursive graph type: maps id(datapipe) -> (datapipe instance, sub-graph of its inputs).
DataPipeGraph = Dict[int, Tuple[DataPipe, "DataPipeGraph"]]  # type: ignore[misc]
  12. def _stub_unpickler():
  13. return "STUB"
# TODO(VitalyFedyunin): Make sure it works without dill module installed
def _list_connected_datapipes(scan_obj: DataPipe, only_datapipe: bool, cache: Set[int]) -> List[DataPipe]:
    """Find the DataPipes that ``scan_obj`` directly references.

    Works by pickling ``scan_obj`` into a throwaway buffer with temporary
    class-level hooks installed on ``IterDataPipe``/``MapDataPipe``: whenever
    the pickler reaches a nested DataPipe, ``reduce_hook`` records it and
    substitutes a stub so serialization does not recurse further.

    Args:
        scan_obj: the DataPipe whose direct connections are being collected
        only_datapipe: if True, restrict pickled state to attributes that are
            DataPipes or Python collections (via ``getstate_hook``)
        cache: ids of DataPipes already seen on this path; used to skip
            duplicates and avoid re-recording them

    Returns:
        List of DataPipe instances directly referenced by ``scan_obj``.
    """
    f = io.BytesIO()
    p = pickle.Pickler(f)  # Not going to work for lambdas, but dill infinite loops on typing and can't be used as is
    if dill_available():
        from dill import Pickler as dill_Pickler
        d = dill_Pickler(f)
    else:
        d = None

    # Filled in by reduce_hook as the pickler encounters nested DataPipes.
    captured_connections = []

    def getstate_hook(ori_state):
        # Filter pickled state down to values that could contain DataPipes:
        # DataPipes themselves and Python collections (which may hold them).
        state = None
        if isinstance(ori_state, dict):
            state = {}  # type: ignore[assignment]
            for k, v in ori_state.items():
                if isinstance(v, (IterDataPipe, MapDataPipe, Collection)):
                    state[k] = v  # type: ignore[attr-defined]
        elif isinstance(ori_state, (tuple, list)):
            state = []  # type: ignore[assignment]
            for v in ori_state:
                if isinstance(v, (IterDataPipe, MapDataPipe, Collection)):
                    state.append(v)  # type: ignore[attr-defined]
        elif isinstance(ori_state, (IterDataPipe, MapDataPipe, Collection)):
            state = ori_state  # type: ignore[assignment]
        return state

    def reduce_hook(obj):
        # Raising NotImplementedError tells the DataPipe reduce-ex machinery
        # to fall back to default pickling for the root object / known ids.
        if obj == scan_obj or id(obj) in cache:
            raise NotImplementedError
        else:
            captured_connections.append(obj)
            # Adding id to remove duplicate DataPipe serialized at the same level
            cache.add(id(obj))
            # Replace the nested DataPipe with a stub so pickling stops here.
            return _stub_unpickler, ()

    datapipe_classes: Tuple[Type[DataPipe]] = (IterDataPipe, MapDataPipe)  # type: ignore[assignment]

    try:
        # Hooks are class-level (global) state; the finally block below must
        # always undo them.
        for cls in datapipe_classes:
            cls.set_reduce_ex_hook(reduce_hook)
            if only_datapipe:
                cls.set_getstate_hook(getstate_hook)
        try:
            p.dump(scan_obj)
        except (pickle.PickleError, AttributeError, TypeError):
            # Plain pickle failed (e.g. lambdas in state); retry with dill
            # if it is installed, otherwise surface the original error.
            if dill_available():
                d.dump(scan_obj)
            else:
                raise
    finally:
        for cls in datapipe_classes:
            cls.set_reduce_ex_hook(None)
            if only_datapipe:
                cls.set_getstate_hook(None)
        if dill_available():
            from dill import extend as dill_extend
            dill_extend(False)  # Undo change to dispatch table

    return captured_connections
  69. def traverse_dps(datapipe: DataPipe) -> DataPipeGraph:
  70. r"""
  71. Traverse the DataPipes and their attributes to extract the DataPipe graph.
  72. This only looks into the attribute from each DataPipe that is either a
  73. DataPipe and a Python collection object such as ``list``, ``tuple``,
  74. ``set`` and ``dict``.
  75. Args:
  76. datapipe: the end DataPipe of the graph
  77. Returns:
  78. A graph represented as a nested dictionary, where keys are ids of DataPipe instances
  79. and values are tuples of DataPipe instance and the sub-graph
  80. """
  81. cache: Set[int] = set()
  82. return _traverse_helper(datapipe, only_datapipe=True, cache=cache)
  83. def traverse(datapipe: DataPipe, only_datapipe: Optional[bool] = None) -> DataPipeGraph:
  84. r"""
  85. Traverse the DataPipes and their attributes to extract the DataPipe graph.
  86. [Deprecated]
  87. When ``only_dataPipe`` is specified as ``True``, it would only look into the
  88. attribute from each DataPipe that is either a DataPipe and a Python collection object
  89. such as ``list``, ``tuple``, ``set`` and ``dict``.
  90. Note:
  91. This function is deprecated. Please use `traverse_dps` instead.
  92. Args:
  93. datapipe: the end DataPipe of the graph
  94. only_datapipe: If ``False`` (default), all attributes of each DataPipe are traversed.
  95. This argument is deprecating and will be removed after the next release.
  96. Returns:
  97. A graph represented as a nested dictionary, where keys are ids of DataPipe instances
  98. and values are tuples of DataPipe instance and the sub-graph
  99. """
  100. msg = "`traverse` function and will be removed after 1.13. " \
  101. "Please use `traverse_dps` instead."
  102. if not only_datapipe:
  103. msg += " And, the behavior will be changed to the equivalent of `only_datapipe=True`."
  104. warnings.warn(msg, FutureWarning)
  105. if only_datapipe is None:
  106. only_datapipe = False
  107. cache: Set[int] = set()
  108. return _traverse_helper(datapipe, only_datapipe, cache)
  109. # Add cache here to prevent infinite recursion on DataPipe
  110. def _traverse_helper(datapipe: DataPipe, only_datapipe: bool, cache: Set[int]) -> DataPipeGraph:
  111. if not isinstance(datapipe, (IterDataPipe, MapDataPipe)):
  112. raise RuntimeError(f"Expected `IterDataPipe` or `MapDataPipe`, but {type(datapipe)} is found")
  113. dp_id = id(datapipe)
  114. if dp_id in cache:
  115. return {}
  116. cache.add(dp_id)
  117. # Using cache.copy() here is to prevent the same DataPipe pollutes the cache on different paths
  118. items = _list_connected_datapipes(datapipe, only_datapipe, cache.copy())
  119. d: DataPipeGraph = {dp_id: (datapipe, {})}
  120. for item in items:
  121. # Using cache.copy() here is to prevent recursion on a single path rather than global graph
  122. # Single DataPipe can present multiple times in different paths in graph
  123. d[dp_id][1].update(_traverse_helper(item, only_datapipe, cache.copy()))
  124. return d