_deploy.py

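# Serialization helpers for torch::deploy: ``_save_storages`` pickles an
# object while keeping its storages out-of-band, and ``_load_storages``
# rebuilds the object in another interpreter from those same storages.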
# mypy: allow-untyped-defs
import io

import torch
from torch.package import Importer, OrderedImporter, PackageImporter, sys_importer
from torch.package._package_pickler import create_pickler
from torch.package._package_unpickler import PackageUnpickler
from torch.serialization import _maybe_decode_ascii


def _save_storages(importer, obj):
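    """Pickle ``obj``, recording its storages and dtypes out-of-band.

    Returns a ``(pickle_bytes, serialized_storages, serialized_dtypes,
    zip_reader)`` tuple that ``_load_storages`` can use to rebuild the
    object; ``zip_reader`` is ``None`` unless ``importer`` is a
    ``PackageImporter``.
    """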
    serialized_storages = []
    serialized_dtypes = []

    # Resolve modules through the provided PackageImporter when there is one,
    # falling back to the normal Python environment otherwise.
    importer = importer if isinstance(importer, torch.package.PackageImporter) else None
    importers: Importer
    if importer is not None:
        importers = OrderedImporter(importer, sys_importer)
    else:
        importers = sys_importer

    def persistent_id(obj):
        # Storages are recorded out-of-band in `serialized_storages` and
        # referenced from the pickle stream by index.
        if torch.is_storage(obj) or isinstance(obj, torch.storage.TypedStorage):
            if isinstance(obj, torch.storage.TypedStorage):
                # TODO: Once we decide to break serialization FC, we can
                # remove this case
                storage = obj._untyped_storage
                dtype = obj.dtype
            else:
                storage = obj
                dtype = torch.uint8

            serialized_storages.append(obj)
            serialized_dtypes.append(dtype)
            return ("storage", len(serialized_storages) - 1)

        # Objects implementing __reduce_deploy__ are reduced at most once per
        # object identity; the memoized tuple is shared across pickles.
        if hasattr(obj, "__reduce_deploy__"):
            if _serialized_reduces.get(id(obj)) is None:
                _serialized_reduces[id(obj)] = (
                    "reduce_deploy",
                    id(obj),
                    *obj.__reduce_deploy__(importers),
                )
            return _serialized_reduces[id(obj)]

        return None

    # Write the pickle data for `obj`
    data_buf = io.BytesIO()
    pickler = create_pickler(data_buf, importers)
    pickler.persistent_id = persistent_id
    pickler.dump(obj)
    data_value = data_buf.getvalue()
    return (
        data_value,
        serialized_storages,
        serialized_dtypes,
        importer.zip_reader if importer else None,
    )


def _load_storages(id, zip_reader, obj_bytes, serialized_storages, serialized_dtypes):
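    """Rebuild the object produced by ``_save_storages`` and cache it in
    ``_deploy_objects`` under ``id``.

    ``persistent_load`` mirrors ``persistent_id`` above: "storage" entries
    are looked up by index in the out-of-band lists, and "reduce_deploy"
    entries are reconstructed once per id via the saved reduce function.
    """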
    def persistent_load(saved_id):
        assert isinstance(saved_id, tuple)
        typename = _maybe_decode_ascii(saved_id[0])
        data = saved_id[1:]

        if typename == "storage":
            # TODO: Once we decide to break serialization FC, we can
            # stop wrapping with TypedStorage
            storage = serialized_storages[data[0]]
            dtype = serialized_dtypes[data[0]]
            return torch.storage.TypedStorage(
                wrap_storage=storage.untyped(), dtype=dtype
            )

        if typename == "reduce_deploy":
            reduce_id, func, args = data
            if reduce_id not in _loaded_reduces:
                _loaded_reduces[reduce_id] = func(_raw_packages[zip_reader], *args)
            return _loaded_reduces[reduce_id]

        return None

    importer: Importer
    if zip_reader is not None:
        importer = OrderedImporter(_get_package(zip_reader), sys_importer)
    else:
        importer = sys_importer

    unpickler = PackageUnpickler(importer, io.BytesIO(obj_bytes))
    unpickler.persistent_load = persistent_load  # type: ignore[method-assign]
    result = _deploy_objects[id] = unpickler.load()
    return result


def _get_package(zip_reader):
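    """Return the cached PackageImporter for ``zip_reader``, creating it on first use."""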
    if zip_reader not in _raw_packages:
        _raw_packages[zip_reader] = PackageImporter(zip_reader)
    return _raw_packages[zip_reader]
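

# Module-level caches: PackageImporters keyed by zip reader, loaded objects
# keyed by deploy id, and the memo tables used by the persistent_id /
# persistent_load hooks above.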
_raw_packages: dict = {}
_deploy_objects: dict = {}
_serialized_reduces: dict = {}
_loaded_reduces: dict = {}