# _fsdp_extensions.py

from abc import ABC, abstractmethod
from typing import Any, List, Optional, Tuple

import torch
import torch.distributed as dist
from torch.distributed._shard.sharded_tensor.api import ShardedTensor
from torch.distributed._shard.sharded_tensor.shard import Shard
from torch.distributed._tensor import DeviceMesh, DTensor
from torch.distributed.fsdp._shard_utils import (
    _all_gather_dtensor,
    _create_chunk_dtensor,
    _create_chunk_sharded_tensor,
)


class FSDPExtensions(ABC):
    """
    This enables customizable hooks for composability with tensor
    parallelism. To activate these hooks, use :func:`_set_fsdp_extensions` to
    set a custom :class:`FSDPExtensions` that implements the hooks.
    """

    @abstractmethod
    def pre_flatten_transform(
        self,
        tensor: torch.Tensor,
    ) -> Tuple[torch.Tensor, Optional[Any]]:
        """E.g. converting ``DistributedTensor`` to local tensor."""
        ...

    @abstractmethod
    def post_unflatten_transform(
        self,
        tensor: torch.Tensor,
        param_extension: Any,
    ) -> torch.Tensor:
        """E.g. converting local tensor to ``DistributedTensor``."""
        ...

    @abstractmethod
    def chunk_tensor(
        self,
        tensor: torch.Tensor,
        rank: int,
        world_size: int,
        num_devices_per_node: int,
        pg: dist.ProcessGroup,
        device: Optional[torch.device] = None,
    ) -> torch.Tensor:
        """Shards a tensor into chunks and returns the local chunk."""
        ...

    @abstractmethod
    def chunk_dtensor(
        self,
        tensor: torch.Tensor,
        rank: int,
        device_mesh: DeviceMesh,
    ) -> torch.Tensor:
        """Shards a tensor/DTensor to a DTensor and returns the local DTensor."""
        ...

    @abstractmethod
    def pre_load_state_dict_transform(
        self,
        tensor: torch.Tensor,
    ) -> Tuple[torch.Tensor, List[Shard]]:
        """
        This is to be called before loading a *sharded* model state dict and
        should return the tensor and list of shards from which to load data.
        """
        ...

    @abstractmethod
    def all_gather_dtensor(
        self,
        tensor: DTensor,
        parent_mesh: Optional[DeviceMesh],
    ) -> torch.Tensor:
        """
        This is to be called before loading a *sharded* DTensor state dict.
        It all-gathers the tensor along the FSDP dimension and returns the
        local tensor of the TP DTensor.
        """
        ...
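

# ---------------------------------------------------------------------------
# Hedged example (not part of the upstream module): a minimal sketch of an
# ``FSDPExtensions`` subclass. The class name ``_PassthroughExtensions`` is
# hypothetical; it leaves plain tensors untouched and delegates sharding to
# the same default helpers that the ``_ext_*`` wrappers below fall back to.
# A real tensor-parallel extension would instead convert between DTensor and
# local tensors in ``pre_flatten_transform`` / ``post_unflatten_transform``.
class _PassthroughExtensions(FSDPExtensions):
    def pre_flatten_transform(
        self,
        tensor: torch.Tensor,
    ) -> Tuple[torch.Tensor, Optional[Any]]:
        # Returning ``None`` as the extension signals "nothing to undo later".
        return tensor, None

    def post_unflatten_transform(
        self,
        tensor: torch.Tensor,
        param_extension: Any,
    ) -> torch.Tensor:
        return tensor

    def chunk_tensor(
        self,
        tensor: torch.Tensor,
        rank: int,
        world_size: int,
        num_devices_per_node: int,
        pg: dist.ProcessGroup,
        device: Optional[torch.device] = None,
    ) -> torch.Tensor:
        # Same default chunking that ``_ext_chunk_tensor`` uses; ``device``
        # is accepted for API compatibility and ignored here.
        return _create_chunk_sharded_tensor(
            tensor, rank, world_size, num_devices_per_node, pg
        )

    def chunk_dtensor(
        self,
        tensor: torch.Tensor,
        rank: int,
        device_mesh: DeviceMesh,
    ) -> torch.Tensor:
        return _create_chunk_dtensor(tensor, rank, device_mesh)

    def pre_load_state_dict_transform(
        self,
        tensor: torch.Tensor,
    ) -> Tuple[torch.Tensor, List[Shard]]:
        assert type(tensor) is ShardedTensor
        return tensor, tensor.local_shards()

    def all_gather_dtensor(
        self,
        tensor: DTensor,
        parent_mesh: Optional[DeviceMesh],
    ) -> torch.Tensor:
        return _all_gather_dtensor(tensor, parent_mesh)
# ---------------------------------------------------------------------------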


_extensions: Optional[FSDPExtensions] = None


def _set_fsdp_extensions(flattener: FSDPExtensions) -> None:
    global _extensions
    _extensions = flattener
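
# Hedged usage sketch (illustrative): per the class docstring, the hooks are
# activated by registering an extension instance, e.g.
#
#   _set_fsdp_extensions(_PassthroughExtensions())
#
# where ``_PassthroughExtensions`` is the hypothetical example defined above.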


def _ext_pre_flatten_transform(
    tensor: torch.Tensor,
    fsdp_extension: Optional[FSDPExtensions] = None,
) -> Tuple[torch.Tensor, Optional[Any]]:
    if fsdp_extension is not None:
        new_tensor, param_extension = fsdp_extension.pre_flatten_transform(tensor)
        if param_extension is not None:
            return new_tensor, param_extension
    return tensor, None


def _ext_post_unflatten_transform(
    tensor: torch.Tensor,
    param_extension: Any,
    fsdp_extension: Optional[FSDPExtensions] = None,
) -> torch.Tensor:
    if fsdp_extension is not None and param_extension is not None:
        return fsdp_extension.post_unflatten_transform(tensor, param_extension)
    return tensor
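
# Hedged sketch of the expected call pattern (illustrative, not actual FSDP
# internals): the two hooks bracket flat-parameter construction, roughly
#
#   local_tensor, param_extension = _ext_pre_flatten_transform(param, ext)
#   ...flatten ``local_tensor`` into the flat parameter...
#   restored = _ext_post_unflatten_transform(unflat_view, param_extension, ext)
#
# so ``param_extension`` carries whatever metadata is needed to rebuild the
# original tensor type (e.g. a DTensor's placement) after unflattening.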


def _ext_chunk_tensor(
    tensor: torch.Tensor,
    rank: int,
    world_size: int,
    num_devices_per_node: int,
    pg: dist.ProcessGroup,
    fsdp_extension: Optional[FSDPExtensions] = None,
) -> torch.Tensor:
    chunk_tensor_fn = (
        fsdp_extension.chunk_tensor
        if fsdp_extension is not None
        else _create_chunk_sharded_tensor
    )
    return chunk_tensor_fn(
        tensor,
        rank,
        world_size,
        num_devices_per_node,
        pg,
    )
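
# Hedged usage sketch (illustrative): a typical caller passes the current
# rank and process-group size, e.g.
#
#   sharded = _ext_chunk_tensor(
#       tensor,
#       rank=dist.get_rank(pg),
#       world_size=dist.get_world_size(pg),
#       num_devices_per_node=torch.cuda.device_count(),
#       pg=pg,
#   )
#
# With no extension registered this produces a ``ShardedTensor`` via
# ``_create_chunk_sharded_tensor``.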


def _ext_chunk_dtensor(
    tensor: torch.Tensor,
    rank: int,
    device_mesh: DeviceMesh,
    fsdp_extension: Optional[FSDPExtensions] = None,
) -> torch.Tensor:
    chunk_dtensor_fn = (
        fsdp_extension.chunk_dtensor
        if fsdp_extension is not None
        else _create_chunk_dtensor
    )
    return chunk_dtensor_fn(
        tensor,
        rank,
        device_mesh,
    )
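
# Hedged usage sketch (illustrative): ``device_mesh`` describes the FSDP
# sharding dimension, e.g. a 1-D mesh over all ranks:
#
#   mesh = DeviceMesh("cuda", torch.arange(dist.get_world_size()))
#   dtensor = _ext_chunk_dtensor(tensor, dist.get_rank(), mesh)
#
# With no extension registered this produces a ``DTensor`` via
# ``_create_chunk_dtensor``.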


def _ext_pre_load_state_dict_transform(
    tensor: torch.Tensor,
    fsdp_extension: Optional[FSDPExtensions] = None,
) -> Tuple[torch.Tensor, List[Shard]]:
    if fsdp_extension is not None:
        return fsdp_extension.pre_load_state_dict_transform(tensor)
    assert type(tensor) is ShardedTensor
    shards = tensor.local_shards()
    return (tensor, shards)
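
# Note (added): in the default path the incoming state-dict value must
# already be a ``ShardedTensor``; its ``local_shards()`` are returned as the
# shards from which data is loaded, matching the contract of
# ``pre_load_state_dict_transform`` above.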


def _ext_all_gather_dtensor(
    tensor: DTensor,
    parent_mesh: Optional[DeviceMesh],
    fsdp_extension: Optional[FSDPExtensions] = None,
) -> torch.Tensor:
    all_gather_dtensor_fn = (
        fsdp_extension.all_gather_dtensor
        if fsdp_extension is not None
        else _all_gather_dtensor
    )
    return all_gather_dtensor_fn(tensor, parent_mesh)