_package_pickler.py 4.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119
  1. # mypy: allow-untyped-defs
  2. """isort:skip_file"""
  3. from pickle import ( # type: ignore[attr-defined]
  4. _compat_pickle,
  5. _extension_registry,
  6. _getattribute,
  7. _Pickler,
  8. EXT1,
  9. EXT2,
  10. EXT4,
  11. GLOBAL,
  12. Pickler,
  13. PicklingError,
  14. STACK_GLOBAL,
  15. )
  16. from struct import pack
  17. from types import FunctionType
  18. from .importer import Importer, ObjMismatchError, ObjNotFoundError, sys_importer
  19. class PackagePickler(_Pickler):
  20. """Package-aware pickler.
  21. This behaves the same as a normal pickler, except it uses an `Importer`
  22. to find objects and modules to save.
  23. """
  24. def __init__(self, importer: Importer, *args, **kwargs):
  25. self.importer = importer
  26. super().__init__(*args, **kwargs)
  27. # Make sure the dispatch table copied from _Pickler is up-to-date.
  28. # Previous issues have been encountered where a library (e.g. dill)
  29. # mutate _Pickler.dispatch, PackagePickler makes a copy when this lib
  30. # is imported, then the offending library removes its dispatch entries,
  31. # leaving PackagePickler with a stale dispatch table that may cause
  32. # unwanted behavior.
  33. self.dispatch = _Pickler.dispatch.copy() # type: ignore[misc]
  34. self.dispatch[FunctionType] = PackagePickler.save_global # type: ignore[assignment]
  35. def save_global(self, obj, name=None):
  36. # unfortunately the pickler code is factored in a way that
  37. # forces us to copy/paste this function. The only change is marked
  38. # CHANGED below.
  39. write = self.write # type: ignore[attr-defined]
  40. memo = self.memo # type: ignore[attr-defined]
  41. # CHANGED: import module from module environment instead of __import__
  42. try:
  43. module_name, name = self.importer.get_name(obj, name)
  44. except (ObjNotFoundError, ObjMismatchError) as err:
  45. raise PicklingError(f"Can't pickle {obj}: {str(err)}") from None
  46. module = self.importer.import_module(module_name)
  47. _, parent = _getattribute(module, name)
  48. # END CHANGED
  49. if self.proto >= 2: # type: ignore[attr-defined]
  50. code = _extension_registry.get((module_name, name))
  51. if code:
  52. assert code > 0
  53. if code <= 0xFF:
  54. write(EXT1 + pack("<B", code))
  55. elif code <= 0xFFFF:
  56. write(EXT2 + pack("<H", code))
  57. else:
  58. write(EXT4 + pack("<i", code))
  59. return
  60. lastname = name.rpartition(".")[2]
  61. if parent is module:
  62. name = lastname
  63. # Non-ASCII identifiers are supported only with protocols >= 3.
  64. if self.proto >= 4: # type: ignore[attr-defined]
  65. self.save(module_name) # type: ignore[attr-defined]
  66. self.save(name) # type: ignore[attr-defined]
  67. write(STACK_GLOBAL)
  68. elif parent is not module:
  69. self.save_reduce(getattr, (parent, lastname)) # type: ignore[attr-defined]
  70. elif self.proto >= 3: # type: ignore[attr-defined]
  71. write(
  72. GLOBAL
  73. + bytes(module_name, "utf-8")
  74. + b"\n"
  75. + bytes(name, "utf-8")
  76. + b"\n"
  77. )
  78. else:
  79. if self.fix_imports: # type: ignore[attr-defined]
  80. r_name_mapping = _compat_pickle.REVERSE_NAME_MAPPING
  81. r_import_mapping = _compat_pickle.REVERSE_IMPORT_MAPPING
  82. if (module_name, name) in r_name_mapping:
  83. module_name, name = r_name_mapping[(module_name, name)]
  84. elif module_name in r_import_mapping:
  85. module_name = r_import_mapping[module_name]
  86. try:
  87. write(
  88. GLOBAL
  89. + bytes(module_name, "ascii")
  90. + b"\n"
  91. + bytes(name, "ascii")
  92. + b"\n"
  93. )
  94. except UnicodeEncodeError:
  95. raise PicklingError(
  96. "can't pickle global identifier '%s.%s' using "
  97. "pickle protocol %i" % (module, name, self.proto) # type: ignore[attr-defined]
  98. ) from None
  99. self.memoize(obj) # type: ignore[attr-defined]
  100. def create_pickler(data_buf, importer, protocol=4):
  101. if importer is sys_importer:
  102. # if we are using the normal import library system, then
  103. # we can use the C implementation of pickle which is faster
  104. return Pickler(data_buf, protocol=protocol)
  105. else:
  106. return PackagePickler(importer, data_buf, protocol=protocol)