_mangling.py 1.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263
  1. # mypy: allow-untyped-defs
  2. """Import mangling.
  3. See mangling.md for details.
  4. """
  5. import re
# Module-wide counter: each PackageMangler takes the current value as its own
# index (used to build its unique "<torch_package_N>" parent) and then
# increments it, so successive manglers get distinct parents.
_mangle_index = 0
  7. class PackageMangler:
  8. """
  9. Used on import, to ensure that all modules imported have a shared mangle parent.
  10. """
  11. def __init__(self):
  12. global _mangle_index
  13. self._mangle_index = _mangle_index
  14. # Increment the global index
  15. _mangle_index += 1
  16. # Angle brackets are used so that there is almost no chance of
  17. # confusing this module for a real module. Plus, it is Python's
  18. # preferred way of denoting special modules.
  19. self._mangle_parent = f"<torch_package_{self._mangle_index}>"
  20. def mangle(self, name) -> str:
  21. assert len(name) != 0
  22. return self._mangle_parent + "." + name
  23. def demangle(self, mangled: str) -> str:
  24. """
  25. Note: This only demangles names that were mangled by this specific
  26. PackageMangler. It will pass through names created by a different
  27. PackageMangler instance.
  28. """
  29. if mangled.startswith(self._mangle_parent + "."):
  30. return mangled.partition(".")[2]
  31. # wasn't a mangled name
  32. return mangled
  33. def parent_name(self):
  34. return self._mangle_parent
  35. def is_mangled(name: str) -> bool:
  36. return bool(re.match(r"<torch_package_\d+>", name))
  37. def demangle(name: str) -> str:
  38. """
  39. Note: Unlike PackageMangler.demangle, this version works on any
  40. mangled name, irrespective of which PackageMangler created it.
  41. """
  42. if is_mangled(name):
  43. first, sep, last = name.partition(".")
  44. # If there is only a base mangle prefix, e.g. '<torch_package_0>',
  45. # then return an empty string.
  46. return last if len(sep) != 0 else ""
  47. return name
  48. def get_mangle_prefix(name: str) -> str:
  49. return name.partition(".")[0] if is_mangled(name) else name