_utils.py 1.1 KB

12345678910111213141516171819202122232425262728
  1. import torch
  2. from torch.distributed._shard.metadata import ShardMetadata
  3. from typing import Sequence
  4. DEPRECATE_MSG = "Please use DTensor instead and we are deprecating ShardedTensor."
  5. def narrow_tensor_by_index(tensor: torch.Tensor, offsets: Sequence[int], sizes: Sequence[int]) -> torch.Tensor:
  6. """
  7. Narrow the tensor according to ``offsets`` and ``sizes``.
  8. """
  9. narrowed_tensor = tensor
  10. for idx, (offset, size) in enumerate(zip(offsets, sizes)):
  11. if size < tensor.size(idx):
  12. # Reshape to get shard for this rank and we don't want autograd
  13. # recording here for the narrow op and 'local_shard' should be a
  14. # leaf variable in the autograd graph.
  15. narrowed_tensor = narrowed_tensor.narrow(
  16. idx,
  17. offset,
  18. size
  19. )
  20. return narrowed_tensor
  21. def narrow_tensor(tensor: torch.Tensor, metadata: ShardMetadata) -> torch.Tensor:
  22. """
  23. Narrow the tensor according to the metadata
  24. """
  25. return narrow_tensor_by_index(tensor, metadata.shard_offsets, metadata.shard_sizes)