data_parallel.py

# mypy: allow-untyped-defs
import operator
from contextlib import contextmanager
from enum import Enum
from typing import Any, cast, Dict, List, Optional, Tuple

import torch
import torch.fx as fx
import torch.library
import torch.nn as nn
import torch.utils._pytree as pytree

from torch.distributed._spmd.batch_dim_utils import BatchDimAnalyzer
from torch.distributed._tensor import DeviceMesh, distribute_tensor, Replicate, Shard
from torch.distributed._tensor._op_schema import (
    OpStrategy,
    PlacementStrategy,
    StrategyType,
    TupleStrategy,
)
from torch.distributed._tensor._redistribute import redistribute_local_tensor
from torch.distributed._tensor._utils import compute_local_shape
from torch.distributed._tensor.placement_types import _Partial, DTensorSpec, Placement
from torch.fx import GraphModule
from torch.fx.experimental.proxy_tensor import make_fx
from torch.fx.passes.shape_prop import _extract_tensor_metadata
from torch.nn.utils._named_member_accessor import NamedMemberAccessor


aten = torch.ops.aten

# Dummy op used by data parallel to tag gradients.
_spmd_lib_def = torch.library.Library("_spmd", "DEF")
_spmd_lib_def.define("tag_grad(Tensor self) -> Tensor")

_spmd_lib_impl = torch.library.Library("_spmd", "IMPL")
_spmd_lib_impl.impl("tag_grad", lambda x: x, "CompositeExplicitAutograd")


class DataParallelStyle(Enum):
    """This enum represents the style of the data parallel operation.

    We have three types of data parallel style:
    1. DEFAULT: the default data parallel style, which represents a mix of
       replicate and fully-shard behavior. For each parameter that can be
       sharded evenly, we shard it; otherwise we replicate the parameter.
       This style avoids potential padding when parameters cannot be sharded
       evenly, but it generates a mix of all_reduce and reduce_scatter.
    2. REPLICATE: the data parallel style that replicates all model parameters.
       This is similar to the behavior of DistributedDataParallel.
    3. FULLY_SHARD: the data parallel style that shards all model parameters.
       This is similar to the behavior of FullyShardedDataParallel (ZeRO-3);
       the difference is that FullyShardedDataParallel shards the model using
       FlatParameter based sharding, while this style shards each parameter
       into a DTensor.
    """

    DEFAULT = 0
    REPLICATE = 1
    FULLY_SHARD = 2
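
# For intuition (a sketch, not exercised by this module): on a 1-D DeviceMesh of
# world size N, REPLICATE keeps a full copy of every parameter on each rank
# (DDP-like), while FULLY_SHARD places each parameter as a DTensor with a
# Shard(0) placement, so each rank holds roughly 1/N of every parameter
# (FSDP/ZeRO-3-like, but per-parameter rather than FlatParameter based).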


class NodeType(Enum):
    """NodeType is an enum that records the type of the tensors in the graph.

    This is used to determine the data parallel strategy.
    """

    PARAM = 0
    ACT = 1
    GRAD = 2
    STATE = 3
    NON_TENSOR = 4  # NON_TENSOR is used to tag non-tensor nodes (i.e. the graph output)


class DataParallelStrategy(OpStrategy):
    """DataParallelStrategy is a special case of OpStrategy that only records
    the "data parallel style" placement strategy for each fx Node.

    It takes a list of PlacementStrategy, where each PlacementStrategy describes
    one way to distribute the tensor and computation. In the data parallel case,
    there are two possible ways to distribute the parameters:
    1. replicate the parameter over a set of devices (DDP-like behavior)
    2. shard the parameter on its tensor dimension 0 over a set of devices
       (FSDP-like behavior).

    In addition to the strategy list, we also record:
    1. `node_type`: the type of each node in the graph, so that we can
       determine how to propagate in a data parallel fashion.
    2. `reduction_over_batch`: this is specifically tied to data parallel, as
       the loss calculation usually results in a scalar tensor that comes from
       a reduction over the batch dimension. We need to know this so that we
       can keep the output sharded.
    """

    def __init__(
        self,
        node_type: NodeType,
        strategy_list: List[PlacementStrategy],
        reduction_over_batch: bool = False,
    ):
        super().__init__(strategy_list)
        self.node_type = node_type
        self.reduction_over_batch = reduction_over_batch

    def __str__(self) -> str:
        return f"type: {self.node_type}, {super().__str__()}"


@contextmanager
def gradients_tagging(params: Dict[str, torch.Tensor]):
    """Tag the gradients of the parameters with a special tag, so that we can
    identify them during SPMD expansion.

    It's safe to trace these hooks; the tagging nodes are removed later.
    """
    tagging_hooks = []
    try:
        for p in params.values():
            h = p.register_hook(torch.ops._spmd.tag_grad)
            tagging_hooks.append(h)
        yield
    finally:
        # remove those hooks after tracing
        for h in tagging_hooks:
            h.remove()
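
# Example usage of `gradients_tagging` (a minimal sketch; `model` and `loss`
# are hypothetical names supplied by the caller):
#
#     params = dict(model.named_parameters())
#     with gradients_tagging(params):
#         loss.backward()
#
# While the context is active, every parameter gradient flows through the
# `_spmd.tag_grad` identity op, so the traced graph contains one tag_grad node
# per parameter gradient that the SPMD expansion can locate and later erase.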


def _gen_shard_strategy(
    mesh: DeviceMesh, shard_dim: int, input_specs: Optional[List[DTensorSpec]] = None
) -> PlacementStrategy:
    """Util function to generate a shard strategy on shard_dim."""
    return PlacementStrategy(
        output_specs=DTensorSpec(mesh=mesh, placements=(Shard(shard_dim),)),
        input_specs=input_specs,
    )


def _gen_replicate_strategy(
    mesh: DeviceMesh, input_specs: Optional[List[DTensorSpec]] = None
) -> PlacementStrategy:
    """Util function to generate a replicate strategy."""
    return PlacementStrategy(
        output_specs=DTensorSpec(mesh=mesh, placements=(Replicate(),)),
        input_specs=input_specs,
    )


def _gen_partial_strategy(mesh: DeviceMesh) -> PlacementStrategy:
    """Util function to generate a partial strategy."""
    # NOTE: we use AVG by default. Whether avg reduction is appropriate depends
    # on the loss function; for most loss functions it should do gradient
    # averaging. There might be certain cases where it should not do gradient
    # averaging (i.e. sum), but they are pretty rare.
    # TODO: Only NCCL supports AVG, so using a backend like Gloo would crash.
    # We should figure out a way to support avg reduction for non-NCCL backends.
    return PlacementStrategy(
        output_specs=DTensorSpec(mesh=mesh, placements=(_Partial("avg"),)),
    )
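
# For intuition (a sketch, not part of this module's logic): with a data
# parallel world size of N and per-rank local gradients g_0, ..., g_{N-1}, the
# "avg" partial placement reduces to (g_0 + ... + g_{N-1}) / N, which matches
# the gradient of a loss that is averaged over the global batch.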


def build_data_parallel_strategies(
    train_step_graph: GraphModule,
    num_params: int,
    num_states: int,
    mesh: DeviceMesh,
    batch_dim: int = 0,
) -> Dict[fx.Node, StrategyType]:
    """Loop through the train step graph and build the data parallel strategy for each fx Node."""
    activation_idx = num_params + num_states
    non_compute_ops = [
        aten.clone.default,
        aten.detach.default,
        aten.ones_like.default,
        aten.reshape.default,
        aten.t.default,
        aten.view.default,
        torch.ops._spmd.tag_grad.default,
        operator.getitem,
    ]

    tuple_strategy_ops = [aten._fused_adam.default]

    dp_strategy_map: Dict[fx.Node, StrategyType] = {}
    batch_dim_analyzer = BatchDimAnalyzer(batch_dim)
    placeholder_idx = 0
    num_param_grad = 0

    # First we propagate backward to mark the param gradient sharding with the
    # help of the tag_grad nodes, then delete the tag_grad nodes.
    for node in reversed(list(train_step_graph.graph.nodes)):
        # find a param_grad node via the tagging
        if node.target == torch.ops._spmd.tag_grad.default:
            cur_node = node
            while cur_node.target in non_compute_ops:
                cur_node = cur_node.args[0]
                partial_strategy = _gen_partial_strategy(mesh)
                dp_strategy_map[cur_node] = DataParallelStrategy(
                    NodeType.GRAD, [partial_strategy]
                )
            num_param_grad += 1
            # remove the tag_grad node from graph
            node.replace_all_uses_with(node.args[0])
            train_step_graph.graph.erase_node(node)

            if num_param_grad == num_params:
                # early break if we have already processed all param_grads
                break

    # next we forward propagate to mark all the sharding
    for node in train_step_graph.graph.nodes:
        if node.op == "placeholder":
            if "val" not in node.meta:
                # NOTE: There are certain cases where the placeholder nodes do
                # not have real tensor values:
                # 1. optimizer states can be None sometimes, i.e. SGD with
                #    no momentum populates the `momentum` state as None, so
                #    the full graph we get from `compile` would have None as
                #    the placeholder value
                # 2. function args might not only contain params or activations,
                #    but also other non-tensor inputs, i.e. the model and
                #    optimizer instances baked in as placeholders; there might
                #    also be some scalar argument which is not a tensor
                #
                # For the above cases, we create a NON_TENSOR strategy so that
                # we know it's not a tensor and we don't need to shard it
                dp_strategy_map[node] = DataParallelStrategy(NodeType.NON_TENSOR, [])

            elif placeholder_idx < num_params:
                # during compilation there's an assumption that the first num_params
                # placeholders should be parameters
                shard_strategy = _gen_shard_strategy(mesh, 0)
                replica_strategy = _gen_replicate_strategy(mesh)
                dp_strategy_map[node] = DataParallelStrategy(
                    NodeType.PARAM, [replica_strategy, shard_strategy]
                )

            elif placeholder_idx < activation_idx:
                # optimizer states follow the same strategy as
                # the corresponding parameters
                replica_strategy = _gen_replicate_strategy(mesh)
                shard_strategy = _gen_shard_strategy(mesh, 0)

                dp_strategy_map[node] = DataParallelStrategy(
                    NodeType.STATE, [replica_strategy, shard_strategy]
                )
            else:
                activation_batch_dim_size = node.meta["val"].shape[batch_dim]
                # find the first activation node and use its batch dim size
                if batch_dim_analyzer.batch_dim_size == -1:
                    batch_dim_analyzer.init_batch_dim_size(activation_batch_dim_size)

                batch_dim_analyzer.set_batch_dim(node, batch_dim)
                shard_strategy = _gen_shard_strategy(mesh, batch_dim)
                dp_strategy_map[node] = DataParallelStrategy(
                    NodeType.ACT, [shard_strategy]
                )
            placeholder_idx += 1
        elif node.op == "call_function":
            # Annotate node types for the computation graph
            # Data parallel node propagation logic:
            #   param (non-compute) -> out: param
            #   grad (non-compute before/after) -> out: grad
            #   state -> out: state
            #
            #   param + activation (param must be replicate, act be sharded) -> out: activation
            #   param/state + grad (param/state/grad must be the same spec) -> out: param/state
            #   param + state -> out: param

            if node.target in non_compute_ops:
                # At this point, we should have removed all the `tag_grad` nodes in the graph
                assert node.target != torch.ops._spmd.tag_grad.default

                input_nodes = node.all_input_nodes
                assert (
                    len(input_nodes) == 1
                ), f"non-compute op only support one input now, found node: {node} with length of inputs: {len(node.args)}"
                arg_strategy = dp_strategy_map[input_nodes[0]]

                if node.target == operator.getitem:
                    # for getitem call, just forward the strategy from the input
                    getitem_idx = node.args[1]
                    if isinstance(arg_strategy, TupleStrategy):
                        # for tuple strategy, we need to get the child strategy from the tuple
                        dp_strategy_map[node] = arg_strategy.childs[getitem_idx]
                    else:
                        # if it's not a tuple strategy, we just forward the arg strategy
                        dp_strategy_map[node] = arg_strategy
                else:
                    assert isinstance(arg_strategy, DataParallelStrategy)
                    arg_node_type = arg_strategy.node_type
                    if arg_node_type == NodeType.PARAM:
                        replica_strategy = _gen_replicate_strategy(mesh)
                        dp_strategy_map[node] = DataParallelStrategy(
                            NodeType.PARAM, [replica_strategy]
                        )
                    elif arg_node_type == NodeType.GRAD:
                        partial_sig = _gen_partial_strategy(mesh)
                        dp_strategy_map[node] = DataParallelStrategy(
                            NodeType.GRAD, [partial_sig]
                        )
                    elif arg_node_type == NodeType.ACT:
                        arg_node_spec = batch_dim_analyzer.compute_act_spec(
                            input_nodes[0], mesh
                        )
                        output_spec = batch_dim_analyzer.compute_act_spec(node, mesh)
                        shard_strategy = PlacementStrategy(
                            output_specs=output_spec, input_specs=[arg_node_spec]
                        )
                        dp_strategy_map[node] = DataParallelStrategy(
                            NodeType.ACT, [shard_strategy]
                        )
                    else:
                        raise RuntimeError(
                            f"non compute op not supporting {arg_node_type}! "
                        )

                # finished processing this non-compute node
                continue

            # for computation nodes, we need to check all the inputs
            input_args = node.all_input_nodes
            input_specs = []
            if node in dp_strategy_map:
                # found a param_grad node that already has a pre-filled output spec;
                # fill in the expected input specs for the pre-filled strategy
                node_strategy = dp_strategy_map[node]
                assert isinstance(node_strategy, DataParallelStrategy)
                node_type = node_strategy.node_type
                assert node_type == NodeType.GRAD
                produce_param_grad_strat = node_strategy.strategies
                has_activation = False
                for arg in input_args:
                    arg_strategy = dp_strategy_map[arg]
                    assert isinstance(arg_strategy, DataParallelStrategy)
                    arg_node_type = arg_strategy.node_type
                    if arg_node_type == NodeType.ACT:
                        # activation sharded
                        has_activation = True
                        act_spec = batch_dim_analyzer.compute_act_spec(arg, mesh)
                        input_specs.append(act_spec)

                if has_activation:
                    assert len(produce_param_grad_strat) == 1
                    produce_param_grad_strat[0].input_specs = input_specs
            elif node.target in tuple_strategy_ops:
                # ops that need to build a tuple strategy instead of a normal strategy.
                # This should happen rarely and is only needed when we need to generate
                # different node strategies for multiple outputs (i.e. the fused_adam op)
                # TODO: Currently this specializes to fused optimizer ops, but we need
                # to see how to generalize this strategy building logic
                output_strategy_len = len(node.args) - 1
                tuple_strategies = []
                for i in range(output_strategy_len):
                    if not isinstance(node.args[i], list):
                        raise RuntimeError(
                            f"Expecting list as arg to build Tuple Strategy, but found type {type(node.args[i])}!"
                        )
                    # for list/tuple arg, use the first one to find out the node type
                    if len(node.args[i]) > 0:
                        arg_strategy = dp_strategy_map[node.args[i][0]]
                        assert isinstance(arg_strategy, DataParallelStrategy)
                        assert arg_strategy.node_type in [
                            NodeType.PARAM,
                            NodeType.GRAD,
                            NodeType.STATE,
                        ], "Expecting param/grad/state as arg to build Tuple Strategy!"

                        replica_strategy = _gen_replicate_strategy(mesh)
                        shard_strategy = _gen_shard_strategy(mesh, shard_dim=0)
                        out_node_strategy: StrategyType = DataParallelStrategy(
                            arg_strategy.node_type, [replica_strategy, shard_strategy]
                        )

                        tuple_strategies.append(out_node_strategy)

                output_tuple_strategy = TupleStrategy(tuple(tuple_strategies))
                dp_strategy_map[node] = output_tuple_strategy
            else:
                # NOTE: This is the common region for all regular computation ops

                input_node_types = [
                    cast(DataParallelStrategy, dp_strategy_map[arg]).node_type
                    for arg in input_args
                    if isinstance(dp_strategy_map[arg], DataParallelStrategy)
                ]
                if NodeType.GRAD in input_node_types:
                    # param/state + grad, build up acceptable strategy
                    # the strategy should be the same for all the inputs/outputs
                    # TODO: optimizer parts should follow the dtensor prop logic
                    # to support more general cases that allow optimizer states
                    # to have different shardings compared to the params
                    replica_strategy = _gen_replicate_strategy(mesh)
                    shard_strategy = _gen_shard_strategy(mesh, shard_dim=0)
                    output_node_type = NodeType.PARAM

                    non_grad_types = [t for t in input_node_types if t != NodeType.GRAD]

                    output_node_type = non_grad_types[0]
                    for non_grad_type in non_grad_types:
                        assert (
                            non_grad_type == output_node_type
                        ), f"Found more than one non grad types! Expect {output_node_type} but found {non_grad_type}!"
                    assert output_node_type in [
                        NodeType.PARAM,
                        NodeType.STATE,
                    ], f"Expecting output node type to be either state or param, but found {output_node_type}!"

                    dp_strategy_map[node] = DataParallelStrategy(
                        output_node_type, [replica_strategy, shard_strategy]
                    )
                elif NodeType.STATE in input_node_types:
                    # either param + state or state + state
                    replica_strategy = _gen_replicate_strategy(mesh)
                    shard_strategy = _gen_shard_strategy(mesh, shard_dim=0)
                    output_node_type = (
                        NodeType.PARAM
                        if NodeType.PARAM in input_node_types
                        else NodeType.STATE
                    )

                    dp_strategy_map[node] = DataParallelStrategy(
                        output_node_type, [replica_strategy, shard_strategy]
                    )
                elif NodeType.PARAM in input_node_types:
                    if NodeType.ACT in input_node_types:
                        # param + activation, build up acceptable strategy
                        # param must be replicated, activation must be sharded
                        for arg in input_args:
                            arg_strategy = dp_strategy_map[arg]
                            assert isinstance(arg_strategy, DataParallelStrategy)
                            node_type = arg_strategy.node_type
                            if node_type == NodeType.ACT:
                                # compute activation spec
                                act_spec = batch_dim_analyzer.compute_act_spec(
                                    arg, mesh
                                )
                                input_specs.append(act_spec)
                            elif node_type == NodeType.PARAM:
                                # param must be replicated
                                input_specs.append(
                                    DTensorSpec(mesh=mesh, placements=(Replicate(),))
                                )
                            else:
                                raise RuntimeError(
                                    f"Expecting node with parameter and activation, but found {input_node_types}! "
                                )
                        # produce activation type sharding for output
                        output_spec = batch_dim_analyzer.compute_act_spec(node, mesh)

                        act_strategy = PlacementStrategy(
                            output_specs=output_spec, input_specs=input_specs
                        )

                        dp_strategy_map[node] = DataParallelStrategy(
                            NodeType.ACT, [act_strategy]
                        )
                    else:
                        # If the inputs only have parameters, the strategy of
                        # this node should follow the input
                        dp_strategy_map[node] = dp_strategy_map[input_args[0]]
                else:
                    # If the input nodes do not have PARAM/GRAD/STATE, then it
                    # should be a pure activation computation and it should
                    # produce an activation output.
                    # Activations are usually sharded unless the model creates
                    # new tensors during computation; depending on whether the
                    # new tensor is associated with a batch dim or not, it could
                    # be shard/replicate/partial, and the batch dim analyzer
                    # should tell us the correct sharding.
                    for arg in input_args:
                        arg_strategy = dp_strategy_map[arg]
                        assert isinstance(arg_strategy, DataParallelStrategy)
                        input_spec = batch_dim_analyzer.compute_act_spec(arg, mesh)
                        input_specs.append(input_spec)

                    act_spec = batch_dim_analyzer.compute_act_spec(node, mesh)
                    op_strategy = PlacementStrategy(
                        output_specs=act_spec, input_specs=input_specs
                    )
                    dp_strategy_map[node] = DataParallelStrategy(
                        NodeType.ACT, [op_strategy]
                    )
        elif node.op == "output":
            dp_strategy_map[node] = DataParallelStrategy(NodeType.NON_TENSOR, [])
        else:
            raise RuntimeError(f"op code {node.op} not supported")

    return dp_strategy_map  # type: ignore[return-value]
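
# The returned map associates every fx.Node with a strategy. As a rough sketch
# of its shape (illustrative only): a parameter placeholder maps to
#     DataParallelStrategy(NodeType.PARAM, [replicate_strategy, shard_strategy])
# an activation maps to a single batch-dim shard strategy, and a fused
# optimizer op maps to a TupleStrategy with one child per output.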


def mark_data_parallel_shardings(
    train_step_graph: GraphModule,
    num_parameters: int,
    num_states: int,
    dp_strategy_map: Dict[fx.Node, StrategyType],
    parallel_mode: DataParallelStyle = DataParallelStyle.FULLY_SHARD,
) -> None:
    """Mark the sharding for the nodes in the train_step_graph."""
    activation_idx = num_parameters + num_states
    placeholder_idx = 0
    for node in train_step_graph.graph.nodes:
        node_strategy = dp_strategy_map[node]
        if node.op == "placeholder":
            assert isinstance(node_strategy, DataParallelStrategy)
            node_type = node_strategy.node_type
            node_strategies = node_strategy.strategies
            if node_type == NodeType.NON_TENSOR:
                # set node sharding to None
                node_sharding = None
            elif placeholder_idx < activation_idx:
                assert len(node_strategies) > 0, "node_strategies should not be empty"
                if parallel_mode == DataParallelStyle.REPLICATE:
                    # set to replicate for replicate style
                    node_sharding = node_strategies[0]
                elif parallel_mode == DataParallelStyle.FULLY_SHARD:
                    # set to shard for fully shard style
                    if len(node_strategies) == 1:
                        # only one strategy, use that instead
                        # i.e. optimizer state steps can only be replicate
                        node_sharding = node_strategies[0]
                    else:
                        # use the full sharding strategy
                        node_sharding = node_strategies[1]
                elif parallel_mode == DataParallelStyle.DEFAULT:
                    # TODO: add support for default mode
                    # default mode would generate either replicate or shard
                    raise NotImplementedError("default mode not implemented")
            else:
                assert len(node_strategies) > 0, "node_strategies should not be empty"
                # mark activation as sharded on batch dim
                node_sharding = node_strategies[0]

            node.meta["sharding"] = node_sharding  # type: ignore[possibly-undefined]

            placeholder_idx += 1
        elif node.op == "call_function":
            if isinstance(node_strategy, TupleStrategy):
                # For tuple strategy in the data parallel mode, it should have the same strategy
                # for all tuple elements, assert that then use the first element's strategy as sharding
                first_strategy = cast(DataParallelStrategy, node_strategy.childs[0])
                for child_strategy in node_strategy.childs:
                    assert isinstance(child_strategy, DataParallelStrategy)
                    assert child_strategy.strategies == first_strategy.strategies

                node_strategies = first_strategy.strategies
            else:
                assert isinstance(node_strategy, DataParallelStrategy)
                node_strategies = node_strategy.strategies

            assert (
                len(node_strategies) <= 2
            ), "data parallel should have at most 2 strategies"
            if len(node_strategies) == 1:
                node.meta["sharding"] = node_strategies[0]
            elif len(node_strategies) == 2:
                if parallel_mode == DataParallelStyle.REPLICATE:
                    # set to replicate for replicate style
                    node.meta["sharding"] = node_strategies[0]
                elif parallel_mode == DataParallelStyle.FULLY_SHARD:
                    # set to shard for fully shard style
                    node.meta["sharding"] = node_strategies[1]
                else:
                    raise RuntimeError("default mode not supported yet!")
            else:
                raise RuntimeError(
                    f"node {node} strategy length {len(node_strategies)} is not expected!"
                )
        elif node.op == "output":
            assert (
                isinstance(node_strategy, DataParallelStrategy)
                and node_strategy.node_type == NodeType.NON_TENSOR
            ), "output node should not be tensor"
            node.meta["sharding"] = None
        else:
            raise RuntimeError(f"op code {node.op} not supported")


def _partition_val(val: Any, spec: DTensorSpec) -> Any:
    """Util function to convert a full tensor val to its local component."""
    if isinstance(val, torch.Tensor):
        local_shard = val
        if val.ndim == 0:
            # If it's already a scalar tensor, it is already local, we don't
            # need to do anything
            return local_shard

        for idx, placement in enumerate(spec.placements):
            if placement.is_shard():
                placement = cast(Shard, placement)
                num_chunks = spec.mesh.size(mesh_dim=idx)
                my_coord = spec.mesh.get_coordinate()
                assert my_coord is not None, "current rank not in mesh!"
                my_coord_on_mesh_dim = my_coord[idx]
                local_shard = placement._split_tensor(
                    local_shard, num_chunks, with_padding=False, contiguous=False
                )[0][my_coord_on_mesh_dim]
        return local_shard
    elif isinstance(val, (tuple, list)):
        return val.__class__(_partition_val(v, spec) for v in val)
    else:
        raise RuntimeError(f"val type {type(val)} not supported")
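
# For example (a sketch): given a full tensor of shape (8, 4) and a spec with a
# single Shard(0) placement on a 1-D mesh of size 4, _partition_val returns this
# rank's (2, 4) chunk; for Replicate() or _Partial() placements the value is
# returned unchanged, since the full value is already a valid local tensor.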


def partitioner(graph: GraphModule) -> GraphModule:
    """Graph partitioner that partitions the single device graph to a distributed graph."""
    shape_adjustment_ops = {
        aten._unsafe_view.default: 1,
        aten.expand.default: 1,
        aten.new_zeros.default: 1,
        aten.ones.default: 0,
        aten.reshape.default: 1,
        aten.view.default: 1,
        aten.zeros.default: 0,
    }
    # partition the graph to distributed
    for node in graph.graph.nodes:
        node_sharding = node.meta["sharding"]
        # None sharding means this node doesn't need sharding
        if node_sharding is None:
            continue

        if node.op == "placeholder":
            out_spec = node_sharding.output_spec
            if not hasattr(out_spec, "from_local"):
                local_val = _partition_val(node.meta["val"], out_spec)
                # update node value
                node.meta["val"] = local_val
        elif node.op == "call_function":
            out_spec = node_sharding.output_spec

            # check if there's misaligned sharding, insert reshard if there is
            expected_input_specs = node_sharding.input_specs
            for idx, input_arg in enumerate(node.all_input_nodes):
                input_arg_sharding = input_arg.meta["sharding"]
                input_arg_spec = input_arg_sharding.output_spec
                desired_spec = (
                    out_spec
                    if expected_input_specs is None
                    else expected_input_specs[idx]
                )
                if input_arg_spec != desired_spec:
                    input_arg_spec.tensor_meta = input_arg.meta["tensor_meta"]
                    desired_spec.tensor_meta = input_arg.meta["tensor_meta"]
                    input_arg_tensor = input_arg.meta["val"]

                    # insert reshard operation
                    def reshard_fn(local_tensor: torch.Tensor) -> torch.Tensor:
                        return redistribute_local_tensor(
                            local_tensor,
                            input_arg_spec,
                            desired_spec,
                        )

                    reshard_gm = make_fx(reshard_fn)(input_arg_tensor)
                    reshard_gm_nodes = list(reshard_gm.graph.nodes)
                    input_node = reshard_gm_nodes[0]
                    with graph.graph.inserting_before(node):
                        output_node = graph.graph.graph_copy(
                            reshard_gm.graph,
                            val_map={
                                input_node: input_arg,
                            },
                        )
                    node.replace_input_with(input_arg, output_node)

            output_val = node.meta["val"]

            if node.target == torch.ops.aten.repeat.default:
                # for repeat op, we need to infer the repeat sizes
                assert isinstance(output_val, torch.Tensor)
                local_shape = compute_local_shape(
                    output_val.shape, out_spec.mesh, out_spec.placements
                )
                input_shape = node.args[0].meta["val"].shape

                def infer_repeat_sizes(repeated_shape, input_shape):
                    repeated_size = [1] * len(repeated_shape)
                    padded_length = len(repeated_shape) - len(input_shape)
                    for i in range(len(repeated_shape)):
                        if i < padded_length:
                            repeated_size[i] = repeated_shape[i]
                        else:
                            repeated_size[i] = (
                                repeated_shape[i] // input_shape[i - padded_length]
                            )

                    return repeated_size

                node.update_arg(1, infer_repeat_sizes(local_shape, input_shape))

            elif node.target in shape_adjustment_ops:
                # for view related ops that need a shape, adjust the shape to the local shape if needed
                assert isinstance(output_val, torch.Tensor)
                local_shape = compute_local_shape(
                    output_val.shape, out_spec.mesh, out_spec.placements
                )
                shape_arg_num = shape_adjustment_ops[node.target]
                node.update_arg(shape_arg_num, local_shape)

            # convert output val to its local component
            node.meta["val"] = _partition_val(output_val, out_spec)
        elif node.op == "output":
            break
        else:
            raise RuntimeError(f"op code {node} not supported")

    # clean up the graph by removing sharding and partitioning related metadata
    for node in graph.graph.nodes:
        if "sharding" in node.meta:
            del node.meta["sharding"]
        if "val" in node.meta and isinstance(node.meta["val"], torch.Tensor):
            local_tensor_meta = _extract_tensor_metadata(node.meta["val"])
            node.meta["tensor_meta"] = local_tensor_meta

    graph.graph.lint()
    graph.recompile()
    return graph


def partition_data_parallel(
    graph: GraphModule,
    model: nn.Module,
    optimizer: Optional[torch.optim.Optimizer],
    params_buffers: Dict[str, torch.Tensor],
    named_states: Dict[str, Any],
    args: Tuple[Any, ...],
    kwargs: Dict[str, Any],
    mesh: DeviceMesh,
    parallel_style: DataParallelStyle,
    input_batch_dim: int,
) -> GraphModule:
    """Partition the graph into a data parallel graph.

    This function also shards/replicates the model parameters and optimizer states to DTensors.
    """
    num_params_buffers = len(params_buffers)
    flattened_states = pytree.tree_leaves(named_states)
    num_states = len(flattened_states)

    changed = graph.graph.eliminate_dead_code()
    if changed:
        graph.recompile()

    # 1. First build up the data parallel strategies for the whole graph
    strategy_map = build_data_parallel_strategies(
        graph, num_params_buffers, num_states, mesh=mesh, batch_dim=input_batch_dim
    )

    # 2. Next we mark the data parallel strategy for each node based on
    #    the parallel_style
    mark_data_parallel_shardings(
        graph,
        num_parameters=num_params_buffers,
        num_states=num_states,
        dp_strategy_map=strategy_map,
        parallel_mode=parallel_style,
    )

    # 3. Partition the single machine graph to the distributed graph
    partitioned_graph = partitioner(graph)

    # preserve node types for the expanded graph
    for node in partitioned_graph.graph.nodes:
        if node in strategy_map:
            node_strategy = strategy_map[node]
            if isinstance(node_strategy, DataParallelStrategy):
                node.meta["node_type"] = node_strategy.node_type
            elif isinstance(node_strategy, TupleStrategy):
                node.meta["node_type"] = NodeType.NON_TENSOR
            else:
                raise RuntimeError(f"Unknown node strategy {node_strategy}")
        else:
            # if the nodes are expanded nodes (collectives), we mark them
            # the same type as the input node.
            input_node = node.all_input_nodes[0]
            node.meta["node_type"] = input_node.meta["node_type"]

    # 4. Last, in-place partition the weights and optim states to
    #    DTensors based on the parallel style
    accessor = NamedMemberAccessor(model)
    for param_key, param in params_buffers.items():
        placement: Placement = Replicate()
        if parallel_style == DataParallelStyle.FULLY_SHARD:
            placement = Shard(0)
        elif parallel_style != DataParallelStyle.REPLICATE:
            raise RuntimeError(f"parallel style {parallel_style} not supported yet")

        dtensor_param = distribute_tensor(param, mesh, [placement])
        # update the re-parameterized module param dict and optim states dict to DTensor
        params_buffers[param_key] = dtensor_param.to_local()
        # update module parameters to DTensor
        accessor.set_tensor(param_key, dtensor_param)

        # update the optimizer state keys and values to DTensor
        if optimizer is not None and param in optimizer.state:
            param_states = named_states[param_key]
            param_dtensor_states = {}
            for state_key, state_val in param_states.items():
                if isinstance(state_val, torch.Tensor) and state_val.ndim > 0:
                    # shard/replicate non-scalar tensors; for scalar tensors, we
                    # don't do anything
                    dtensor_state = distribute_tensor(state_val, mesh, [placement])
                    param_dtensor_states[state_key] = dtensor_state
                    param_states[state_key] = dtensor_state.to_local()
                else:
                    param_dtensor_states[state_key] = state_val

            optimizer.state.pop(param)  # type: ignore[call-overload]
            optimizer.state[dtensor_param] = param_dtensor_states  # type: ignore[index]

    return partitioned_graph
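
# End-to-end usage sketch (hypothetical, for orientation only): the SPMD compile
# path is expected to call into this module roughly as follows, where
# `train_step_gm`, `model`, `optim`, `params_and_buffers`, `named_states`,
# `args`/`kwargs`, and `mesh` are supplied by the caller:
#
#     partitioned = partition_data_parallel(
#         train_step_gm,
#         model,
#         optim,
#         params_and_buffers,
#         named_states,
#         args,
#         kwargs,
#         mesh,
#         DataParallelStyle.FULLY_SHARD,
#         input_batch_dim=0,
#     )
#
# After the call, the module's parameters and optimizer states are DTensors and
# `partitioned` is the per-rank local graph ready to run on each rank.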