| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181 |
- # mypy: allow-untyped-defs
- import os
- from dataclasses import dataclass, field
- from enum import Enum
- from typing import Any, Dict, List, Optional, Sequence, Union
- import torch
- from torch.distributed.checkpoint.stateful import StatefulT
# Public names exported by a star-import of this module.
__all__ = [
    "ChunkStorageMetadata",
    "TensorStorageMetadata",
    "BytesStorageMetadata",
    "Metadata",
    "MetadataIndex",
    "TensorProperties",
    "StorageMeta",
]
@dataclass
class ChunkStorageMetadata:
    """
    Describe a single chunk of a tensor.

    A chunk is expected to share the properties of the
    ``TensorStorageMetadata`` that includes it.
    """

    # NOTE(review): presumably where the chunk starts inside the containing
    # tensor -- confirm against callers.
    offsets: torch.Size
    # NOTE(review): presumably the chunk's extent along each dimension --
    # confirm against callers.
    sizes: torch.Size
class _MEM_FORMAT_ENCODING(Enum):
    """Enumerate the memory formats a tensor may be described with."""

    TORCH_CONTIGUOUS_FORMAT = 0
    TORCH_CHANNELS_LAST = 1
    TORCH_PRESERVE_FORMAT = 2
@dataclass
class TensorProperties:
    """
    Properties used to create :class:`Tensor`.

    Instances are pickled as a 5-tuple in which ``memory_format`` is replaced
    by the matching ``_MEM_FORMAT_ENCODING`` member.
    """

    # Regular tensor fields
    dtype: torch.dtype = field(default_factory=torch.get_default_dtype)
    # This field is deprecated.
    layout: torch.layout = torch.strided
    # This field is deprecated.
    requires_grad: bool = field(default=False)
    # This field is deprecated.
    memory_format: torch.memory_format = torch.contiguous_format
    # This field is deprecated.
    pin_memory: bool = field(default=False)

    def __getstate__(self):
        """
        Return the picklable state tuple.

        The tuple is ``(dtype, layout, requires_grad, encoding, pin_memory)``.

        Raises:
            RuntimeError: if ``memory_format`` is not one of the three
                formats that have an encoding.
        """
        # torch.memory_format cannot be pickled, so it is swapped for an
        # enum member before being returned.
        memory_format = self.memory_format
        encodings = (
            (torch.contiguous_format, _MEM_FORMAT_ENCODING.TORCH_CONTIGUOUS_FORMAT),
            (torch.channels_last, _MEM_FORMAT_ENCODING.TORCH_CHANNELS_LAST),
            (torch.preserve_format, _MEM_FORMAT_ENCODING.TORCH_PRESERVE_FORMAT),
        )
        for known_format, mem_format_encoding in encodings:
            if memory_format == known_format:
                break
        else:
            raise RuntimeError(f"Invalid torch.memory_format: {memory_format}")

        return (
            self.dtype,
            self.layout,
            self.requires_grad,
            mem_format_encoding,
            self.pin_memory,
        )

    def __setstate__(self, state):
        """
        Restore the instance from a tuple produced by ``__getstate__``.

        Raises:
            RuntimeError: if the encoded memory format is not recognised.
        """
        dtype, layout, requires_grad, mem_format_encoding, pin_memory = state
        self.dtype = dtype
        self.layout = layout
        self.requires_grad = requires_grad
        self.pin_memory = pin_memory

        # Map the pickled enum member back onto the torch.memory_format value.
        decodings = (
            (_MEM_FORMAT_ENCODING.TORCH_CONTIGUOUS_FORMAT, torch.contiguous_format),
            (_MEM_FORMAT_ENCODING.TORCH_CHANNELS_LAST, torch.channels_last),
            (_MEM_FORMAT_ENCODING.TORCH_PRESERVE_FORMAT, torch.preserve_format),
        )
        for known_encoding, decoded_format in decodings:
            if mem_format_encoding == known_encoding:
                self.memory_format = decoded_format
                return

        raise RuntimeError(
            f"Invalid torch.memory_format encoding: {mem_format_encoding}"
        )

    @staticmethod
    def create_from_tensor(tensor: torch.Tensor) -> "TensorProperties":
        """
        Build a ``TensorProperties`` from ``tensor``.

        ``dtype``, ``layout``, ``requires_grad`` and the pinned state are read
        from the tensor; ``memory_format`` is always recorded as
        ``torch.contiguous_format``.
        """
        dtype = tensor.dtype
        layout = tensor.layout
        requires_grad = tensor.requires_grad
        pinned = tensor.is_pinned()
        return TensorProperties(
            dtype=dtype,
            layout=layout,
            requires_grad=requires_grad,
            memory_format=torch.contiguous_format,
            pin_memory=pinned,
        )
@dataclass
class TensorStorageMetadata:
    """Bundle a tensor's properties, its size, and its chunk descriptors."""

    properties: TensorProperties
    # NOTE(review): presumably the size of the whole tensor rather than of a
    # single chunk -- confirm against callers.
    size: torch.Size
    chunks: List[ChunkStorageMetadata]
@dataclass
class BytesStorageMetadata:
    """
    Field-less dataclass used as one member of ``STORAGE_TYPES``.

    NOTE(review): presumably marks entries stored as raw bytes rather than as
    tensors -- confirm against callers.
    """
# Either of the two storage-metadata dataclasses.
STORAGE_TYPES = Union[TensorStorageMetadata, BytesStorageMetadata]
# String-keyed mapping whose values are stateful objects or anything else.
STATE_DICT_TYPE = Dict[str, Union[StatefulT, Any]]
@dataclass
class StorageMeta:
    """Optional identifiers attached to a checkpoint; each defaults to None."""

    # May be a string or a path-like object.
    checkpoint_id: Optional[Union[str, os.PathLike]] = None
    save_id: Optional[str] = None
    load_id: Optional[str] = None
@dataclass
class Metadata:
    """This class represents the metadata of the checkpoint."""

    # Keyed by the same keys as the `state_dict` that was used.
    state_dict_metadata: Dict[str, STORAGE_TYPES]

    # Planner and storage plugins are responsible for keeping planner_data
    # and storage_data backward compatible. DCP additionally keeps the
    # metadata defined in this file, and the metadata of the built-in planner
    # and storage plugins, backward compatible.
    planner_data: Any = None
    storage_data: Any = None
    storage_meta: Optional[StorageMeta] = None
@dataclass(frozen=True)
class MetadataIndex:
    """This class represents a lookup key for items in a state dict or Metadata."""

    fqn: str
    """Fully Qualified Name of the object"""

    offset: Optional[torch.Size] = None
    """If the object is a tensor, offset into the tensor we're looking for"""

    index: Optional[int] = field(default=None, compare=False, hash=False)
    """
    Optional index hint that speeds up the search for a tensor chunk.

    A sharded tensor is commonly represented as a list of chunks, so finding
    a chunk's position in that list requires a linear search.
    When a MetadataIndex pointing into such a list is constructed, the index
    can be supplied as a hint; it is probed first, before the linear search,
    which makes the lookup significantly faster.
    """

    def __init__(
        self,
        fqn: str,
        offset: Optional[Sequence[int]] = None,
        index: Optional[int] = None,
    ):
        # frozen=True forbids ordinary attribute assignment, so every write
        # goes through object.__setattr__.
        assign = object.__setattr__
        assign(self, "fqn", fqn)
        assign(self, "index", index)
        if offset is None:
            # Nothing is stored on the instance; attribute lookup falls back
            # to the class-level default of None.
            return
        assign(self, "offset", torch.Size(offset))
|