# metadata.py

# mypy: allow-untyped-defs
from dataclasses import dataclass
from functools import reduce
from typing import List, Optional, Union

from torch.distributed.remote_device import _remote_device


@dataclass
class ShardMetadata:
    """
    Represents a shard of the overall Tensor including its
    offsets, lengths and device placement.

    Args:
        shard_offsets(List[int]): Offsets in the original tensor indicating
            the start offsets for this shard. Should have the same rank as
            the original tensor.
        shard_sizes(List[int]): Integers indicating the size of each
            dimension for this shard. Should have the same rank as the
            original tensor.
        placement(:class:`torch.distributed._remote_device`):
            Specifies the placement of this shard.
    """

    __slots__ = ['shard_offsets', 'shard_sizes', 'placement']

    shard_offsets: List[int]
    shard_sizes: List[int]
    placement: Optional[_remote_device]
    def __init__(
        self,
        shard_offsets: List[int],
        shard_sizes: List[int],
        placement: Optional[Union[str, _remote_device]] = None
    ):
        self.shard_offsets = shard_offsets
        self.shard_sizes = shard_sizes
        if isinstance(placement, str):
            self.placement = _remote_device(placement)
        else:
            self.placement = placement
        if len(self.shard_offsets) != len(self.shard_sizes):
            raise ValueError(
                f'shard_offsets and shard_sizes should have '
                f'the same number of elements, found {len(self.shard_offsets)} '
                f'and {len(self.shard_sizes)} respectively')
        for i in range(len(self.shard_offsets)):
            if self.shard_offsets[i] < 0:
                raise ValueError('shard_offsets should be >= 0')
            if self.shard_sizes[i] < 0:
                raise ValueError('shard_sizes should be >= 0')
    def __hash__(self):
        def _hash_reduce(a, b):
            return (a << 8) + hash(b)

        # Fold the offsets, sizes and placement into a single hash value so
        # that ShardMetadata instances can be used in sets and as dict keys.
        res = reduce(_hash_reduce, self.shard_offsets, 37)
        res = reduce(_hash_reduce, self.shard_sizes, res)
        res = _hash_reduce(res, self.placement)
        return res
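

# A minimal usage sketch (illustrative only, not part of the module): it shows
# how ShardMetadata might describe a 4x8 tensor split row-wise across two
# ranks, and how the rank-mismatch validation in __init__ surfaces as a
# ValueError. The "rank:N/cuda:N" placement strings and the concrete shapes
# below are assumptions made for this example.
if __name__ == "__main__":
    shards = [
        ShardMetadata(
            shard_offsets=[0, 0],   # first shard starts at row 0
            shard_sizes=[2, 8],     # covers rows 0-1, all 8 columns
            placement="rank:0/cuda:0",
        ),
        ShardMetadata(
            shard_offsets=[2, 0],   # second shard starts at row 2
            shard_sizes=[2, 8],     # covers rows 2-3, all 8 columns
            placement="rank:1/cuda:1",
        ),
    ]
    print(shards)

    # Offsets and sizes with different ranks are rejected.
    try:
        ShardMetadata(shard_offsets=[0], shard_sizes=[2, 8])
    except ValueError as e:
        print(e)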