show_pickle.py 5.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151
  1. #!/usr/bin/env python3
  2. # mypy: allow-untyped-defs
  3. import sys
  4. import pickle
  5. import struct
  6. import pprint
  7. import zipfile
  8. import fnmatch
  9. from typing import Any, IO, BinaryIO, Union
  10. __all__ = ["FakeObject", "FakeClass", "DumpUnpickler", "main"]
  11. class FakeObject:
  12. def __init__(self, module, name, args):
  13. self.module = module
  14. self.name = name
  15. self.args = args
  16. # NOTE: We don't distinguish between state never set and state set to None.
  17. self.state = None
  18. def __repr__(self):
  19. state_str = "" if self.state is None else f"(state={self.state!r})"
  20. return f"{self.module}.{self.name}{self.args!r}{state_str}"
  21. def __setstate__(self, state):
  22. self.state = state
  23. @staticmethod
  24. def pp_format(printer, obj, stream, indent, allowance, context, level):
  25. if not obj.args and obj.state is None:
  26. stream.write(repr(obj))
  27. return
  28. if obj.state is None:
  29. stream.write(f"{obj.module}.{obj.name}")
  30. printer._format(obj.args, stream, indent + 1, allowance + 1, context, level)
  31. return
  32. if not obj.args:
  33. stream.write(f"{obj.module}.{obj.name}()(state=\n")
  34. indent += printer._indent_per_level
  35. stream.write(" " * indent)
  36. printer._format(obj.state, stream, indent, allowance + 1, context, level + 1)
  37. stream.write(")")
  38. return
  39. raise Exception("Need to implement") # noqa: TRY002
  40. class FakeClass:
  41. def __init__(self, module, name):
  42. self.module = module
  43. self.name = name
  44. self.__new__ = self.fake_new # type: ignore[assignment]
  45. def __repr__(self):
  46. return f"{self.module}.{self.name}"
  47. def __call__(self, *args):
  48. return FakeObject(self.module, self.name, args)
  49. def fake_new(self, *args):
  50. return FakeObject(self.module, self.name, args[1:])
  51. class DumpUnpickler(pickle._Unpickler): # type: ignore[name-defined]
  52. def __init__(
  53. self,
  54. file,
  55. *,
  56. catch_invalid_utf8=False,
  57. **kwargs):
  58. super().__init__(file, **kwargs)
  59. self.catch_invalid_utf8 = catch_invalid_utf8
  60. def find_class(self, module, name):
  61. return FakeClass(module, name)
  62. def persistent_load(self, pid):
  63. return FakeObject("pers", "obj", (pid,))
  64. dispatch = dict(pickle._Unpickler.dispatch) # type: ignore[attr-defined]
  65. # Custom objects in TorchScript are able to return invalid UTF-8 strings
  66. # from their pickle (__getstate__) functions. Install a custom loader
  67. # for strings that catches the decode exception and replaces it with
  68. # a sentinel object.
  69. def load_binunicode(self):
  70. strlen, = struct.unpack("<I", self.read(4)) # type: ignore[attr-defined]
  71. if strlen > sys.maxsize:
  72. raise Exception("String too long.") # noqa: TRY002
  73. str_bytes = self.read(strlen) # type: ignore[attr-defined]
  74. obj: Any
  75. try:
  76. obj = str(str_bytes, "utf-8", "surrogatepass")
  77. except UnicodeDecodeError as exn:
  78. if not self.catch_invalid_utf8:
  79. raise
  80. obj = FakeObject("builtin", "UnicodeDecodeError", (str(exn),))
  81. self.append(obj) # type: ignore[attr-defined]
  82. dispatch[pickle.BINUNICODE[0]] = load_binunicode # type: ignore[assignment]
  83. @classmethod
  84. def dump(cls, in_stream, out_stream):
  85. value = cls(in_stream).load()
  86. pprint.pprint(value, stream=out_stream)
  87. return value
  88. def main(argv, output_stream=None):
  89. if len(argv) != 2:
  90. # Don't spam stderr if not using stdout.
  91. if output_stream is not None:
  92. raise Exception("Pass argv of length 2.") # noqa: TRY002
  93. sys.stderr.write("usage: show_pickle PICKLE_FILE\n")
  94. sys.stderr.write(" PICKLE_FILE can be any of:\n")
  95. sys.stderr.write(" path to a pickle file\n")
  96. sys.stderr.write(" file.zip@member.pkl\n")
  97. sys.stderr.write(" file.zip@*/pattern.*\n")
  98. sys.stderr.write(" (shell glob pattern for members)\n")
  99. sys.stderr.write(" (only first match will be shown)\n")
  100. return 2
  101. fname = argv[1]
  102. handle: Union[IO[bytes], BinaryIO]
  103. if "@" not in fname:
  104. with open(fname, "rb") as handle:
  105. DumpUnpickler.dump(handle, output_stream)
  106. else:
  107. zfname, mname = fname.split("@", 1)
  108. with zipfile.ZipFile(zfname) as zf:
  109. if "*" not in mname:
  110. with zf.open(mname) as handle:
  111. DumpUnpickler.dump(handle, output_stream)
  112. else:
  113. found = False
  114. for info in zf.infolist():
  115. if fnmatch.fnmatch(info.filename, mname):
  116. with zf.open(info) as handle:
  117. DumpUnpickler.dump(handle, output_stream)
  118. found = True
  119. break
  120. if not found:
  121. raise Exception(f"Could not find member matching {mname} in {zfname}") # noqa: TRY002
  122. if __name__ == "__main__":
  123. # This hack works on every version of Python I've tested.
  124. # I've tested on the following versions:
  125. # 3.7.4
  126. if True:
  127. pprint.PrettyPrinter._dispatch[FakeObject.__repr__] = FakeObject.pp_format # type: ignore[attr-defined]
  128. sys.exit(main(sys.argv))