lazy.py 4.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144
  1. # mypy: ignore-errors
  2. import collections
  3. import functools
  4. from typing import Optional
  5. from .base import VariableTracker
  6. class LazyCache:
  7. """Container to cache the real VariableTracker"""
  8. def __init__(self, value, source):
  9. assert source
  10. self.value = value
  11. self.source = source
  12. self.vt: Optional[VariableTracker] = None
  13. def realize(self):
  14. assert self.vt is None
  15. from ..symbolic_convert import InstructionTranslator
  16. from .builder import VariableBuilder
  17. tx = InstructionTranslator.current_tx()
  18. self.vt = VariableBuilder(tx, self.source)(self.value)
  19. del self.value
  20. del self.source
  21. class LazyVariableTracker(VariableTracker):
  22. """
  23. A structure that defers the creation of the actual VariableTracker
  24. for a given underlying value until it is accessed.
  25. The `realize` function invokes VariableBuilder to produce the real object.
  26. Once a LazyVariableTracker has been realized, internal bookkeeping will
  27. prevent double realization.
  28. This object should be utilized for processing containers, or objects that
  29. reference other objects where we may not want to take on creating all the
  30. VariableTrackers right away.
  31. """
  32. _nonvar_fields = {"_cache", *VariableTracker._nonvar_fields}
  33. @staticmethod
  34. def create(value, source, **options):
  35. return LazyVariableTracker(LazyCache(value, source), source=source, **options)
  36. def __init__(self, _cache, **kwargs):
  37. assert isinstance(_cache, LazyCache)
  38. super().__init__(**kwargs)
  39. self._cache = _cache
  40. def realize(self) -> VariableTracker:
  41. """Force construction of the real VariableTracker"""
  42. if self._cache.vt is None:
  43. self._cache.realize()
  44. return self._cache.vt
  45. def unwrap(self):
  46. """Return the real VariableTracker if it already exists"""
  47. if self.is_realized():
  48. return self._cache.vt
  49. return self
  50. def is_realized(self):
  51. return self._cache.vt is not None
  52. def clone(self, **kwargs):
  53. assert kwargs.get("_cache", self._cache) is self._cache
  54. if kwargs.get("source", self.source) is not self.source:
  55. self.realize()
  56. return VariableTracker.clone(self.unwrap(), **kwargs)
  57. def __str__(self):
  58. if self.is_realized():
  59. return self.unwrap().__str__()
  60. return VariableTracker.__str__(self.unwrap())
  61. def __getattr__(self, item):
  62. return getattr(self.realize(), item)
  63. # most methods are auto-generated below, these are the ones we want to exclude
  64. visit = VariableTracker.visit
  65. __repr__ = VariableTracker.__repr__
  66. @classmethod
  67. def realize_all(
  68. cls,
  69. value,
  70. cache=None,
  71. ):
  72. """
  73. Walk an object and realize all LazyVariableTrackers inside it.
  74. """
  75. if cache is None:
  76. cache = dict()
  77. idx = id(value)
  78. if idx in cache:
  79. return cache[idx][0]
  80. value_cls = type(value)
  81. if issubclass(value_cls, LazyVariableTracker):
  82. result = cls.realize_all(value.realize(), cache)
  83. elif issubclass(value_cls, VariableTracker):
  84. # update value in-place
  85. result = value
  86. value_dict = value.__dict__
  87. nonvars = value._nonvar_fields
  88. for key in value_dict:
  89. if key not in nonvars:
  90. value_dict[key] = cls.realize_all(value_dict[key], cache)
  91. elif value_cls is list:
  92. result = [cls.realize_all(v, cache) for v in value]
  93. elif value_cls is tuple:
  94. result = tuple(cls.realize_all(v, cache) for v in value)
  95. elif value_cls in (dict, collections.OrderedDict):
  96. result = {k: cls.realize_all(v, cache) for k, v in list(value.items())}
  97. else:
  98. result = value
  99. # save `value` to keep it alive and ensure id() isn't reused
  100. cache[idx] = (result, value)
  101. return result
  102. def _create_realize_and_forward(name):
  103. @functools.wraps(getattr(VariableTracker, name))
  104. def realize_and_forward(self, *args, **kwargs):
  105. return getattr(self.realize(), name)(*args, **kwargs)
  106. return realize_and_forward
  107. def _populate():
  108. for name, value in VariableTracker.__dict__.items():
  109. if name not in LazyVariableTracker.__dict__:
  110. if callable(value):
  111. setattr(LazyVariableTracker, name, _create_realize_and_forward(name))
  112. _populate()