# mypy: ignore-errors

import functools
import inspect
from typing import Dict, List

import torch
from ...fx.experimental._backward_state import BackwardState

from .. import compiled_autograd, variables
from .._trace_wrapped_higher_order_op import trace_wrapped
from ..exc import unimplemented
from ..external_utils import call_module_hooks_from_backward_state
from ..guards import GuardBuilder, install_guard
from ..source import AttrSource
from ..utils import istype
from .base import VariableTracker
from .constant import ConstantVariable


class DistributedVariable(VariableTracker):
    """
    The base distributed variable that encapsulates common methods
    for distributed objects (e.g. ProcessGroup, DeviceMesh, etc.).
    Concrete distributed objects can inherit this class and add
    object-specific logic.

    It checks that the distributed package is available and holds the
    tracked value for the corresponding distributed object.
    """

    def __init__(self, value, **kwargs):
        super().__init__(**kwargs)
        if not DistributedVariable.is_available():
            unimplemented("torch.distributed package is not available!")
        self.value = value

    def python_type(self):
        return type(self.value)

    @staticmethod
    def is_available():
        # check if the distributed package is available or not
        return torch.distributed.is_available()

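
# Illustrative sketch (assumption, not part of this module): the variable
# builder is expected to dispatch on the type-check helpers in this file before
# wrapping a value, roughly along these lines; the real call sites live in
# builder.py and may differ.
#
#     if DeviceMeshVariable.is_device_mesh(value):
#         return DeviceMeshVariable(value, source=self.source)
#     elif ProcessGroupVariable.is_process_group(value):
#         return ProcessGroupVariable(value, source=self.source)
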
def is_from_local(value):
    if not DistributedVariable.is_available():
        return False
    from torch.distributed._tensor import DTensor

    return inspect.isfunction(value) and value is DTensor.from_local

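
# For reference: this predicate matches direct references to DTensor.from_local,
# i.e. (assumed) user code such as
#
#     dt = DTensor.from_local(local_tensor, device_mesh, placements)
#
# so the call can be special-cased (mesh and placements captured as constants)
# instead of being traced through.
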
def is_constant_pg_functions(value):
    if not DistributedVariable.is_available():
        return False

    from torch.distributed.distributed_c10d import (
        _get_group_size_by_name,
        _get_group_tag,
        _rank_not_in_group,
        _resolve_group_name_by_ranks_and_tag,
        get_process_group_ranks,
    )

    constant_processgroup_functions = [
        _get_group_size_by_name,
        _get_group_tag,
        _rank_not_in_group,
        get_process_group_ranks,
        _resolve_group_name_by_ranks_and_tag,
    ]

    return inspect.isfunction(value) and value in constant_processgroup_functions

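
# For reference: every function listed above maps a process group (or its
# name/tag) to plain Python data, e.g. get_process_group_ranks(pg) returns a
# list of ints. Dynamo can therefore evaluate these calls eagerly at trace time
# and embed the results as constants instead of proxying the process group.
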
class WorldMetaClassVariable(DistributedVariable):
    """
    Tracks torch.distributed.GroupMember and torch.distributed.group, which are
    instances of the metaclass _WorldMeta.
    """

    @classmethod
    def is_group_member_type(cls, value):
        if not cls.is_available():
            return False

        from torch.distributed.distributed_c10d import _WorldMeta

        return type(value) is _WorldMeta

    def var_getattr(self, tx, name: str) -> VariableTracker:
        if name == "WORLD":
            source = AttrSource(base=self.source, member="WORLD")
            install_guard(source.make_guard(GuardBuilder.ID_MATCH))
            return ProcessGroupVariable(self.value.WORLD)
        return super().var_getattr(tx, name)

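
# For reference: this covers (assumed) user code such as
#
#     pg = torch.distributed.group.WORLD
#
# inside a compiled region: the WORLD attribute access is guarded by identity
# and wrapped as a ProcessGroupVariable (defined below), which can then fold
# pg.rank() / pg.size() into constants.
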
class PlacementClassVariable(DistributedVariable):
    @staticmethod
    def is_placement_type(value):
        # we can't rely on importing/accessing torch.distributed; it is not always built.
        if not DistributedVariable.is_available():
            return False

        from torch.distributed._tensor.placement_types import Placement

        return type(value) is type and issubclass(value, Placement)

    def as_python_constant(self):
        return self.value

    def call_function(
        self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
    ) -> "VariableTracker":
        if (
            inspect.getattr_static(self.value, "__new__", None) in (object.__new__,)
            and self.source
        ):
            # NOTE: we don't need to track mutations to the placement class as
            # placements are supposed to be immutable.
            new_obj = object.__new__(self.value)
            var = PlacementVariable(new_obj)
            if inspect.getattr_static(self.value, "__init__", None):
                var.call_method(tx, "__init__", args, kwargs)
            return var

        return super().call_function(tx, args, kwargs)

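
# For reference: the call_function path above lets (assumed) user code construct
# placements inline without a graph break, e.g.
#
#     dt = DTensor.from_local(t, mesh, [Shard(0)])
#
# Shard(0) is wrapped as a PlacementVariable (defined below) around a freshly
# allocated Shard object, with __init__ handled via constant folding.
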
class PlacementVariable(DistributedVariable):
    @staticmethod
    def is_placement(value):
        # we can't rely on importing/accessing torch.distributed; it is not always built.
        if not DistributedVariable.is_available():
            return False

        from torch.distributed._tensor.placement_types import Placement

        return isinstance(value, Placement)

    def as_python_constant(self):
        return self.value

    def var_getattr(self, tx, name: str) -> VariableTracker:
        if name == "dim":
            return ConstantVariable.create(self.value.dim)
        return super().var_getattr(tx, name)

    def call_method(
        self,
        tx,
        name,
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
    ) -> "VariableTracker":
        from . import ConstantVariable

        # Dynamo only tracks the following methods on placement types;
        # __setattr__ is included to support constructors like `Shard(dim)`.
        # Methods in the list must satisfy:
        #    1. Their input arguments are constants and do not need to be guarded on;
        #    2. Their output is constant with respect to their inputs.
        constant_fold_functions = [
            "__init__",
            "__setattr__",
            "is_shard",
            "is_partial",
            "is_replicate",
        ]

        if name in constant_fold_functions:
            try:
                value_type = type(self.value)
                assert (
                    inspect.getattr_static(value_type, "__getattr__", None) is None
                ), "no custom getattr allowed!"
                method = inspect.getattr_static(value_type, name)
            except AttributeError:
                method = None

            if method is object.__init__:
                return ConstantVariable.create(None)

            args = [x.as_python_constant() for x in args]
            kwargs = {k: v.as_python_constant() for k, v in kwargs.items()}
            if name == "__setattr__":
                method(self.value, *args, **kwargs)
                return self
            constant_val = method(self.value, *args, **kwargs)
            return ConstantVariable.create(constant_val)

        return super().call_method(tx, name, args, kwargs)

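
# For reference: the constant-fold path above covers (assumed) user code such as
#
#     placement = Shard(0)   # __init__ runs eagerly on the real object
#     placement.is_shard()   # evaluated at trace time, folded to True
#
# so placement objects never need to appear as proxies in the output graph.
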
class DeviceMeshVariable(DistributedVariable):
    @staticmethod
    def is_device_mesh(value):
        # we can't rely on importing/accessing torch.distributed; it is not always built.
        if not DistributedVariable.is_available():
            return False

        from torch.distributed.device_mesh import DeviceMesh

        return istype(value, DeviceMesh)

    def as_python_constant(self):
        return self.value

    def var_getattr(self, tx, name: str) -> VariableTracker:
        if name == "ndim":
            return ConstantVariable.create(self.value.ndim)
        if name == "device_type":
            return ConstantVariable.create(self.value.device_type)
        return super().var_getattr(tx, name)

    def call_method(
        self,
        tx,
        name,
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
    ) -> "VariableTracker":
        if name == "size":
            const_args = [x.as_python_constant() for x in args]
            const_kwargs = {k: v.as_python_constant() for k, v in kwargs.items()}
            return ConstantVariable.create(self.value.size(*const_args, **const_kwargs))
        if name == "get_coordinate":
            return ConstantVariable.create(self.value.get_coordinate())
        if name == "get_group":
            return ConstantVariable.create(self.value.get_group())
        if name == "_get_or_create_default_group":
            return ProcessGroupVariable(self.value._get_or_create_default_group())
        return super().call_method(tx, name, args, kwargs)

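
# For reference: DeviceMesh queries that depend only on the mesh itself are
# evaluated at trace time, e.g. for (assumed) user code
#
#     mesh = init_device_mesh("cuda", (2, 4))
#     ndim = mesh.ndim       # folded to the constant 2
#     size = mesh.size(1)    # folded to the constant 4
#
# so the mesh object itself never becomes a proxy in the output graph.
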
class ProcessGroupVariable(DistributedVariable):
    """
    We don't want a ProcessGroup object to end up in our output graph.

    But it's common for dynamo to intercept a PG that is then used to get info like
    rank() or world_size(), as well as passed to utility functions in distributed_c10d
    which desugar it into plain types like a rank list and tag.

    For convenience and proper guarding, we construct a variable type.

    TODO: make it possible to use ProcessGroupVariable as input to simple functions
          like _expand_group without dynamo complaining about making a proxy for it.
          It is not a tensor-like type, and we don't want a proxy, but dynamo assumes
          torch library functions are dealing with tensor-like types and would have
          proxies for their args.
    TODO: should we make this inherit VT instead of UDOV? Do we want any of the default
          behaviors, or just graph-break whenever one of our special cases is not hit?
    """

    def as_python_constant(self):
        return self.value

    def call_method(
        self,
        tx,
        name,
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
    ) -> "VariableTracker":
        if name == "rank":
            return variables.ConstantVariable.create(self.value.rank())
        if name == "size":
            return variables.ConstantVariable.create(self.value.size())
        return super().call_method(tx, name, args, kwargs)

    def var_getattr(self, tx, name):
        if name == "group_name":
            return variables.ConstantVariable.create(self.value.group_name)
        if name in ["rank", "size"]:
            return variables.LambdaVariable(
                lambda *args, **kwargs: self.call_method(tx, name, args, kwargs)
            )
        # TODO should this just raise unimplemented?
        return super().var_getattr(tx, name)

    @staticmethod
    def is_process_group(value):
        # we can't rely on importing/accessing torch.distributed; it is not always built.
        if not DistributedVariable.is_available():
            return False

        from torch._C._distributed_c10d import ProcessGroup
        from torch.testing._internal.distributed.fake_pg import FakeProcessGroup

        return istype(value, (ProcessGroup, FakeProcessGroup))

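
# For reference: with the handling above, (assumed) user code such as
#
#     def fn(x, pg):
#         return x + pg.rank()
#
# compiles with pg.rank() baked in as a Python int; the process group itself is
# expected to be guarded by identity at the builder level rather than proxied
# into the graph.
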
class BackwardHookVariable(VariableTracker):
    """
    Handles torch.utils.hooks.BackwardHook for module-level backward
    hooks.
    """

    @staticmethod
    def create(
        tx,
        module: VariableTracker,
        user_hooks: VariableTracker,
        user_pre_hooks: VariableTracker,
    ):
        if not compiled_autograd.compiled_autograd_enabled:
            unimplemented("module-level backwards hooks require compiled autograd")

        def _in_graph_bw_hooks(bw_state: BackwardState):
            """
            Rather than installing the user hooks in the graph (which
            don't survive AotAutograd), we install hooks that call
            trace_wrapped in the backward pass, which CompiledAutograd
            can turn into actual hook calls.
            """
            return torch.utils.hooks.BackwardHook(
                None,
                (
                    functools.partial(
                        trace_wrapped,
                        fn=call_module_hooks_from_backward_state,
                        bw_state=bw_state,
                        hooks_name=user_hooks_name,
                        module_name=module_name,
                    ),
                ),
                (
                    functools.partial(
                        trace_wrapped,
                        fn=call_module_hooks_from_backward_state,
                        bw_state=bw_state,
                        hooks_name=user_pre_hooks_name,
                        module_name=module_name,
                    ),
                ),
            )

        module_name, bw_state_proxy = tx.output.add_backward_state_hook(module, "mod")
        user_pre_hooks_name, _ = tx.output.add_backward_state_hook(user_pre_hooks)
        user_hooks_name, _ = tx.output.add_backward_state_hook(user_hooks)
        proxy = tx.output.create_proxy(
            "call_function",
            _in_graph_bw_hooks,
            (bw_state_proxy,),
            {},
        )
        proxy.node.meta["example_value"] = torch.utils.hooks.BackwardHook(None, (), ())
        return BackwardHookVariable(proxy, module, user_hooks, user_pre_hooks)

    def __init__(
        self,
        proxy: torch.fx.Proxy,
        module: VariableTracker,
        user_hooks: VariableTracker,
        user_pre_hooks: VariableTracker,
        **options,
    ):
        super().__init__(**options)
        self.proxy = proxy
        self.module = module
        self.user_hooks = user_hooks
        self.user_pre_hooks = user_pre_hooks

    def as_proxy(self):
        return self.proxy

    def call_method(
        self,
        tx,
        name,
        args: List[VariableTracker],
        kwargs: Dict[str, VariableTracker],
    ) -> VariableTracker:
        if name in ("setup_input_hook", "setup_output_hook"):
            return self._setup_hook(tx, name, *args, **kwargs)
        return super().call_method(tx, name, args, kwargs)

    def _setup_hook(self, tx, hook_method_name, args):
        from .builder import wrap_fx_proxy

        return wrap_fx_proxy(
            tx,
            tx.output.create_proxy(
                "call_method",
                hook_method_name,
                (self.as_proxy(), args.as_proxy()),
                {},
            ),
        )
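
# For reference: BackwardHookVariable is created when Dynamo traces a module
# with full backward hooks registered while compiled autograd is enabled, e.g.
# (assumed) user code along the lines of
#
#     mod.register_full_backward_hook(hook)
#     with torch._dynamo.compiled_autograd.enable(compiler_fn):
#         torch.compile(mod)(x).sum().backward()
#
# The forward graph only records the BackwardHook setup calls; the real user
# hooks run later via BackwardState when compiled autograd replays the backward.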
|