| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362 |
- # mypy: ignore-errors
- import collections
- from enum import Enum
- from typing import Any, Callable, Dict, List
- from .. import variables
- from ..current_scope_id import current_scope_id
- from ..exc import unimplemented
- from ..source import AttrSource, Source
- from ..utils import istype
class MutableLocalSource(Enum):
    """
    Origin of the variable represented by VariableTracker.mutable_local.

    Existing: the variable already existed when Dynamo began tracking it
        during introspection.
    Local: the variable is new and was created during Dynamo introspection.
    """

    Existing = 0
    Local = 1
class MutableLocalBase:
    """
    Base class for Variable.mutable_local.

    Records, in ``self.scope``, the scope in which the tracked object was
    created.
    """

    def __init__(self, typ: MutableLocalSource):
        # When tracing a HigherOrderOperator we must tell apart
        # MutableLocals created inside the operator from those created
        # outside of it. In the example below it is not safe to mutate
        # `a`, because it was constructed in a different scope:
        #
        #     def f(x):
        #         a = 1
        #         def g(x):
        #             nonlocal a
        #             a = 2
        #             return x
        #         return wrap(g, x) + a
        #
        # self.scope encodes this distinction:
        #   scope == 0: the object was an existing variable
        #   scope == 1: the object was created while Dynamo was
        #               introspecting a function (and no HigherOrderOps
        #               were involved)
        #   scope >= 2: the object was created through Dynamo
        #               introspection of a HigherOrderOp; the exact
        #               number corresponds to the level of nested
        #               HigherOrderOps.
        if typ is MutableLocalSource.Existing:
            self.scope = 0
            return
        if typ is MutableLocalSource.Local:
            self.scope = current_scope_id()
            return
        # Anything else is not a source kind this class knows about.
        unimplemented(f"Unsupported MutableLocalSource: {typ}")
class MutableLocal(MutableLocalBase):
    """
    Marker stating that the associated object (list, iter, etc) was
    constructed in local scope, so analysis can mutate it safely without
    leaking state.

    Instances hash and compare by identity.
    """

    def __init__(self):
        # A MutableLocal is always of the "created during introspection" kind.
        super().__init__(MutableLocalSource.Local)

    def __hash__(self):
        # Identity-based hash, consistent with the identity-based __eq__.
        return id(self)

    def __eq__(self, other):
        # Two markers are equal only when they are the very same object.
        return self is other
def _is_top_level_scope(scope_id):
    """Return whether ``scope_id`` equals 1, the id of the top-level scope."""
    top_level_scope_id = 1
    return scope_id == top_level_scope_id
def is_side_effect_safe(m: MutableLocalBase):
    """
    Decide whether the object marked by ``m`` may be mutated from the
    currently active scope.
    """
    active_scope = current_scope_id()
    # In the top-level scope (no HigherOrderOperators involved) both
    # variables created in this scope and pre-existing variables may be
    # modified. In any other scope, only variables created in that same
    # scope may be mutated locally.
    return _is_top_level_scope(active_scope) or m.scope == active_scope
class VariableTrackerMeta(type):
    """
    Metaclass that registers every class it creates and makes
    isinstance() work with LazyVariableTracker.
    """

    # Every class created with this metaclass, in creation order.
    all_subclasses = []

    def __instancecheck__(cls, instance) -> bool:
        """Make isinstance work with LazyVariableTracker"""
        lazy_cls = variables.LazyVariableTracker
        is_lazy = type.__instancecheck__(lazy_cls, instance)
        # A lazy instance is realized before the check, except when the
        # check is against VariableTracker or LazyVariableTracker itself.
        if is_lazy and cls not in (VariableTracker, lazy_cls):
            instance = instance.realize()
        return type.__instancecheck__(cls, instance)

    def __init__(cls, name, bases, attrs):
        super().__init__(name, bases, attrs)
        # Record the freshly created class in the shared registry.
        VariableTrackerMeta.all_subclasses.append(cls)
class VariableTracker(metaclass=VariableTrackerMeta):
    """
    Base class for tracked locals and stack values.

    VariableTracker instances are immutable and should be copied (see
    clone()) in order to change them.
    """

    # Attribute names that visit() does not descend into when it walks
    # an instance's __dict__.
    _nonvar_fields = {
        "value",
        "guards",
        "source",
        "mutable_local",
        "parents_tracker",
        "user_code_variable_name",
    }

    def clone(self, **kwargs):
        """Shallow copy with some (optional) changes"""
        # Current attributes first, so that explicit kwargs override them.
        ctor_kwargs = {**self.__dict__, **kwargs}
        return self.__class__(**ctor_kwargs)

    @classmethod
    def visit(
        cls,
        fn: Callable[["VariableTracker"], None],
        value,
        cache=None,
    ):
        """
        Walk value and call fn on all the VariableTracker instances

        Lists, tuples, dicts and OrderedDicts (exact types, per istype) are
        traversed; ``cache`` maps id(obj) -> obj for objects already seen,
        so each object is visited at most once.
        """
        if cache is None:
            cache = {}

        marker = id(value)
        if marker in cache:
            return
        # save `value` to keep it alive and ensure id() isn't reused
        cache[marker] = value

        if isinstance(value, VariableTracker):
            value = value.unwrap()
            fn(value)
            value = value.unwrap()  # calling fn() might have realized it
            skipped = value._nonvar_fields
            for attr_name, attr_value in value.__dict__.items():
                if attr_name in skipped:
                    continue
                cls.visit(fn, attr_value, cache)
            return

        if istype(value, (list, tuple)):
            children = value
        elif istype(value, (dict, collections.OrderedDict)):
            children = value.values()
        else:
            # Any other kind of object is a leaf of the walk.
            return

        for child in children:
            cls.visit(fn, child, cache)

    def __repr__(self):
        cls_name = self.__class__.__name__
        return f"{cls_name}()"

    def debug_repr(self):
        """
        Return the repr of the constant value when there is one, else the
        plain repr of this tracker.
        """
        # Intended to be overridden to provide more info
        try:
            rendered = repr(self.as_python_constant())
        except NotImplementedError:
            rendered = repr(self)
        return rendered

    def python_type(self):
        """
        Return the Python type represented by this tracker.

        Subclasses override this to give a uniform way of retrieving the
        Python type (int, str, list, ...) of the tracked variable. When the
        type cannot be determined or is not relevant, leaving it undefined
        or invoking super() is always sound.

        Example:
            class SetVariable(VariableTracker):
                def python_type(self):
                    return set

        Raises:
            NotImplementedError: If the method is not implemented in a subclass.
        """
        raise NotImplementedError(f"{self} has no type")

    def as_python_constant(self):
        """For constants"""
        raise NotImplementedError(f"{self} is not a constant")

    def guard_as_python_constant(self):
        """
        Similar to as_python_constant(), but intended to add ID_MATCH guards
        to try to force things to become constants.

        In this base implementation a NotImplementedError from
        as_python_constant() is routed to unimplemented() with its message.
        """
        try:
            return self.as_python_constant()
        except NotImplementedError as err:
            unimplemented(str(err))

    def is_python_constant(self):
        """Return True when as_python_constant() does not raise NotImplementedError."""
        try:
            self.as_python_constant()
        except NotImplementedError:
            return False
        return True

    def make_guard(self, fn):
        """
        Delegate to ``self.source.make_guard(fn)``.

        Raises:
            NotImplementedError: If this tracker has no (truthy) source.
        """
        if not self.source:
            raise NotImplementedError
        return self.source.make_guard(fn)

    def const_getattr(self, tx, name: str) -> Any:
        """getattr(self, name) returning a python constant"""
        raise NotImplementedError

    def var_getattr(self, tx, name: str) -> "VariableTracker":
        """getattr(self, name) returning a new variable"""
        value = self.const_getattr(tx, name)
        if not variables.ConstantVariable.is_literal(value):
            raise NotImplementedError
        # Only derive an attribute source when this tracker has one itself.
        attr_source = AttrSource(self.source, name) if self.source else None
        return variables.ConstantVariable.create(value, source=attr_source)

    def is_proxy(self):
        """Return True when as_proxy() does not raise NotImplementedError."""
        try:
            self.as_proxy()
        except NotImplementedError:
            return False
        return True

    def as_proxy(self):
        raise NotImplementedError(str(self))

    def maybe_fx_node(self):
        """
        Return ``as_proxy().node`` when the proxy is a torch.fx.Proxy,
        otherwise None (also when as_proxy() is not implemented).
        """
        try:
            proxy = self.as_proxy()
            import torch.fx

            return proxy.node if isinstance(proxy, torch.fx.Proxy) else None
        except NotImplementedError:
            return None

    def reconstruct(self, codegen):
        raise NotImplementedError

    def can_reconstruct(self, tx):
        """If it is possible to reconstruct the Python object this
        VariableTracker represents."""
        assert tx is tx.output.root_tx, "Only root tx can reconstruct"
        try:
            from ..codegen import PyCodegen

            # Trial run against a fresh code generator.
            self.reconstruct(PyCodegen(tx))
        except NotImplementedError:
            return False
        return True

    def unpack_var_sequence(self, tx) -> List["VariableTracker"]:
        raise NotImplementedError

    def has_unpack_var_sequence(self, tx) -> bool:
        """Return True when unpack_var_sequence(tx) does not raise NotImplementedError."""
        try:
            self.unpack_var_sequence(tx)
        except NotImplementedError:
            return False
        return True

    def inspect_parameter_names(self) -> List[str]:
        unimplemented(f"inspect_parameter_names: {self}")

    def call_hasattr(self, tx, name: str) -> "VariableTracker":
        unimplemented(f"hasattr {self.__class__.__name__} {name}")

    def call_function(
        self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
    ) -> "VariableTracker":
        unimplemented(f"call_function {self} {args} {kwargs}")

    def call_method(
        self,
        tx,
        name,
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
    ) -> "VariableTracker":
        """
        Generic method-call handling.

        Supports ``__len__`` on anything that can be unpacked as a variable
        sequence, and ``__getattr__`` with a single constant argument;
        everything else goes to unimplemented().
        """
        if name == "__len__" and self.has_unpack_var_sequence(tx):
            assert not (args or kwargs)
            items = self.unpack_var_sequence(tx)
            return variables.ConstantVariable.create(len(items))

        if (
            name == "__getattr__"
            and len(args) == 1
            and args[0].is_python_constant()
            and not kwargs
        ):
            attr_name = args[0].as_python_constant()
            return self.var_getattr(tx, attr_name)

        unimplemented(f"call_method {self} {name} {args} {kwargs}")

    def set_name_hint(self, name):
        # No-op in the base class.
        pass

    def realize(self) -> "VariableTracker":
        """Used by LazyVariableTracker to build the real VariableTracker"""
        return self

    def unwrap(self) -> "VariableTracker":
        """Used by LazyVariableTracker to return the real VariableTracker if it already exists"""
        return self

    def is_realized(self):
        """Used by LazyVariableTracker to indicate an unrealized node"""
        return True

    def next_variable(self, tx):
        unimplemented(f"next({self})")

    def is_strict_mode(self, tx):
        """Return ``tx.strict_checks_fn(self)`` when a strict-checks function is set (else its falsy value)."""
        return tx.strict_checks_fn and tx.strict_checks_fn(self)

    def __init__(
        self,
        *,
        source: Source = None,
        mutable_local: MutableLocal = None,
    ):
        """Store the (keyword-only) source and mutable_local of this tracker."""
        super().__init__()
        self.source = source
        self.mutable_local = mutable_local
def typestr(*objs):
    """
    Describe the given objects for messages.

    A single VariableTracker is rendered with str(); any other single
    object by its type name. Zero or several objects are rendered one by
    one and joined with single spaces.
    """
    if len(objs) != 1:
        return " ".join(typestr(item) for item in objs)

    (obj,) = objs
    return str(obj) if isinstance(obj, VariableTracker) else type(obj).__name__
|