- # mypy: ignore-errors
- import collections
- import dataclasses
- import functools
- import inspect
- import itertools
- import re
- import sys
- import types
- from typing import Dict, List
- import torch._C
- import torch._numpy as tnp
- import torch.utils._pytree as pytree
- from .. import config, variables
- from ..bytecode_transformation import create_call_function, create_instruction
- from ..exc import unimplemented
- from ..guards import GuardBuilder, install_guard
- from ..source import AttrSource, GetItemSource, ODictGetItemSource, TypeSource
- from ..utils import (
- check_unspec_or_constant_args,
- identity,
- is_tensor_base_attr_getter,
- proxy_args_kwargs,
- set_example_value,
- )
- from .base import VariableTracker
- from .functions import NestedUserFunctionVariable, UserFunctionVariable
- from .user_defined import is_standard_setattr, UserDefinedObjectVariable
- class SuperVariable(VariableTracker):
- _nonvar_fields = {
- "specialized",
- *VariableTracker._nonvar_fields,
- }
- def __init__(self, typevar, objvar=None, specialized=False, **kwargs):
- super().__init__(**kwargs)
- # typevar is the first argument to super(). In the case where no argument
- # is provided to super(), it is the __class__ object where
- # the super() function is being called
- self.typevar = typevar
- # objvar here must be an instance or subtype of typevar.
- # In the case where super() is called without arguments, it is the first argument
- # to the current function where super() is called from (self for regular method,
- # cls for a classmethod)
- self.objvar = objvar
- self.specialized = specialized # directly get attr from self.typevar if true
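- # Rough illustration (not part of the original comments): for a zero-argument
- # `super().__init__()` written inside `MyModule.forward`, typevar tracks the
- # enclosing `__class__` cell (MyModule) and objvar tracks the bound `self`;
- # for an explicit `super(MyModule, cls)`, objvar tracks `cls` instead.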
- def reconstruct(self, codegen):
- codegen(variables.BuiltinVariable(super))
- codegen(self.typevar)
- if self.objvar is not None:
- codegen(self.objvar)
- codegen.extend_output(create_call_function(2, True))
- else:
- codegen.extend_output(create_call_function(1, True))
- def _resolved_getattr_and_source(self, tx, name):
- assert self.objvar, "1-arg super not implemented"
- if self.specialized:
- return getattr(self.typevar.as_python_constant(), name)
- search_type = self.typevar.as_python_constant()
- # The rest of this function does two things:
- # - Walk the mro to find where the attribute comes from to be
- # able to provide accurate source
- # - Call the getattr to get the object
- # Find the class object, where the function lives.
- # When objvar is "self", use type(self), when objvar is "cls", use it as-is
- type_to_use = self.objvar.python_type()
- type_to_use_source = (
- TypeSource(self.objvar.source) if self.objvar.source else None
- )
- if issubclass(type_to_use, type):
- type_to_use = self.objvar.value
- type_to_use_source = self.objvar.source
- source = None
- if self.objvar.source is not None:
- # Walk the mro tuple to find out the actual class where the
- # attribute resides.
- search_mro = type_to_use.__mro__
- start_index = search_mro.index(search_type) + 1
- for index in range(start_index, len(search_mro)):
- if hasattr(search_mro[index], name):
- # Equivalent of something like type(L['self']).__mro__[1].attr_name
- source = AttrSource(
- GetItemSource(AttrSource(type_to_use_source, "__mro__"), index),
- name,
- )
- break
- # TODO(jansel): there is a small chance this could trigger user code, prevent that
- return getattr(super(search_type, type_to_use), name), source
- def var_getattr(self, tx, name: str) -> "VariableTracker":
- # Check if getattr is a constant. If not, delay the actual work by
- # wrapping the result in GetAttrVariable. Mostly super is called with a
- # method, so most of the work is delayed to call_function.
- #
- # We could have just implemented a const_getattr. However, super is
- # special when it comes to finding sources. Compared to other VTs, super
- # requires the attr name to walk the mro and find the actual source (and
- # not just AttrSource).
- value, source = self._resolved_getattr_and_source(self, name)
- if not variables.ConstantVariable.is_literal(value):
- return GetAttrVariable(self, name)
- if source:
- install_guard(source.make_guard(GuardBuilder.CONSTANT_MATCH))
- return variables.ConstantVariable.create(value, source=source)
- return variables.ConstantVariable.create(value)
- def call_method(
- self,
- tx,
- name,
- args: "List[VariableTracker]",
- kwargs: "Dict[str, VariableTracker]",
- ) -> "VariableTracker":
- inner_fn, source = self._resolved_getattr_and_source(self, name)
- if inner_fn is object.__init__:
- return LambdaVariable(identity)
- elif inner_fn is torch.nn.Module.__init__:
- objvar = self.objvar
- from ..side_effects import AttributeMutationNew
- if (
- isinstance(objvar, variables.UserDefinedObjectVariable)
- and isinstance(objvar.mutable_local, AttributeMutationNew)
- and not (args or kwargs)
- ):
- tx.output.side_effects.store_attr(
- objvar,
- "__call_nn_module_init",
- variables.ConstantVariable.create(True),
- )
- return variables.ConstantVariable.create(None)
- else:
- unimplemented("super() nn.Module.__init__")
- elif isinstance(inner_fn, types.FunctionType):
- return variables.UserFunctionVariable(
- inner_fn, source=source
- ).call_function(tx, [self.objvar] + args, kwargs)
- elif isinstance(inner_fn, types.MethodType):
- return variables.UserMethodVariable(
- inner_fn.__func__, self.objvar, source=source
- ).call_function(tx, args, kwargs)
- elif (
- inner_fn is collections.OrderedDict.__getitem__
- and isinstance(self.objvar, variables.UserDefinedObjectVariable)
- and self.objvar.source
- and len(args) == 1
- and len(kwargs) == 0
- and args[0].is_python_constant()
- ):
- from .builder import VariableBuilder
- key = args[0].as_python_constant()
- return VariableBuilder(tx, ODictGetItemSource(self.objvar.source, key))(
- collections.OrderedDict.__getitem__(self.objvar.value, key)
- )
- elif inner_fn in (
- collections.OrderedDict.__setitem__,
- object.__setattr__,
- ) and isinstance(self.objvar, variables.CustomizedDictVariable):
- assert not kwargs and len(args) == 2
- return super(variables.CustomizedDictVariable, self.objvar).call_method(
- tx, "__setitem__", args, kwargs
- )
- elif is_standard_setattr(inner_fn) and isinstance(
- self.objvar, UserDefinedObjectVariable
- ):
- return self.objvar.method_setattr_standard(tx, *args, **kwargs)
- unimplemented(f"non-function or method super: {inner_fn}")
- class ExceptionVariable(VariableTracker):
- def __init__(self, exc_type, args, **kwargs):
- super().__init__(**kwargs)
- self.exc_type = exc_type
- self.args = args
- def reconstruct(self, codegen):
- codegen.load_import_from("builtins", self.exc_type.__name__)
- codegen.foreach(self.args)
- codegen.call_function(len(self.args), True)
- class UnknownVariable(VariableTracker):
- """
- It could be anything!
- """
- class DelayGraphBreakVariable(UnknownVariable):
- """
- Used to insert a dummy variable in the stack to do the graph break at CALL_FUNCTION.
- """
- class ComptimeVariable(VariableTracker):
- """
- This variable is special, it lets you execute arbitrary code at
- Dynamo compile time
- """
- def reconstruct(self, codegen):
- raise NotImplementedError("comptime is special form")
- def var_getattr(self, tx, name: str) -> "VariableTracker":
- from ..comptime import comptime
- # To support the comptime.print_graph convenience accessors
- from .functions import UserFunctionVariable
- return UserFunctionVariable(
- getattr(comptime, name), source=AttrSource(self.source, name)
- )
- def call_function(
- self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
- ) -> "VariableTracker":
- from ..comptime import ComptimeContext
- # TODO: support an expression form as well
- assert not kwargs
- # Second argument is runtime lambda, ignored
- assert len(args) <= 2
- fn = args[0]
- if isinstance(fn, UserFunctionVariable):
- fn.get_function()(ComptimeContext(tx))
- elif isinstance(fn, NestedUserFunctionVariable):
- # We have to manually bind the freevars ourselves
- code = fn.get_code()
- assert not fn.closure, (
- "comptime function must not have free variables, "
- f"but these variables were free: {code.co_freevars}"
- )
- func = types.FunctionType(
- code,
- fn.f_globals,
- fn.fn_name.as_python_constant(),
- tuple(fn.defaults.items) if fn.defaults else None,
- # We could automatically promote free variables into
- # ComptimeVar but this is confusing if you access
- # a free variable that we actually DO have the runtime
- # value for
- # tuple(make_cell(ComptimeVar(i)) for i in fn.closure.items)
- tuple(),
- )
- func(ComptimeContext(tx))
- else:
- raise RuntimeError(f"unsupported argument to comptime: {type(fn)}")
- return variables.ConstantVariable.create(None)
- class ClosureVariable(UnknownVariable):
- _nonvar_fields = {
- "name",
- *UnknownVariable._nonvar_fields,
- }
- def __init__(self, name, **kwargs):
- super().__init__(**kwargs)
- self.name = name
- def reconstruct(self, codegen):
- codegen.append_output(codegen.create_load_closure(self.name))
- # closure variable created by an inlined function
- class InlinedClosureVariable(UnknownVariable):
- _nonvar_fields = {
- "name",
- *UnknownVariable._nonvar_fields,
- }
- def __init__(self, name, **kwargs):
- super().__init__(**kwargs)
- self.name = name
- def reconstruct(self, codegen):
- codegen.append_output(codegen.create_load_closure(self.name))
- class NewCellVariable(VariableTracker):
- def __init__(self, **kwargs):
- super().__init__(**kwargs)
- class NewGlobalVariable(VariableTracker):
- def __init__(self, **kwargs):
- super().__init__(**kwargs)
- class InspectSignatureVariable(VariableTracker):
- """represents inspect.signature(...)"""
- @staticmethod
- def create(callable, **kwargs):
- if kwargs:
- unimplemented(f"inspect.signature with {kwargs}")
- return InspectSignatureVariable(callable)
- def __init__(self, inspected: VariableTracker, **kwargs):
- super().__init__(**kwargs)
- self.inspected = inspected
- def var_getattr(self, tx, name: str) -> "VariableTracker":
- if name == "parameters":
- return variables.ConstDictVariable(
- {
- variables.ConstantVariable.create(name): InspectParameterVariable()
- for name in self.inspected.inspect_parameter_names()
- },
- user_cls=dict,
- )
- return super().var_getattr(tx, name)
- class InspectParameterVariable(VariableTracker):
- """This is not implemented, if used will graph break."""
- pass
- def produce_trampoline_autograd_apply(fn_cls):
- def trampoline_autograd_apply(*args, **kwargs):
- return fn_cls.apply(*args, **kwargs)
- trampoline_autograd_apply._origin = produce_trampoline_autograd_apply
- return trampoline_autograd_apply
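- # The `_origin` tag above presumably lets later trace-rule checks recognize a
- # proxy target produced by this factory; the trampoline itself just forwards
- # to `fn_cls.apply` so the call can be recorded in the FX graph (see
- # AutogradFunctionVariable.call_method below).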
- class AutogradFunctionVariable(VariableTracker):
- """represents a torch.autograd.Function subclass"""
- _nonvar_fields = {
- "fn_cls",
- *VariableTracker._nonvar_fields,
- }
- def __init__(self, fn_cls, **kwargs):
- super().__init__(**kwargs)
- self.fn_cls = fn_cls
- def call_apply(self, tx, args, kwargs):
- requires_grad = False
- def visit(node):
- nonlocal requires_grad
- if isinstance(node, variables.TensorVariable):
- if node.requires_grad is not False:
- requires_grad = True
- if isinstance(node, variables.NNModuleVariable):
- if node.is_training(tx):
- requires_grad = True
- VariableTracker.visit(visit, (args, kwargs))
- if (
- requires_grad
- and torch.is_grad_enabled()
- and config.capture_autograd_function
- ):
- from torch._functorch.autograd_function import (
- autograd_function_forward_rewritten,
- )
- from torch.autograd.function import _is_setup_context_defined
- forward_fn = self.fn_cls.forward
- is_setup_ctx_defined = _is_setup_context_defined(self.fn_cls.setup_context)
- if is_setup_ctx_defined:
- # If setup_context is defined, we generate a new forward function which includes
- # the original forward and setup_context function, and trace the new forward function.
- forward_fn = autograd_function_forward_rewritten(
- self.fn_cls.forward, self.fn_cls.setup_context
- )
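- # Rough sketch of the rewritten forward (assuming the helper's behavior, not
- # copied from it):
- # def new_forward(ctx, *args, **kwargs):
- # out = fn_cls.forward(*args, **kwargs)
- # fn_cls.setup_context(ctx, args, out)
- # return out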
- vjp_fn = self.fn_cls.vjp # type: ignore[attr-defined]
- if vjp_fn is not torch.autograd.Function.vjp:
- unimplemented("NYI - User defind vjp")
- jvp_fn = self.fn_cls.jvp # type: ignore[attr-defined]
- if jvp_fn is not torch.autograd.Function.jvp:
- unimplemented("NYI - User defind jvp")
- from .higher_order_ops import AutogradFunctionApplyVariable
- source = self.source
- if source is None:
- source = AttrSource(
- tx.import_source(self.fn_cls.__module__), self.fn_cls.__name__
- )
- val = AutogradFunctionApplyVariable(
- forward_fn,
- self.fn_cls.backward,
- source,
- source=AttrSource(source, member="apply"),
- ).call_function(tx, args, kwargs)
- # Inside of AutogradFunctionApplyVariable.call_function, we use a sourceless variable to wrap
- # the forward function, as we don't want to generate guards for new_forward.__closure__
- # if forward is rewritten by autograd_function_forward_rewritten.
- # But we still need to generate correct guards for the original forward and setup_context
- # functions, so we have to add guards manually.
- if self.source:
- fwd_src = AttrSource(self.source, "forward")
- install_guard(fwd_src.make_guard(GuardBuilder.FUNCTION_MATCH))
- if is_setup_ctx_defined:
- setup_ctx_src = AttrSource(self.source, "setup_context")
- install_guard(setup_ctx_src.make_guard(GuardBuilder.FUNCTION_MATCH))
- return val
- if self.source:
- source = AttrSource(self.source, "forward")
- else:
- source = None
- fn = self.fn_cls.forward
- ctx = AutogradFunctionContextVariable.create(tx, args, kwargs)
- args = [ctx, *args]
- if isinstance(fn, types.FunctionType):
- return variables.UserFunctionVariable(fn, source=source).call_function(
- tx, args, kwargs
- )
- elif isinstance(fn, types.MethodType):
- return variables.UserMethodVariable(
- fn.__func__,
- variables.UserDefinedClassVariable(self.fn_cls),
- source=source,
- ).call_function(tx, args, kwargs)
- else:
- unimplemented(
- f"non-function or method in subclass of torch.autograd.Function: {fn}"
- )
- def call_backward(self, tx, args, kwargs):
- fn = self.fn_cls.backward
- self.source = AttrSource(self.source, "backward")
- assert type(args[0].value) is torch._dynamo.external_utils.FakeBackwardCFunction
- assert isinstance(fn, types.FunctionType)
- return variables.UserFunctionVariable(fn, source=self.source).call_function(
- tx, args, kwargs
- )
- def call_function(self, tx, args, kwargs):
- return AutogradFunctionVariable(self.fn_cls)
- def call_method(
- self,
- tx,
- name,
- args: "List[VariableTracker]",
- kwargs: "Dict[str, VariableTracker]",
- ):
- from ..trace_rules import is_callable_allowed
- from .builder import wrap_fx_proxy
- if name == "apply":
- if is_callable_allowed(self.fn_cls):
- trampoline_autograd_apply = produce_trampoline_autograd_apply(
- self.fn_cls
- )
- return wrap_fx_proxy(
- tx=tx,
- proxy=tx.output.create_proxy(
- "call_function",
- trampoline_autograd_apply,
- *proxy_args_kwargs(args, kwargs),
- ),
- )
- else:
- return self.call_apply(tx, args, kwargs)
- elif name == "backward":
- return self.call_backward(tx, args, kwargs)
- else:
- from .. import trace_rules
- source = AttrSource(self.source, name) if self.source is not None else None
- try:
- obj = inspect.getattr_static(self.fn_cls, name)
- except AttributeError:
- obj = None
- if isinstance(obj, staticmethod):
- func = obj.__get__(self.fn_cls)
- if source is not None:
- return (
- trace_rules.lookup(func)
- .create_with_source(func, source=source)
- .call_function(tx, args, kwargs)
- )
- else:
- return trace_rules.lookup(func)(func).call_function(
- tx, args, kwargs
- )
- elif isinstance(obj, classmethod):
- return variables.UserMethodVariable(
- obj.__func__, self, source=source
- ).call_function(tx, args, kwargs)
- else:
- unimplemented(f"Unsupported method: {name}")
- @dataclasses.dataclass
- class SavedTensorBox:
- tensors: List[VariableTracker] = dataclasses.field(default_factory=list)
- class AutogradFunctionContextVariable(UserDefinedObjectVariable):
- """
- Tracks an autograd.Function() context using mutation tracking in side_effects.py
- """
- _nonvar_fields = {
- "proxy",
- "inference",
- "saved_tensors",
- *UserDefinedObjectVariable._nonvar_fields,
- }
- def __init__(
- self,
- value,
- value_type=None,
- inference=False,
- proxy=None,
- saved_tensors=None,
- needs_input_grad=None,
- **kwargs,
- ):
- super().__init__(value=value, value_type=value_type, **kwargs)
- self.inference = inference
- self.proxy = proxy
- self.saved_tensors = saved_tensors
- self.needs_input_grad = needs_input_grad
- @staticmethod
- def create(tx, args=None, kwargs=None):
- needs_input_grad = None
- if args and not kwargs:
- needs_input_grad = tuple(
- isinstance(x, variables.TensorVariable) and x.requires_grad
- for x in args
- )
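- # This mirrors eager autograd, where ctx.needs_input_grad is a tuple with one
- # bool per forward() input; here it is approximated from the traced
- # TensorVariables' requires_grad flags.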
- proxy = tx.output.create_proxy(
- "call_function", torch.autograd.function.FunctionCtx, tuple(), {}
- )
- out = tx.output.side_effects.track_object_new(
- None,
- torch.autograd.function.FunctionCtx,
- functools.partial(
- AutogradFunctionContextVariable,
- inference=True,
- proxy=proxy,
- saved_tensors=SavedTensorBox(),
- needs_input_grad=needs_input_grad,
- ),
- {},
- )
- set_example_value(proxy.node, out.value)
- return out
- def as_proxy(self):
- if self.proxy is None:
- unimplemented("proxy not set")
- return self.proxy
- def call_method(
- self,
- tx,
- name,
- args: "List[VariableTracker]",
- kwargs: "Dict[str, VariableTracker]",
- ) -> "VariableTracker":
- if name == "__setattr__":
- return super().call_method(tx, name, args, kwargs)
- if name != "save_for_backward":
- unimplemented(f"autograd.Function context method: {name}")
- if self.saved_tensors is None:
- unimplemented(
- "save_for_backward only supported on a newly constructed FunctionCtx"
- )
- if not self.inference:
- assert self.source and not kwargs
- tx.output.side_effects.track_save_for_backward(self, args)
- # In eager mode, multiple calls to .save_for_backward() will overwrite previous calls.
- if len(self.saved_tensors.tensors) > 0:
- self.saved_tensors.tensors = []
- for arg in args:
- self.saved_tensors.tensors.append(arg)
- return variables.ConstantVariable.create(None)
- def var_getattr(self, tx, name):
- if name == "save_for_backward":
- return LambdaVariable(
- lambda *args, **kwargs: self.call_method(tx, name, args, kwargs)
- )
- if name == "saved_tensors" and self.saved_tensors is not None:
- return variables.TupleVariable(list(self.saved_tensors.tensors))
- if name == "needs_input_grad":
- if self.needs_input_grad is not None:
- return variables.ConstantVariable.create(self.needs_input_grad)
- if self.source:
- from .builder import VariableBuilder
- return VariableBuilder(tx, AttrSource(self.source, "needs_input_grad"))(
- self.value.needs_input_grad
- )
- return super().var_getattr(tx, name)
- class LambdaVariable(VariableTracker):
- def __init__(self, fn, **kwargs):
- super().__init__(**kwargs)
- self.fn = fn
- def call_function(
- self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
- ) -> "VariableTracker":
- return self.fn(*args, **kwargs)
- class GetAttrVariable(VariableTracker):
- _nonvar_fields = {
- "name",
- *VariableTracker._nonvar_fields,
- }
- def __init__(self, obj, name, **kwargs):
- super().__init__(**kwargs)
- assert isinstance(obj, VariableTracker)
- assert isinstance(name, str)
- self.obj = obj
- self.name = name
- def __str__(self):
- return f"{self.__class__.__name__}({self.obj}, {self.name})"
- @staticmethod
- def create_getattr_proxy(base_proxy: torch.fx.Proxy, attr):
- return getattr(base_proxy, attr)
- def as_proxy(self):
- return GetAttrVariable.create_getattr_proxy(self.obj.as_proxy(), self.name)
- def const_getattr(self, tx, name):
- if not isinstance(self.obj, variables.NNModuleVariable):
- raise NotImplementedError
- step1 = tx.output.get_submodule(self.obj.module_key)
- if self.name not in step1.__dict__:
- raise NotImplementedError
- step2 = inspect.getattr_static(step1, self.name)
- if name not in step2.__dict__:
- raise NotImplementedError
- return inspect.getattr_static(step2, name)
- def reconstruct(self, codegen):
- codegen(self.obj)
- codegen.extend_output(codegen.create_load_attrs(self.name))
- def call_function(
- self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
- ) -> "VariableTracker":
- return self.obj.call_method(tx, self.name, args, kwargs)
- def call_method(
- self,
- tx,
- name,
- args: List[VariableTracker],
- kwargs: Dict[str, VariableTracker],
- ) -> VariableTracker:
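- # Special-case lookups that go through an object's __dict__, e.g. (informally)
- # obj.__dict__["foo"], obj.__dict__.get("foo", default) or
- # "foo" in obj.__dict__, by redirecting to the object's own attribute lookup.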
- if (
- name in ("__getitem__", "get")
- and self.name == "__dict__"
- and not kwargs
- and args[0].is_python_constant()
- and isinstance(
- self.obj,
- (variables.UserDefinedObjectVariable, variables.NNModuleVariable),
- )
- ):
- obj = self.obj
- key = args[0].as_python_constant()
- if obj.has_key_in_generic_dict(tx, key):
- # redirect to var_getattr on the original obj
- return obj.var_getattr(tx, key)
- # Return the default value for get
- if name == "get":
- if len(args) == 2:
- return args[1]
- else:
- return variables.ConstantVariable(None)
- elif (
- name == "__contains__"
- and self.name == "__dict__"
- and len(args) == 1
- and args[0].is_python_constant()
- and not kwargs
- and isinstance(
- self.obj,
- (variables.UserDefinedObjectVariable, variables.NNModuleVariable),
- )
- ):
- obj = self.obj
- key = args[0].as_python_constant()
- if obj.has_key_in_generic_dict(tx, key):
- return variables.ConstantVariable(True)
- else:
- return variables.ConstantVariable(False)
- return super().call_method(tx, name, args, kwargs)
- class MethodWrapperVariable(VariableTracker):
- def __init__(self, method_wrapper, **kwargs):
- super().__init__(**kwargs)
- self.method_wrapper = method_wrapper
- def call_function(
- self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
- ) -> "VariableTracker":
- if is_tensor_base_attr_getter(self.method_wrapper) and isinstance(
- args[0], variables.TensorVariable
- ):
- assert len(args) == 1 and len(kwargs) == 0
- return args[0].var_getattr(tx, self.method_wrapper.__self__.__name__)
- return super().call_function(tx, args, kwargs)
- def is_python_constant(self):
- return True
- def as_python_constant(self):
- return self.method_wrapper
- class GetSetDescriptorVariable(VariableTracker):
- def __init__(self, desc, **kwargs):
- super().__init__(**kwargs)
- self.desc = desc
- def var_getattr(self, tx, name):
- if name == "__get__" and self.source:
- from .builder import VariableBuilder
- return VariableBuilder(tx, AttrSource(self.source, "__get__"))(
- self.desc.__get__
- )
- else:
- return super().var_getattr(tx, name)
- def is_python_constant(self):
- return True
- def as_python_constant(self):
- return self.desc
- class PythonModuleVariable(VariableTracker):
- _nonvar_fields = {
- "value",
- "is_torch",
- *VariableTracker._nonvar_fields,
- }
- def __init__(self, value: types.ModuleType, **kwargs):
- super().__init__(**kwargs)
- self.value = value
- self.is_torch = self.value is torch or self.value.__name__.startswith("torch.")
- def python_type(self):
- return types.ModuleType
- def as_python_constant(self):
- return self.value
- def __repr__(self):
- return f"PythonModuleVariable({self.value})"
- def call_hasattr(self, tx, name):
- if self.is_torch:
- result = hasattr(self.value, name)
- return variables.ConstantVariable.create(result)
- return super().call_hasattr(tx, name)
- class TypingVariable(VariableTracker):
- def __init__(self, value, **kwargs):
- super().__init__(**kwargs)
- self.value = value
- def call_method(
- self,
- tx,
- name,
- args: "List[VariableTracker]",
- kwargs: "Dict[str, VariableTracker]",
- ) -> "VariableTracker":
- if name == "__getitem__" and len(args) == 1:
- return variables.ConstantVariable.create(
- self.value[args[0].as_python_constant()],
- )
- unimplemented("typing")
- def python_type(self):
- return type(self.value)
- def as_python_constant(self):
- return self.value
- @functools.lru_cache(maxsize=1)
- def get_np_to_tnp_map():
- from ..utils import NP_TO_TNP_MODULE
- np_fn_to_tnp_fn = {}
- for np_mod, tnp_mod in NP_TO_TNP_MODULE.items():
- for fn_name, tnp_fn in tnp_mod.__dict__.items():
- if callable(tnp_fn):
- # some internal details leak from tnp
- # that are not part of the numpy API.
- if np_fn := getattr(np_mod, fn_name, None):
- np_fn_to_tnp_fn[np_fn] = tnp_fn
- return np_fn_to_tnp_fn
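- # Rough example of the resulting mapping (assuming both modules expose the
- # name): an entry such as {numpy.sum: torch._numpy.sum}; NumpyVariable below
- # uses it to swap a traced numpy call for its torch._numpy counterpart.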
- class NumpyVariable(VariableTracker):
- """
- Wrapper around `numpy.*`. Currently able to trace a small subset of numpy functions as well as numpy dtypes.
- """
- constant_fold_functions = (tnp.issubdtype,)
- def __init__(self, value, **kwargs):
- super().__init__(**kwargs)
- self.value = value
- @classmethod
- def can_constant_fold_through(cls, fn):
- mod = fn.__module__.split(".")
- assert len(mod) >= 2 and mod[:2] == ["torch", "_numpy"]
- return fn in cls.constant_fold_functions
- @classmethod
- def get_constant_collection_for_func(cls, fn):
- mod = fn.__module__.split(".")
- assert len(mod) >= 2 and mod[:2] == ["torch", "_numpy"]
- return np_constant_collections_map.get(fn, None)
- def call_function(
- self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
- ) -> "VariableTracker":
- if not config.trace_numpy:
- unimplemented(f"numpy.{self.value}()")
- from ..utils import numpy_to_tensor_wrapper
- from .tensor import NumpyNdarrayVariable
- func = get_np_to_tnp_map().get(self.value)
- if func is None:
- unimplemented(
- f"Can't find numpy function {self.value} in torch._numpy. "
- " Please file an issue to request support for this function."
- )
- # We are dealing with a function that produces a const collection type (np.dtype, np.iinfo/np.finfo)
- if (
- collection_variable_typ := self.get_constant_collection_for_func(func)
- ) is not None:
- try:
- return collection_variable_typ(
- self.value(
- *[x.as_python_constant() for x in args],
- **{k: v.as_python_constant() for k, v in kwargs.items()},
- )
- )
- except NotImplementedError:
- unimplemented(
- f"{self.value.__name__} with non-const args: {args} {kwargs}"
- )
- else:
- if (
- func.__module__ == "torch._numpy.random"
- and config.use_numpy_random_stream
- ):
- msg = f"delegate '{func.__qualname__}' to NumPy itself via "
- msg += f"confg.use_numpy_random_stream={config.use_numpy_random_stream}"
- unimplemented(msg)
- args, kwargs = NumpyNdarrayVariable.patch_args(func.__name__, args, kwargs)
- if self.can_constant_fold_through(func) and (
- check_unspec_or_constant_args(args, kwargs)
- ):
- # constant fold
- return variables.ConstantVariable.create(
- self.as_python_constant()(
- *[x.as_python_constant() for x in args],
- **{k: v.as_python_constant() for k, v in kwargs.items()},
- ),
- )
- # TODO Add all the functions that go from constants to constants to can_constant_fold_through
- proxy = tx.output.create_proxy(
- "call_function",
- numpy_to_tensor_wrapper(func),
- *proxy_args_kwargs(args, kwargs),
- )
- return NumpyNdarrayVariable.create(tx, proxy)
- def call_method(
- self,
- tx,
- name,
- args: "List[VariableTracker]",
- kwargs: "Dict[str, VariableTracker]",
- ) -> "VariableTracker":
- unimplemented("numpy")
- def python_type(self):
- return type(self.value)
- def as_python_constant(self):
- return self.value
- def as_proxy(self):
- if config.trace_numpy and isinstance(self.value, type):
- # This handles numpy dtype attributes such as np.float32
- # We return a string as we don't want to serialize non-PyTorch objects in the output FX graph
- # In torch/_numpy we normalize strings to their dtypes when the input is a dtype, as NumPy does
- return self.value.__name__
- return super().as_proxy()
- # Used to keep track of NULLs pushed on the stack for Python 3.11 function calls
- class NullVariable(VariableTracker):
- def __init__(self, **kwargs):
- super().__init__(**kwargs)
- def __str__(self):
- return "NullVariable"
- def reconstruct(self, codegen):
- if sys.version_info < (3, 11):
- unimplemented("cannot reconstruct NullVariable in < Python 3.11")
- codegen.append_output(create_instruction("PUSH_NULL"))
- class DeletedVariable(VariableTracker):
- """Marker used to implement delattr()"""
- class StringFormatVariable(VariableTracker):
- """
- Represents a call to str.format(); we delay calling format until after the graph.
- """
- _nonvar_fields = {"format_string", *VariableTracker._nonvar_fields}
- @classmethod
- def create(cls, format_string, sym_args, sym_kwargs):
- if all(
- x.is_python_constant()
- for x in itertools.chain(sym_args, sym_kwargs.values())
- ):
- return variables.ConstantVariable.create(
- format_string.format(
- *[v.as_python_constant() for v in sym_args],
- **{k: v.as_python_constant() for k, v in sym_kwargs.items()},
- )
- )
- return cls(format_string, list(sym_args), dict(sym_kwargs))
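- # Informal example: create("x={}", [ConstantVariable.create(1)], {}) folds
- # immediately to a ConstantVariable holding "x=1", while a TensorVariable
- # argument keeps a StringFormatVariable so formatting happens after the graph.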
- def __init__(self, format_string, sym_args, sym_kwargs, **kwargs):
- super().__init__(**kwargs)
- assert isinstance(format_string, str)
- self.format_string = format_string
- self.sym_args = sym_args
- self.sym_kwargs = sym_kwargs
- def __repr__(self):
- return f"{self.__class__.__name__}({self.format_string!r}, {self.sym_args!r}, {self.sym_kwargs!r})"
- def reconstruct(self, codegen):
- if sys.version_info >= (3, 11):
- codegen.append_output(create_instruction("PUSH_NULL"))
- codegen.append_output(codegen.create_load_const(self.format_string))
- codegen.append_output(codegen.create_load_attr("format"))
- codegen(variables.TupleVariable(self.sym_args))
- kwargs = {
- variables.ConstantVariable.create(k): v for k, v in self.sym_kwargs.items()
- }
- codegen(variables.ConstDictVariable(kwargs))
- codegen.append_output(create_instruction("CALL_FUNCTION_EX", arg=1))
- class DebuggingVariable(VariableTracker):
- """
- Represents a call to a debugging function like print(), or something
- registered to config.reorderable_logging_functions.
- """
- def __init__(self, value, **kwargs):
- super().__init__(**kwargs)
- self.value = value
- @staticmethod
- def is_reorderable_logging_function(obj):
- return (
- callable(obj)
- and isinstance(obj, (types.FunctionType, types.BuiltinFunctionType))
- and obj in torch._dynamo.config.reorderable_logging_functions
- )
- def call_function(self, tx, args, kwargs):
- if tx.export:
- # For export cases, we can just make debugging functions no-ops
- return
- if not self.can_reorder_logs(self.value, args, kwargs):
- unimplemented(
- f"Reordering debugging function {self.value} "
- f"with inputs {args} {kwargs} is not yet implemented."
- )
- tx.debug_locals.append((self, list(args)))
- def reconstruct(self, codegen):
- return self.source.reconstruct(codegen)
- @staticmethod
- def can_reorder_logs(fn, args, kwargs) -> bool:
- """
- Run some additional checks on what sorts of function calls we can
- actually reorder.
- """
- allowed_input_types = (
- variables.TensorVariable,
- variables.ConstantVariable,
- StringFormatVariable,
- )
- flat_args = pytree.tree_leaves([args, kwargs])
- for arg in flat_args:
- if not isinstance(arg, allowed_input_types):
- return False
- return True
- class LoggingLoggerVariable(VariableTracker):
- """
- Represents a call to any of the logging.Logger methods
- """
- def __init__(self, value, **kwargs):
- super().__init__(**kwargs)
- def call_method(
- self,
- tx,
- name,
- args: "List[VariableTracker]",
- kwargs: "Dict[str, VariableTracker]",
- ) -> "VariableTracker":
- if tx.export:
- # For export cases, we can just make these logger methods no-ops
- return
- unimplemented("Logger not supported for non-export cases")
- class StopIterationVariable(VariableTracker):
- def __init__(self, args, **kwargs):
- super().__init__(**kwargs)
- self.args = args
- def reconstruct(self, codegen):
- codegen.load_import_from("builtins", "StopIteration")
- codegen.foreach(self.args)
- codegen.call_function(len(self.args), True)
- class ConstantLikeVariable(VariableTracker):
- """self.value is a compile-time constant, but not a literal"""
- _error_prefix = "ConstantLikeVariable"
- try:
- from numpy import (
- dtype as np_dtype,
- floating as np_floating,
- generic as np_generic,
- )
- except ImportError:
- np_floating = type("invalid_type", (), {})
- np_dtype = type("invalid_type", (), {})
- # np_generic is also referenced in var_getattr below, so stub it as well
- np_generic = type("invalid_type", (), {})
- def __init__(self, value, **kwargs):
- super().__init__(**kwargs)
- self.value = value
- def python_type(self):
- return type(self.value)
- def as_python_constant(self):
- return self.value
- def call_method(
- self,
- tx,
- name,
- args: List[VariableTracker],
- kwargs: Dict[str, VariableTracker],
- ) -> VariableTracker:
- try:
- # we only support constant propagation for methods
- cargs = [x.as_python_constant() for x in args]
- ckwargs = {k: v.as_python_constant() for k, v in kwargs.items()}
- except NotImplementedError:
- unimplemented(f"{self._error_prefix}.{name}(*{args}, **{kwargs})")
- result = getattr(self.value, name)(*cargs, **ckwargs)
- if variables.ConstantVariable.is_literal(result):
- return variables.ConstantVariable.create(result)
- if isinstance(result, re.Match):
- return ConstantRegexMatchVariable(result)
- unimplemented(f"{self._error_prefix}.{name}() -> {result}")
- def var_getattr(self, tx, name: str) -> VariableTracker:
- result = getattr(self.value, name)
- if isinstance(result, self.np_floating):
- result = float(result)
- if isinstance(result, self.np_dtype):
- return NumpyDTypeVariable(result)
- if isinstance(result, type) and issubclass(result, self.np_generic):
- # things like x.dtype.type
- return NumpyVariable(result)
- if variables.ConstantVariable.is_literal(result):
- return variables.ConstantVariable.create(result)
- return GetAttrVariable(self, name)
- class RegexPatternVariable(ConstantLikeVariable):
- _error_prefix = "re.Pattern"
- class ConstantRegexMatchVariable(ConstantLikeVariable):
- _error_prefix = "re.Match"
- class TorchVersionVariable(ConstantLikeVariable):
- _error_prefix = "torch.__version__"
- def __init__(self, **kwargs):
- kwargs.setdefault("value", torch.__version__)
- assert kwargs["value"] is torch.__version__
- super().__init__(**kwargs)
- class NumpyTypeInfoVariable(ConstantLikeVariable):
- _error_prefix = "np.iinfo/np.finfo"
- class NumpyDTypeVariable(ConstantLikeVariable):
- _error_prefix = "np.dtype[...]"
- def as_proxy(self):
- """Similar to how numpy dtype descriptors (e.g. np.float32 ) are handled by NumpyVariable:
- np.dtype() objects are serialized as strings, torch._numpy wrappers will normalize to the torch dtype.
- This also handles unsupported things nicely (i.e. structured arrays and object arrays).
- """
- return self.value.type.__name__
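- # For instance, np.dtype("float32").type.__name__ == "float32", so only the
- # string name (not the dtype object) ends up in the graph.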
- np_constant_collections_map = {
- tnp.finfo: NumpyTypeInfoVariable,
- tnp.iinfo: NumpyTypeInfoVariable,
- tnp.dtype: NumpyDTypeVariable,
- }
|