weak.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322
  1. # mypy: allow-untyped-defs
  2. from __future__ import annotations
  3. import weakref
  4. from weakref import ref
  5. from _weakrefset import _IterationGuard # type: ignore[attr-defined]
  6. from collections.abc import MutableMapping, Mapping
  7. from torch import Tensor
  8. import collections.abc as _collections_abc
# Alias for ``weakref.ref``; used in this module to annotate attributes that
# hold a weak reference (e.g. ``WeakRef[Tensor]``).
WeakRef = ref
# Public names exported by this module.
__all__ = ['TensorWeakRef', 'WeakIdRef', 'WeakIdKeyDictionary', 'WeakTensorKeyDictionary']
  11. # This file defines a variant of WeakKeyDictionary that overrides the hashing
  12. # behavior of the key to use object identity, rather than the builtin
  13. # __eq__/__hash__ functions. This is useful for Tensor weak keys, as their
  14. # __eq__ implementation returns a Tensor (elementwise equality), which means
  15. # you can't use them directly with the WeakKeyDictionary in standard library.
  16. #
  17. # Our implementation strategy is to create a wrapper weak key object, which we
  18. # use as a key in a stock Python dictionary. This is similar to how weakref
  19. # implements WeakKeyDictionary, but instead of using weakref.ref as the
  20. # wrapper, we use a custom wrapper that has different __eq__ and __hash__
  21. # behavior. Note that we subsequently store this weak key directly in an
  22. # ORDINARY dictionary, since the newly constructed WeakIdKey's only use would
  23. # be a dictionary so it would have no strong references. Ensuring that
  24. # only live WeakIdKeys are in the map is handled by putting finalizers on the
  25. # original key object.
  26. # It is simpler to implement this with composition, but if we want to
  27. # directly reuse the callback mechanism on weakref, we need the weakref
  28. # and the key to be exactly the same object. Reusing the callback mechanism
  29. # minimizes the divergence between our implementation and Lib/weakref.py
  30. #
  31. # NB: Prefer using this when working with weakrefs of Tensors; e.g., do
  32. # WeakIdRef(tensor) rather than weakref.ref(tensor); it handles a number of
  33. # easy to get wrong cases transparently for you.
  34. class WeakIdRef(weakref.ref):
  35. __slots__ = ['_id']
  36. def __init__(self, key, callback=None):
  37. # Unlike stock weakref, which preserves hash semantics of the
  38. # original object but lazily defers hash calls until the first
  39. # time the user attempts to hash the weakref, we can eagerly
  40. # cache the id of the key as we know this is definitely the hash
  41. # method
  42. self._id = id(key)
  43. super().__init__(key, callback) # type: ignore[call-arg]
  44. def __call__(self):
  45. r = super().__call__()
  46. # Special logic for Tensor PyObject resurrection
  47. if hasattr(r, '_fix_weakref'):
  48. r._fix_weakref() # type: ignore[union-attr]
  49. return r
  50. def __hash__(self):
  51. return self._id
  52. def __eq__(self, other):
  53. # An attractive but wrong alternate implementation is to only test if
  54. # the stored _ids match. This can lead to an ABA problem if you have:
  55. #
  56. # a1 = A()
  57. # w1 = WeakIdRef(a1)
  58. # del a1
  59. # a2 = A() # suppose it gets the same ID as a1
  60. # w2 = WeakIdRef(a2)
  61. # print(w1 == w2)
  62. #
  63. # This should be False, as a1 and a2 are unrelated (and a1 is
  64. # dead anyway)
  65. a = self()
  66. b = other()
  67. if a is not None and b is not None:
  68. return a is b
  69. return self is other
  70. # This is the same as WeakIdRef but equality is checked using hash() rather than id.
  71. # This will be equivalent to the one above except for classes where hash is not their id.
  72. class _WeakHashRef(weakref.ref):
  73. __slots__ = ['_id']
  74. def __init__(self, key, callback=None):
  75. # Unlike stock weakref, which preserves hash semantics of the
  76. # original object but lazily defers hash calls until the first
  77. # time the user attempts to hash the weakref, we can eagerly
  78. # cache the id of the key as we know this is definitely the hash
  79. # method
  80. self._id = hash(key)
  81. super().__init__(key, callback) # type: ignore[call-arg]
  82. def __call__(self):
  83. r = super().__call__()
  84. # Special logic for Tensor PyObject resurrection
  85. if hasattr(r, '_fix_weakref'):
  86. r._fix_weakref() # type: ignore[union-attr]
  87. return r
  88. def __hash__(self):
  89. return self._id
  90. def __eq__(self, other):
  91. # Use hash equality to determine ref equality.
  92. # ScriptObject implements __hash__ to return the wrapped IValue's id, so
  93. # this is equivalent to doing an identity comparison.
  94. a = self()
  95. b = other()
  96. if a is not None and b is not None:
  97. return hash(a) == hash(b)
  98. return self is other
  99. # This is directly adapted from cpython/Lib/weakref.py
  100. class WeakIdKeyDictionary(MutableMapping):
  101. def __init__(self, dict=None, ref_type=WeakIdRef): # CHANGED
  102. self.data = {}
  103. self.ref_type = ref_type # CHANGED
  104. def remove(k, selfref=ref(self)):
  105. self = selfref()
  106. if self is not None:
  107. if self._iterating:
  108. self._pending_removals.append(k)
  109. else:
  110. try:
  111. del self.data[k]
  112. except KeyError:
  113. pass
  114. self._remove = remove
  115. # A list of dead weakrefs (keys to be removed)
  116. self._pending_removals = []
  117. self._iterating = set()
  118. self._dirty_len = False
  119. if dict is not None:
  120. self.update(dict)
  121. def _commit_removals(self):
  122. # NOTE: We don't need to call this method before mutating the dict,
  123. # because a dead weakref never compares equal to a live weakref,
  124. # even if they happened to refer to equal objects.
  125. # However, it means keys may already have been removed.
  126. pop = self._pending_removals.pop
  127. d = self.data
  128. while True:
  129. try:
  130. key = pop()
  131. except IndexError:
  132. return
  133. try:
  134. del d[key]
  135. except KeyError:
  136. pass
  137. def _scrub_removals(self):
  138. d = self.data
  139. self._pending_removals = [k for k in self._pending_removals if k in d]
  140. self._dirty_len = False
  141. def __delitem__(self, key):
  142. self._dirty_len = True
  143. del self.data[self.ref_type(key)] # CHANGED
  144. def __getitem__(self, key):
  145. return self.data[self.ref_type(key)] # CHANGED
  146. def __len__(self):
  147. if self._dirty_len and self._pending_removals:
  148. # self._pending_removals may still contain keys which were
  149. # explicitly removed, we have to scrub them (see issue #21173).
  150. self._scrub_removals()
  151. return len(self.data) - len(self._pending_removals)
  152. def __repr__(self):
  153. return f"<{self.__class__.__name__} at {id(self):#x}>"
  154. def __setitem__(self, key, value):
  155. self.data[self.ref_type(key, self._remove)] = value # CHANGED
  156. def copy(self):
  157. new = WeakIdKeyDictionary()
  158. with _IterationGuard(self):
  159. for key, value in self.data.items():
  160. o = key()
  161. if o is not None:
  162. new[o] = value
  163. return new
  164. __copy__ = copy
  165. def __deepcopy__(self, memo):
  166. from copy import deepcopy
  167. new = self.__class__()
  168. with _IterationGuard(self):
  169. for key, value in self.data.items():
  170. o = key()
  171. if o is not None:
  172. new[o] = deepcopy(value, memo)
  173. return new
  174. def get(self, key, default=None):
  175. return self.data.get(self.ref_type(key), default) # CHANGED
  176. def __contains__(self, key):
  177. try:
  178. wr = self.ref_type(key) # CHANGED
  179. except TypeError:
  180. return False
  181. return wr in self.data
  182. def items(self):
  183. with _IterationGuard(self):
  184. for wr, value in self.data.items():
  185. key = wr()
  186. if key is not None:
  187. yield key, value
  188. def keys(self):
  189. with _IterationGuard(self):
  190. for wr in self.data:
  191. obj = wr()
  192. if obj is not None:
  193. yield obj
  194. __iter__ = keys
  195. def values(self):
  196. with _IterationGuard(self):
  197. for wr, value in self.data.items():
  198. if wr() is not None:
  199. yield value
  200. def keyrefs(self):
  201. """Return a list of weak references to the keys.
  202. The references are not guaranteed to be 'live' at the time
  203. they are used, so the result of calling the references needs
  204. to be checked before being used. This can be used to avoid
  205. creating references that will cause the garbage collector to
  206. keep the keys around longer than needed.
  207. """
  208. return list(self.data)
  209. def popitem(self):
  210. self._dirty_len = True
  211. while True:
  212. key, value = self.data.popitem()
  213. o = key()
  214. if o is not None:
  215. return o, value
  216. def pop(self, key, *args):
  217. self._dirty_len = True
  218. return self.data.pop(self.ref_type(key), *args) # CHANGED
  219. def setdefault(self, key, default=None):
  220. return self.data.setdefault(self.ref_type(key, self._remove), default) # CHANGED
  221. def update(self, dict=None, **kwargs):
  222. d = self.data
  223. if dict is not None:
  224. if not hasattr(dict, "items"):
  225. dict = type({})(dict)
  226. for key, value in dict.items():
  227. d[self.ref_type(key, self._remove)] = value # CHANGED
  228. if len(kwargs):
  229. self.update(kwargs)
  230. def __ior__(self, other):
  231. self.update(other)
  232. return self
  233. def __or__(self, other):
  234. if isinstance(other, _collections_abc.Mapping):
  235. c = self.copy()
  236. c.update(other)
  237. return c
  238. return NotImplemented
  239. def __ror__(self, other):
  240. if isinstance(other, _collections_abc.Mapping):
  241. c = self.__class__()
  242. c.update(other)
  243. c.update(self)
  244. return c
  245. return NotImplemented
  246. # Default Mapping equality will tests keys for equality, but
  247. # we want to test ids for equality
  248. def __eq__(self, other):
  249. if not isinstance(other, Mapping):
  250. return NotImplemented
  251. return {id(k): v for k, v in self.items()} == {id(k): v for k, v in other.items()}
  252. # Convenience alias
  253. WeakTensorKeyDictionary = WeakIdKeyDictionary
  254. class TensorWeakRef:
  255. """Wrapper around a weak ref of a Tensor that handles the _fix_weakref() call required when unwrapping a Tensor weakref."""
  256. ref: WeakRef[Tensor]
  257. def __init__(self, tensor: Tensor):
  258. assert isinstance(tensor, Tensor)
  259. self.ref = weakref.ref(tensor)
  260. def __call__(self):
  261. out = self.ref()
  262. if out is None:
  263. return out
  264. assert isinstance(out, Tensor)
  265. # TODO, add _fix_weakref type binding
  266. out._fix_weakref() # type: ignore[attr-defined]
  267. return out