metadata.py 5.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181
  1. # mypy: allow-untyped-defs
  2. import os
  3. from dataclasses import dataclass, field
  4. from enum import Enum
  5. from typing import Any, Dict, List, Optional, Sequence, Union
  6. import torch
  7. from torch.distributed.checkpoint.stateful import StatefulT
# Public API of this module; consumed by `from ... import *` and doc tooling.
__all__ = [
    "ChunkStorageMetadata",
    "TensorStorageMetadata",
    "BytesStorageMetadata",
    "Metadata",
    "MetadataIndex",
    "TensorProperties",
    "StorageMeta",
]
@dataclass
class ChunkStorageMetadata:
    """
    Each chunk is expected to have the same properties of the TensorStorageMetadata
    that includes it.
    """

    # Per-dimension start offsets of this chunk within the overall tensor.
    offsets: torch.Size
    # Per-dimension extents of this chunk.
    sizes: torch.Size
class _MEM_FORMAT_ENCODING(Enum):
    """Describe the memory format of a tensor."""

    # NOTE: these members are emitted by TensorProperties.__getstate__ and so
    # end up inside pickled checkpoint metadata — do not renumber or remove
    # existing members, or previously saved state will fail to decode.
    TORCH_CONTIGUOUS_FORMAT = 0
    TORCH_CHANNELS_LAST = 1
    TORCH_PRESERVE_FORMAT = 2
  30. @dataclass
  31. class TensorProperties:
  32. """Properties used to create :class:`Tensor`"""
  33. # Regular tensor fields
  34. dtype: torch.dtype = field(default_factory=torch.get_default_dtype)
  35. # This field is deprecated.
  36. layout: torch.layout = field(default=torch.strided)
  37. # This field is deprecated.
  38. requires_grad: bool = False
  39. # This field is deprecated.
  40. memory_format: torch.memory_format = field(default=torch.contiguous_format)
  41. # This field is deprecated.
  42. pin_memory: bool = False
  43. def __getstate__(self):
  44. # Since torch.memory_format cannot be pickled!
  45. memory_format = self.memory_format
  46. if memory_format == torch.contiguous_format:
  47. mem_format_encoding = _MEM_FORMAT_ENCODING.TORCH_CONTIGUOUS_FORMAT
  48. elif memory_format == torch.channels_last:
  49. mem_format_encoding = _MEM_FORMAT_ENCODING.TORCH_CHANNELS_LAST
  50. elif memory_format == torch.preserve_format:
  51. mem_format_encoding = _MEM_FORMAT_ENCODING.TORCH_PRESERVE_FORMAT
  52. else:
  53. raise RuntimeError(f"Invalid torch.memory_format: {memory_format}")
  54. return (
  55. self.dtype,
  56. self.layout,
  57. self.requires_grad,
  58. mem_format_encoding,
  59. self.pin_memory,
  60. )
  61. def __setstate__(
  62. self,
  63. state,
  64. ):
  65. (
  66. self.dtype,
  67. self.layout,
  68. self.requires_grad,
  69. mem_format_encoding,
  70. self.pin_memory,
  71. ) = state
  72. if mem_format_encoding == _MEM_FORMAT_ENCODING.TORCH_CONTIGUOUS_FORMAT:
  73. memory_format = torch.contiguous_format
  74. elif mem_format_encoding == _MEM_FORMAT_ENCODING.TORCH_CHANNELS_LAST:
  75. memory_format = torch.channels_last
  76. elif mem_format_encoding == _MEM_FORMAT_ENCODING.TORCH_PRESERVE_FORMAT:
  77. memory_format = torch.preserve_format
  78. else:
  79. raise RuntimeError(
  80. f"Invalid torch.memory_format encoding: {mem_format_encoding}"
  81. )
  82. self.memory_format = memory_format
  83. @staticmethod
  84. def create_from_tensor(tensor: torch.Tensor) -> "TensorProperties":
  85. return TensorProperties(
  86. dtype=tensor.dtype,
  87. layout=tensor.layout,
  88. requires_grad=tensor.requires_grad,
  89. memory_format=torch.contiguous_format,
  90. pin_memory=tensor.is_pinned(),
  91. )
@dataclass
class TensorStorageMetadata:
    """Storage metadata for one tensor entry of the checkpoint."""

    # Construction properties (dtype, layout, ...) shared by all chunks.
    properties: TensorProperties
    # Overall size of the (possibly chunked/sharded) tensor.
    size: torch.Size
    # The chunks that together cover the stored tensor.
    chunks: List[ChunkStorageMetadata]
@dataclass
class BytesStorageMetadata:
    """Marker metadata for a non-tensor entry stored as opaque bytes; carries no fields."""

    pass
# A storage entry is either a tensor (with size/chunk layout) or opaque bytes.
STORAGE_TYPES = Union[TensorStorageMetadata, BytesStorageMetadata]
# State-dict values may be Stateful objects or anything else (e.g. plain tensors).
STATE_DICT_TYPE = Dict[str, Union[StatefulT, Any]]
@dataclass
class StorageMeta:
    """Metadata about the storage used for a checkpoint."""

    # Location of the checkpoint (path-like or string); None if unspecified.
    checkpoint_id: Union[str, os.PathLike, None] = None
    # presumably ids identifying the save/load sessions — TODO confirm against
    # the writer/reader that populates them.
    save_id: Optional[str] = None
    load_id: Optional[str] = None
@dataclass
class Metadata:
    """This class represents the metadata of the checkpoint."""

    # Keys are the same from the `state_dict` used.
    state_dict_metadata: Dict[str, STORAGE_TYPES]
    # It is the responsibility of the planner and storage plugins to ensure
    # backward compatibility of the planner_data and storage_data. DCP will
    # also ensure the backward compatibility of the metadata in this file and
    # the metadata of the built-in planner and storage plugins.
    planner_data: Any = None
    storage_data: Any = None
    # Optional storage-level metadata (checkpoint id, save/load ids).
    storage_meta: Optional[StorageMeta] = None
  119. @dataclass(frozen=True)
  120. class MetadataIndex:
  121. """This class represents a lookup key for items in a state dict or Metadata."""
  122. fqn: str
  123. """Fully Qualified Name of the object"""
  124. offset: Optional[torch.Size] = None
  125. """If the object is a tensor, offset into the tensor we're looking for"""
  126. index: Optional[int] = field(hash=False, compare=False, default=None)
  127. """
  128. Index hint when searching for tensor chunk to speedup lookups (optional)
  129. A common representation of a sharded tensor is as a list of chunks so to
  130. find the index in such a list you need to linear search it.
  131. When constructing an instance of MetadataIndex that points to that list,
  132. one can provide the index as a hint and it will be probed first before
  133. the linear search and thus making it significantly faster.
  134. """
  135. def __init__(
  136. self,
  137. fqn: str,
  138. offset: Optional[Sequence[int]] = None,
  139. index: Optional[int] = None,
  140. ):
  141. # We must use object.__setattr__ due to frozen=True
  142. object.__setattr__(self, "fqn", fqn)
  143. object.__setattr__(self, "index", index)
  144. if offset is not None:
  145. object.__setattr__(self, "offset", torch.Size(offset))