serialization.py 64 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536
  1. # mypy: allow-untyped-defs
  2. import difflib
  3. import functools
  4. import os
  5. import io
  6. import re
  7. import shutil
  8. import struct
  9. import sys
  10. import torch
  11. import tarfile
  12. import tempfile
  13. import warnings
  14. from contextlib import closing, contextmanager
  15. from enum import Enum
  16. from ._utils import _import_dotted_name
  17. from torch._sources import get_source_lines_and_file
  18. from torch.types import Storage
  19. from torch.storage import _get_dtype_from_pickle_storage_type
  20. from typing import Any, BinaryIO, Callable, cast, Dict, Optional, Type, Tuple, Union, IO, List
  21. from typing_extensions import TypeAlias, TypeGuard # Python 3.10+
  22. import copyreg
  23. import pickle
  24. import torch._weights_only_unpickler as _weights_only_unpickler
  25. DEFAULT_PROTOCOL = 2
  26. LONG_SIZE = struct.Struct('=l').size
  27. INT_SIZE = struct.Struct('=i').size
  28. SHORT_SIZE = struct.Struct('=h').size
  29. MAGIC_NUMBER = 0x1950a86a20f9469cfc6c
  30. PROTOCOL_VERSION = 1001
  31. STORAGE_KEY_SEPARATOR = ','
  32. FILE_LIKE: TypeAlias = Union[str, os.PathLike, BinaryIO, IO[bytes]]
  33. MAP_LOCATION: TypeAlias = Optional[Union[Callable[[Storage, str], Storage], torch.device, str, Dict[str, str]]]
  34. STORAGE: TypeAlias = Union[Storage, torch.storage.TypedStorage, torch.UntypedStorage]
  35. IS_WINDOWS = sys.platform == "win32"
  36. if not IS_WINDOWS:
  37. from mmap import MAP_SHARED, MAP_PRIVATE
  38. else:
  39. MAP_SHARED, MAP_PRIVATE = None, None # type: ignore[assignment]
# Public names exported by ``from torch.serialization import *``.
# NOTE(review): get_default_mmap_options / set_default_mmap_options are defined in
# this module but are not listed here — confirm whether that is intentional.
__all__ = [
    'SourceChangeWarning',
    'mkdtemp',
    'register_package',
    'check_module_version_greater_or_equal',
    'validate_cuda_device',
    'validate_hpu_device',
    'location_tag',
    'default_restore_location',
    'normalize_storage_type',
    'storage_to_tensor_type',
    'save',
    'load',
    'StorageType',
    'LoadEndianness',
    'get_default_load_endianness',
    'set_default_load_endianness',
    'clear_safe_globals',
    'get_safe_globals',
    'add_safe_globals',
]
  61. class SourceChangeWarning(Warning):
  62. pass
  63. @contextmanager
  64. def mkdtemp():
  65. path = tempfile.mkdtemp()
  66. try:
  67. yield path
  68. finally:
  69. shutil.rmtree(path)
# Registry of (priority, tagger, deserializer) entries. register_package appends to it
# and re-sorts it; location_tag and default_restore_location iterate it in that order.
_package_registry: List[Tuple[int, Callable[[STORAGE], Optional[str]], Callable[[STORAGE, str], Optional[STORAGE]]]] = []
  71. class LoadEndianness(Enum):
  72. NATIVE = 1
  73. LITTLE = 2
  74. BIG = 3
# Fallback byte order set through set_default_load_endianness; None until it is set.
_default_load_endian: Optional[LoadEndianness] = None
  76. def get_default_load_endianness() -> Optional[LoadEndianness]:
  77. '''
  78. Get fallback byte order for loading files
  79. If byteorder mark is not present in saved checkpoint,
  80. this byte order is used as fallback.
  81. By default, it's "native" byte order.
  82. Returns:
  83. default_load_endian: Optional[LoadEndianness]
  84. '''
  85. return _default_load_endian
  86. def set_default_load_endianness(endianness):
  87. '''
  88. Set fallback byte order for loading files
  89. If byteorder mark is not present in saved checkpoint,
  90. this byte order is used as fallback.
  91. By default, it's "native" byte order.
  92. Args:
  93. endianness: the new fallback byte order
  94. '''
  95. global _default_load_endian
  96. if not isinstance(endianness, LoadEndianness) and endianness is not None:
  97. raise TypeError("Invalid argument type in function set_default_load_endianness")
  98. _default_load_endian = endianness
# mmap flags used by torch.load(mmap=True); changed through set_default_mmap_options.
# On Windows MAP_PRIVATE is bound to None, so this starts out as None there.
_default_mmap_options: int = MAP_PRIVATE
  100. def get_default_mmap_options() -> int:
  101. '''
  102. Get default mmap options for :func:`torch.load` with ``mmap=True``.
  103. Defaults to ``mmap.MAP_PRIVATE``.
  104. Returns:
  105. default_mmap_options: int
  106. '''
  107. return _default_mmap_options
  108. def set_default_mmap_options(flags: int):
  109. '''
  110. Set default mmap options for :func:`torch.load` with ``mmap=True`` to flags.
  111. For now, only either ``mmap.MAP_PRIVATE`` or ``mmap.MAP_SHARED`` are supported.
  112. Please open an issue if you need any other option to be added here.
  113. .. note::
  114. This feature is currently not supported for Windows.
  115. Args:
  116. flags: ``mmap.MAP_PRIVATE`` or ``mmap.MAP_SHARED``
  117. '''
  118. global _default_mmap_options
  119. if IS_WINDOWS:
  120. raise RuntimeError("Changing the default mmap options is currently not supported for Windows")
  121. if (flags != MAP_PRIVATE and flags != MAP_SHARED):
  122. raise ValueError("Invalid argument in function set_default_mmap_options, "
  123. f"expected mmap.MAP_PRIVATE or mmap.MAP_SHARED, but got {flags}")
  124. _default_mmap_options = flags
  125. def clear_safe_globals() -> None:
  126. '''
  127. Clears the list of globals that are safe for ``weights_only`` load.
  128. '''
  129. _weights_only_unpickler._clear_safe_globals()
  130. def get_safe_globals() -> List[Any]:
  131. '''
  132. Returns the list of user-added globals that are safe for ``weights_only`` load.
  133. '''
  134. return _weights_only_unpickler._get_safe_globals()
  135. def add_safe_globals(safe_globals: List[Any]) -> None:
  136. '''
  137. Marks the given globals as safe for ``weights_only`` load. For example, functions
  138. added to this list can be called during unpickling, classes could be instantiated
  139. and have state set.
  140. Args:
  141. safe_globals (List[Any]): list of globals to mark as safe
  142. Example:
  143. >>> # xdoctest: +SKIP("Can't torch.save(t, ...) as doctest thinks MyTensor is defined on torch.serialization")
  144. >>> import tempfile
  145. >>> class MyTensor(torch.Tensor):
  146. ... pass
  147. >>> t = MyTensor(torch.randn(2, 3))
  148. >>> with tempfile.NamedTemporaryFile() as f:
  149. ... torch.save(t, f.name)
  150. # Running `torch.load(f.name, weights_only=True)` will fail with
  151. # Unsupported global: GLOBAL __main__.MyTensor was not an allowed global by default.
  152. # Check the code and make sure MyTensor is safe to be used when loaded from an arbitrary checkpoint.
  153. ... torch.serialization.add_safe_globals([MyTensor])
  154. ... torch.load(f.name, weights_only=True)
  155. # MyTensor([[-0.5024, -1.8152, -0.5455],
  156. # [-0.8234, 2.0500, -0.3657]])
  157. '''
  158. _weights_only_unpickler._add_safe_globals(safe_globals)
  159. def _is_zipfile(f) -> bool:
  160. # This is a stricter implementation than zipfile.is_zipfile().
  161. # zipfile.is_zipfile() is True if the magic number appears anywhere in the
  162. # binary. Since we expect the files here to be generated by torch.save or
  163. # torch.jit.save, it's safe to only check the start bytes and avoid
  164. # collisions and assume the zip has only 1 file.
  165. # See bugs.python.org/issue28494.
  166. start = f.tell()
  167. # Read the first few bytes and match against the ZIP file signature
  168. local_header_magic_number = b'PK\x03\x04'
  169. read_bytes = f.read(len(local_header_magic_number))
  170. f.seek(start)
  171. return read_bytes == local_header_magic_number
  172. def register_package(
  173. priority: int,
  174. tagger: Callable[[STORAGE], Optional[str]],
  175. deserializer: Callable[[STORAGE, str], Optional[STORAGE]]
  176. ):
  177. '''
  178. Registers callables for tagging and deserializing storage objects with an associated priority.
  179. Tagging associates a device with a storage object at save time while deserializing moves a
  180. storage object to an appropriate device at load time. :attr:`tagger` and :attr:`deserializer`
  181. are run in the order given by their :attr:`priority` until a tagger/deserializer returns a
  182. value that is not `None`.
  183. To override the deserialization behavior for a device in the global registry, one can register a
  184. tagger with a higher priority than the existing tagger.
  185. This function can also be used to register a tagger and deserializer for new devices.
  186. Args:
  187. priority: Indicates the priority associated with the tagger and deserializer, where a lower
  188. value indicates higher priority.
  189. tagger: Callable that takes in a storage object and returns its tagged device as a string
  190. or None.
  191. deserializer: Callable that takes in storage object and a device string and returns a storage
  192. object on the appropriate device or None.
  193. Returns:
  194. `None`
  195. Example:
  196. >>> def ipu_tag(obj):
  197. >>> if obj.device.type == 'ipu':
  198. >>> return 'ipu'
  199. >>> def ipu_deserialize(obj, location):
  200. >>> if location.startswith('ipu'):
  201. >>> ipu = getattr(torch, "ipu", None)
  202. >>> assert ipu is not None, "IPU device module is not loaded"
  203. >>> assert torch.ipu.is_available(), "ipu is not available"
  204. >>> return obj.ipu(location)
  205. >>> torch.serialization.register_package(11, ipu_tag, ipu_deserialize)
  206. '''
  207. queue_elem = (priority, tagger, deserializer)
  208. _package_registry.append(queue_elem)
  209. _package_registry.sort()
  210. def check_module_version_greater_or_equal(module, req_version_tuple, error_if_malformed=True):
  211. '''
  212. Check if a module's version satisfies requirements
  213. Usually, a module's version string will be like 'x.y.z', which would be represented
  214. as a tuple (x, y, z), but sometimes it could be an unexpected format. If the version
  215. string does not match the given tuple's format up to the length of the tuple, then
  216. error and exit or emit a warning.
  217. Args:
  218. module: the module to check the version of
  219. req_version_tuple: tuple (usually of ints) representing the required version
  220. error_if_malformed: whether we should exit if module version string is malformed
  221. Returns:
  222. requirement_is_met: bool
  223. '''
  224. try:
  225. version_strs = module.__version__.split('.')
  226. # Cast module version fields to match the types of the required version
  227. module_version = tuple(
  228. type(req_field)(version_strs[idx]) for idx, req_field in enumerate(req_version_tuple)
  229. )
  230. requirement_is_met = module_version >= req_version_tuple
  231. except Exception as e:
  232. message = (
  233. f"'{module.__name__}' module version string is malformed '{module.__version__}' and cannot be compared"
  234. f" with tuple {str(req_version_tuple)}"
  235. )
  236. if error_if_malformed:
  237. raise RuntimeError(message) from e
  238. else:
  239. warnings.warn(message + ', but continuing assuming that requirement is met')
  240. requirement_is_met = True
  241. return requirement_is_met
  242. def _cpu_tag(obj):
  243. if obj.device.type == 'cpu':
  244. return 'cpu'
  245. def _mps_tag(obj):
  246. if obj.device.type == 'mps':
  247. return 'mps'
  248. def _meta_tag(obj):
  249. if obj.device.type == 'meta':
  250. return 'meta'
  251. def _backend_tag(backend_name, obj):
  252. if backend_name == 'privateuse1':
  253. backend_name = torch._C._get_privateuse1_backend_name()
  254. if obj.device.type == backend_name:
  255. if obj.device.index is None:
  256. return backend_name
  257. else:
  258. return backend_name + ':' + str(obj.device.index)
  259. def _cpu_deserialize(obj, location):
  260. if location == 'cpu':
  261. return obj
  262. def _mps_deserialize(obj, location):
  263. if location.startswith('mps'):
  264. return obj.mps()
  265. def _meta_deserialize(obj, location):
  266. if location == 'meta':
  267. return torch.UntypedStorage(obj.nbytes(), device='meta')
  268. def _validate_device(location, backend_name):
  269. '''
  270. Check whether the device index of specified backend is valid
  271. In case of privateuse1 backend, your must first register a device_module for
  272. privateuse1 using torch._register_device_module. Implement the following
  273. methods in device_module like cuda: device_module._utils._get_device_index(location, True),
  274. device_module.device_count().
  275. Args:
  276. location: string of device
  277. backend_name: the backend name or the name of privateuse1, which can be renamed
  278. Returns:
  279. device_index: int
  280. '''
  281. if not hasattr(torch, backend_name):
  282. raise RuntimeError(f'The {backend_name.upper()} device module is not registered. '
  283. 'If you are running on a CPU-only machine, '
  284. 'please use torch.load with map_location=torch.device(\'cpu\') '
  285. 'to map your storages to the CPU.')
  286. device_module = getattr(torch, backend_name)
  287. if hasattr(device_module, '_utils') and hasattr(device_module._utils, '_get_device_index'):
  288. device_index = device_module._utils._get_device_index(location, True)
  289. device = torch.device(backend_name, device_index)
  290. else:
  291. device = torch.device(location)
  292. device_index = device.index if device.index else 0
  293. if hasattr(device_module, 'is_available') and not device_module.is_available():
  294. raise RuntimeError(f'Attempting to deserialize object on a {backend_name.upper()} '
  295. f'device but torch.{backend_name}.is_available() is False. '
  296. 'If you are running on a CPU-only machine, '
  297. 'please use torch.load with map_location=torch.device(\'cpu\') '
  298. 'to map your storages to the CPU.')
  299. if hasattr(device_module, 'device_count'):
  300. device_count = device_module.device_count()
  301. if device_index >= device_count:
  302. raise RuntimeError(f'Attempting to deserialize object on {backend_name.upper()} device '
  303. f'{device_index} but torch.{backend_name}.device_count() is {device_count}. '
  304. 'Please use torch.load with map_location to map your storages '
  305. 'to an existing device.')
  306. return device
  307. def validate_cuda_device(location):
  308. return _validate_device(location, 'cuda').index
  309. def validate_hpu_device(location):
  310. return _validate_device(location, 'hpu').index
  311. def _deserialize(backend_name, obj, location):
  312. if backend_name == 'privateuse1':
  313. backend_name = torch._C._get_privateuse1_backend_name()
  314. if location.startswith(backend_name):
  315. device = _validate_device(location, backend_name)
  316. return obj.to(device=device)
  317. register_package(10, _cpu_tag, _cpu_deserialize)
  318. register_package(20, functools.partial(_backend_tag, 'cuda'), functools.partial(_deserialize, 'cuda'))
  319. register_package(21, _mps_tag, _mps_deserialize)
  320. register_package(22, _meta_tag, _meta_deserialize)
  321. register_package(23, functools.partial(_backend_tag, 'privateuse1'), functools.partial(_deserialize, 'privateuse1'))
  322. register_package(24, functools.partial(_backend_tag, 'hpu'), functools.partial(_deserialize, 'hpu'))
  323. register_package(25, functools.partial(_backend_tag, 'xpu'), functools.partial(_deserialize, 'xpu'))
  324. def location_tag(storage: Union[Storage, torch.storage.TypedStorage, torch.UntypedStorage]):
  325. for _, tagger, _ in _package_registry:
  326. location = tagger(storage)
  327. if location:
  328. return location
  329. raise RuntimeError("don't know how to determine data location of "
  330. + torch.typename(storage))
  331. def default_restore_location(storage, location):
  332. for _, _, fn in _package_registry:
  333. result = fn(storage, location)
  334. if result is not None:
  335. return result
  336. raise RuntimeError("don't know how to restore data location of "
  337. + torch.typename(storage) + " (tagged with "
  338. + location + ")")
  339. def normalize_storage_type(storage_type):
  340. return getattr(torch, storage_type.__name__)
  341. def storage_to_tensor_type(storage):
  342. storage_type = type(storage)
  343. module = _import_dotted_name(storage_type.__module__)
  344. return getattr(module, storage_type.__name__.replace('Storage', 'Tensor'))
  345. def _is_path(name_or_buffer) -> TypeGuard[Union[str, os.PathLike]]:
  346. return isinstance(name_or_buffer, (str, os.PathLike))
  347. class _opener:
  348. def __init__(self, file_like):
  349. self.file_like = file_like
  350. def __enter__(self):
  351. return self.file_like
  352. def __exit__(self, *args):
  353. pass
  354. class _open_file(_opener):
  355. def __init__(self, name, mode):
  356. super().__init__(open(name, mode))
  357. def __exit__(self, *args):
  358. self.file_like.close()
  359. class _open_buffer_reader(_opener):
  360. def __init__(self, buffer):
  361. super().__init__(buffer)
  362. _check_seekable(buffer)
  363. class _open_buffer_writer(_opener):
  364. def __exit__(self, *args):
  365. self.file_like.flush()
  366. def _open_file_like(name_or_buffer, mode):
  367. if _is_path(name_or_buffer):
  368. return _open_file(name_or_buffer, mode)
  369. else:
  370. if 'w' in mode:
  371. return _open_buffer_writer(name_or_buffer)
  372. elif 'r' in mode:
  373. return _open_buffer_reader(name_or_buffer)
  374. else:
  375. raise RuntimeError(f"Expected 'r' or 'w' in mode but got {mode}")
  376. class _open_zipfile_reader(_opener):
  377. def __init__(self, name_or_buffer) -> None:
  378. super().__init__(torch._C.PyTorchFileReader(name_or_buffer))
  379. class _open_zipfile_writer_file(_opener):
  380. def __init__(self, name) -> None:
  381. self.file_stream = None
  382. self.name = str(name)
  383. try:
  384. self.name.encode('ascii')
  385. except UnicodeEncodeError:
  386. # PyTorchFileWriter only supports ascii filename.
  387. # For filenames with non-ascii characters, we rely on Python
  388. # for writing out the file.
  389. self.file_stream = io.FileIO(self.name, mode='w')
  390. super().__init__(torch._C.PyTorchFileWriter(self.file_stream))
  391. else:
  392. super().__init__(torch._C.PyTorchFileWriter(self.name))
  393. def __exit__(self, *args) -> None:
  394. self.file_like.write_end_of_file()
  395. if self.file_stream is not None:
  396. self.file_stream.close()
  397. class _open_zipfile_writer_buffer(_opener):
  398. def __init__(self, buffer) -> None:
  399. if not callable(getattr(buffer, "write", None)):
  400. msg = f"Buffer of {str(type(buffer)).strip('<>')} has no callable attribute 'write'"
  401. if not hasattr(buffer, "write"):
  402. raise AttributeError(msg)
  403. raise TypeError(msg)
  404. self.buffer = buffer
  405. super().__init__(torch._C.PyTorchFileWriter(buffer))
  406. def __exit__(self, *args) -> None:
  407. self.file_like.write_end_of_file()
  408. self.buffer.flush()
  409. def _open_zipfile_writer(name_or_buffer):
  410. container: Type[_opener]
  411. if _is_path(name_or_buffer):
  412. container = _open_zipfile_writer_file
  413. else:
  414. container = _open_zipfile_writer_buffer
  415. return container(name_or_buffer)
  416. def _is_compressed_file(f) -> bool:
  417. compress_modules = ['gzip']
  418. try:
  419. return f.__module__ in compress_modules
  420. except AttributeError:
  421. return False
  422. def _should_read_directly(f):
  423. """
  424. Checks if f is a file that should be read directly. It should be read
  425. directly if it is backed by a real file (has a fileno) and is not a
  426. a compressed file (e.g. gzip)
  427. """
  428. if _is_compressed_file(f):
  429. return False
  430. try:
  431. return f.fileno() >= 0
  432. except io.UnsupportedOperation:
  433. return False
  434. except AttributeError:
  435. return False
  436. def _check_seekable(f) -> bool:
  437. def raise_err_msg(patterns, e):
  438. for p in patterns:
  439. if p in str(e):
  440. msg = (str(e) + ". You can only torch.load from a file that is seekable."
  441. + " Please pre-load the data into a buffer like io.BytesIO and"
  442. + " try to load from it instead.")
  443. raise type(e)(msg)
  444. raise e
  445. try:
  446. f.seek(f.tell())
  447. return True
  448. except (io.UnsupportedOperation, AttributeError) as e:
  449. raise_err_msg(["seek", "tell"], e)
  450. return False
  451. def _check_dill_version(pickle_module) -> None:
  452. '''Checks if using dill as the pickle module, and if so, checks if it is the correct version.
  453. If dill version is lower than 0.3.1, a ValueError is raised.
  454. Args:
  455. pickle_module: module used for pickling metadata and objects
  456. '''
  457. if pickle_module is not None and pickle_module.__name__ == 'dill':
  458. required_dill_version = (0, 3, 1)
  459. if not check_module_version_greater_or_equal(pickle_module, required_dill_version, False):
  460. raise ValueError((
  461. "'torch' supports dill >= {}, but you have dill {}."
  462. " Please upgrade dill or switch to 'pickle'"
  463. ).format(
  464. '.'.join([str(num) for num in required_dill_version]),
  465. pickle_module.__version__
  466. ))
  467. def _check_save_filelike(f):
  468. if not _is_path(f) and not hasattr(f, 'write'):
  469. raise AttributeError(
  470. "expected 'f' to be string, path, or a file-like object with "
  471. "a 'write' attribute")
  472. def save(
  473. obj: object,
  474. f: FILE_LIKE,
  475. pickle_module: Any = pickle,
  476. pickle_protocol: int = DEFAULT_PROTOCOL,
  477. _use_new_zipfile_serialization: bool = True,
  478. _disable_byteorder_record: bool = False
  479. ) -> None:
  480. # Reference: https://github.com/pytorch/pytorch/issues/54354
  481. # The first line of this docstring overrides the one Sphinx generates for the
  482. # documentation. We need it so that Sphinx doesn't leak `pickle`s path from
  483. # the build environment (e.g. `<module 'pickle' from '/leaked/path').
  484. """save(obj, f, pickle_module=pickle, pickle_protocol=DEFAULT_PROTOCOL, _use_new_zipfile_serialization=True)
  485. Saves an object to a disk file.
  486. See also: :ref:`saving-loading-tensors`
  487. Args:
  488. obj: saved object
  489. f: a file-like object (has to implement write and flush) or a string or
  490. os.PathLike object containing a file name
  491. pickle_module: module used for pickling metadata and objects
  492. pickle_protocol: can be specified to override the default protocol
  493. .. note::
  494. A common PyTorch convention is to save tensors using .pt file extension.
  495. .. note::
  496. PyTorch preserves storage sharing across serialization. See
  497. :ref:`preserve-storage-sharing` for more details.
  498. .. note::
  499. The 1.6 release of PyTorch switched ``torch.save`` to use a new
  500. zipfile-based file format. ``torch.load`` still retains the ability to
  501. load files in the old format. If for any reason you want ``torch.save``
  502. to use the old format, pass the kwarg ``_use_new_zipfile_serialization=False``.
  503. Example:
  504. >>> # xdoctest: +SKIP("makes cwd dirty")
  505. >>> # Save to file
  506. >>> x = torch.tensor([0, 1, 2, 3, 4])
  507. >>> torch.save(x, 'tensor.pt')
  508. >>> # Save to io.BytesIO buffer
  509. >>> buffer = io.BytesIO()
  510. >>> torch.save(x, buffer)
  511. """
  512. torch._C._log_api_usage_once("torch.save")
  513. _check_dill_version(pickle_module)
  514. _check_save_filelike(f)
  515. if _use_new_zipfile_serialization:
  516. with _open_zipfile_writer(f) as opened_zipfile:
  517. _save(obj, opened_zipfile, pickle_module, pickle_protocol, _disable_byteorder_record)
  518. return
  519. else:
  520. with _open_file_like(f, 'wb') as opened_file:
  521. _legacy_save(obj, opened_file, pickle_module, pickle_protocol)
  522. def _legacy_save(obj, f, pickle_module, pickle_protocol) -> None:
  523. import torch.nn as nn
  524. serialized_container_types = {}
  525. serialized_storages = {}
  526. # Since loading storages that view the same data with different dtypes is
  527. # not supported, we need to keep track of the dtype associated with each
  528. # storage data_ptr and throw an error if the dtype is ever different.
  529. # TODO: This feature could be added in the future
  530. storage_dtypes: Dict[int, torch.dtype] = {}
  531. def persistent_id(obj: Any) -> Optional[Tuple]:
  532. # FIXME: the docs say that persistent_id should only return a string
  533. # but torch store returns tuples. This works only in the binary protocol
  534. # see
  535. # https://docs.python.org/2/library/pickle.html#pickling-and-unpickling-external-objects
  536. # https://github.com/python/cpython/blob/master/Lib/pickle.py#L527-L537
  537. if isinstance(obj, type) and issubclass(obj, nn.Module):
  538. if obj in serialized_container_types:
  539. return None
  540. serialized_container_types[obj] = True
  541. source_file = source = None
  542. try:
  543. source_lines, _, source_file = get_source_lines_and_file(obj)
  544. source = ''.join(source_lines)
  545. except Exception: # saving the source is optional, so we can ignore any errors
  546. warnings.warn("Couldn't retrieve source code for container of "
  547. "type " + obj.__name__ + ". It won't be checked "
  548. "for correctness upon loading.")
  549. return ('module', obj, source_file, source)
  550. if isinstance(obj, torch.storage.TypedStorage) or torch.is_storage(obj):
  551. storage: torch.UntypedStorage
  552. if isinstance(obj, torch.storage.TypedStorage):
  553. # TODO: Once we decide to break serialization FC, this case
  554. # can be deleted
  555. storage = obj._untyped_storage
  556. storage_dtype = obj.dtype
  557. storage_type_str = obj._pickle_storage_type()
  558. storage_type = getattr(torch, storage_type_str)
  559. dtype = obj.dtype
  560. storage_numel = obj._size()
  561. elif isinstance(obj, torch.UntypedStorage):
  562. storage = obj
  563. storage_dtype = torch.uint8
  564. storage_type = normalize_storage_type(type(obj))
  565. dtype = torch.uint8
  566. storage_numel = storage.nbytes()
  567. else:
  568. raise TypeError(f'type not recognized: {type(obj)}')
  569. # If storage is allocated, ensure that any other saved storages
  570. # pointing to the same data all have the same dtype. If storage is
  571. # not allocated, don't perform this check
  572. if storage.data_ptr() != 0:
  573. if storage.data_ptr() in storage_dtypes:
  574. if storage_dtype != storage_dtypes[storage.data_ptr()]:
  575. raise RuntimeError(
  576. 'Cannot save multiple tensors or storages that '
  577. 'view the same data as different types')
  578. else:
  579. storage_dtypes[storage.data_ptr()] = storage_dtype
  580. view_metadata: Optional[Tuple[str, int, int]]
  581. # Offset is always 0, but we keep it for backwards compatibility
  582. # with the old serialization format (which supported storage views)
  583. offset = 0
  584. storage_key = str(storage._cdata)
  585. location = location_tag(storage)
  586. # TODO: There's an issue here with FC. It might be impossible to
  587. # solve, but it's worth noting. Imagine we save a list `[storage,
  588. # tensor]`, where `tensor.storage()` is the same as `storage`, and
  589. # `tensor.element_size() > 1`. Let's say that `tensor.dtype ==
  590. # torch.float`. The storage will be serialized with element size
  591. # of 1, since we're choosing to serialize the first occurance of
  592. # a duplicate storage. Since this legacy serialization format saves
  593. # the numel of the storage, rather than nbytes directly, we'll be
  594. # effectively saving nbytes in this case. We'll be able to load it
  595. # and the tensor back up with no problems in _this_ and future
  596. # versions of pytorch, but in older versions, here's the problem:
  597. # the storage will be loaded up as a UntypedStorage, and then the
  598. # FloatTensor will loaded and the UntypedStorage will be assigned to
  599. # it. Since the storage dtype does not match the tensor dtype, this
  600. # will cause an error. If we reverse the list, like `[tensor,
  601. # storage]`, then we will save the `tensor.storage()` as a faked
  602. # `FloatStorage`, and the saved size will be the correct
  603. # dtype-specific numel count that old versions expect. `tensor`
  604. # will be able to load up properly in old versions, pointing to
  605. # a FloatStorage. However, `storage` is still being translated to
  606. # a UntypedStorage, and it will try to resolve to the same
  607. # FloatStorage that `tensor` contains. This will also cause an
  608. # error. It doesn't seem like there's any way around this.
  609. # Probably, we just cannot maintain FC for the legacy format if the
  610. # saved list contains both a tensor and a storage that point to the
  611. # same data. We should still be able to maintain FC for lists of
  612. # just tensors, as long as all views share the same dtype as the
  613. # tensor they are viewing.
  614. if storage_key not in serialized_storages:
  615. serialized_storages[storage_key] = (storage, dtype)
  616. is_view = storage._cdata != storage._cdata
  617. if is_view:
  618. view_metadata = (str(storage._cdata), offset, storage.nbytes())
  619. else:
  620. view_metadata = None
  621. res = ('storage',
  622. storage_type,
  623. storage_key,
  624. location,
  625. storage_numel,
  626. view_metadata)
  627. return res
  628. return None
  629. sys_info = dict(
  630. protocol_version=PROTOCOL_VERSION,
  631. little_endian=sys.byteorder == 'little',
  632. type_sizes=dict(
  633. short=SHORT_SIZE,
  634. int=INT_SIZE,
  635. long=LONG_SIZE,
  636. ),
  637. )
  638. pickle_module.dump(MAGIC_NUMBER, f, protocol=pickle_protocol)
  639. pickle_module.dump(PROTOCOL_VERSION, f, protocol=pickle_protocol)
  640. pickle_module.dump(sys_info, f, protocol=pickle_protocol)
  641. pickler = pickle_module.Pickler(f, protocol=pickle_protocol)
  642. pickler.persistent_id = persistent_id
  643. pickler.dump(obj)
  644. serialized_storage_keys = sorted(serialized_storages.keys())
  645. pickle_module.dump(serialized_storage_keys, f, protocol=pickle_protocol)
  646. f.flush()
  647. for key in serialized_storage_keys:
  648. storage, dtype = serialized_storages[key]
  649. storage._write_file(f, _should_read_directly(f), True, torch._utils._element_size(dtype))
  650. def _save(obj, zip_file, pickle_module, pickle_protocol, _disable_byteorder_record):
  651. serialized_storages = {}
  652. id_map: Dict[int, str] = {}
  653. # Since loading storages that view the same data with different dtypes is
  654. # not supported, we need to keep track of the dtype associated with each
  655. # storage data_ptr and throw an error if the dtype is ever different.
  656. # TODO: This feature could be added in the future
  657. storage_dtypes: Dict[int, torch.dtype] = {}
  658. def persistent_id(obj):
  659. # FIXME: the docs say that persistent_id should only return a string
  660. # but torch store returns tuples. This works only in the binary protocol
  661. # see
  662. # https://docs.python.org/2/library/pickle.html#pickling-and-unpickling-external-objects
  663. # https://github.com/python/cpython/blob/master/Lib/pickle.py#L527-L537
  664. if isinstance(obj, torch.storage.TypedStorage) or torch.is_storage(obj):
  665. if isinstance(obj, torch.storage.TypedStorage):
  666. # TODO: Once we decide to break serialization FC, this case
  667. # can be deleted
  668. storage = obj._untyped_storage
  669. storage_dtype = obj.dtype
  670. storage_type_str = obj._pickle_storage_type()
  671. storage_type = getattr(torch, storage_type_str)
  672. storage_numel = obj._size()
  673. else:
  674. storage = obj
  675. storage_dtype = torch.uint8
  676. storage_type = normalize_storage_type(type(obj))
  677. storage_numel = storage.nbytes()
  678. # If storage is allocated, ensure that any other saved storages
  679. # pointing to the same data all have the same dtype. If storage is
  680. # not allocated, don't perform this check
  681. if storage.data_ptr() != 0:
  682. if storage.data_ptr() in storage_dtypes:
  683. if storage_dtype != storage_dtypes[storage.data_ptr()]:
  684. raise RuntimeError(
  685. 'Cannot save multiple tensors or storages that '
  686. 'view the same data as different types')
  687. else:
  688. storage_dtypes[storage.data_ptr()] = storage_dtype
  689. storage_key = id_map.setdefault(storage._cdata, str(len(id_map)))
  690. location = location_tag(storage)
  691. serialized_storages[storage_key] = storage
  692. return ('storage',
  693. storage_type,
  694. storage_key,
  695. location,
  696. storage_numel)
  697. return None
  698. # Write the pickle data for `obj`
  699. data_buf = io.BytesIO()
  700. pickler = pickle_module.Pickler(data_buf, protocol=pickle_protocol)
  701. pickler.persistent_id = persistent_id
  702. pickler.dump(obj)
  703. data_value = data_buf.getvalue()
  704. zip_file.write_record('data.pkl', data_value, len(data_value))
  705. # Write byte order marker
  706. if not _disable_byteorder_record:
  707. if sys.byteorder not in ['little', 'big']:
  708. raise ValueError('Unknown endianness type: ' + sys.byteorder)
  709. zip_file.write_record('byteorder', sys.byteorder, len(sys.byteorder))
  710. # Write each tensor to a file named tensor/the_tensor_key in the zip archive
  711. for key in sorted(serialized_storages.keys()):
  712. name = f'data/{key}'
  713. storage = serialized_storages[key]
  714. # given that we copy things around anyway, we might use storage.cpu()
  715. # this means to that to get tensors serialized, you need to implement
  716. # .cpu() on the underlying Storage
  717. if storage.device.type != 'cpu':
  718. storage = storage.cpu()
  719. # Now that it is on the CPU we can directly copy it into the zip file
  720. num_bytes = storage.nbytes()
  721. zip_file.write_record(name, storage, num_bytes)
  722. def load(
  723. f: FILE_LIKE,
  724. map_location: MAP_LOCATION = None,
  725. pickle_module: Any = None,
  726. *,
  727. weights_only: Optional[bool] = None,
  728. mmap: Optional[bool] = None,
  729. **pickle_load_args: Any
  730. ) -> Any:
  731. # Reference: https://github.com/pytorch/pytorch/issues/54354
  732. # The first line of this docstring overrides the one Sphinx generates for the
  733. # documentation. We need it so that Sphinx doesn't leak `pickle`s path from
  734. # the build environment (e.g. `<module 'pickle' from '/leaked/path').
  735. """load(f, map_location=None, pickle_module=pickle, *, weights_only=False, mmap=None, **pickle_load_args)
  736. Loads an object saved with :func:`torch.save` from a file.
  737. :func:`torch.load` uses Python's unpickling facilities but treats storages,
  738. which underlie tensors, specially. They are first deserialized on the
  739. CPU and are then moved to the device they were saved from. If this fails
  740. (e.g. because the run time system doesn't have certain devices), an exception
  741. is raised. However, storages can be dynamically remapped to an alternative
  742. set of devices using the :attr:`map_location` argument.
  743. If :attr:`map_location` is a callable, it will be called once for each serialized
  744. storage with two arguments: storage and location. The storage argument
  745. will be the initial deserialization of the storage, residing on the CPU.
  746. Each serialized storage has a location tag associated with it which
  747. identifies the device it was saved from, and this tag is the second
  748. argument passed to :attr:`map_location`. The builtin location tags are ``'cpu'``
  749. for CPU tensors and ``'cuda:device_id'`` (e.g. ``'cuda:2'``) for CUDA tensors.
  750. :attr:`map_location` should return either ``None`` or a storage. If
  751. :attr:`map_location` returns a storage, it will be used as the final deserialized
  752. object, already moved to the right device. Otherwise, :func:`torch.load` will
  753. fall back to the default behavior, as if :attr:`map_location` wasn't specified.
  754. If :attr:`map_location` is a :class:`torch.device` object or a string containing
  755. a device tag, it indicates the location where all tensors should be loaded.
  756. Otherwise, if :attr:`map_location` is a dict, it will be used to remap location tags
  757. appearing in the file (keys), to ones that specify where to put the
  758. storages (values).
  759. User extensions can register their own location tags and tagging and
  760. deserialization methods using :func:`torch.serialization.register_package`.
  761. Args:
  762. f: a file-like object (has to implement :meth:`read`, :meth:`readline`, :meth:`tell`, and :meth:`seek`),
  763. or a string or os.PathLike object containing a file name
  764. map_location: a function, :class:`torch.device`, string or a dict specifying how to remap storage
  765. locations
  766. pickle_module: module used for unpickling metadata and objects (has to
  767. match the :attr:`pickle_module` used to serialize file)
  768. weights_only: Indicates whether unpickler should be restricted to
  769. loading only tensors, primitive types, dictionaries
  770. and any types added via :func:`torch.serialization.add_safe_globals`.
  771. mmap: Indicates whether the file should be mmaped rather than loading all the storages into memory.
  772. Typically, tensor storages in the file will first be moved from disk to CPU memory, after which they
  773. are moved to the location that they were tagged with when saving, or specified by ``map_location``. This
  774. second step is a no-op if the final location is CPU. When the ``mmap`` flag is set, instead of copying the
  775. tensor storages from disk to CPU memory in the first step, ``f`` is mmaped.
  776. pickle_load_args: (Python 3 only) optional keyword arguments passed over to
  777. :func:`pickle_module.load` and :func:`pickle_module.Unpickler`, e.g.,
  778. :attr:`errors=...`.
  779. .. warning::
  780. :func:`torch.load()` unless `weights_only` parameter is set to `True`,
  781. uses ``pickle`` module implicitly, which is known to be insecure.
  782. It is possible to construct malicious pickle data which will execute arbitrary code
  783. during unpickling. Never load data that could have come from an untrusted
  784. source in an unsafe mode, or that could have been tampered with. **Only load data you trust**.
  785. .. note::
  786. When you call :func:`torch.load()` on a file which contains GPU tensors, those tensors
  787. will be loaded to GPU by default. You can call ``torch.load(.., map_location='cpu')``
  788. and then :meth:`load_state_dict` to avoid GPU RAM surge when loading a model checkpoint.
  789. .. note::
  790. By default, we decode byte strings as ``utf-8``. This is to avoid a common error
  791. case ``UnicodeDecodeError: 'ascii' codec can't decode byte 0x...``
  792. when loading files saved by Python 2 in Python 3. If this default
  793. is incorrect, you may use an extra :attr:`encoding` keyword argument to specify how
  794. these objects should be loaded, e.g., :attr:`encoding='latin1'` decodes them
  795. to strings using ``latin1`` encoding, and :attr:`encoding='bytes'` keeps them
  796. as byte arrays which can be decoded later with ``byte_array.decode(...)``.
  797. Example:
  798. >>> # xdoctest: +SKIP("undefined filepaths")
  799. >>> torch.load('tensors.pt', weights_only=True)
  800. # Load all tensors onto the CPU
  801. >>> torch.load('tensors.pt', map_location=torch.device('cpu'), weights_only=True)
  802. # Load all tensors onto the CPU, using a function
  803. >>> torch.load('tensors.pt', map_location=lambda storage, loc: storage, weights_only=True)
  804. # Load all tensors onto GPU 1
  805. >>> torch.load('tensors.pt', map_location=lambda storage, loc: storage.cuda(1), weights_only=True)
  806. # Map tensors from GPU 1 to GPU 0
  807. >>> torch.load('tensors.pt', map_location={'cuda:1': 'cuda:0'}, weights_only=True)
  808. # Load tensor from io.BytesIO object
  809. # Loading from a buffer setting weights_only=False, warning this can be unsafe
  810. >>> with open('tensor.pt', 'rb') as f:
  811. ... buffer = io.BytesIO(f.read())
  812. >>> torch.load(buffer, weights_only=False)
  813. # Load a module with 'ascii' encoding for unpickling
  814. # Loading from a module setting weights_only=False, warning this can be unsafe
  815. >>> torch.load('module.pt', encoding='ascii', weights_only=False)
  816. """
  817. torch._C._log_api_usage_once("torch.load")
  818. UNSAFE_MESSAGE = (
  819. "Re-running `torch.load` with `weights_only` set to `False` will likely succeed, "
  820. "but it can result in arbitrary code execution. Do it only if you got the file from a "
  821. "trusted source."
  822. )
  823. DOCS_MESSAGE = (
  824. "\n\nCheck the documentation of torch.load to learn more about types accepted by default with "
  825. "weights_only https://pytorch.org/docs/stable/generated/torch.load.html."
  826. )
  827. def _get_wo_message(message: str) -> str:
  828. pattern = r"GLOBAL (\S+) was not an allowed global by default."
  829. has_unsafe_global = re.search(pattern, message) is not None
  830. if has_unsafe_global:
  831. updated_message = (
  832. "Weights only load failed. This file can still be loaded, to do so you have two options "
  833. f"\n\t(1) {UNSAFE_MESSAGE}\n\t(2) Alternatively, to load with `weights_only=True` please check "
  834. "the recommended steps in the following error message.\n\tWeightsUnpickler error: "
  835. + message
  836. )
  837. else:
  838. updated_message = (
  839. f"Weights only load failed. {UNSAFE_MESSAGE}\n Please file an issue with the following "
  840. "so that we can make `weights_only=True` compatible with your use case: WeightsUnpickler "
  841. "error: " + message
  842. )
  843. return updated_message + DOCS_MESSAGE
  844. if weights_only is None:
  845. weights_only, warn_weights_only = False, True
  846. else:
  847. warn_weights_only = False
  848. # Add ability to force safe only weight loads via environment variable
  849. if os.getenv("TORCH_FORCE_WEIGHTS_ONLY_LOAD", "0").lower() in ['1', 'y', 'yes', 'true']:
  850. weights_only = True
  851. if weights_only:
  852. if pickle_module is not None:
  853. raise RuntimeError("Can not safely load weights when explicit pickle_module is specified")
  854. else:
  855. if pickle_module is None:
  856. if warn_weights_only:
  857. warnings.warn(
  858. "You are using `torch.load` with `weights_only=False` (the current default value), which uses "
  859. "the default pickle module implicitly. It is possible to construct malicious pickle data "
  860. "which will execute arbitrary code during unpickling (See "
  861. "https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). "
  862. "In a future release, the default value for `weights_only` will be flipped to `True`. This "
  863. "limits the functions that could be executed during unpickling. Arbitrary objects will no "
  864. "longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the "
  865. "user via `torch.serialization.add_safe_globals`. We recommend you start setting "
  866. "`weights_only=True` for any use case where you don't have full control of the loaded file. "
  867. "Please open an issue on GitHub for any issues related to this experimental feature.",
  868. FutureWarning,
  869. stacklevel=2,
  870. )
  871. pickle_module = pickle
  872. # make flipping default BC-compatible
  873. if mmap is None:
  874. mmap = False
  875. _check_dill_version(pickle_module)
  876. if 'encoding' not in pickle_load_args.keys():
  877. pickle_load_args['encoding'] = 'utf-8'
  878. with _open_file_like(f, 'rb') as opened_file:
  879. if _is_zipfile(opened_file):
  880. # The zipfile reader is going to advance the current file position.
  881. # If we want to actually tail call to torch.jit.load, we need to
  882. # reset back to the original position.
  883. orig_position = opened_file.tell()
  884. overall_storage = None
  885. with _open_zipfile_reader(opened_file) as opened_zipfile:
  886. if _is_torchscript_zip(opened_zipfile):
  887. warnings.warn("'torch.load' received a zip file that looks like a TorchScript archive"
  888. " dispatching to 'torch.jit.load' (call 'torch.jit.load' directly to"
  889. " silence this warning)", UserWarning)
  890. opened_file.seek(orig_position)
  891. return torch.jit.load(opened_file, map_location=map_location)
  892. if mmap:
  893. if not _is_path(f):
  894. raise ValueError("f must be a file path in order to use the mmap argument")
  895. size = os.path.getsize(f)
  896. if not IS_WINDOWS:
  897. shared = get_default_mmap_options() == MAP_SHARED
  898. else:
  899. shared = False
  900. overall_storage = torch.UntypedStorage.from_file(os.fspath(f), shared, size)
  901. if weights_only:
  902. try:
  903. return _load(opened_zipfile,
  904. map_location,
  905. _weights_only_unpickler,
  906. overall_storage=overall_storage,
  907. **pickle_load_args)
  908. except RuntimeError as e:
  909. raise pickle.UnpicklingError(_get_wo_message(str(e))) from None
  910. return _load(
  911. opened_zipfile,
  912. map_location,
  913. pickle_module,
  914. overall_storage=overall_storage,
  915. **pickle_load_args,
  916. )
  917. if mmap:
  918. f_name = "" if not isinstance(f, str) else f"{f}, "
  919. raise RuntimeError("mmap can only be used with files saved with "
  920. f"`torch.save({f_name}_use_new_zipfile_serialization=True), "
  921. "please torch.save your checkpoint with this option in order to use mmap.")
  922. if weights_only:
  923. try:
  924. return _legacy_load(opened_file, map_location, _weights_only_unpickler, **pickle_load_args)
  925. except RuntimeError as e:
  926. raise pickle.UnpicklingError(_get_wo_message(str(e))) from None
  927. return _legacy_load(
  928. opened_file, map_location, pickle_module, **pickle_load_args
  929. )
  930. # Register pickling support for layout instances such as
  931. # torch.sparse_coo, etc
  932. def _get_layout(name):
  933. """Get layout extension object from its string representation.
  934. """
  935. cache = _get_layout.cache # type: ignore[attr-defined]
  936. if not cache:
  937. for v in torch.__dict__.values():
  938. if isinstance(v, torch.layout):
  939. cache[str(v)] = v
  940. return cache[name]
  941. # There are yet not good way to type annotate function attributes https://github.com/python/mypy/issues/2087
  942. _get_layout.cache = {} # type: ignore[attr-defined]
  943. copyreg.pickle(torch.layout, lambda obj: (_get_layout, (str(obj),)))
  944. def _legacy_load(f, map_location, pickle_module, **pickle_load_args):
  945. deserialized_objects: Dict[int, Any] = {}
  946. restore_location = _get_restore_location(map_location)
  947. class UnpicklerWrapper(pickle_module.Unpickler): # type: ignore[name-defined]
  948. def find_class(self, mod_name, name):
  949. if type(name) is str and 'Storage' in name:
  950. try:
  951. return StorageType(name)
  952. except KeyError:
  953. pass
  954. return super().find_class(mod_name, name)
  955. def _check_container_source(container_type, source_file, original_source):
  956. try:
  957. current_source = ''.join(get_source_lines_and_file(container_type)[0])
  958. except Exception: # saving the source is optional, so we can ignore any errors
  959. warnings.warn("Couldn't retrieve source code for container of "
  960. "type " + container_type.__name__ + ". It won't be checked "
  961. "for correctness upon loading.")
  962. return
  963. if original_source != current_source:
  964. if container_type.dump_patches:
  965. file_name = container_type.__name__ + '.patch'
  966. diff = difflib.unified_diff(current_source.split('\n'),
  967. original_source.split('\n'),
  968. source_file,
  969. source_file, lineterm="")
  970. lines = '\n'.join(diff)
  971. try:
  972. with open(file_name, 'a+') as f:
  973. file_size = f.seek(0, 2)
  974. f.seek(0)
  975. if file_size == 0:
  976. f.write(lines)
  977. elif file_size != len(lines) or f.read() != lines:
  978. raise OSError
  979. msg = ("Saved a reverse patch to " + file_name + ". "
  980. "Run `patch -p0 < " + file_name + "` to revert your "
  981. "changes.")
  982. except OSError:
  983. msg = ("Tried to save a patch, but couldn't create a "
  984. "writable file " + file_name + ". Make sure it "
  985. "doesn't exist and your working directory is "
  986. "writable.")
  987. else:
  988. msg = ("you can retrieve the original source code by "
  989. "accessing the object's source attribute or set "
  990. "`torch.nn.Module.dump_patches = True` and use the "
  991. "patch tool to revert the changes.")
  992. msg = f"source code of class '{torch.typename(container_type)}' has changed. {msg}"
  993. warnings.warn(msg, SourceChangeWarning)
  994. def legacy_load(f):
  995. deserialized_objects: Dict[int, Any] = {}
  996. def persistent_load(saved_id):
  997. if isinstance(saved_id, tuple):
  998. # Ignore containers that don't have any sources saved
  999. if all(saved_id[1:]):
  1000. _check_container_source(*saved_id)
  1001. return saved_id[0]
  1002. return deserialized_objects[int(saved_id)]
  1003. with closing(tarfile.open(fileobj=f, mode='r:', format=tarfile.PAX_FORMAT)) as tar, \
  1004. mkdtemp() as tmpdir:
  1005. tar.extract('storages', path=tmpdir)
  1006. with open(os.path.join(tmpdir, 'storages'), 'rb', 0) as f:
  1007. num_storages = pickle_module.load(f, **pickle_load_args)
  1008. for i in range(num_storages):
  1009. args = pickle_module.load(f, **pickle_load_args)
  1010. key, location, storage_type = args
  1011. dtype = storage_type._dtype
  1012. obj = cast(Storage, torch.UntypedStorage)._new_with_file(f, torch._utils._element_size(dtype))
  1013. obj = restore_location(obj, location)
  1014. # TODO: Once we decide to break serialization FC, we can
  1015. # stop wrapping with TypedStorage
  1016. deserialized_objects[key] = torch.storage.TypedStorage(
  1017. wrap_storage=obj,
  1018. dtype=dtype,
  1019. _internal=True)
  1020. storage_views = pickle_module.load(f, **pickle_load_args)
  1021. for target_cdata, root_cdata, offset, numel in storage_views:
  1022. root = deserialized_objects[root_cdata]
  1023. element_size = torch._utils._element_size(root.dtype)
  1024. offset_bytes = offset * element_size
  1025. # TODO: Once we decide to break serialization FC, we can
  1026. # stop wrapping with TypedStorage
  1027. deserialized_objects[target_cdata] = torch.storage.TypedStorage(
  1028. wrap_storage=root._untyped_storage[offset_bytes:offset_bytes + numel * element_size],
  1029. dtype=root.dtype,
  1030. _internal=True)
  1031. tar.extract('tensors', path=tmpdir)
  1032. with open(os.path.join(tmpdir, 'tensors'), 'rb', 0) as f:
  1033. num_tensors = pickle_module.load(f, **pickle_load_args)
  1034. for _ in range(num_tensors):
  1035. args = pickle_module.load(f, **pickle_load_args)
  1036. key, storage_id, original_tensor_type = args
  1037. storage = deserialized_objects[storage_id]
  1038. ndim, = struct.unpack('<i', f.read(4))
  1039. # skip next 4 bytes; legacy encoding treated ndim as 8 bytes
  1040. f.read(4)
  1041. numel = struct.unpack(f'<{ndim}q', f.read(8 * ndim))
  1042. stride = struct.unpack(f'<{ndim}q', f.read(8 * ndim))
  1043. storage_offset, = struct.unpack('<q', f.read(8))
  1044. tensor = torch.empty((0,), dtype=storage.dtype).set_(
  1045. storage._untyped_storage, storage_offset, numel, stride)
  1046. deserialized_objects[key] = tensor
  1047. pickle_file = tar.extractfile('pickle')
  1048. unpickler = UnpicklerWrapper(pickle_file, **pickle_load_args)
  1049. unpickler.persistent_load = persistent_load
  1050. result = unpickler.load()
  1051. return result
  1052. deserialized_objects = {}
  1053. def persistent_load(saved_id):
  1054. assert isinstance(saved_id, tuple)
  1055. typename = _maybe_decode_ascii(saved_id[0])
  1056. data = saved_id[1:]
  1057. if typename == 'module':
  1058. # Ignore containers that don't have any sources saved
  1059. if all(data[1:]):
  1060. _check_container_source(*data)
  1061. return data[0]
  1062. elif typename == 'storage':
  1063. storage_type, root_key, location, numel, view_metadata = data
  1064. location = _maybe_decode_ascii(location)
  1065. dtype = storage_type.dtype
  1066. nbytes = numel * torch._utils._element_size(dtype)
  1067. if root_key not in deserialized_objects:
  1068. if torch._guards.active_fake_mode() is not None:
  1069. obj = cast(Storage, torch.UntypedStorage(nbytes, device='meta'))
  1070. else:
  1071. obj = cast(Storage, torch.UntypedStorage(nbytes))
  1072. obj._torch_load_uninitialized = True
  1073. obj = restore_location(obj, location)
  1074. # TODO: Once we decide to break serialization FC, we can
  1075. # stop wrapping with TypedStorage
  1076. typed_storage = torch.storage.TypedStorage(
  1077. wrap_storage=obj,
  1078. dtype=dtype,
  1079. _internal=True)
  1080. deserialized_objects[root_key] = typed_storage
  1081. else:
  1082. typed_storage = deserialized_objects[root_key]
  1083. if typed_storage._data_ptr() == 0:
  1084. typed_storage = torch.storage.TypedStorage(
  1085. device=typed_storage._untyped_storage.device,
  1086. dtype=dtype,
  1087. _internal=True)
  1088. if view_metadata is not None:
  1089. view_key, offset, view_size = view_metadata
  1090. offset_bytes = offset * torch._utils._element_size(dtype)
  1091. view_size_bytes = view_size * torch._utils._element_size(dtype)
  1092. if view_key not in deserialized_objects:
  1093. # TODO: Once we decide to break serialization FC, we can
  1094. # stop wrapping with TypedStorage
  1095. deserialized_objects[view_key] = torch.storage.TypedStorage(
  1096. wrap_storage=typed_storage._untyped_storage[offset_bytes:offset_bytes + view_size_bytes],
  1097. dtype=dtype,
  1098. _internal=True)
  1099. res = deserialized_objects[view_key]
  1100. else:
  1101. res = typed_storage
  1102. return res
  1103. else:
  1104. raise RuntimeError(f"Unknown saved id type: {saved_id[0]}")
  1105. _check_seekable(f)
  1106. f_should_read_directly = _should_read_directly(f)
  1107. if f_should_read_directly and f.tell() == 0:
  1108. # legacy_load requires that f has fileno()
  1109. # only if offset is zero we can attempt the legacy tar file loader
  1110. try:
  1111. return legacy_load(f)
  1112. except tarfile.TarError:
  1113. if _is_zipfile(f):
  1114. # .zip is used for torch.jit.save and will throw an un-pickling error here
  1115. raise RuntimeError(
  1116. f"{f.name} is a zip archive (did you mean to use torch.jit.load()?)") from None
  1117. # if not a tarfile, reset file offset and proceed
  1118. f.seek(0)
  1119. if not hasattr(f, 'readinto') and (3, 8, 0) <= sys.version_info < (3, 8, 2):
  1120. raise RuntimeError(
  1121. "torch.load does not work with file-like objects that do not implement readinto on Python 3.8.0 and 3.8.1. "
  1122. f'Received object of type "{type(f)}". Please update to Python 3.8.2 or newer to restore this '
  1123. "functionality.")
  1124. magic_number = pickle_module.load(f, **pickle_load_args)
  1125. if magic_number != MAGIC_NUMBER:
  1126. raise RuntimeError("Invalid magic number; corrupt file?")
  1127. protocol_version = pickle_module.load(f, **pickle_load_args)
  1128. if protocol_version != PROTOCOL_VERSION:
  1129. raise RuntimeError(f"Invalid protocol version: {protocol_version}")
  1130. _sys_info = pickle_module.load(f, **pickle_load_args)
  1131. unpickler = UnpicklerWrapper(f, **pickle_load_args)
  1132. unpickler.persistent_load = persistent_load
  1133. result = unpickler.load()
  1134. deserialized_storage_keys = pickle_module.load(f, **pickle_load_args)
  1135. if torch._guards.active_fake_mode() is None:
  1136. offset = f.tell() if f_should_read_directly else None
  1137. for key in deserialized_storage_keys:
  1138. assert key in deserialized_objects
  1139. typed_storage = deserialized_objects[key]
  1140. typed_storage._untyped_storage._set_from_file(
  1141. f, offset, f_should_read_directly,
  1142. torch._utils._element_size(typed_storage.dtype))
  1143. if offset is not None:
  1144. offset = f.tell()
  1145. torch._utils._validate_loaded_sparse_tensors()
  1146. return result
  1147. def _maybe_decode_ascii(bytes_str: Union[bytes, str]) -> str:
  1148. # When using encoding='bytes' in Py3, some **internal** keys stored as
  1149. # strings in Py2 are loaded as bytes. This function decodes them with
  1150. # ascii encoding, one that Py3 uses by default.
  1151. #
  1152. # NOTE: This should only be used on internal keys (e.g., `typename` and
  1153. # `location` in `persistent_load` below!
  1154. if isinstance(bytes_str, bytes):
  1155. return bytes_str.decode('ascii')
  1156. return bytes_str
  1157. def _get_restore_location(map_location):
  1158. if map_location is None:
  1159. restore_location = default_restore_location
  1160. elif isinstance(map_location, dict):
  1161. def restore_location(storage, location):
  1162. location = map_location.get(location, location)
  1163. return default_restore_location(storage, location)
  1164. elif isinstance(map_location, (str, bytes)):
  1165. def restore_location(storage, location):
  1166. return default_restore_location(storage, map_location)
  1167. elif isinstance(map_location, torch.device):
  1168. def restore_location(storage, location):
  1169. return default_restore_location(storage, str(map_location))
  1170. else:
  1171. def restore_location(storage, location):
  1172. result = map_location(storage, location)
  1173. if result is None:
  1174. result = default_restore_location(storage, location)
  1175. return result
  1176. return restore_location
  1177. class StorageType:
  1178. def __init__(self, name):
  1179. self._dtype = _get_dtype_from_pickle_storage_type(name)
  1180. @property
  1181. def dtype(self):
  1182. return self._dtype
  1183. def __str__(self):
  1184. return f'StorageType(dtype={self.dtype})'
  1185. def _load(zip_file, map_location, pickle_module, pickle_file='data.pkl', overall_storage=None, **pickle_load_args):
  1186. restore_location = _get_restore_location(map_location)
  1187. loaded_storages = {}
  1188. # check if byteswapping is needed
  1189. byteordername = 'byteorder'
  1190. byteorderdata = None
  1191. if zip_file.has_record(byteordername):
  1192. byteorderdata = zip_file.get_record(byteordername)
  1193. if byteorderdata not in [b'little', b'big']:
  1194. raise ValueError('Unknown endianness type: ' + byteorderdata.decode())
  1195. elif get_default_load_endianness() == LoadEndianness.LITTLE or \
  1196. get_default_load_endianness() is None:
  1197. byteorderdata = b'little'
  1198. elif get_default_load_endianness() == LoadEndianness.BIG:
  1199. byteorderdata = b'big'
  1200. elif get_default_load_endianness() == LoadEndianness.NATIVE:
  1201. pass
  1202. else:
  1203. raise ValueError('Invalid load endianness type')
  1204. if not zip_file.has_record(byteordername) and \
  1205. get_default_load_endianness() is None and \
  1206. sys.byteorder == 'big':
  1207. # Default behaviour was changed
  1208. # See https://github.com/pytorch/pytorch/issues/101688
  1209. warnings.warn("The default load endianness for checkpoints without a byteorder mark "
  1210. "on big endian machines was changed from 'native' to 'little' endian, "
  1211. "to avoid this behavior please use "
  1212. "torch.serialization.set_default_load_endianness to set "
  1213. "the desired default load endianness",
  1214. UserWarning)
  1215. def load_tensor(dtype, numel, key, location):
  1216. name = f'data/{key}'
  1217. if torch._guards.detect_fake_mode(None) is not None:
  1218. nbytes = numel * torch._utils._element_size(dtype)
  1219. storage = torch.UntypedStorage(nbytes, device='meta')
  1220. elif overall_storage is not None:
  1221. storage_offset = zip_file.get_record_offset(name)
  1222. storage = overall_storage[storage_offset:storage_offset + numel]
  1223. else:
  1224. storage = zip_file.get_storage_from_record(name, numel, torch.UntypedStorage)._typed_storage()._untyped_storage
  1225. # swap here if byteswapping is needed
  1226. if byteorderdata is not None:
  1227. if byteorderdata.decode() != sys.byteorder:
  1228. storage.byteswap(dtype)
  1229. # TODO: Once we decide to break serialization FC, we can
  1230. # stop wrapping with TypedStorage
  1231. typed_storage = torch.storage.TypedStorage(
  1232. wrap_storage=restore_location(storage, location),
  1233. dtype=dtype,
  1234. _internal=True)
  1235. if typed_storage._data_ptr() != 0:
  1236. loaded_storages[key] = typed_storage
  1237. return typed_storage
  1238. def persistent_load(saved_id):
  1239. assert isinstance(saved_id, tuple)
  1240. typename = _maybe_decode_ascii(saved_id[0])
  1241. data = saved_id[1:]
  1242. assert typename == 'storage', \
  1243. f"Unknown typename for persistent_load, expected 'storage' but got '{typename}'"
  1244. storage_type, key, location, numel = data
  1245. if storage_type is torch.UntypedStorage:
  1246. dtype = torch.uint8
  1247. else:
  1248. dtype = storage_type.dtype
  1249. if key in loaded_storages:
  1250. typed_storage = loaded_storages[key]
  1251. else:
  1252. nbytes = numel * torch._utils._element_size(dtype)
  1253. typed_storage = load_tensor(dtype, nbytes, key, _maybe_decode_ascii(location))
  1254. return typed_storage
  1255. load_module_mapping: Dict[str, str] = {
  1256. # See https://github.com/pytorch/pytorch/pull/51633
  1257. 'torch.tensor': 'torch._tensor'
  1258. }
  1259. # Need to subclass Unpickler instead of directly monkey-patching the find_class method
  1260. # because it's marked readonly in pickle.
  1261. # The type: ignore is because mypy can't statically determine the type of this class.
  1262. class UnpicklerWrapper(pickle_module.Unpickler): # type: ignore[name-defined]
  1263. # from https://stackoverflow.com/questions/13398462/unpickling-python-objects-with-a-changed-module-path/13405732
  1264. # Lets us override the imports that pickle uses when unpickling an object.
  1265. # This is useful for maintaining BC if we change a module path that tensor instantiation relies on.
  1266. def find_class(self, mod_name, name):
  1267. if type(name) is str and 'Storage' in name:
  1268. try:
  1269. return StorageType(name)
  1270. except KeyError:
  1271. pass
  1272. mod_name = load_module_mapping.get(mod_name, mod_name)
  1273. return super().find_class(mod_name, name)
  1274. # Load the data (which may in turn use `persistent_load` to load tensors)
  1275. data_file = io.BytesIO(zip_file.get_record(pickle_file))
  1276. unpickler = UnpicklerWrapper(data_file, **pickle_load_args)
  1277. unpickler.persistent_load = persistent_load
  1278. # Needed for tensors where storage device and rebuild tensor device are
  1279. # not connected (wrapper subclasses and tensors rebuilt using numpy)
  1280. torch._utils._thread_local_state.map_location = map_location
  1281. result = unpickler.load()
  1282. del torch._utils._thread_local_state.map_location
  1283. torch._utils._validate_loaded_sparse_tensors()
  1284. torch._C._log_api_usage_metadata(
  1285. "torch.load.metadata", {"serialization_id": zip_file.serialization_id()}
  1286. )
  1287. return result
  1288. def _is_torchscript_zip(zip_file):
  1289. return 'constants.pkl' in zip_file.get_all_records()