utils.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436
  1. # mypy: allow-untyped-defs
  2. import cProfile
  3. import inspect
  4. import io
  5. import itertools
  6. import os
  7. import warnings
  8. from contextlib import contextmanager
  9. from functools import wraps
  10. from pstats import Stats
  11. from typing import Any, Callable, cast, Dict, List, Optional, Sequence, TypeVar, Union
  12. import torch
  13. import torch.distributed as dist
  14. from torch.distributed._shard.sharded_tensor import ShardedTensor
  15. from torch.distributed._shard.sharded_tensor.shard import Shard
  16. from torch.distributed._tensor import DTensor
  17. from torch.distributed.checkpoint.planner import _Checkpointable
  18. from .api import (
  19. _is_wrapped_exception,
  20. _wrap_exception,
  21. CheckpointException,
  22. WRAPPED_EXCEPTION,
  23. )
  24. from .metadata import MetadataIndex, STATE_DICT_TYPE
  25. __all__ = ["find_tensor_shard", "find_state_dict_object"]
  26. T = TypeVar("T")
  27. R = TypeVar("R")
  28. def _get_failure_dict(
  29. results: List[Union[T, WRAPPED_EXCEPTION]]
  30. ) -> Dict[int, WRAPPED_EXCEPTION]:
  31. return cast(
  32. Dict[int, WRAPPED_EXCEPTION],
  33. {i: err for i, err in enumerate(results) if _is_wrapped_exception(err)},
  34. )
  35. def _all_gather_keys(
  36. local_dict: Dict[Any, Any], group: Optional[dist.ProcessGroup] = None
  37. ) -> List[Any]:
  38. """Gathers all keys, and returns them sorted."""
  39. keys = list(local_dict.keys())
  40. gathered_keys: List[List[Any]] = [None] * dist.get_world_size() # type: ignore[list-item]
  41. dist.all_gather_object(gathered_keys, keys, group=group)
  42. return sorted(set(itertools.chain.from_iterable(gathered_keys)))
  43. class _DistWrapper:
  44. """
  45. This is a wrapper around PG that provides a series of features around object collectives.
  46. It works without distributed initialized, where most collectives turns into nops.
  47. All variants that take functions are exception robust, meaning that if one or more
  48. ranks raise errors, all ranks will observe those.
  49. """
  50. def __init__(
  51. self,
  52. group: Optional[dist.ProcessGroup],
  53. use_dist: bool,
  54. coordinator_rank: int,
  55. ):
  56. self.group = group
  57. self.use_dist = use_dist
  58. self.coordinator_rank = coordinator_rank
  59. if self.use_dist:
  60. self.rank = dist.get_rank(group)
  61. self.is_coordinator = self.rank == coordinator_rank
  62. else:
  63. self.rank = 0
  64. self.is_coordinator = True
  65. def get_rank(self) -> int:
  66. return self.rank
  67. def get_world_size(self) -> int:
  68. if self.use_dist:
  69. return dist.get_world_size(self.group)
  70. return 1
  71. def broadcast_object(self, object: Optional[T]) -> T:
  72. """Implement functionality similar to c10d::broadcast_object_list but without distributed enabled."""
  73. object_list = [object]
  74. if self.use_dist:
  75. dist.broadcast_object_list(
  76. object_list=object_list,
  77. group=self.group,
  78. src=self.coordinator_rank,
  79. )
  80. return cast(T, object_list[0])
  81. def gather_object(self, object: T) -> Optional[List[T]]:
  82. """Implement functionality similar to c10d::gather_object but without distributed enabled."""
  83. if self.use_dist:
  84. gather_objs = (
  85. cast(List[T], [None] * dist.get_world_size(self.group))
  86. if self.is_coordinator
  87. else None
  88. )
  89. dist.gather_object(
  90. obj=object,
  91. object_gather_list=gather_objs if self.is_coordinator else None,
  92. dst=self.coordinator_rank,
  93. group=self.group,
  94. )
  95. result = gather_objs
  96. else:
  97. result = [object]
  98. return result
  99. def all_gather_object(self, object: T) -> List[T]:
  100. """Implement functionality similar to c10d::all_gather_object but without distributed enabled."""
  101. if self.use_dist:
  102. gather_objs = cast(List[T], [None] * dist.get_world_size(self.group))
  103. dist.all_gather_object(
  104. object_list=gather_objs, obj=object, group=self.group
  105. )
  106. else:
  107. gather_objs = [object]
  108. return gather_objs
  109. def scatter_object(self, object_list: Optional[List[T]]) -> T:
  110. """Implement functionality similar to c10d::scatter_object but without distributed enabled."""
  111. if self.use_dist:
  112. gather_result = cast(List[T], [None])
  113. dist.scatter_object_list(
  114. scatter_object_output_list=gather_result,
  115. scatter_object_input_list=object_list if self.is_coordinator else None,
  116. src=self.coordinator_rank,
  117. group=self.group,
  118. )
  119. local_reply = gather_result[0]
  120. else:
  121. assert object_list is not None
  122. local_reply = object_list[0]
  123. return local_reply
  124. def reduce_scatter(
  125. self,
  126. step: str,
  127. map_fun: Callable[[], T],
  128. reduce_fun: Callable[[List[T]], List[R]],
  129. ) -> R:
  130. """
  131. Compute a value on each rank, then do centralized reduce on a single rank, followed by a scatter.
  132. This method operates in the following way:
  133. Run ``map_fun`` on all ranks
  134. Gather results on rank 0
  135. Call ``reduce_fun`` on all those values
  136. Scatter to each rank part of the result.
  137. """
  138. local_data: Union[WRAPPED_EXCEPTION, T]
  139. try:
  140. local_data = map_fun()
  141. except BaseException as e:
  142. local_data = _wrap_exception(e)
  143. all_data = self.gather_object(local_data)
  144. all_results: Optional[List[Union[R, CheckpointException]]] = None
  145. if self.is_coordinator:
  146. assert all_data is not None
  147. node_failures = _get_failure_dict(all_data)
  148. if len(node_failures) == 0:
  149. try:
  150. # N.B. why can't mypy cast List[R] to List[Union[R, WRAPPED_EXCEPTION]]?
  151. all_results = cast(
  152. List[Union[R, CheckpointException]],
  153. reduce_fun(cast(List[T], all_data)),
  154. )
  155. except BaseException as e:
  156. node_failures[self.rank] = _wrap_exception(e)
  157. if len(node_failures) > 0:
  158. all_results = [
  159. CheckpointException(step, node_failures)
  160. ] * self.get_world_size()
  161. result = self.scatter_object(all_results)
  162. if isinstance(result, CheckpointException):
  163. raise result
  164. return result
  165. def all_reduce(
  166. self,
  167. step: str,
  168. map_fun: Callable[[], T],
  169. reduce_fun: Callable[[List[T]], R],
  170. ) -> R:
  171. """
  172. Compute a value on each rank, then do centralized reduce on a single rank, followed by a broadcast.
  173. This method operates in the following way:
  174. Run ``map_fun`` on all ranks
  175. Gather results on rank 0
  176. Call ``reduce_fun`` on all those values
  177. Broadcast the reduced value to all ranks.
  178. """
  179. local_data: Union[T, WRAPPED_EXCEPTION]
  180. try:
  181. local_data = map_fun()
  182. except BaseException as e:
  183. local_data = _wrap_exception(e)
  184. all_data = self.gather_object(local_data)
  185. result: Optional[Union[R, CheckpointException]] = None
  186. if self.is_coordinator:
  187. assert all_data is not None
  188. node_failures = _get_failure_dict(all_data)
  189. if len(node_failures) == 0:
  190. try:
  191. result = reduce_fun(cast(List[T], all_data))
  192. except BaseException as e:
  193. node_failures[self.rank] = _wrap_exception(e)
  194. if len(node_failures) > 0:
  195. result = CheckpointException(step, node_failures)
  196. final_result = self.broadcast_object(result)
  197. if isinstance(final_result, CheckpointException):
  198. raise final_result
  199. return cast(R, final_result)
  200. def all_gather(
  201. self,
  202. step: str,
  203. map_fun: Callable[[], T],
  204. ) -> List[T]:
  205. """
  206. Compute a value on each rank, then all_gather them.
  207. This method operates in the following way:
  208. Run ``map_cp`` on all ranks
  209. all_gather the values to all ranks
  210. """
  211. result: Union[T, WRAPPED_EXCEPTION]
  212. try:
  213. result = map_fun()
  214. except BaseException as e:
  215. result = _wrap_exception(e)
  216. all_results = self.all_gather_object(result)
  217. node_failures = _get_failure_dict(all_results)
  218. if len(node_failures) > 0:
  219. raise CheckpointException(step, node_failures)
  220. return cast(List[T], all_results)
  221. def broadcast(
  222. self,
  223. step: str,
  224. map_fun: Callable[[], T],
  225. ) -> T:
  226. """
  227. Compute a value on rank 0 and broadcast it.
  228. This method operates in the following way:
  229. Run ``map_cp`` on rank 0
  230. broadcast the value
  231. """
  232. result: Optional[Union[T, CheckpointException]] = None
  233. if self.is_coordinator:
  234. try:
  235. result = map_fun()
  236. except BaseException as e:
  237. result = CheckpointException(step, {self.rank: _wrap_exception(e)})
  238. final_result = self.broadcast_object(result)
  239. if isinstance(final_result, CheckpointException):
  240. raise final_result
  241. return cast(T, final_result)
  242. def _find_shard(tensor: ShardedTensor, index: MetadataIndex) -> Shard:
  243. if index.offset is None:
  244. raise ValueError(
  245. f"Cannot lookup {index.fqn} since its a ShardedTensor and no offset was provided"
  246. )
  247. shards = tensor.local_shards()
  248. # index fast path
  249. if index.index is not None:
  250. if (
  251. len(shards) > index.index
  252. and torch.Size(shards[index.index].metadata.shard_offsets) == index.offset
  253. ):
  254. return shards[index.index]
  255. for shard in shards:
  256. if torch.Size(shard.metadata.shard_offsets) == index.offset:
  257. return shard
  258. raise ValueError(f"Could not find shard at '{index.offset}' for FQN: '{index.fqn}'")
  259. def find_tensor_shard(tensor: torch.Tensor, index: MetadataIndex) -> torch.Tensor:
  260. if isinstance(tensor, _Checkpointable):
  261. return tensor._get_tensor_shard(tensor, index)
  262. elif isinstance(tensor, DTensor):
  263. # DTensor can contain a local tensor that is a tensor subclass
  264. if isinstance(tensor.to_local(), _Checkpointable):
  265. return tensor.to_local()._get_tensor_shard(tensor, index) # type: ignore[arg-type]
  266. return tensor.to_local()
  267. if isinstance(tensor, ShardedTensor):
  268. return _find_shard(tensor, index).tensor
  269. if index.offset is not None:
  270. # special case looking up a tensor by origin
  271. if index.offset == torch.Size([0] * len(tensor.size())):
  272. return tensor
  273. raise ValueError(
  274. f"FQN: '{index.fqn}' is not a ShardedTensor, can't find by offset: '{index.offset}'"
  275. )
  276. return tensor
  277. def find_state_dict_object(state_dict: STATE_DICT_TYPE, index: MetadataIndex) -> Any:
  278. if index.fqn not in state_dict:
  279. raise ValueError(f"Could not find FQN: '{index.fqn}'")
  280. obj = state_dict[index.fqn]
  281. if isinstance(obj, torch.Tensor):
  282. return find_tensor_shard(obj, index)
  283. elif index.offset is not None:
  284. raise ValueError(
  285. f"FQN: '{index.fqn}' is not a ShardedTensor, can't find by offset: '{index.offset}'"
  286. )
  287. return obj
  288. def _element_wise_add(a: Sequence[int], b: Sequence[int]) -> List[int]:
  289. return [i_a + i_b for i_a, i_b in zip(a, b)]
  290. def _element_wise_sub(a: Sequence[int], b: Sequence[int]) -> List[int]:
  291. return [i_a - i_b for i_a, i_b in zip(a, b)]
  292. class _ReaderView(io.IOBase):
  293. def __init__(self, base_stream: io.IOBase, offset: int, len: int):
  294. super().__init__()
  295. self.offset = offset
  296. self.len = len
  297. self.base_stream = base_stream
  298. self.seek(0)
  299. def seek(self, __offset: int, __whence: int = os.SEEK_SET) -> int:
  300. if __whence == os.SEEK_SET:
  301. __offset = self.offset + __offset
  302. elif __whence == os.SEEK_END:
  303. __whence = os.SEEK_SET
  304. __offset = (self.offset + self.len) - __offset
  305. return self.base_stream.seek(__offset, __whence)
  306. def tell(self) -> int:
  307. return self.base_stream.tell() - self.offset
  308. def readable(self) -> bool:
  309. return self.base_stream.readable()
  310. def seekable(self) -> bool:
  311. return self.base_stream.seekable()
  312. def readinto(self, b):
  313. return self.base_stream.readinto(b) # type: ignore[attr-defined]
  314. def read(self, size=-1):
  315. return self.base_stream.read(size)
  316. def _create_file_view(file: io.IOBase, offset: int, length: int) -> io.IOBase:
  317. # FIXME (kumpera) torch.load fails if we wrap with io.BufferedReader
  318. return _ReaderView(file, offset, length)
  319. def _normalize_device_info(device_type: str, device_id: int) -> str:
  320. """Device info normalization."""
  321. if device_type == "cpu":
  322. return "cpu"
  323. return f"{device_type}:{device_id}"
  324. # TODO: integrate with distributed logging flag
  325. ENABLE_PROFILE = False
  326. @contextmanager
  327. def _profile():
  328. # Only log the profiling when it is enable and is on rank0 or dist is not
  329. # avaiable.
  330. if ENABLE_PROFILE and (not dist.is_available() or dist.get_rank() == 0):
  331. profiler = cProfile.Profile()
  332. profiler.enable()
  333. try:
  334. yield
  335. finally:
  336. profiler.disable()
  337. stats = Stats(profiler)
  338. stats.sort_stats("time").print_stats(10)
  339. else:
  340. yield
  341. def _api_bc_check(func):
  342. @wraps(func)
  343. def inner_func(*args, **kwargs) -> Any:
  344. if len(args) == 2:
  345. warnings.warn(
  346. f"The argument order of {func.__name__} has been changed. "
  347. "Please check the document to avoid future breakages."
  348. )
  349. sig = inspect.signature(func)
  350. kwonlyargs = [
  351. p.name for p in sig.parameters.values() if p.kind == p.KEYWORD_ONLY
  352. ]
  353. if "storage_writer" in kwonlyargs:
  354. assert "storage_writer" not in kwargs, (args, kwargs)
  355. kwargs["storage_writer"] = args[1]
  356. elif "storage_reader" in kwonlyargs:
  357. assert "storage_reader" not in kwargs, (args, kwargs)
  358. kwargs["storage_reader"] = args[1]
  359. else:
  360. raise RuntimeError(f"Unexpected kwonlyargs = {kwonlyargs}")
  361. return func(args[0], **kwargs)
  362. else:
  363. return func(*args, **kwargs)
  364. return inner_func