# sharding.py (3.3 KB)
  1. # mypy: allow-untyped-defs
  2. from typing import (
  3. Dict,
  4. Sized,
  5. Tuple,
  6. )
  7. from torch.utils.data.datapipes._decorator import functional_datapipe
  8. from torch.utils.data.datapipes.datapipe import IterDataPipe
  9. from enum import IntEnum
  10. __all__ = [
  11. "SHARDING_PRIORITIES",
  12. "ShardingFilterIterDataPipe",
  13. ]
  14. class SHARDING_PRIORITIES(IntEnum):
  15. DEFAULT = 1
  16. DISTRIBUTED = 2
  17. MULTIPROCESSING = 3
class _ShardingIterDataPipe(IterDataPipe):
    """Internal base class for IterDataPipes that support sharding.

    Subclasses must override ``apply_sharding`` to record how elements are
    partitioned across ``num_of_instances`` shards.
    """

    def apply_sharding(self, num_of_instances: int, instance_id: int, sharding_group: SHARDING_PRIORITIES):
        # Abstract hook: subclasses implement the actual sharding bookkeeping.
        raise NotImplementedError
  21. @functional_datapipe('sharding_filter')
  22. class ShardingFilterIterDataPipe(_ShardingIterDataPipe):
  23. r"""
  24. Wrapper that allows DataPipe to be sharded (functional name: ``sharding_filter``).
  25. After ``apply_sharding`` is called, each instance of the DataPipe (on different workers) will have every `n`-th element of the
  26. original DataPipe, where `n` equals to the number of instances.
  27. Args:
  28. source_datapipe: Iterable DataPipe that will be sharded
  29. """
  30. def __init__(self, source_datapipe: IterDataPipe, sharding_group_filter=None):
  31. self.source_datapipe = source_datapipe
  32. self.sharding_group_filter = sharding_group_filter
  33. self.groups: Dict[int, Tuple[int, int]] = {}
  34. self.num_of_instances = 1
  35. self.instance_id = 0
  36. self._update_num_of_instances()
  37. def apply_sharding(self, num_of_instances, instance_id, sharding_group=SHARDING_PRIORITIES.DEFAULT):
  38. if instance_id >= num_of_instances:
  39. raise ValueError(f"instance_id({instance_id}) should be smaller than num_of_instances({num_of_instances})")
  40. if sharding_group == SHARDING_PRIORITIES.DEFAULT:
  41. if len(self.groups) and SHARDING_PRIORITIES.DEFAULT not in self.groups:
  42. raise Exception('ShardingFilter cannot mix DEFAULT and non DEFAULT groups') # noqa: TRY002
  43. else:
  44. if SHARDING_PRIORITIES.DEFAULT in self.groups:
  45. raise Exception('ShardingFilter cannot mix DEFAULT and non DEFAULT groups') # noqa: TRY002
  46. self.groups[sharding_group] = (num_of_instances, instance_id)
  47. self._update_num_of_instances()
  48. def _update_num_of_instances(self):
  49. sorted_sharding_groups = []
  50. for key in sorted(self.groups.keys()):
  51. if self.sharding_group_filter is None or key == self.sharding_group_filter:
  52. sorted_sharding_groups.append(self.groups[key])
  53. sorted_sharding_groups.reverse()
  54. self.num_of_instances = 1
  55. self.instance_id = 0
  56. for group_num_of_instances, group_instance_id in sorted_sharding_groups:
  57. self.instance_id += self.num_of_instances * group_instance_id
  58. self.num_of_instances *= group_num_of_instances
  59. def __iter__(self):
  60. for i, item in enumerate(self.source_datapipe):
  61. if i % self.num_of_instances == self.instance_id:
  62. yield item
  63. def __len__(self):
  64. if isinstance(self.source_datapipe, Sized):
  65. return len(self.source_datapipe) // self.num_of_instances +\
  66. (1 if (self.instance_id < len(self.source_datapipe) % self.num_of_instances) else 0)
  67. raise TypeError(f"{type(self).__name__} instance doesn't have valid length")