_directory_reader.py 1.9 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364
  1. # mypy: allow-untyped-defs
  2. import os.path
  3. from glob import glob
  4. from typing import cast
  5. import torch
  6. from torch.types import Storage
  # Name of the record, relative to the package root, that holds the
  # serialization id; looked up by DirectoryReader.serialization_id.
  __serialization_id_record_name__ = ".data/serialization_id"
  # Needed because get_storage_from_record returns a tensor!? Callers therefore
  # expect an object with a ``storage()`` method rather than the storage itself.
  # NOTE(review): presumably this refers to PyTorchFileReader's method of the
  # same name -- confirm against the callers.
  class _HasStorage:
      """Thin holder that exposes a wrapped storage through ``storage()``."""

      def __init__(self, storage):
          # Kept as-is; no copy or validation is performed.
          self._storage = storage

      def storage(self):
          """Return the exact object that was passed to the constructor."""
          return self._storage
  class DirectoryReader:
      """
      Class to allow PackageImporter to operate on unzipped packages. Methods
      copy the behavior of the internal PyTorchFileReader class (which is used for
      accessing packages in all other cases).

      N.B.: ScriptObjects are not depickleable or accessible via this DirectoryReader
      class due to ScriptObjects requiring an actual PyTorchFileReader instance.
      """

      def __init__(self, directory):
          """Remember ``directory`` as the root that record names are resolved against."""
          self.directory = directory

      def get_record(self, name):
          """Return the full contents of record ``name`` as ``bytes``.

          The record is the file ``<directory>/<name>``; any error raised by
          ``open`` (for example ``FileNotFoundError``) propagates to the caller.
          """
          filename = f"{self.directory}/{name}"
          with open(filename, "rb") as f:
              return f.read()

      def get_storage_from_record(self, name, numel, dtype):
          """Return a ``_HasStorage`` wrapping a storage built from record ``name``.

          Args:
              name: record name, relative to the root directory.
              numel: number of elements stored in the record.
              dtype: element dtype; its element size times ``numel`` gives the
                  number of bytes requested from the file.
          """
          filename = f"{self.directory}/{name}"
          nbytes = torch._utils._element_size(dtype) * numel
          # ``cast`` only informs the type checker; at runtime this is still
          # the ``torch.UntypedStorage`` class itself.
          storage = cast(Storage, torch.UntypedStorage)
          return _HasStorage(storage.from_file(filename=filename, nbytes=nbytes))

      def has_record(self, path):
          """Return True if ``path`` names a regular file under the root directory."""
          full_path = os.path.join(self.directory, path)
          return os.path.isfile(full_path)

      def get_all_records(
          self,
      ):
          """Return the names of all records (non-directory entries) under the root.

          Names are returned relative to the root directory, in the order
          ``glob`` yields them.

          NOTE(review): ``glob``'s ``**`` wildcard does not match names starting
          with a dot by default, so entries under dot-directories such as
          ``.data/`` are not listed -- confirm this is the intended behavior.
          """
          # Local import: only ``glob.glob`` is imported at module level.
          from glob import escape

          # Accept ``os.PathLike`` roots as well as plain strings.
          root = os.fspath(self.directory)
          # Escape the root so glob magic characters in the directory name
          # ("[", "*", "?") are matched literally rather than as a pattern.
          pattern = f"{escape(root)}/**"
          # ``relpath`` strips the root prefix correctly even when the root has
          # a trailing separator, where slicing by ``len(root) + 1`` would cut
          # into the record name.
          return [
              os.path.relpath(match, root)
              for match in glob(pattern, recursive=True)
              if not os.path.isdir(match)
          ]

      def serialization_id(
          self,
      ):
          """Return the contents of the serialization-id record, or ``""`` if absent.

          Note that a present record is returned as ``bytes`` (it comes from
          ``get_record``), while the fallback is an empty ``str``.
          """
          if self.has_record(__serialization_id_record_name__):
              return self.get_record(__serialization_id_record_name__)
          else:
              return ""