#pragma once

#include <cerrno>
#include <cstdio>
#include <cstring>
#include <fstream>
#include <istream>
#include <mutex>
#include <ostream>
#include <unordered_set>

#include <c10/core/Allocator.h>
#include <c10/core/Backend.h>

#include "caffe2/serialize/istream_adapter.h"
#include "caffe2/serialize/read_adapter_interface.h"
#include "caffe2/serialize/versions.h"

extern "C" {
typedef struct mz_zip_archive mz_zip_archive;
}
// PyTorch containers are a special zip archive with the following layout:
// archive_name.zip contains:
//   archive_name/
//     version # a file with a single decimal number written in ascii,
//             # used to establish the version of the archive format
//     model.json # overall model description; this is a json output of
//                # ModelDef from torch.proto
//     # the following names are by convention only; model.json will
//     # refer to these files by full names
//     tensors/
//       0 # flat storage for tensor data; meta-data about shapes, etc. is
//         # in model.json
//       1
//       ...
//     # code entries will only exist for modules that have methods attached
//     code/
//       archive_name.py # serialized torch script code (python syntax,
//                       # using PythonPrint)
//       archive_name_my_submodule.py # submodules have separate files
//
// The PyTorchStreamWriter also ensures additional useful properties for these
// files:
// 1. All files are stored uncompressed.
// 2. All files in the archive are aligned to 64-byte boundaries such that
//    it is possible to mmap the entire file and get an aligned pointer to
//    tensor data.
// 3. We universally write in ZIP64 format for consistency.
//
// The PyTorchStreamReader also provides additional properties:
// 1. It can read zip files that are created with common zip tools. This
//    means that even though our writer doesn't compress files, the reader
//    can still read files that were compressed.
// 2. It provides a getRecordOffset function which returns the offset into
//    the raw file where file data lives. If the file was written with
//    PyTorchStreamWriter it is guaranteed to be 64-byte aligned.
//
// PyTorchStreamReader/Writer handle checking the version number on the
// archive format and ensure that all files are written to an archive_name
// directory so they unzip cleanly.
// When developing this format we want to pay particular attention to the
// following use cases:
//
// -- Reading --
// 1) Reading with full random access
//    a) Reading with file APIs such as fread()
//    b) mmaping the file and jumping around the mapped region
// 2) Reading with 1-pass sequential access
//    -> A reader will need to build up a data structure of parsed structures
//       as it reads
//
// -- Writing --
// 1) Writing with full random access
// 2) Writing with 1-pass sequential access
//    -> We must take care not to require updating values that have already
//       been written. We place the variable-length index at the end and do
//       not put any indices into the header to fulfill this constraint.
//
// The model.json, which contains all the metadata information, should be
// written as the last file. One reason is that the size of tensor data is
// usually stable: as long as the shape and type of the tensor do not change,
// the size of the data won't change. On the other hand, the size of the
// serialized model is likely to change, so we store it as the last record
// and never need to move previous records when updating the model data.
//
// The zip format is sufficiently flexible to handle the above use cases:
// it puts its central directory at the end of the archive, and we write
// model.json as the last file when writing, after we have accumulated all
// other information.
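//
// For illustration, the 64-byte alignment guarantee is what enables the mmap
// use case above. A minimal sketch, assuming a POSIX platform, an archive
// written by PyTorchStreamWriter, and a hypothetical record name "tensors/0":
//
//   caffe2::serialize::PyTorchStreamReader reader("model.pt");
//   size_t offset = reader.getRecordOffset("tensors/0");
//   size_t length = reader.getRecordSize("tensors/0");
//   int fd = open("model.pt", O_RDONLY);
//   void* base =
//       mmap(nullptr, offset + length, PROT_READ, MAP_PRIVATE, fd, 0);
//   // `base` is page aligned and `offset` is a multiple of 64, so this
//   // pointer is suitably aligned for tensor data:
//   const char* tensor_data = static_cast<const char*>(base) + offset;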
namespace caffe2 {
namespace serialize {

static constexpr const char* kSerializationIdRecordName =
    ".data/serialization_id";

struct MzZipReaderIterWrapper;

class TORCH_API ChunkRecordIterator {
 public:
  ~ChunkRecordIterator();

  // Reads at most `chunkSize` bytes into `buf`. Returns the number of bytes
  // actually read.
  size_t next(void* buf);
  size_t recordSize() const { return recordSize_; }

 private:
  ChunkRecordIterator(
      size_t recordSize,
      size_t chunkSize,
      std::unique_ptr<MzZipReaderIterWrapper> iter);

  const size_t recordSize_;
  const size_t chunkSize_;
  size_t offset_;
  std::unique_ptr<MzZipReaderIterWrapper> iter_;

  friend class PyTorchStreamReader;
};
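// A minimal sketch of chunked reading, assuming `reader` is a
// PyTorchStreamReader (declared below) and using a hypothetical record name
// and chunk size:
//
//   constexpr size_t kChunkSize = 1024;
//   size_t recordSize = reader.getRecordSize("tensors/0");
//   auto it = reader.createChunkReaderIter("tensors/0", recordSize, kChunkSize);
//   std::vector<char> chunk(kChunkSize);
//   while (size_t n = it.next(chunk.data())) {
//     // process `n` bytes of the record here
//   }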
class TORCH_API PyTorchStreamReader final {
 public:
  explicit PyTorchStreamReader(const std::string& file_name);
  explicit PyTorchStreamReader(std::istream* in);
  explicit PyTorchStreamReader(std::shared_ptr<ReadAdapterInterface> in);

  // Returns (data pointer, size).
  std::tuple<at::DataPtr, size_t> getRecord(const std::string& name);
  // Multi-threaded getRecord.
  std::tuple<at::DataPtr, size_t> getRecord(
      const std::string& name,
      std::vector<std::shared_ptr<ReadAdapterInterface>>& additionalReaders);
  // In-place memory writing.
  size_t getRecord(const std::string& name, void* dst, size_t n);
  // In-place memory writing, multi-threaded.
  // When additionalReaders is empty, the default behavior is to call
  // getRecord(name, dst, n) with the default reader.
  // This approach can be used for reading large tensors.
  size_t getRecord(
      const std::string& name,
      void* dst,
      size_t n,
      std::vector<std::shared_ptr<ReadAdapterInterface>>& additionalReaders);
  size_t getRecord(
      const std::string& name,
      void* dst,
      size_t n,
      size_t chunk_size,
      void* buf,
      const std::function<void(void*, const void*, size_t)>& memcpy_func =
          nullptr);

  // Concurrently reads a record with multiple readers. additionalReaders are
  // additional clients that access the underlying record at different
  // offsets and write to different chunks of the buffer. For example, if the
  // overall size of the tensor is 10 and there are two additional readers,
  // the default reader reads [0, 4) and writes to buffer[0, 4), the first
  // additional reader reads [4, 8) and writes to buffer[4, 8), and the
  // second additional reader reads [8, 10) and writes to buffer[8, 10).
  // When additionalReaders is empty, the default behavior is to call
  // getRecord(name) with the default reader.
  // This approach can be used for reading large tensors.
  size_t getRecordMultiReaders(
      const std::string& name,
      std::vector<std::shared_ptr<ReadAdapterInterface>>& additionalReaders,
      void* dst,
      size_t n);
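  // A sketch of multi-reader usage, assuming the archive lives in a file
  // that can be opened several times (FileAdapter from
  // caffe2/serialize/file_adapter.h is one ReadAdapterInterface
  // implementation; the record name is hypothetical):
  //
  //   std::vector<std::shared_ptr<ReadAdapterInterface>> extras = {
  //       std::make_shared<FileAdapter>("model.pt"),
  //       std::make_shared<FileAdapter>("model.pt")};
  //   std::vector<char> dst(reader.getRecordSize("tensors/0"));
  //   reader.getRecord("tensors/0", dst.data(), dst.size(), extras);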
  size_t getRecordSize(const std::string& name);
  size_t getRecordOffset(const std::string& name);
  bool hasRecord(const std::string& name);
  std::vector<std::string> getAllRecords();

  ChunkRecordIterator createChunkReaderIter(
      const std::string& name,
      const size_t recordSize,
      const size_t chunkSize);

  ~PyTorchStreamReader();

  uint64_t version() const {
    return version_;
  }
  const std::string& serializationId() {
    return serialization_id_;
  }

  void setShouldLoadDebugSymbol(bool should_load_debug_symbol) {
    load_debug_symbol_ = should_load_debug_symbol;
  }
  void setAdditionalReaderSizeThreshold(const size_t& size) {
    additional_reader_size_threshold_ = size;
  }

 private:
  void init();
  size_t read(uint64_t pos, char* buf, size_t n);
  void valid(const char* what, const char* info = "");
  size_t getRecordID(const std::string& name);

  friend size_t
  istream_read_func(void* pOpaque, uint64_t file_ofs, void* pBuf, size_t n);

  std::unique_ptr<mz_zip_archive> ar_;
  std::string archive_name_;
  std::string archive_name_plus_slash_;
  std::shared_ptr<ReadAdapterInterface> in_;
  int64_t version_;
  std::mutex reader_lock_;
  bool load_debug_symbol_ = true;
  std::string serialization_id_;
  size_t additional_reader_size_threshold_;
};
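// A minimal usage sketch for the reader, assuming "model.pt" is an archive
// produced by PyTorchStreamWriter:
//
//   caffe2::serialize::PyTorchStreamReader reader("model.pt");
//   for (const auto& name : reader.getAllRecords()) {
//     auto [data, size] = reader.getRecord(name);
//     // `data` is an at::DataPtr that owns the `size` bytes of this record
//   }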
class TORCH_API PyTorchStreamWriter final {
 public:
  explicit PyTorchStreamWriter(const std::string& archive_name);
  explicit PyTorchStreamWriter(
      const std::function<size_t(const void*, size_t)> writer_func);

  void setMinVersion(const uint64_t version);

  void writeRecord(
      const std::string& name,
      const void* data,
      size_t size,
      bool compress = false);
  void writeEndOfFile();

  const std::unordered_set<std::string>& getAllWrittenRecords();

  bool finalized() const {
    return finalized_;
  }
  const std::string& archiveName() {
    return archive_name_;
  }
  const std::string& serializationId() {
    return serialization_id_;
  }

  ~PyTorchStreamWriter();

 private:
  void setup(const std::string& file_name);
  void valid(const char* what, const char* info = "");
  void writeSerializationId();

  size_t current_pos_ = 0;
  std::unordered_set<std::string> files_written_;
  std::unique_ptr<mz_zip_archive> ar_;
  std::string archive_name_;
  std::string archive_name_plus_slash_;
  std::string padding_;
  std::ofstream file_stream_;
  std::function<size_t(const void*, size_t)> writer_func_;
  uint64_t combined_uncomp_crc32_ = 0;
  std::string serialization_id_;
  // This number will be updated when the model has operators
  // that have valid upgraders.
  uint64_t version_ = kMinProducedFileFormatVersion;
  bool finalized_ = false;
  bool err_seen_ = false;

  friend size_t ostream_write_func(
      void* pOpaque,
      uint64_t file_ofs,
      const void* pBuf,
      size_t n);
};
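// A minimal usage sketch for the writer (the record name and payload are
// illustrative):
//
//   caffe2::serialize::PyTorchStreamWriter writer("model.pt");
//   std::string payload = "raw tensor bytes";
//   writer.writeRecord("tensors/0", payload.data(), payload.size());
//   writer.writeEndOfFile(); // finalizes the archive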
namespace detail {

// Writer-specific constants.
constexpr uint64_t kFieldAlignment = 64;

// Builds a record to be appended to the local user extra data entry in order
// to make the record data begin at a kFieldAlignment-byte boundary. The
// record is written into `padding_buf`; its size is returned.
size_t getPadding(
    size_t cursor,
    size_t filename_size,
    size_t size,
    std::string& padding_buf);
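// A sketch of the intended call site inside the writer (illustrative; the
// names mirror PyTorchStreamWriter's members above):
//
//   std::string padding_buf;
//   size_t padding_size =
//       detail::getPadding(current_pos_, name.size(), data_size, padding_buf);
//   // padding_buf.data() / padding_size is then passed to the zip writer as
//   // the local extra field, pushing the payload to the next 64-byte
//   // boundary.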
} // namespace detail

} // namespace serialize
} // namespace caffe2