distributed.py

# mypy: ignore-errors

import functools
import inspect
from typing import Dict, List

import torch
from ...fx.experimental._backward_state import BackwardState
from .. import compiled_autograd, variables
from .._trace_wrapped_higher_order_op import trace_wrapped
from ..exc import unimplemented
from ..external_utils import call_module_hooks_from_backward_state
from ..guards import GuardBuilder, install_guard
from ..source import AttrSource
from ..utils import istype
from .base import VariableTracker
from .constant import ConstantVariable


class DistributedVariable(VariableTracker):
    """
    The base distributed variable that encapsulates common methods
    for distributed objects (i.e. ProcessGroup, DeviceMesh, etc.).
    Concrete distributed objects can inherit this class and add
    object-specific logic.

    It provides the check for distributed package existence and holds
    the tracked value for the corresponding distributed object.
    """

    def __init__(self, value, **kwargs):
        super().__init__(**kwargs)
        if not DistributedVariable.is_available():
            unimplemented("torch.distributed package is not available!")
        self.value = value

    def python_type(self):
        return type(self.value)

    @staticmethod
    def is_available():
        # check if the distributed package is available or not
        return torch.distributed.is_available()


def is_from_local(value):
    if not DistributedVariable.is_available():
        return False
    from torch.distributed._tensor import DTensor

    return inspect.isfunction(value) and value is DTensor.from_local
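

# Illustrative sketch (not part of this module): `is_from_local` is how dynamo
# recognizes DTensor.from_local so the call can be handled specially during
# tracing. A user-level call it would match looks roughly like the following,
# assuming a process group and device mesh have already been set up and that
# `world_size` and `local_tensor` are placeholders:
#
#     from torch.distributed._tensor import DTensor, Replicate
#     from torch.distributed.device_mesh import init_device_mesh
#
#     mesh = init_device_mesh("cuda", (world_size,))
#     dt = DTensor.from_local(local_tensor, mesh, [Replicate()])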


def is_constant_pg_functions(value):
    if not DistributedVariable.is_available():
        return False

    from torch.distributed.distributed_c10d import (
        _get_group_size_by_name,
        _get_group_tag,
        _rank_not_in_group,
        _resolve_group_name_by_ranks_and_tag,
        get_process_group_ranks,
    )

    constant_processgroup_functions = [
        _get_group_size_by_name,
        _get_group_tag,
        _rank_not_in_group,
        get_process_group_ranks,
        _resolve_group_name_by_ranks_and_tag,
    ]

    return inspect.isfunction(value) and value in constant_processgroup_functions
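

# Illustrative sketch (not part of this module): the functions listed above are
# treated as constant with respect to an intercepted ProcessGroup, so compiled
# code can query group metadata without the PG itself entering the graph.
# Assuming a default process group is already initialized, something like this
# should trace cleanly:
#
#     import torch
#     import torch.distributed as dist
#     from torch.distributed.distributed_c10d import get_process_group_ranks
#
#     @torch.compile
#     def scale_by_group_size(x, pg):
#         # len(ranks) is evaluated at trace time and baked in as a constant
#         return x / len(get_process_group_ranks(pg))
#
#     out = scale_by_group_size(torch.ones(4), dist.group.WORLD)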


class WorldMetaClassVariable(DistributedVariable):
    """
    Tracks torch.distributed.GroupMember and torch.distributed.group, which are
    instances of the metaclass _WorldMeta.
    """

    @classmethod
    def is_group_member_type(cls, value):
        if not cls.is_available():
            return False

        from torch.distributed.distributed_c10d import _WorldMeta

        return type(value) is _WorldMeta

    def var_getattr(self, tx, name: str) -> VariableTracker:
        if name == "WORLD":
            source = AttrSource(base=self.source, member="WORLD")
            install_guard(source.make_guard(GuardBuilder.ID_MATCH))
            return ProcessGroupVariable(self.value.WORLD)
        return super().var_getattr(tx, name)
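

# Illustrative sketch (not part of this module): WorldMetaClassVariable covers
# attribute access such as `torch.distributed.GroupMember.WORLD` inside compiled
# code; the WORLD attribute is guarded by identity and wrapped as a
# ProcessGroupVariable, so (assuming distributed is initialized) code like this
# can read the default group without a graph break:
#
#     @torch.compile
#     def fn(x):
#         return x + torch.distributed.GroupMember.WORLD.rank()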


class PlacementClassVariable(DistributedVariable):
    @staticmethod
    def is_placement_type(value):
        # we can't rely on importing/accessing torch distributed, it is not always built.
        if not DistributedVariable.is_available():
            return False

        from torch.distributed._tensor.placement_types import Placement

        return type(value) is type and issubclass(value, Placement)

    def as_python_constant(self):
        return self.value

    def call_function(
        self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
    ) -> "VariableTracker":
        if (
            inspect.getattr_static(self.value, "__new__", None) in (object.__new__,)
            and self.source
        ):
            # NOTE: we don't need to track mutations to the placement class as
            # placements are supposed to be immutable.
            new_obj = object.__new__(self.value)
            var = PlacementVariable(new_obj)
            if inspect.getattr_static(self.value, "__init__", None):
                var.call_method(tx, "__init__", args, kwargs)
                return var

        return super().call_function(tx, args, kwargs)


class PlacementVariable(DistributedVariable):
    @staticmethod
    def is_placement(value):
        # we can't rely on importing/accessing torch distributed, it is not always built.
        if not DistributedVariable.is_available():
            return False

        from torch.distributed._tensor.placement_types import Placement

        return isinstance(value, Placement)

    def as_python_constant(self):
        return self.value

    def var_getattr(self, tx, name: str) -> VariableTracker:
        if name == "dim":
            return ConstantVariable.create(self.value.dim)
        return super().var_getattr(tx, name)

    def call_method(
        self,
        tx,
        name,
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
    ) -> "VariableTracker":
        from . import ConstantVariable

        # Dynamo tracking of Placement types only allows the following methods;
        # __setattr__ is included for constructor calls like `Shard(dim)`.
        # Methods in the list must satisfy:
        #     1. Input arguments are constants and do not need to be guarded on;
        #     2. Output is constant with respect to their inputs
        constant_fold_functions = [
            "__init__",
            "__setattr__",
            "is_shard",
            "is_partial",
            "is_replicate",
        ]

        if name in constant_fold_functions:
            try:
                value_type = type(self.value)
                assert (
                    inspect.getattr_static(value_type, "__getattr__", None) is None
                ), "no custom getattr allowed!"
                method = inspect.getattr_static(value_type, name)
            except AttributeError:
                method = None
            if method is object.__init__:
                return ConstantVariable.create(None)

            args = [x.as_python_constant() for x in args]
            kwargs = {k: v.as_python_constant() for k, v in kwargs.items()}
            if name == "__setattr__":
                method(self.value, *args, **kwargs)
                return self
            constant_val = method(self.value, *args, **kwargs)
            return ConstantVariable.create(constant_val)

        return super().call_method(tx, name, args, kwargs)
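

# Illustrative sketch (not part of this module): constant folding of Placement
# methods lets DTensor-style code construct and inspect placements during
# tracing without graph breaks. Assuming torch.distributed is built and `dt` is
# a DTensor, roughly:
#
#     from torch.distributed._tensor.placement_types import Shard
#
#     def redistribute_hint(dt):
#         placement = Shard(0)          # __init__/__setattr__ folded to constants
#         if placement.is_shard():      # is_shard() folded to a constant bool
#             return dt.redistribute(placements=[placement])
#         return dt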


class DeviceMeshVariable(DistributedVariable):
    @staticmethod
    def is_device_mesh(value):
        # we can't rely on importing/accessing torch distributed, it is not always built.
        if not DistributedVariable.is_available():
            return False

        from torch.distributed.device_mesh import DeviceMesh

        return istype(value, DeviceMesh)

    def as_python_constant(self):
        return self.value

    def var_getattr(self, tx, name: str) -> VariableTracker:
        if name == "ndim":
            return ConstantVariable.create(self.value.ndim)
        if name == "device_type":
            return ConstantVariable.create(self.value.device_type)
        return super().var_getattr(tx, name)

    def call_method(
        self,
        tx,
        name,
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
    ) -> "VariableTracker":
        if name == "size":
            const_args = [x.as_python_constant() for x in args]
            const_kwargs = {k: v.as_python_constant() for k, v in kwargs.items()}
            return ConstantVariable.create(self.value.size(*const_args, **const_kwargs))
        if name == "get_coordinate":
            return ConstantVariable.create(self.value.get_coordinate())
        if name == "get_group":
            return ConstantVariable.create(self.value.get_group())
        if name == "_get_or_create_default_group":
            return ProcessGroupVariable(self.value._get_or_create_default_group())
        return super().call_method(tx, name, args, kwargs)
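

# Illustrative sketch (not part of this module): DeviceMesh queries such as
# ndim, device_type and size() are folded to constants during tracing, so they
# can drive Python-level control flow in compiled code. Assuming `mesh` was
# created with init_device_mesh, roughly:
#
#     @torch.compile
#     def fn(x, mesh):
#         if mesh.ndim == 2 and mesh.size(0) > 1:
#             return x * mesh.size(0)
#         return x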


class ProcessGroupVariable(DistributedVariable):
    """
    We don't want a ProcessGroup object to end up in our output graph.

    But it's common for dynamo to intercept a PG that is then used to get info like
    rank() or world_size(), as well as passed to utility functions in distributed_c10d
    which desugar it into plain types like a rank list and tag.

    For convenience and proper guarding, we construct a variable type.

    TODO: make it possible to use ProcessGroupVariable as input to simple functions
          like _expand_group without dynamo complaining about making a proxy for it.
          It is not a tensor-like type, and we don't want a proxy, but dynamo assumes
          torch library functions are dealing with tensor-like types and would have
          proxies for their args.
    TODO: should we make this inherit VT instead of UDOV? Do we want any of the default
          behaviors or just graph-break whenever one of our special cases is not hit?
    """

    def as_python_constant(self):
        return self.value

    def call_method(
        self,
        tx,
        name,
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
    ) -> "VariableTracker":
        if name == "rank":
            return variables.ConstantVariable.create(self.value.rank())
        if name == "size":
            return variables.ConstantVariable.create(self.value.size())

        return super().call_method(tx, name, args, kwargs)

    def var_getattr(self, tx, name):
        if name == "group_name":
            return variables.ConstantVariable.create(self.value.group_name)
        if name in ["rank", "size"]:
            return variables.LambdaVariable(
                lambda *args, **kwargs: self.call_method(tx, name, args, kwargs)
            )
        # TODO should this just raise unimplemented?
        return super().var_getattr(tx, name)

    @staticmethod
    def is_process_group(value):
        # we can't rely on importing/accessing torch distributed, it is not always built.
        if not DistributedVariable.is_available():
            return False
        from torch._C._distributed_c10d import ProcessGroup
        from torch.testing._internal.distributed.fake_pg import FakeProcessGroup

        return istype(value, (ProcessGroup, FakeProcessGroup))
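

# Illustrative sketch (not part of this module): rank() and size() calls on an
# intercepted ProcessGroup are turned into constants, so per-rank logic does not
# force the PG into the graph. Assuming a default group is initialized, roughly:
#
#     @torch.compile
#     def shard_slice(x, pg):
#         chunk = x.shape[0] // pg.size()
#         return x[pg.rank() * chunk : (pg.rank() + 1) * chunk]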


class BackwardHookVariable(VariableTracker):
    """
    Handles torch.utils.hooks.BackwardHook for module-level backward
    hooks.
    """

    @staticmethod
    def create(
        tx,
        module: VariableTracker,
        user_hooks: VariableTracker,
        user_pre_hooks: VariableTracker,
    ):
        if not compiled_autograd.compiled_autograd_enabled:
            unimplemented("module-level backwards hooks require compiled autograd")

        def _in_graph_bw_hooks(bw_state: BackwardState):
            """
            Rather than installing the user hooks in the graph (which
            don't survive AotAutograd), we install hooks that will call
            trace_wrapped in the backward pass that CompiledAutograd
            can turn into actual hook calls.
            """
            return torch.utils.hooks.BackwardHook(
                None,
                (
                    functools.partial(
                        trace_wrapped,
                        fn=call_module_hooks_from_backward_state,
                        bw_state=bw_state,
                        hooks_name=user_hooks_name,
                        module_name=module_name,
                    ),
                ),
                (
                    functools.partial(
                        trace_wrapped,
                        fn=call_module_hooks_from_backward_state,
                        bw_state=bw_state,
                        hooks_name=user_pre_hooks_name,
                        module_name=module_name,
                    ),
                ),
            )

        module_name, bw_state_proxy = tx.output.add_backward_state_hook(module, "mod")
        user_pre_hooks_name, _ = tx.output.add_backward_state_hook(user_pre_hooks)
        user_hooks_name, _ = tx.output.add_backward_state_hook(user_hooks)
        proxy = tx.output.create_proxy(
            "call_function",
            _in_graph_bw_hooks,
            (bw_state_proxy,),
            {},
        )
        proxy.node.meta["example_value"] = torch.utils.hooks.BackwardHook(None, (), ())
        return BackwardHookVariable(proxy, module, user_hooks, user_pre_hooks)

    def __init__(
        self,
        proxy: torch.fx.Proxy,
        module: VariableTracker,
        user_hooks: VariableTracker,
        user_pre_hooks: VariableTracker,
        **options,
    ):
        super().__init__(**options)
        self.proxy = proxy
        self.module = module
        self.user_hooks = user_hooks
        self.user_pre_hooks = user_pre_hooks

    def as_proxy(self):
        return self.proxy

    def call_method(
        self,
        tx,
        name,
        args: List[VariableTracker],
        kwargs: Dict[str, VariableTracker],
    ) -> VariableTracker:
        if name in ("setup_input_hook", "setup_output_hook"):
            return self._setup_hook(tx, name, *args, **kwargs)
        return super().call_method(tx, name, args, kwargs)

    def _setup_hook(self, tx, hook_method_name, args):
        from .builder import wrap_fx_proxy

        return wrap_fx_proxy(
            tx,
            tx.output.create_proxy(
                "call_method",
                hook_method_name,
                (self.as_proxy(), args.as_proxy()),
                {},
            ),
        )
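

# Illustrative sketch (not part of this module): BackwardHookVariable is what
# dynamo creates when it traces a module that has full backward hooks
# registered. Per the check in create(), such hooks only fire through compiled
# autograd, so usage looks roughly like this (the hook body, `model`, and `inp`
# are placeholders):
#
#     def log_grads(module, grad_input, grad_output):
#         print(grad_output[0].norm())
#
#     model.register_full_backward_hook(log_grads)
#     compiled = torch.compile(model)
#     with torch._dynamo.compiled_autograd.enable(torch.compile):
#         compiled(inp).sum().backward()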