| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113111411151116111711181119112011211122112311241125112611271128112911301131113211331134113511361137113811391140114111421143114411451146114711481149115011511152115311541155115611571158115911601161116211631164116511661167116811691170117111721173117411751176117711781179118011811182118311841185118611871188118911901191119211931194119511961197119811991200120112021203120412051206120712081209121012111212121312141215121612171218121912201221122212231224122512261227122812291230123112321233123412351236123712381239124012411242124312441245124612471248124912501251125212531254125512561257125812591260126112621263126412651266126712681269127012711272127312741275127612771278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536 |
- # mypy: allow-untyped-defs
- import difflib
- import functools
- import os
- import io
- import re
- import shutil
- import struct
- import sys
- import torch
- import tarfile
- import tempfile
- import warnings
- from contextlib import closing, contextmanager
- from enum import Enum
- from ._utils import _import_dotted_name
- from torch._sources import get_source_lines_and_file
- from torch.types import Storage
- from torch.storage import _get_dtype_from_pickle_storage_type
- from typing import Any, BinaryIO, Callable, cast, Dict, Optional, Type, Tuple, Union, IO, List
- from typing_extensions import TypeAlias, TypeGuard # Python 3.10+
- import copyreg
- import pickle
- import torch._weights_only_unpickler as _weights_only_unpickler
- DEFAULT_PROTOCOL = 2
- LONG_SIZE = struct.Struct('=l').size
- INT_SIZE = struct.Struct('=i').size
- SHORT_SIZE = struct.Struct('=h').size
- MAGIC_NUMBER = 0x1950a86a20f9469cfc6c
- PROTOCOL_VERSION = 1001
- STORAGE_KEY_SEPARATOR = ','
- FILE_LIKE: TypeAlias = Union[str, os.PathLike, BinaryIO, IO[bytes]]
- MAP_LOCATION: TypeAlias = Optional[Union[Callable[[Storage, str], Storage], torch.device, str, Dict[str, str]]]
- STORAGE: TypeAlias = Union[Storage, torch.storage.TypedStorage, torch.UntypedStorage]
- IS_WINDOWS = sys.platform == "win32"
- if not IS_WINDOWS:
- from mmap import MAP_SHARED, MAP_PRIVATE
- else:
- MAP_SHARED, MAP_PRIVATE = None, None # type: ignore[assignment]
- __all__ = [
- 'SourceChangeWarning',
- 'mkdtemp',
- 'register_package',
- 'check_module_version_greater_or_equal',
- 'validate_cuda_device',
- 'validate_hpu_device',
- 'location_tag',
- 'default_restore_location',
- 'normalize_storage_type',
- 'storage_to_tensor_type',
- 'save',
- 'load',
- 'StorageType',
- 'LoadEndianness',
- 'get_default_load_endianness',
- 'set_default_load_endianness',
- 'clear_safe_globals',
- 'get_safe_globals',
- 'add_safe_globals',
- ]
- class SourceChangeWarning(Warning):
- pass
- @contextmanager
- def mkdtemp():
- path = tempfile.mkdtemp()
- try:
- yield path
- finally:
- shutil.rmtree(path)
- _package_registry: List[Tuple[int, Callable[[STORAGE], Optional[str]], Callable[[STORAGE, str], Optional[STORAGE]]]] = []
- class LoadEndianness(Enum):
- NATIVE = 1
- LITTLE = 2
- BIG = 3
- _default_load_endian: Optional[LoadEndianness] = None
- def get_default_load_endianness() -> Optional[LoadEndianness]:
- '''
- Get fallback byte order for loading files
- If byteorder mark is not present in saved checkpoint,
- this byte order is used as fallback.
- By default, it's "native" byte order.
- Returns:
- default_load_endian: Optional[LoadEndianness]
- '''
- return _default_load_endian
- def set_default_load_endianness(endianness):
- '''
- Set fallback byte order for loading files
- If byteorder mark is not present in saved checkpoint,
- this byte order is used as fallback.
- By default, it's "native" byte order.
- Args:
- endianness: the new fallback byte order
- '''
- global _default_load_endian
- if not isinstance(endianness, LoadEndianness) and endianness is not None:
- raise TypeError("Invalid argument type in function set_default_load_endianness")
- _default_load_endian = endianness
- _default_mmap_options: int = MAP_PRIVATE
- def get_default_mmap_options() -> int:
- '''
- Get default mmap options for :func:`torch.load` with ``mmap=True``.
- Defaults to ``mmap.MAP_PRIVATE``.
- Returns:
- default_mmap_options: int
- '''
- return _default_mmap_options
- def set_default_mmap_options(flags: int):
- '''
- Set default mmap options for :func:`torch.load` with ``mmap=True`` to flags.
- For now, only either ``mmap.MAP_PRIVATE`` or ``mmap.MAP_SHARED`` are supported.
- Please open an issue if you need any other option to be added here.
- .. note::
- This feature is currently not supported for Windows.
- Args:
- flags: ``mmap.MAP_PRIVATE`` or ``mmap.MAP_SHARED``
- '''
- global _default_mmap_options
- if IS_WINDOWS:
- raise RuntimeError("Changing the default mmap options is currently not supported for Windows")
- if (flags != MAP_PRIVATE and flags != MAP_SHARED):
- raise ValueError("Invalid argument in function set_default_mmap_options, "
- f"expected mmap.MAP_PRIVATE or mmap.MAP_SHARED, but got {flags}")
- _default_mmap_options = flags
- def clear_safe_globals() -> None:
- '''
- Clears the list of globals that are safe for ``weights_only`` load.
- '''
- _weights_only_unpickler._clear_safe_globals()
- def get_safe_globals() -> List[Any]:
- '''
- Returns the list of user-added globals that are safe for ``weights_only`` load.
- '''
- return _weights_only_unpickler._get_safe_globals()
- def add_safe_globals(safe_globals: List[Any]) -> None:
- '''
- Marks the given globals as safe for ``weights_only`` load. For example, functions
- added to this list can be called during unpickling, classes could be instantiated
- and have state set.
- Args:
- safe_globals (List[Any]): list of globals to mark as safe
- Example:
- >>> # xdoctest: +SKIP("Can't torch.save(t, ...) as doctest thinks MyTensor is defined on torch.serialization")
- >>> import tempfile
- >>> class MyTensor(torch.Tensor):
- ... pass
- >>> t = MyTensor(torch.randn(2, 3))
- >>> with tempfile.NamedTemporaryFile() as f:
- ... torch.save(t, f.name)
- # Running `torch.load(f.name, weights_only=True)` will fail with
- # Unsupported global: GLOBAL __main__.MyTensor was not an allowed global by default.
- # Check the code and make sure MyTensor is safe to be used when loaded from an arbitrary checkpoint.
- ... torch.serialization.add_safe_globals([MyTensor])
- ... torch.load(f.name, weights_only=True)
- # MyTensor([[-0.5024, -1.8152, -0.5455],
- # [-0.8234, 2.0500, -0.3657]])
- '''
- _weights_only_unpickler._add_safe_globals(safe_globals)
- def _is_zipfile(f) -> bool:
- # This is a stricter implementation than zipfile.is_zipfile().
- # zipfile.is_zipfile() is True if the magic number appears anywhere in the
- # binary. Since we expect the files here to be generated by torch.save or
- # torch.jit.save, it's safe to only check the start bytes and avoid
- # collisions and assume the zip has only 1 file.
- # See bugs.python.org/issue28494.
- start = f.tell()
- # Read the first few bytes and match against the ZIP file signature
- local_header_magic_number = b'PK\x03\x04'
- read_bytes = f.read(len(local_header_magic_number))
- f.seek(start)
- return read_bytes == local_header_magic_number
- def register_package(
- priority: int,
- tagger: Callable[[STORAGE], Optional[str]],
- deserializer: Callable[[STORAGE, str], Optional[STORAGE]]
- ):
- '''
- Registers callables for tagging and deserializing storage objects with an associated priority.
- Tagging associates a device with a storage object at save time while deserializing moves a
- storage object to an appropriate device at load time. :attr:`tagger` and :attr:`deserializer`
- are run in the order given by their :attr:`priority` until a tagger/deserializer returns a
- value that is not `None`.
- To override the deserialization behavior for a device in the global registry, one can register a
- tagger with a higher priority than the existing tagger.
- This function can also be used to register a tagger and deserializer for new devices.
- Args:
- priority: Indicates the priority associated with the tagger and deserializer, where a lower
- value indicates higher priority.
- tagger: Callable that takes in a storage object and returns its tagged device as a string
- or None.
- deserializer: Callable that takes in storage object and a device string and returns a storage
- object on the appropriate device or None.
- Returns:
- `None`
- Example:
- >>> def ipu_tag(obj):
- >>> if obj.device.type == 'ipu':
- >>> return 'ipu'
- >>> def ipu_deserialize(obj, location):
- >>> if location.startswith('ipu'):
- >>> ipu = getattr(torch, "ipu", None)
- >>> assert ipu is not None, "IPU device module is not loaded"
- >>> assert torch.ipu.is_available(), "ipu is not available"
- >>> return obj.ipu(location)
- >>> torch.serialization.register_package(11, ipu_tag, ipu_deserialize)
- '''
- queue_elem = (priority, tagger, deserializer)
- _package_registry.append(queue_elem)
- _package_registry.sort()
- def check_module_version_greater_or_equal(module, req_version_tuple, error_if_malformed=True):
- '''
- Check if a module's version satisfies requirements
- Usually, a module's version string will be like 'x.y.z', which would be represented
- as a tuple (x, y, z), but sometimes it could be an unexpected format. If the version
- string does not match the given tuple's format up to the length of the tuple, then
- error and exit or emit a warning.
- Args:
- module: the module to check the version of
- req_version_tuple: tuple (usually of ints) representing the required version
- error_if_malformed: whether we should exit if module version string is malformed
- Returns:
- requirement_is_met: bool
- '''
- try:
- version_strs = module.__version__.split('.')
- # Cast module version fields to match the types of the required version
- module_version = tuple(
- type(req_field)(version_strs[idx]) for idx, req_field in enumerate(req_version_tuple)
- )
- requirement_is_met = module_version >= req_version_tuple
- except Exception as e:
- message = (
- f"'{module.__name__}' module version string is malformed '{module.__version__}' and cannot be compared"
- f" with tuple {str(req_version_tuple)}"
- )
- if error_if_malformed:
- raise RuntimeError(message) from e
- else:
- warnings.warn(message + ', but continuing assuming that requirement is met')
- requirement_is_met = True
- return requirement_is_met
- def _cpu_tag(obj):
- if obj.device.type == 'cpu':
- return 'cpu'
- def _mps_tag(obj):
- if obj.device.type == 'mps':
- return 'mps'
- def _meta_tag(obj):
- if obj.device.type == 'meta':
- return 'meta'
- def _backend_tag(backend_name, obj):
- if backend_name == 'privateuse1':
- backend_name = torch._C._get_privateuse1_backend_name()
- if obj.device.type == backend_name:
- if obj.device.index is None:
- return backend_name
- else:
- return backend_name + ':' + str(obj.device.index)
- def _cpu_deserialize(obj, location):
- if location == 'cpu':
- return obj
- def _mps_deserialize(obj, location):
- if location.startswith('mps'):
- return obj.mps()
- def _meta_deserialize(obj, location):
- if location == 'meta':
- return torch.UntypedStorage(obj.nbytes(), device='meta')
- def _validate_device(location, backend_name):
- '''
- Check whether the device index of specified backend is valid
- In case of privateuse1 backend, your must first register a device_module for
- privateuse1 using torch._register_device_module. Implement the following
- methods in device_module like cuda: device_module._utils._get_device_index(location, True),
- device_module.device_count().
- Args:
- location: string of device
- backend_name: the backend name or the name of privateuse1, which can be renamed
- Returns:
- device_index: int
- '''
- if not hasattr(torch, backend_name):
- raise RuntimeError(f'The {backend_name.upper()} device module is not registered. '
- 'If you are running on a CPU-only machine, '
- 'please use torch.load with map_location=torch.device(\'cpu\') '
- 'to map your storages to the CPU.')
- device_module = getattr(torch, backend_name)
- if hasattr(device_module, '_utils') and hasattr(device_module._utils, '_get_device_index'):
- device_index = device_module._utils._get_device_index(location, True)
- device = torch.device(backend_name, device_index)
- else:
- device = torch.device(location)
- device_index = device.index if device.index else 0
- if hasattr(device_module, 'is_available') and not device_module.is_available():
- raise RuntimeError(f'Attempting to deserialize object on a {backend_name.upper()} '
- f'device but torch.{backend_name}.is_available() is False. '
- 'If you are running on a CPU-only machine, '
- 'please use torch.load with map_location=torch.device(\'cpu\') '
- 'to map your storages to the CPU.')
- if hasattr(device_module, 'device_count'):
- device_count = device_module.device_count()
- if device_index >= device_count:
- raise RuntimeError(f'Attempting to deserialize object on {backend_name.upper()} device '
- f'{device_index} but torch.{backend_name}.device_count() is {device_count}. '
- 'Please use torch.load with map_location to map your storages '
- 'to an existing device.')
- return device
- def validate_cuda_device(location):
- return _validate_device(location, 'cuda').index
- def validate_hpu_device(location):
- return _validate_device(location, 'hpu').index
- def _deserialize(backend_name, obj, location):
- if backend_name == 'privateuse1':
- backend_name = torch._C._get_privateuse1_backend_name()
- if location.startswith(backend_name):
- device = _validate_device(location, backend_name)
- return obj.to(device=device)
- register_package(10, _cpu_tag, _cpu_deserialize)
- register_package(20, functools.partial(_backend_tag, 'cuda'), functools.partial(_deserialize, 'cuda'))
- register_package(21, _mps_tag, _mps_deserialize)
- register_package(22, _meta_tag, _meta_deserialize)
- register_package(23, functools.partial(_backend_tag, 'privateuse1'), functools.partial(_deserialize, 'privateuse1'))
- register_package(24, functools.partial(_backend_tag, 'hpu'), functools.partial(_deserialize, 'hpu'))
- register_package(25, functools.partial(_backend_tag, 'xpu'), functools.partial(_deserialize, 'xpu'))
- def location_tag(storage: Union[Storage, torch.storage.TypedStorage, torch.UntypedStorage]):
- for _, tagger, _ in _package_registry:
- location = tagger(storage)
- if location:
- return location
- raise RuntimeError("don't know how to determine data location of "
- + torch.typename(storage))
- def default_restore_location(storage, location):
- for _, _, fn in _package_registry:
- result = fn(storage, location)
- if result is not None:
- return result
- raise RuntimeError("don't know how to restore data location of "
- + torch.typename(storage) + " (tagged with "
- + location + ")")
- def normalize_storage_type(storage_type):
- return getattr(torch, storage_type.__name__)
- def storage_to_tensor_type(storage):
- storage_type = type(storage)
- module = _import_dotted_name(storage_type.__module__)
- return getattr(module, storage_type.__name__.replace('Storage', 'Tensor'))
- def _is_path(name_or_buffer) -> TypeGuard[Union[str, os.PathLike]]:
- return isinstance(name_or_buffer, (str, os.PathLike))
- class _opener:
- def __init__(self, file_like):
- self.file_like = file_like
- def __enter__(self):
- return self.file_like
- def __exit__(self, *args):
- pass
- class _open_file(_opener):
- def __init__(self, name, mode):
- super().__init__(open(name, mode))
- def __exit__(self, *args):
- self.file_like.close()
- class _open_buffer_reader(_opener):
- def __init__(self, buffer):
- super().__init__(buffer)
- _check_seekable(buffer)
- class _open_buffer_writer(_opener):
- def __exit__(self, *args):
- self.file_like.flush()
- def _open_file_like(name_or_buffer, mode):
- if _is_path(name_or_buffer):
- return _open_file(name_or_buffer, mode)
- else:
- if 'w' in mode:
- return _open_buffer_writer(name_or_buffer)
- elif 'r' in mode:
- return _open_buffer_reader(name_or_buffer)
- else:
- raise RuntimeError(f"Expected 'r' or 'w' in mode but got {mode}")
- class _open_zipfile_reader(_opener):
- def __init__(self, name_or_buffer) -> None:
- super().__init__(torch._C.PyTorchFileReader(name_or_buffer))
- class _open_zipfile_writer_file(_opener):
- def __init__(self, name) -> None:
- self.file_stream = None
- self.name = str(name)
- try:
- self.name.encode('ascii')
- except UnicodeEncodeError:
- # PyTorchFileWriter only supports ascii filename.
- # For filenames with non-ascii characters, we rely on Python
- # for writing out the file.
- self.file_stream = io.FileIO(self.name, mode='w')
- super().__init__(torch._C.PyTorchFileWriter(self.file_stream))
- else:
- super().__init__(torch._C.PyTorchFileWriter(self.name))
- def __exit__(self, *args) -> None:
- self.file_like.write_end_of_file()
- if self.file_stream is not None:
- self.file_stream.close()
- class _open_zipfile_writer_buffer(_opener):
- def __init__(self, buffer) -> None:
- if not callable(getattr(buffer, "write", None)):
- msg = f"Buffer of {str(type(buffer)).strip('<>')} has no callable attribute 'write'"
- if not hasattr(buffer, "write"):
- raise AttributeError(msg)
- raise TypeError(msg)
- self.buffer = buffer
- super().__init__(torch._C.PyTorchFileWriter(buffer))
- def __exit__(self, *args) -> None:
- self.file_like.write_end_of_file()
- self.buffer.flush()
- def _open_zipfile_writer(name_or_buffer):
- container: Type[_opener]
- if _is_path(name_or_buffer):
- container = _open_zipfile_writer_file
- else:
- container = _open_zipfile_writer_buffer
- return container(name_or_buffer)
- def _is_compressed_file(f) -> bool:
- compress_modules = ['gzip']
- try:
- return f.__module__ in compress_modules
- except AttributeError:
- return False
- def _should_read_directly(f):
- """
- Checks if f is a file that should be read directly. It should be read
- directly if it is backed by a real file (has a fileno) and is not a
- a compressed file (e.g. gzip)
- """
- if _is_compressed_file(f):
- return False
- try:
- return f.fileno() >= 0
- except io.UnsupportedOperation:
- return False
- except AttributeError:
- return False
- def _check_seekable(f) -> bool:
- def raise_err_msg(patterns, e):
- for p in patterns:
- if p in str(e):
- msg = (str(e) + ". You can only torch.load from a file that is seekable."
- + " Please pre-load the data into a buffer like io.BytesIO and"
- + " try to load from it instead.")
- raise type(e)(msg)
- raise e
- try:
- f.seek(f.tell())
- return True
- except (io.UnsupportedOperation, AttributeError) as e:
- raise_err_msg(["seek", "tell"], e)
- return False
- def _check_dill_version(pickle_module) -> None:
- '''Checks if using dill as the pickle module, and if so, checks if it is the correct version.
- If dill version is lower than 0.3.1, a ValueError is raised.
- Args:
- pickle_module: module used for pickling metadata and objects
- '''
- if pickle_module is not None and pickle_module.__name__ == 'dill':
- required_dill_version = (0, 3, 1)
- if not check_module_version_greater_or_equal(pickle_module, required_dill_version, False):
- raise ValueError((
- "'torch' supports dill >= {}, but you have dill {}."
- " Please upgrade dill or switch to 'pickle'"
- ).format(
- '.'.join([str(num) for num in required_dill_version]),
- pickle_module.__version__
- ))
- def _check_save_filelike(f):
- if not _is_path(f) and not hasattr(f, 'write'):
- raise AttributeError(
- "expected 'f' to be string, path, or a file-like object with "
- "a 'write' attribute")
- def save(
- obj: object,
- f: FILE_LIKE,
- pickle_module: Any = pickle,
- pickle_protocol: int = DEFAULT_PROTOCOL,
- _use_new_zipfile_serialization: bool = True,
- _disable_byteorder_record: bool = False
- ) -> None:
- # Reference: https://github.com/pytorch/pytorch/issues/54354
- # The first line of this docstring overrides the one Sphinx generates for the
- # documentation. We need it so that Sphinx doesn't leak `pickle`s path from
- # the build environment (e.g. `<module 'pickle' from '/leaked/path').
- """save(obj, f, pickle_module=pickle, pickle_protocol=DEFAULT_PROTOCOL, _use_new_zipfile_serialization=True)
- Saves an object to a disk file.
- See also: :ref:`saving-loading-tensors`
- Args:
- obj: saved object
- f: a file-like object (has to implement write and flush) or a string or
- os.PathLike object containing a file name
- pickle_module: module used for pickling metadata and objects
- pickle_protocol: can be specified to override the default protocol
- .. note::
- A common PyTorch convention is to save tensors using .pt file extension.
- .. note::
- PyTorch preserves storage sharing across serialization. See
- :ref:`preserve-storage-sharing` for more details.
- .. note::
- The 1.6 release of PyTorch switched ``torch.save`` to use a new
- zipfile-based file format. ``torch.load`` still retains the ability to
- load files in the old format. If for any reason you want ``torch.save``
- to use the old format, pass the kwarg ``_use_new_zipfile_serialization=False``.
- Example:
- >>> # xdoctest: +SKIP("makes cwd dirty")
- >>> # Save to file
- >>> x = torch.tensor([0, 1, 2, 3, 4])
- >>> torch.save(x, 'tensor.pt')
- >>> # Save to io.BytesIO buffer
- >>> buffer = io.BytesIO()
- >>> torch.save(x, buffer)
- """
- torch._C._log_api_usage_once("torch.save")
- _check_dill_version(pickle_module)
- _check_save_filelike(f)
- if _use_new_zipfile_serialization:
- with _open_zipfile_writer(f) as opened_zipfile:
- _save(obj, opened_zipfile, pickle_module, pickle_protocol, _disable_byteorder_record)
- return
- else:
- with _open_file_like(f, 'wb') as opened_file:
- _legacy_save(obj, opened_file, pickle_module, pickle_protocol)
- def _legacy_save(obj, f, pickle_module, pickle_protocol) -> None:
- import torch.nn as nn
- serialized_container_types = {}
- serialized_storages = {}
- # Since loading storages that view the same data with different dtypes is
- # not supported, we need to keep track of the dtype associated with each
- # storage data_ptr and throw an error if the dtype is ever different.
- # TODO: This feature could be added in the future
- storage_dtypes: Dict[int, torch.dtype] = {}
- def persistent_id(obj: Any) -> Optional[Tuple]:
- # FIXME: the docs say that persistent_id should only return a string
- # but torch store returns tuples. This works only in the binary protocol
- # see
- # https://docs.python.org/2/library/pickle.html#pickling-and-unpickling-external-objects
- # https://github.com/python/cpython/blob/master/Lib/pickle.py#L527-L537
- if isinstance(obj, type) and issubclass(obj, nn.Module):
- if obj in serialized_container_types:
- return None
- serialized_container_types[obj] = True
- source_file = source = None
- try:
- source_lines, _, source_file = get_source_lines_and_file(obj)
- source = ''.join(source_lines)
- except Exception: # saving the source is optional, so we can ignore any errors
- warnings.warn("Couldn't retrieve source code for container of "
- "type " + obj.__name__ + ". It won't be checked "
- "for correctness upon loading.")
- return ('module', obj, source_file, source)
- if isinstance(obj, torch.storage.TypedStorage) or torch.is_storage(obj):
- storage: torch.UntypedStorage
- if isinstance(obj, torch.storage.TypedStorage):
- # TODO: Once we decide to break serialization FC, this case
- # can be deleted
- storage = obj._untyped_storage
- storage_dtype = obj.dtype
- storage_type_str = obj._pickle_storage_type()
- storage_type = getattr(torch, storage_type_str)
- dtype = obj.dtype
- storage_numel = obj._size()
- elif isinstance(obj, torch.UntypedStorage):
- storage = obj
- storage_dtype = torch.uint8
- storage_type = normalize_storage_type(type(obj))
- dtype = torch.uint8
- storage_numel = storage.nbytes()
- else:
- raise TypeError(f'type not recognized: {type(obj)}')
- # If storage is allocated, ensure that any other saved storages
- # pointing to the same data all have the same dtype. If storage is
- # not allocated, don't perform this check
- if storage.data_ptr() != 0:
- if storage.data_ptr() in storage_dtypes:
- if storage_dtype != storage_dtypes[storage.data_ptr()]:
- raise RuntimeError(
- 'Cannot save multiple tensors or storages that '
- 'view the same data as different types')
- else:
- storage_dtypes[storage.data_ptr()] = storage_dtype
- view_metadata: Optional[Tuple[str, int, int]]
- # Offset is always 0, but we keep it for backwards compatibility
- # with the old serialization format (which supported storage views)
- offset = 0
- storage_key = str(storage._cdata)
- location = location_tag(storage)
- # TODO: There's an issue here with FC. It might be impossible to
- # solve, but it's worth noting. Imagine we save a list `[storage,
- # tensor]`, where `tensor.storage()` is the same as `storage`, and
- # `tensor.element_size() > 1`. Let's say that `tensor.dtype ==
- # torch.float`. The storage will be serialized with element size
- # of 1, since we're choosing to serialize the first occurance of
- # a duplicate storage. Since this legacy serialization format saves
- # the numel of the storage, rather than nbytes directly, we'll be
- # effectively saving nbytes in this case. We'll be able to load it
- # and the tensor back up with no problems in _this_ and future
- # versions of pytorch, but in older versions, here's the problem:
- # the storage will be loaded up as a UntypedStorage, and then the
- # FloatTensor will loaded and the UntypedStorage will be assigned to
- # it. Since the storage dtype does not match the tensor dtype, this
- # will cause an error. If we reverse the list, like `[tensor,
- # storage]`, then we will save the `tensor.storage()` as a faked
- # `FloatStorage`, and the saved size will be the correct
- # dtype-specific numel count that old versions expect. `tensor`
- # will be able to load up properly in old versions, pointing to
- # a FloatStorage. However, `storage` is still being translated to
- # a UntypedStorage, and it will try to resolve to the same
- # FloatStorage that `tensor` contains. This will also cause an
- # error. It doesn't seem like there's any way around this.
- # Probably, we just cannot maintain FC for the legacy format if the
- # saved list contains both a tensor and a storage that point to the
- # same data. We should still be able to maintain FC for lists of
- # just tensors, as long as all views share the same dtype as the
- # tensor they are viewing.
- if storage_key not in serialized_storages:
- serialized_storages[storage_key] = (storage, dtype)
- is_view = storage._cdata != storage._cdata
- if is_view:
- view_metadata = (str(storage._cdata), offset, storage.nbytes())
- else:
- view_metadata = None
- res = ('storage',
- storage_type,
- storage_key,
- location,
- storage_numel,
- view_metadata)
- return res
- return None
- sys_info = dict(
- protocol_version=PROTOCOL_VERSION,
- little_endian=sys.byteorder == 'little',
- type_sizes=dict(
- short=SHORT_SIZE,
- int=INT_SIZE,
- long=LONG_SIZE,
- ),
- )
- pickle_module.dump(MAGIC_NUMBER, f, protocol=pickle_protocol)
- pickle_module.dump(PROTOCOL_VERSION, f, protocol=pickle_protocol)
- pickle_module.dump(sys_info, f, protocol=pickle_protocol)
- pickler = pickle_module.Pickler(f, protocol=pickle_protocol)
- pickler.persistent_id = persistent_id
- pickler.dump(obj)
- serialized_storage_keys = sorted(serialized_storages.keys())
- pickle_module.dump(serialized_storage_keys, f, protocol=pickle_protocol)
- f.flush()
- for key in serialized_storage_keys:
- storage, dtype = serialized_storages[key]
- storage._write_file(f, _should_read_directly(f), True, torch._utils._element_size(dtype))
- def _save(obj, zip_file, pickle_module, pickle_protocol, _disable_byteorder_record):
- serialized_storages = {}
- id_map: Dict[int, str] = {}
- # Since loading storages that view the same data with different dtypes is
- # not supported, we need to keep track of the dtype associated with each
- # storage data_ptr and throw an error if the dtype is ever different.
- # TODO: This feature could be added in the future
- storage_dtypes: Dict[int, torch.dtype] = {}
- def persistent_id(obj):
- # FIXME: the docs say that persistent_id should only return a string
- # but torch store returns tuples. This works only in the binary protocol
- # see
- # https://docs.python.org/2/library/pickle.html#pickling-and-unpickling-external-objects
- # https://github.com/python/cpython/blob/master/Lib/pickle.py#L527-L537
- if isinstance(obj, torch.storage.TypedStorage) or torch.is_storage(obj):
- if isinstance(obj, torch.storage.TypedStorage):
- # TODO: Once we decide to break serialization FC, this case
- # can be deleted
- storage = obj._untyped_storage
- storage_dtype = obj.dtype
- storage_type_str = obj._pickle_storage_type()
- storage_type = getattr(torch, storage_type_str)
- storage_numel = obj._size()
- else:
- storage = obj
- storage_dtype = torch.uint8
- storage_type = normalize_storage_type(type(obj))
- storage_numel = storage.nbytes()
- # If storage is allocated, ensure that any other saved storages
- # pointing to the same data all have the same dtype. If storage is
- # not allocated, don't perform this check
- if storage.data_ptr() != 0:
- if storage.data_ptr() in storage_dtypes:
- if storage_dtype != storage_dtypes[storage.data_ptr()]:
- raise RuntimeError(
- 'Cannot save multiple tensors or storages that '
- 'view the same data as different types')
- else:
- storage_dtypes[storage.data_ptr()] = storage_dtype
- storage_key = id_map.setdefault(storage._cdata, str(len(id_map)))
- location = location_tag(storage)
- serialized_storages[storage_key] = storage
- return ('storage',
- storage_type,
- storage_key,
- location,
- storage_numel)
- return None
- # Write the pickle data for `obj`
- data_buf = io.BytesIO()
- pickler = pickle_module.Pickler(data_buf, protocol=pickle_protocol)
- pickler.persistent_id = persistent_id
- pickler.dump(obj)
- data_value = data_buf.getvalue()
- zip_file.write_record('data.pkl', data_value, len(data_value))
- # Write byte order marker
- if not _disable_byteorder_record:
- if sys.byteorder not in ['little', 'big']:
- raise ValueError('Unknown endianness type: ' + sys.byteorder)
- zip_file.write_record('byteorder', sys.byteorder, len(sys.byteorder))
- # Write each tensor to a file named tensor/the_tensor_key in the zip archive
- for key in sorted(serialized_storages.keys()):
- name = f'data/{key}'
- storage = serialized_storages[key]
- # given that we copy things around anyway, we might use storage.cpu()
- # this means to that to get tensors serialized, you need to implement
- # .cpu() on the underlying Storage
- if storage.device.type != 'cpu':
- storage = storage.cpu()
- # Now that it is on the CPU we can directly copy it into the zip file
- num_bytes = storage.nbytes()
- zip_file.write_record(name, storage, num_bytes)
- def load(
- f: FILE_LIKE,
- map_location: MAP_LOCATION = None,
- pickle_module: Any = None,
- *,
- weights_only: Optional[bool] = None,
- mmap: Optional[bool] = None,
- **pickle_load_args: Any
- ) -> Any:
- # Reference: https://github.com/pytorch/pytorch/issues/54354
- # The first line of this docstring overrides the one Sphinx generates for the
- # documentation. We need it so that Sphinx doesn't leak `pickle`s path from
- # the build environment (e.g. `<module 'pickle' from '/leaked/path').
- """load(f, map_location=None, pickle_module=pickle, *, weights_only=False, mmap=None, **pickle_load_args)
- Loads an object saved with :func:`torch.save` from a file.
- :func:`torch.load` uses Python's unpickling facilities but treats storages,
- which underlie tensors, specially. They are first deserialized on the
- CPU and are then moved to the device they were saved from. If this fails
- (e.g. because the run time system doesn't have certain devices), an exception
- is raised. However, storages can be dynamically remapped to an alternative
- set of devices using the :attr:`map_location` argument.
- If :attr:`map_location` is a callable, it will be called once for each serialized
- storage with two arguments: storage and location. The storage argument
- will be the initial deserialization of the storage, residing on the CPU.
- Each serialized storage has a location tag associated with it which
- identifies the device it was saved from, and this tag is the second
- argument passed to :attr:`map_location`. The builtin location tags are ``'cpu'``
- for CPU tensors and ``'cuda:device_id'`` (e.g. ``'cuda:2'``) for CUDA tensors.
- :attr:`map_location` should return either ``None`` or a storage. If
- :attr:`map_location` returns a storage, it will be used as the final deserialized
- object, already moved to the right device. Otherwise, :func:`torch.load` will
- fall back to the default behavior, as if :attr:`map_location` wasn't specified.
- If :attr:`map_location` is a :class:`torch.device` object or a string containing
- a device tag, it indicates the location where all tensors should be loaded.
- Otherwise, if :attr:`map_location` is a dict, it will be used to remap location tags
- appearing in the file (keys), to ones that specify where to put the
- storages (values).
- User extensions can register their own location tags and tagging and
- deserialization methods using :func:`torch.serialization.register_package`.
- Args:
- f: a file-like object (has to implement :meth:`read`, :meth:`readline`, :meth:`tell`, and :meth:`seek`),
- or a string or os.PathLike object containing a file name
- map_location: a function, :class:`torch.device`, string or a dict specifying how to remap storage
- locations
- pickle_module: module used for unpickling metadata and objects (has to
- match the :attr:`pickle_module` used to serialize file)
- weights_only: Indicates whether unpickler should be restricted to
- loading only tensors, primitive types, dictionaries
- and any types added via :func:`torch.serialization.add_safe_globals`.
- mmap: Indicates whether the file should be mmaped rather than loading all the storages into memory.
- Typically, tensor storages in the file will first be moved from disk to CPU memory, after which they
- are moved to the location that they were tagged with when saving, or specified by ``map_location``. This
- second step is a no-op if the final location is CPU. When the ``mmap`` flag is set, instead of copying the
- tensor storages from disk to CPU memory in the first step, ``f`` is mmaped.
- pickle_load_args: (Python 3 only) optional keyword arguments passed over to
- :func:`pickle_module.load` and :func:`pickle_module.Unpickler`, e.g.,
- :attr:`errors=...`.
- .. warning::
- :func:`torch.load()` unless `weights_only` parameter is set to `True`,
- uses ``pickle`` module implicitly, which is known to be insecure.
- It is possible to construct malicious pickle data which will execute arbitrary code
- during unpickling. Never load data that could have come from an untrusted
- source in an unsafe mode, or that could have been tampered with. **Only load data you trust**.
- .. note::
- When you call :func:`torch.load()` on a file which contains GPU tensors, those tensors
- will be loaded to GPU by default. You can call ``torch.load(.., map_location='cpu')``
- and then :meth:`load_state_dict` to avoid GPU RAM surge when loading a model checkpoint.
- .. note::
- By default, we decode byte strings as ``utf-8``. This is to avoid a common error
- case ``UnicodeDecodeError: 'ascii' codec can't decode byte 0x...``
- when loading files saved by Python 2 in Python 3. If this default
- is incorrect, you may use an extra :attr:`encoding` keyword argument to specify how
- these objects should be loaded, e.g., :attr:`encoding='latin1'` decodes them
- to strings using ``latin1`` encoding, and :attr:`encoding='bytes'` keeps them
- as byte arrays which can be decoded later with ``byte_array.decode(...)``.
- Example:
- >>> # xdoctest: +SKIP("undefined filepaths")
- >>> torch.load('tensors.pt', weights_only=True)
- # Load all tensors onto the CPU
- >>> torch.load('tensors.pt', map_location=torch.device('cpu'), weights_only=True)
- # Load all tensors onto the CPU, using a function
- >>> torch.load('tensors.pt', map_location=lambda storage, loc: storage, weights_only=True)
- # Load all tensors onto GPU 1
- >>> torch.load('tensors.pt', map_location=lambda storage, loc: storage.cuda(1), weights_only=True)
- # Map tensors from GPU 1 to GPU 0
- >>> torch.load('tensors.pt', map_location={'cuda:1': 'cuda:0'}, weights_only=True)
- # Load tensor from io.BytesIO object
- # Loading from a buffer setting weights_only=False, warning this can be unsafe
- >>> with open('tensor.pt', 'rb') as f:
- ... buffer = io.BytesIO(f.read())
- >>> torch.load(buffer, weights_only=False)
- # Load a module with 'ascii' encoding for unpickling
- # Loading from a module setting weights_only=False, warning this can be unsafe
- >>> torch.load('module.pt', encoding='ascii', weights_only=False)
- """
- torch._C._log_api_usage_once("torch.load")
- UNSAFE_MESSAGE = (
- "Re-running `torch.load` with `weights_only` set to `False` will likely succeed, "
- "but it can result in arbitrary code execution. Do it only if you got the file from a "
- "trusted source."
- )
- DOCS_MESSAGE = (
- "\n\nCheck the documentation of torch.load to learn more about types accepted by default with "
- "weights_only https://pytorch.org/docs/stable/generated/torch.load.html."
- )
- def _get_wo_message(message: str) -> str:
- pattern = r"GLOBAL (\S+) was not an allowed global by default."
- has_unsafe_global = re.search(pattern, message) is not None
- if has_unsafe_global:
- updated_message = (
- "Weights only load failed. This file can still be loaded, to do so you have two options "
- f"\n\t(1) {UNSAFE_MESSAGE}\n\t(2) Alternatively, to load with `weights_only=True` please check "
- "the recommended steps in the following error message.\n\tWeightsUnpickler error: "
- + message
- )
- else:
- updated_message = (
- f"Weights only load failed. {UNSAFE_MESSAGE}\n Please file an issue with the following "
- "so that we can make `weights_only=True` compatible with your use case: WeightsUnpickler "
- "error: " + message
- )
- return updated_message + DOCS_MESSAGE
- if weights_only is None:
- weights_only, warn_weights_only = False, True
- else:
- warn_weights_only = False
- # Add ability to force safe only weight loads via environment variable
- if os.getenv("TORCH_FORCE_WEIGHTS_ONLY_LOAD", "0").lower() in ['1', 'y', 'yes', 'true']:
- weights_only = True
- if weights_only:
- if pickle_module is not None:
- raise RuntimeError("Can not safely load weights when explicit pickle_module is specified")
- else:
- if pickle_module is None:
- if warn_weights_only:
- warnings.warn(
- "You are using `torch.load` with `weights_only=False` (the current default value), which uses "
- "the default pickle module implicitly. It is possible to construct malicious pickle data "
- "which will execute arbitrary code during unpickling (See "
- "https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). "
- "In a future release, the default value for `weights_only` will be flipped to `True`. This "
- "limits the functions that could be executed during unpickling. Arbitrary objects will no "
- "longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the "
- "user via `torch.serialization.add_safe_globals`. We recommend you start setting "
- "`weights_only=True` for any use case where you don't have full control of the loaded file. "
- "Please open an issue on GitHub for any issues related to this experimental feature.",
- FutureWarning,
- stacklevel=2,
- )
- pickle_module = pickle
- # make flipping default BC-compatible
- if mmap is None:
- mmap = False
- _check_dill_version(pickle_module)
- if 'encoding' not in pickle_load_args.keys():
- pickle_load_args['encoding'] = 'utf-8'
- with _open_file_like(f, 'rb') as opened_file:
- if _is_zipfile(opened_file):
- # The zipfile reader is going to advance the current file position.
- # If we want to actually tail call to torch.jit.load, we need to
- # reset back to the original position.
- orig_position = opened_file.tell()
- overall_storage = None
- with _open_zipfile_reader(opened_file) as opened_zipfile:
- if _is_torchscript_zip(opened_zipfile):
- warnings.warn("'torch.load' received a zip file that looks like a TorchScript archive"
- " dispatching to 'torch.jit.load' (call 'torch.jit.load' directly to"
- " silence this warning)", UserWarning)
- opened_file.seek(orig_position)
- return torch.jit.load(opened_file, map_location=map_location)
- if mmap:
- if not _is_path(f):
- raise ValueError("f must be a file path in order to use the mmap argument")
- size = os.path.getsize(f)
- if not IS_WINDOWS:
- shared = get_default_mmap_options() == MAP_SHARED
- else:
- shared = False
- overall_storage = torch.UntypedStorage.from_file(os.fspath(f), shared, size)
- if weights_only:
- try:
- return _load(opened_zipfile,
- map_location,
- _weights_only_unpickler,
- overall_storage=overall_storage,
- **pickle_load_args)
- except RuntimeError as e:
- raise pickle.UnpicklingError(_get_wo_message(str(e))) from None
- return _load(
- opened_zipfile,
- map_location,
- pickle_module,
- overall_storage=overall_storage,
- **pickle_load_args,
- )
- if mmap:
- f_name = "" if not isinstance(f, str) else f"{f}, "
- raise RuntimeError("mmap can only be used with files saved with "
- f"`torch.save({f_name}_use_new_zipfile_serialization=True), "
- "please torch.save your checkpoint with this option in order to use mmap.")
- if weights_only:
- try:
- return _legacy_load(opened_file, map_location, _weights_only_unpickler, **pickle_load_args)
- except RuntimeError as e:
- raise pickle.UnpicklingError(_get_wo_message(str(e))) from None
- return _legacy_load(
- opened_file, map_location, pickle_module, **pickle_load_args
- )
- # Register pickling support for layout instances such as
- # torch.sparse_coo, etc
- def _get_layout(name):
- """Get layout extension object from its string representation.
- """
- cache = _get_layout.cache # type: ignore[attr-defined]
- if not cache:
- for v in torch.__dict__.values():
- if isinstance(v, torch.layout):
- cache[str(v)] = v
- return cache[name]
- # There are yet not good way to type annotate function attributes https://github.com/python/mypy/issues/2087
- _get_layout.cache = {} # type: ignore[attr-defined]
- copyreg.pickle(torch.layout, lambda obj: (_get_layout, (str(obj),)))
- def _legacy_load(f, map_location, pickle_module, **pickle_load_args):
- deserialized_objects: Dict[int, Any] = {}
- restore_location = _get_restore_location(map_location)
- class UnpicklerWrapper(pickle_module.Unpickler): # type: ignore[name-defined]
- def find_class(self, mod_name, name):
- if type(name) is str and 'Storage' in name:
- try:
- return StorageType(name)
- except KeyError:
- pass
- return super().find_class(mod_name, name)
- def _check_container_source(container_type, source_file, original_source):
- try:
- current_source = ''.join(get_source_lines_and_file(container_type)[0])
- except Exception: # saving the source is optional, so we can ignore any errors
- warnings.warn("Couldn't retrieve source code for container of "
- "type " + container_type.__name__ + ". It won't be checked "
- "for correctness upon loading.")
- return
- if original_source != current_source:
- if container_type.dump_patches:
- file_name = container_type.__name__ + '.patch'
- diff = difflib.unified_diff(current_source.split('\n'),
- original_source.split('\n'),
- source_file,
- source_file, lineterm="")
- lines = '\n'.join(diff)
- try:
- with open(file_name, 'a+') as f:
- file_size = f.seek(0, 2)
- f.seek(0)
- if file_size == 0:
- f.write(lines)
- elif file_size != len(lines) or f.read() != lines:
- raise OSError
- msg = ("Saved a reverse patch to " + file_name + ". "
- "Run `patch -p0 < " + file_name + "` to revert your "
- "changes.")
- except OSError:
- msg = ("Tried to save a patch, but couldn't create a "
- "writable file " + file_name + ". Make sure it "
- "doesn't exist and your working directory is "
- "writable.")
- else:
- msg = ("you can retrieve the original source code by "
- "accessing the object's source attribute or set "
- "`torch.nn.Module.dump_patches = True` and use the "
- "patch tool to revert the changes.")
- msg = f"source code of class '{torch.typename(container_type)}' has changed. {msg}"
- warnings.warn(msg, SourceChangeWarning)
- def legacy_load(f):
- deserialized_objects: Dict[int, Any] = {}
- def persistent_load(saved_id):
- if isinstance(saved_id, tuple):
- # Ignore containers that don't have any sources saved
- if all(saved_id[1:]):
- _check_container_source(*saved_id)
- return saved_id[0]
- return deserialized_objects[int(saved_id)]
- with closing(tarfile.open(fileobj=f, mode='r:', format=tarfile.PAX_FORMAT)) as tar, \
- mkdtemp() as tmpdir:
- tar.extract('storages', path=tmpdir)
- with open(os.path.join(tmpdir, 'storages'), 'rb', 0) as f:
- num_storages = pickle_module.load(f, **pickle_load_args)
- for i in range(num_storages):
- args = pickle_module.load(f, **pickle_load_args)
- key, location, storage_type = args
- dtype = storage_type._dtype
- obj = cast(Storage, torch.UntypedStorage)._new_with_file(f, torch._utils._element_size(dtype))
- obj = restore_location(obj, location)
- # TODO: Once we decide to break serialization FC, we can
- # stop wrapping with TypedStorage
- deserialized_objects[key] = torch.storage.TypedStorage(
- wrap_storage=obj,
- dtype=dtype,
- _internal=True)
- storage_views = pickle_module.load(f, **pickle_load_args)
- for target_cdata, root_cdata, offset, numel in storage_views:
- root = deserialized_objects[root_cdata]
- element_size = torch._utils._element_size(root.dtype)
- offset_bytes = offset * element_size
- # TODO: Once we decide to break serialization FC, we can
- # stop wrapping with TypedStorage
- deserialized_objects[target_cdata] = torch.storage.TypedStorage(
- wrap_storage=root._untyped_storage[offset_bytes:offset_bytes + numel * element_size],
- dtype=root.dtype,
- _internal=True)
- tar.extract('tensors', path=tmpdir)
- with open(os.path.join(tmpdir, 'tensors'), 'rb', 0) as f:
- num_tensors = pickle_module.load(f, **pickle_load_args)
- for _ in range(num_tensors):
- args = pickle_module.load(f, **pickle_load_args)
- key, storage_id, original_tensor_type = args
- storage = deserialized_objects[storage_id]
- ndim, = struct.unpack('<i', f.read(4))
- # skip next 4 bytes; legacy encoding treated ndim as 8 bytes
- f.read(4)
- numel = struct.unpack(f'<{ndim}q', f.read(8 * ndim))
- stride = struct.unpack(f'<{ndim}q', f.read(8 * ndim))
- storage_offset, = struct.unpack('<q', f.read(8))
- tensor = torch.empty((0,), dtype=storage.dtype).set_(
- storage._untyped_storage, storage_offset, numel, stride)
- deserialized_objects[key] = tensor
- pickle_file = tar.extractfile('pickle')
- unpickler = UnpicklerWrapper(pickle_file, **pickle_load_args)
- unpickler.persistent_load = persistent_load
- result = unpickler.load()
- return result
- deserialized_objects = {}
- def persistent_load(saved_id):
- assert isinstance(saved_id, tuple)
- typename = _maybe_decode_ascii(saved_id[0])
- data = saved_id[1:]
- if typename == 'module':
- # Ignore containers that don't have any sources saved
- if all(data[1:]):
- _check_container_source(*data)
- return data[0]
- elif typename == 'storage':
- storage_type, root_key, location, numel, view_metadata = data
- location = _maybe_decode_ascii(location)
- dtype = storage_type.dtype
- nbytes = numel * torch._utils._element_size(dtype)
- if root_key not in deserialized_objects:
- if torch._guards.active_fake_mode() is not None:
- obj = cast(Storage, torch.UntypedStorage(nbytes, device='meta'))
- else:
- obj = cast(Storage, torch.UntypedStorage(nbytes))
- obj._torch_load_uninitialized = True
- obj = restore_location(obj, location)
- # TODO: Once we decide to break serialization FC, we can
- # stop wrapping with TypedStorage
- typed_storage = torch.storage.TypedStorage(
- wrap_storage=obj,
- dtype=dtype,
- _internal=True)
- deserialized_objects[root_key] = typed_storage
- else:
- typed_storage = deserialized_objects[root_key]
- if typed_storage._data_ptr() == 0:
- typed_storage = torch.storage.TypedStorage(
- device=typed_storage._untyped_storage.device,
- dtype=dtype,
- _internal=True)
- if view_metadata is not None:
- view_key, offset, view_size = view_metadata
- offset_bytes = offset * torch._utils._element_size(dtype)
- view_size_bytes = view_size * torch._utils._element_size(dtype)
- if view_key not in deserialized_objects:
- # TODO: Once we decide to break serialization FC, we can
- # stop wrapping with TypedStorage
- deserialized_objects[view_key] = torch.storage.TypedStorage(
- wrap_storage=typed_storage._untyped_storage[offset_bytes:offset_bytes + view_size_bytes],
- dtype=dtype,
- _internal=True)
- res = deserialized_objects[view_key]
- else:
- res = typed_storage
- return res
- else:
- raise RuntimeError(f"Unknown saved id type: {saved_id[0]}")
- _check_seekable(f)
- f_should_read_directly = _should_read_directly(f)
- if f_should_read_directly and f.tell() == 0:
- # legacy_load requires that f has fileno()
- # only if offset is zero we can attempt the legacy tar file loader
- try:
- return legacy_load(f)
- except tarfile.TarError:
- if _is_zipfile(f):
- # .zip is used for torch.jit.save and will throw an un-pickling error here
- raise RuntimeError(
- f"{f.name} is a zip archive (did you mean to use torch.jit.load()?)") from None
- # if not a tarfile, reset file offset and proceed
- f.seek(0)
- if not hasattr(f, 'readinto') and (3, 8, 0) <= sys.version_info < (3, 8, 2):
- raise RuntimeError(
- "torch.load does not work with file-like objects that do not implement readinto on Python 3.8.0 and 3.8.1. "
- f'Received object of type "{type(f)}". Please update to Python 3.8.2 or newer to restore this '
- "functionality.")
- magic_number = pickle_module.load(f, **pickle_load_args)
- if magic_number != MAGIC_NUMBER:
- raise RuntimeError("Invalid magic number; corrupt file?")
- protocol_version = pickle_module.load(f, **pickle_load_args)
- if protocol_version != PROTOCOL_VERSION:
- raise RuntimeError(f"Invalid protocol version: {protocol_version}")
- _sys_info = pickle_module.load(f, **pickle_load_args)
- unpickler = UnpicklerWrapper(f, **pickle_load_args)
- unpickler.persistent_load = persistent_load
- result = unpickler.load()
- deserialized_storage_keys = pickle_module.load(f, **pickle_load_args)
- if torch._guards.active_fake_mode() is None:
- offset = f.tell() if f_should_read_directly else None
- for key in deserialized_storage_keys:
- assert key in deserialized_objects
- typed_storage = deserialized_objects[key]
- typed_storage._untyped_storage._set_from_file(
- f, offset, f_should_read_directly,
- torch._utils._element_size(typed_storage.dtype))
- if offset is not None:
- offset = f.tell()
- torch._utils._validate_loaded_sparse_tensors()
- return result
- def _maybe_decode_ascii(bytes_str: Union[bytes, str]) -> str:
- # When using encoding='bytes' in Py3, some **internal** keys stored as
- # strings in Py2 are loaded as bytes. This function decodes them with
- # ascii encoding, one that Py3 uses by default.
- #
- # NOTE: This should only be used on internal keys (e.g., `typename` and
- # `location` in `persistent_load` below!
- if isinstance(bytes_str, bytes):
- return bytes_str.decode('ascii')
- return bytes_str
- def _get_restore_location(map_location):
- if map_location is None:
- restore_location = default_restore_location
- elif isinstance(map_location, dict):
- def restore_location(storage, location):
- location = map_location.get(location, location)
- return default_restore_location(storage, location)
- elif isinstance(map_location, (str, bytes)):
- def restore_location(storage, location):
- return default_restore_location(storage, map_location)
- elif isinstance(map_location, torch.device):
- def restore_location(storage, location):
- return default_restore_location(storage, str(map_location))
- else:
- def restore_location(storage, location):
- result = map_location(storage, location)
- if result is None:
- result = default_restore_location(storage, location)
- return result
- return restore_location
- class StorageType:
- def __init__(self, name):
- self._dtype = _get_dtype_from_pickle_storage_type(name)
- @property
- def dtype(self):
- return self._dtype
- def __str__(self):
- return f'StorageType(dtype={self.dtype})'
- def _load(zip_file, map_location, pickle_module, pickle_file='data.pkl', overall_storage=None, **pickle_load_args):
- restore_location = _get_restore_location(map_location)
- loaded_storages = {}
- # check if byteswapping is needed
- byteordername = 'byteorder'
- byteorderdata = None
- if zip_file.has_record(byteordername):
- byteorderdata = zip_file.get_record(byteordername)
- if byteorderdata not in [b'little', b'big']:
- raise ValueError('Unknown endianness type: ' + byteorderdata.decode())
- elif get_default_load_endianness() == LoadEndianness.LITTLE or \
- get_default_load_endianness() is None:
- byteorderdata = b'little'
- elif get_default_load_endianness() == LoadEndianness.BIG:
- byteorderdata = b'big'
- elif get_default_load_endianness() == LoadEndianness.NATIVE:
- pass
- else:
- raise ValueError('Invalid load endianness type')
- if not zip_file.has_record(byteordername) and \
- get_default_load_endianness() is None and \
- sys.byteorder == 'big':
- # Default behaviour was changed
- # See https://github.com/pytorch/pytorch/issues/101688
- warnings.warn("The default load endianness for checkpoints without a byteorder mark "
- "on big endian machines was changed from 'native' to 'little' endian, "
- "to avoid this behavior please use "
- "torch.serialization.set_default_load_endianness to set "
- "the desired default load endianness",
- UserWarning)
- def load_tensor(dtype, numel, key, location):
- name = f'data/{key}'
- if torch._guards.detect_fake_mode(None) is not None:
- nbytes = numel * torch._utils._element_size(dtype)
- storage = torch.UntypedStorage(nbytes, device='meta')
- elif overall_storage is not None:
- storage_offset = zip_file.get_record_offset(name)
- storage = overall_storage[storage_offset:storage_offset + numel]
- else:
- storage = zip_file.get_storage_from_record(name, numel, torch.UntypedStorage)._typed_storage()._untyped_storage
- # swap here if byteswapping is needed
- if byteorderdata is not None:
- if byteorderdata.decode() != sys.byteorder:
- storage.byteswap(dtype)
- # TODO: Once we decide to break serialization FC, we can
- # stop wrapping with TypedStorage
- typed_storage = torch.storage.TypedStorage(
- wrap_storage=restore_location(storage, location),
- dtype=dtype,
- _internal=True)
- if typed_storage._data_ptr() != 0:
- loaded_storages[key] = typed_storage
- return typed_storage
- def persistent_load(saved_id):
- assert isinstance(saved_id, tuple)
- typename = _maybe_decode_ascii(saved_id[0])
- data = saved_id[1:]
- assert typename == 'storage', \
- f"Unknown typename for persistent_load, expected 'storage' but got '{typename}'"
- storage_type, key, location, numel = data
- if storage_type is torch.UntypedStorage:
- dtype = torch.uint8
- else:
- dtype = storage_type.dtype
- if key in loaded_storages:
- typed_storage = loaded_storages[key]
- else:
- nbytes = numel * torch._utils._element_size(dtype)
- typed_storage = load_tensor(dtype, nbytes, key, _maybe_decode_ascii(location))
- return typed_storage
- load_module_mapping: Dict[str, str] = {
- # See https://github.com/pytorch/pytorch/pull/51633
- 'torch.tensor': 'torch._tensor'
- }
- # Need to subclass Unpickler instead of directly monkey-patching the find_class method
- # because it's marked readonly in pickle.
- # The type: ignore is because mypy can't statically determine the type of this class.
- class UnpicklerWrapper(pickle_module.Unpickler): # type: ignore[name-defined]
- # from https://stackoverflow.com/questions/13398462/unpickling-python-objects-with-a-changed-module-path/13405732
- # Lets us override the imports that pickle uses when unpickling an object.
- # This is useful for maintaining BC if we change a module path that tensor instantiation relies on.
- def find_class(self, mod_name, name):
- if type(name) is str and 'Storage' in name:
- try:
- return StorageType(name)
- except KeyError:
- pass
- mod_name = load_module_mapping.get(mod_name, mod_name)
- return super().find_class(mod_name, name)
- # Load the data (which may in turn use `persistent_load` to load tensors)
- data_file = io.BytesIO(zip_file.get_record(pickle_file))
- unpickler = UnpicklerWrapper(data_file, **pickle_load_args)
- unpickler.persistent_load = persistent_load
- # Needed for tensors where storage device and rebuild tensor device are
- # not connected (wrapper subclasses and tensors rebuilt using numpy)
- torch._utils._thread_local_state.map_location = map_location
- result = unpickler.load()
- del torch._utils._thread_local_state.map_location
- torch._utils._validate_loaded_sparse_tensors()
- torch._C._log_api_usage_metadata(
- "torch.load.metadata", {"serialization_id": zip_file.serialization_id()}
- )
- return result
- def _is_torchscript_zip(zip_file):
- return 'constants.pkl' in zip_file.get_all_records()
|