- # mypy: allow-untyped-defs
- # Copyright (c) Meta Platforms, Inc. and affiliates
- from abc import ABC, abstractmethod
- from typing import Optional, Union, Tuple, Dict, Any
- from functools import partial
- import torch
- import torch.nn as nn
- from torch.distributed._tensor import DeviceMesh, DTensor, Placement, Replicate, Shard, distribute_tensor, distribute_module
- __all__ = [
- "ParallelStyle",
- "RowwiseParallel",
- "SequenceParallel",
- "ColwiseParallel",
- "PrepareModuleInput",
- "PrepareModuleOutput",
- ]
- class ParallelStyle(ABC):
- """
- The parallel style contract defines how the module or submodule should be parallelized.
- It only defines the ``_apply`` method for ``parallelize_module`` to use; this allows maximum
- flexibility for different kinds of style implementations.
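- Example (illustrative sketch of a custom style, not one of the built-in styles; it
- assumes plain replication of the module's parameters is all that is wanted)::
- >>> # xdoctest: +SKIP(failing)
- >>> class ReplicateParallel(ParallelStyle):
- >>>     def _apply(self, module, device_mesh):
- >>>         # with no partition_fn, distribute_module replicates all parameters
- >>>         return distribute_module(module, device_mesh)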
- """
- @abstractmethod
- def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
- ...
- class ColwiseParallel(ParallelStyle):
- """
- Partition a compatible nn.Module in a column-wise fashion. Currently supports nn.Linear and nn.Embedding.
- Users can compose it together with RowwiseParallel to achieve the sharding of more complicated modules
- (e.g. MLP, Attention).
- Keyword Args:
- input_layouts (Placement, optional):
- The DTensor layout of the input tensor for the nn.Module; this is used to annotate the input
- tensor so that it becomes a DTensor. If not specified, the input tensor is assumed to be replicated.
- output_layouts (Placement, optional):
- The DTensor layout of the output for the nn.Module; this is used to ensure the output of the
- nn.Module has the user-desired layout. If not specified, the output tensor is sharded on the last dimension.
- use_local_output (bool, optional):
- Whether to use local :class:`torch.Tensor` instead of :class:`DTensor` for the module output, default: True.
- Returns:
- A :class:`ParallelStyle` object that represents Colwise sharding of the nn.Module.
- Example::
- >>> # xdoctest: +SKIP(failing)
- >>> from torch.distributed.tensor.parallel import parallelize_module, ColwiseParallel
- >>> from torch.distributed.device_mesh import init_device_mesh
- >>> ...
- >>> m = Model(...) # m is a nn.Module that contains a "w1" nn.Linear submodule
- >>> tp_mesh = init_device_mesh("cuda", (8,))
- >>>
- >>> # By default, the input of the "w1" Linear will be converted to Replicated DTensor
- >>> # and the output of "w1" will be a :class:`torch.Tensor` sharded on the last dim.
- >>>
- >>> sharded_mod = parallelize_module(m, tp_mesh, {"w1": ColwiseParallel()})
- >>> ...
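- A variation on the example above (illustrative sketch, reusing the hypothetical ``m``
- and ``tp_mesh``): if the consumer needs the full activation on every rank, ask
- ``ColwiseParallel`` to gather its output instead of leaving it sharded on the last dim::
- >>> # xdoctest: +SKIP(failing)
- >>> from torch.distributed._tensor import Replicate
- >>> sharded_mod = parallelize_module(
- >>>     m, tp_mesh, {"w1": ColwiseParallel(output_layouts=Replicate())}
- >>> )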
- .. note:: By default, the ``ColwiseParallel`` output is sharded on the last dimension if ``output_layouts`` is not
- specified. If there are operators that require a specific tensor shape (e.g. before the paired ``RowwiseParallel``),
- keep in mind that if the output is sharded, the operator might need to be adjusted to the sharded size.
- """
- def __init__(
- self,
- *,
- input_layouts: Optional[Placement] = None,
- output_layouts: Optional[Placement] = None,
- use_local_output: bool = True
- ):
- super().__init__()
- self.input_layouts = (input_layouts or Replicate(), )
- self.output_layouts = (output_layouts or Shard(-1), )
- # colwise linear runtime sharding (desired sharding):
- # 1. requires replicate input
- # 2. shard output on last dim
- self.desired_input_layouts = (Replicate(), )
- self.use_local_output = use_local_output
- @staticmethod
- def _prepare_input_fn(input_layouts, desired_input_layouts, mod, inputs, device_mesh):
- # TODO: figure out dynamo support for instance method and switch this to instance method
- # annotate module input placements/sharding with input_layouts
- input_tensor = inputs[0]
- if not isinstance(input_tensor, DTensor):
- input_tensor = DTensor.from_local(input_tensor, device_mesh, input_layouts, run_check=False)
- # transform the input layouts to the desired layouts of ColwiseParallel
- if input_layouts != desired_input_layouts:
- input_tensor = input_tensor.redistribute(placements=desired_input_layouts, async_op=True)
- return input_tensor
- def _partition_linear_fn(self, name, module, device_mesh):
- # Colwise shards both weight and bias on dim 0 (Shard(0)). Since nn.Linear
- # computes input @ weight^T + bias, sharding weight on dim 0 shards the
- # output features, i.e. weight^T is effectively Shard(1)
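- # Shape sketch (illustrative numbers, not from the source): on a 2-way mesh,
- # nn.Linear(4, 6) leaves each rank with a local (3, 4) weight and (3,) bias,
- # so the local output is (..., 3), i.e. Shard(-1) of the full (..., 6) output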
- for name, param in module.named_parameters():
- dist_param = nn.Parameter(
- distribute_tensor(param, device_mesh, [Shard(0)])
- )
- module.register_parameter(name, dist_param)
- def _partition_embedding_fn(self, name, module, device_mesh):
- # Colwise sharding of embedding.weight is straightforward: Shard(1) (the embedding dim)
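- # e.g. (illustrative) nn.Embedding(1000, 8) on a 2-way mesh: each rank holds a
- # local (1000, 4) slice of the table, so lookups come back sharded on the last dim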
- for name, param in module.named_parameters():
- dist_param = nn.Parameter(
- distribute_tensor(param, device_mesh, [Shard(1)])
- )
- module.register_parameter(name, dist_param)
- @staticmethod
- def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_mesh):
- # outputs is a DTensor sharded on the last dimension, i.e. Shard(-1)
- if outputs.placements != output_layouts:
- outputs = outputs.redistribute(placements=output_layouts, async_op=True)
- # back to local tensor
- return outputs.to_local() if use_local_output else outputs
- def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
- if isinstance(module, nn.Linear):
- partition_fn = self._partition_linear_fn
- elif isinstance(module, nn.Embedding):
- partition_fn = self._partition_embedding_fn
- else:
- raise NotImplementedError("ColwiseParallel currently only supports nn.Linear and nn.Embedding!")
- return distribute_module(
- module,
- device_mesh,
- partition_fn,
- partial(self._prepare_input_fn, self.input_layouts, self.desired_input_layouts),
- partial(self._prepare_output_fn, self.output_layouts, self.use_local_output),
- )
- class RowwiseParallel(ParallelStyle):
- """
- Partition a compatible nn.Module in a row-wise fashion. Currently supports nn.Linear and nn.Embedding.
- Users can compose it with ColwiseParallel to achieve the sharding of more complicated modules
- (e.g. MLP, Attention).
- Keyword Args:
- input_layouts (Placement, optional):
- The DTensor layout of the input tensor for the nn.Module; this is used to annotate the input
- tensor so that it becomes a DTensor. If not specified, the input tensor is assumed to be sharded on the last dimension.
- output_layouts (Placement, optional):
- The DTensor layout of the output for the nn.Module; this is used to ensure the output of the
- nn.Module has the user-desired layout. If not specified, the output tensor is replicated.
- use_local_output (bool, optional):
- Whether to use local :class:`torch.Tensor` instead of :class:`DTensor` for the module output, default: True.
- Returns:
- A :class:`ParallelStyle` object that represents Rowwise sharding of the nn.Module.
- Example::
- >>> # xdoctest: +SKIP(failing)
- >>> from torch.distributed.tensor.parallel import parallelize_module, RowwiseParallel
- >>> from torch.distributed.device_mesh import init_device_mesh
- >>> ...
- >>> m = Model(...) # m is a nn.Module that contains a "w2" nn.Linear submodule
- >>> tp_mesh = init_device_mesh("cuda", (8,))
- >>>
- >>> # By default, the input of the "w2" Linear will be converted to DTensor that shards on the last dim
- >>> # and the output of "w2" will be a replicated :class:`torch.Tensor`.
- >>>
- >>> sharded_mod = parallelize_module(m, tp_mesh, {"w2": RowwiseParallel()})
- >>> ...
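- A composition sketch (illustrative; assumes ``m`` also contains a "w1" nn.Linear):
- pairing Colwise then Rowwise shards a two-layer MLP end to end, keeping the
- intermediate activation sharded and reducing only once at the Rowwise output::
- >>> # xdoctest: +SKIP(failing)
- >>> from torch.distributed.tensor.parallel import ColwiseParallel
- >>> sharded_mod = parallelize_module(
- >>>     m, tp_mesh, {"w1": ColwiseParallel(), "w2": RowwiseParallel()}
- >>> )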
- """
- def __init__(
- self,
- *,
- input_layouts: Optional[Placement] = None,
- output_layouts: Optional[Placement] = None,
- use_local_output: bool = True
- ):
- super().__init__()
- self.input_layouts = (input_layouts or Shard(-1), )
- self.output_layouts = (output_layouts or Replicate(), )
- self.use_local_output = use_local_output
- @staticmethod
- def _prepare_input_fn(input_layouts, desired_input_layouts, mod, inputs, device_mesh):
- input_tensor = inputs[0]
- if not isinstance(input_tensor, DTensor):
- input_tensor = DTensor.from_local(input_tensor, device_mesh, input_layouts, run_check=False)
- if input_layouts != desired_input_layouts:
- input_tensor = input_tensor.redistribute(placements=desired_input_layouts, async_op=True)
- return input_tensor
- def _partition_linear_fn(self, name, module, device_mesh):
- # Rowwise shards weight on dim 1 (Shard(1)) and replicates bias. Since
- # nn.Linear computes input @ weight^T + bias, sharding weight on dim 1
- # shards the input features, i.e. weight^T is effectively Shard(0)
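- # Shape sketch (illustrative numbers, not from the source): on a 2-way mesh,
- # nn.Linear(4, 6) leaves each rank with a local (6, 2) weight and a Shard(-1)
- # input of local shape (..., 2); the local matmul yields a partial (..., 6)
- # result that the output redistribution sums across ranks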
- module.register_parameter("weight", nn.Parameter(
- distribute_tensor(module.weight, device_mesh, [Shard(1)])
- ))
- if module.bias is not None:
- module.register_parameter("bias", nn.Parameter(
- distribute_tensor(module.bias, device_mesh, [Replicate()])
- ))
- def _partition_embedding_fn(self, name, module, device_mesh):
- # Rowwise sharding of embedding.weight is Shard(0) (the vocabulary dim)
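- # e.g. (illustrative) nn.Embedding(1000, 8) on a 2-way mesh: each rank holds a
- # (500, 8) slice of the vocabulary, and the runtime combines the per-rank
- # partial lookups across the mesh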
- for name, param in module.named_parameters():
- dist_param = nn.Parameter(
- distribute_tensor(param, device_mesh, [Shard(0)])
- )
- module.register_parameter(name, dist_param)
- @staticmethod
- def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_mesh):
- # Rowwise sharding produces partial output, depending on output layouts:
- # 1. to replicate -> allreduce
- # 2. to shard -> reduce_scatter
- if outputs.placements != output_layouts:
- outputs = outputs.redistribute(placements=output_layouts, async_op=True)
- # back to local tensor if use_local_output is True
- return outputs.to_local() if use_local_output else outputs
- def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
- if isinstance(module, nn.Linear):
- partition_fn = self._partition_linear_fn
- # rowwise linear runtime sharding requires the input tensor to be sharded on the last dim
- self.desired_input_layouts: Tuple[Placement, ...] = (Shard(-1), )
- elif isinstance(module, nn.Embedding):
- partition_fn = self._partition_embedding_fn
- # rowwise embedding runtime sharding requires the input tensor to be replicated
- self.desired_input_layouts = (Replicate(), )
- else:
- raise NotImplementedError("RowwiseParallel currently only supports nn.Linear and nn.Embedding!")
- return distribute_module(
- module,
- device_mesh,
- partition_fn,
- partial(self._prepare_input_fn, self.input_layouts, self.desired_input_layouts),
- partial(self._prepare_output_fn, self.output_layouts, self.use_local_output),
- )
- class SequenceParallel(ParallelStyle):
- """
- SequenceParallel replicates a compatible ``nn.Module``'s parameters and runs the sharded computation with
- input sharded on the sequence dimension. This currently supports ``nn.LayerNorm``, ``nn.Dropout``, and the
- `RMSNorm python implementation <https://github.com/facebookresearch/llama/blob/main/llama/model.py#L34>`__.
- This style implements the operation that is described in the paper
- `Reducing Activation Recomputation in Large Transformer Models <https://arxiv.org/abs/2205.05198>`__
- Both the input and output of the ``nn.Module`` will be sharded on the sequence dimension.
- Keyword Args:
- sequence_dim (int, optional):
- The sequence dimension of the input tensor for the ``nn.Module``; this is used to annotate the input
- tensor so that it becomes a DTensor sharded on the sequence dimension, default: 1.
- use_local_output (bool, optional):
- Whether to use local :class:`torch.Tensor` instead of :class:`DTensor` for the module output, default: False.
- Returns:
- A :class:`ParallelStyle` object that represents Sequence Parallel of the ``nn.Module``.
- Example::
- >>> # xdoctest: +SKIP(failing)
- >>> from torch.distributed.tensor.parallel import parallelize_module, SequenceParallel
- >>> from torch.distributed.device_mesh import init_device_mesh
- >>> ...
- >>> m = Model(...) # m is a nn.Module that contains a "norm" nn.LayerNorm submodule
- >>> tp_mesh = init_device_mesh("cuda", (8,))
- >>>
- >>> # By default, the input of the "norm" will be converted to DTensor that shards on the sequence dim
- >>> # and the output of "norm" will be a :class:`DTensor` sharded on the sequence dimension.
- >>>
- >>> sharded_mod = parallelize_module(m, tp_mesh, {"norm": SequenceParallel()})
- >>> ...
- .. note:: SequenceParallel style assumes ones initialization for any weights in the nn.Module (i.e.
- ``nn.LayerNorm`` or ``RMSNorm``, which have ones initialization by default). If you have custom
- inits for the weights on those modules, you need to broadcast the weights before/after parallelizing
- to ensure that they are replicated.
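- A variant (illustrative): if your activations are laid out as (seq, batch, dim) rather
- than (batch, seq, dim), shard the first dimension instead::
- >>> # xdoctest: +SKIP(failing)
- >>> sharded_mod = parallelize_module(m, tp_mesh, {"norm": SequenceParallel(sequence_dim=0)})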
- """
- def __init__(
- self,
- *,
- sequence_dim: int = 1,
- use_local_output: bool = False
- ):
- super().__init__()
- self.sequence_dim = sequence_dim
- self.use_local_output = use_local_output
- def _replicate_module_fn(self, name: str, module: nn.Module, device_mesh: DeviceMesh):
- for p_name, param in module.named_parameters():
- # Simple replication relying on the fixed ones_ init of LayerNorm/RMSNorm,
- # which allows us to just use from_local directly
- replicated_param = torch.nn.Parameter(
- DTensor.from_local(param, device_mesh, [Replicate()], run_check=False)
- )
- module.register_parameter(p_name, replicated_param)
- @staticmethod
- def _prepare_input_fn(sequence_dim, mod, inputs, device_mesh):
- input_tensor = inputs[0]
- if isinstance(input_tensor, DTensor):
- return inputs
- elif isinstance(input_tensor, torch.Tensor):
- return DTensor.from_local(input_tensor, device_mesh, [Shard(sequence_dim)], run_check=False)
- else:
- raise ValueError(f"expecting input of {mod} to be a torch.Tensor or DTensor, but got {input_tensor}")
- @staticmethod
- def _prepare_output_fn(use_local_output, mod, outputs, device_mesh):
- return outputs.to_local() if use_local_output else outputs
- def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
- return distribute_module(
- module,
- device_mesh,
- self._replicate_module_fn,
- partial(self._prepare_input_fn, self.sequence_dim),
- partial(self._prepare_output_fn, self.use_local_output),
- )
- class PrepareModuleInput(ParallelStyle):
- """
- Configure the nn.Module's inputs to convert the input tensors of the nn.Module to DTensors at runtime according to
- ``input_layouts``, and perform layout redistribution according to the ``desired_input_layouts``.
- Keyword Args:
- input_layouts (Union[Placement, Tuple[Optional[Placement]]]):
- The DTensor layouts of input tensors for the nn.Module; this is used to convert the input tensors to
- DTensors. If some inputs are not torch.Tensor or do not need to be converted to DTensors, ``None`` needs
- to be specified as a placeholder. default: None.
- desired_input_layouts (Union[Placement, Tuple[Optional[Placement]]]):
- The desired DTensor layouts of input tensors for the nn.Module; this is used to ensure the inputs of the nn.Module
- have the desired DTensor layouts. This argument needs to have the same length as ``input_layouts``. default: None.
- input_kwarg_layouts (Dict[str, Placement]):
- The DTensor layouts of input kwargs for the nn.Module; this is used to convert the input kwarg tensors to DTensors.
- default: None.
- desired_input_kwarg_layouts (Dict[str, Placement]):
- The desired DTensor layouts of input kwargs for the nn.Module; this is used to ensure the inputs of the nn.Module
- have the desired DTensor layouts. default: None.
- use_local_output (bool, optional):
- Whether to use local :class:`torch.Tensor` instead of :class:`DTensor` for the module inputs, default: False.
- Returns:
- A :class:`ParallelStyle` object that prepares the sharding layouts of the nn.Module's inputs.
- Example::
- >>> # xdoctest: +SKIP(failing)
- >>> from torch.distributed.tensor.parallel import parallelize_module, PrepareModuleInput
- >>> from torch.distributed.device_mesh import init_device_mesh
- >>> ...
- >>> block = TransformerBlock(...) # block is a nn.Module that contains an "attn" Attention submodule
- >>> tp_mesh = init_device_mesh("cuda", (8,))
- >>>
- >>> # According to the style specified below, the first input of attn will be annotated as a sharded DTensor
- >>> # and then redistributed to Replicated DTensor.
- >>> parallelize_module(
- >>> block, # this can be a submodule or module
- >>> tp_mesh,
- >>> parallelize_plan={
- >>> "attn": PrepareModuleInput(
- >>> input_layouts=(Shard(0), None, None, ...),
- >>> desired_input_layouts=(Replicate(), None, None, ...)
- >>> ),
- >>> }
- >>> )
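- A kwarg sketch (illustrative; hypothetically "attn" also takes a ``freqs_cis`` keyword
- argument that every rank needs in full)::
- >>> # xdoctest: +SKIP(failing)
- >>> PrepareModuleInput(
- >>>     input_kwarg_layouts={"freqs_cis": Replicate()},
- >>>     desired_input_kwarg_layouts={"freqs_cis": Replicate()},
- >>> )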
- """
- def __init__(
- self,
- *,
- input_layouts: Optional[Union[Placement, Tuple[Optional[Placement]]]] = None,
- desired_input_layouts: Optional[Union[Placement, Tuple[Optional[Placement]]]] = None,
- input_kwarg_layouts: Optional[Dict[str, Placement]] = None,
- desired_input_kwarg_layouts: Optional[Dict[str, Placement]] = None,
- use_local_output: bool = False
- ):
- self.input_layouts = (input_layouts,) if isinstance(input_layouts, Placement) else input_layouts
- self.desired_input_layouts = \
- (desired_input_layouts,) if isinstance(desired_input_layouts, Placement) else desired_input_layouts
- self.use_local_output = use_local_output
- if self.input_layouts is not None:
- assert self.desired_input_layouts is not None, "desired module inputs should not be None!"
- assert len(self.input_layouts) == len(self.desired_input_layouts), \
- "input_layouts and desired_input_layouts should have same length!"
- self.with_kwargs = input_kwarg_layouts is not None
- self.input_kwarg_layouts = input_kwarg_layouts or {}
- self.desired_input_kwarg_layouts = desired_input_kwarg_layouts or {}
- if self.with_kwargs:
- assert len(self.input_kwarg_layouts) == len(self.desired_input_kwarg_layouts), \
- "input_kwarg_layouts and desired_input_kwarg_layouts should have same length!"
- def _prepare_input_arg(
- self,
- input: Any,
- mesh: DeviceMesh,
- input_layout: Optional[Placement],
- desired_layout: Optional[Placement]
- ):
- if input_layout is not None:
- if isinstance(input, DTensor):
- # TODO: re-enable the check once we fix the compile path
- # assert input.placements[0] == input_layout
- dt_inp = input
- else:
- assert isinstance(input, torch.Tensor), "expecting input to be a torch.Tensor!"
- dt_inp = DTensor.from_local(input, mesh, (input_layout,), run_check=False)
- if desired_layout is not None and input_layout != desired_layout:
- dt_inp = dt_inp.redistribute(placements=(desired_layout,))
- return dt_inp.to_local() if self.use_local_output else dt_inp
- else:
- return input
- def _prepare_input_fn(self, inputs, device_mesh):
- if self.input_layouts is None:
- return inputs
- prepared_inputs = []
- if not isinstance(inputs, tuple):
- inputs = (inputs,)
- if len(inputs) != len(self.input_layouts):
- raise ValueError("module inputs and input_layouts should have same length!")
- assert self.desired_input_layouts is not None, "desired module inputs should not be None!"
- for inp, input_layout, desired_layout in zip(inputs, self.input_layouts, self.desired_input_layouts):
- prepared_inputs.append(self._prepare_input_arg(inp, device_mesh, input_layout, desired_layout))
- return tuple(prepared_inputs)
- def _prepare_input_kwarg_fn(self, inputs, kwarg_inputs, device_mesh):
- prepared_arg_inputs = self._prepare_input_fn(inputs, device_mesh)
- prepared_kwarg_inputs = {}
- for kwarg_key in kwarg_inputs.keys():
- kwarg_val = kwarg_inputs[kwarg_key]
- input_layout = self.input_kwarg_layouts.get(kwarg_key)
- desired_input_layout = self.desired_input_kwarg_layouts.get(kwarg_key)
- prepared_kwarg_inputs[kwarg_key] = self._prepare_input_arg(kwarg_val, device_mesh, input_layout, desired_input_layout)
- return (prepared_arg_inputs, prepared_kwarg_inputs)
- def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
- if self.with_kwargs:
- module.register_forward_pre_hook(
- lambda _, inputs, kwargs: self._prepare_input_kwarg_fn(inputs, kwargs, device_mesh),
- with_kwargs=True
- ) # type: ignore[misc]
- else:
- module.register_forward_pre_hook(lambda _, inputs: self._prepare_input_fn(inputs, device_mesh)) # type: ignore[misc, call-arg]
- return module
- class PrepareModuleOutput(ParallelStyle):
- """
- Configure the nn.Module's outputs to convert the output tensors of the nn.Module to DTensors at runtime according to
- ``output_layouts``, and perform layout redistribution according to the ``desired_output_layouts``.
- Keyword Args:
- output_layouts (Union[Placement, Tuple[Placement]]):
- The DTensor layouts of output tensors for the nn.Module; this is used to convert the output tensors to
- DTensors if they are :class:`torch.Tensor`. If some outputs are not torch.Tensor or do not need to be converted
- to DTensors, ``None`` needs to be specified as a placeholder.
- desired_output_layouts (Union[Placement, Tuple[Placement]]):
- The desired DTensor layouts of output tensors for the nn.Module; this is used to ensure the outputs of the
- nn.Module have the desired DTensor layouts.
- use_local_output (bool, optional):
- Whether to use local :class:`torch.Tensor` instead of :class:`DTensor` for the module outputs, default: True.
- Returns:
- A :class:`ParallelStyle` object that prepares the sharding layouts of the nn.Module's outputs.
- Example::
- >>> # xdoctest: +SKIP(failing)
- >>> from torch.distributed.tensor.parallel import parallelize_module, PrepareModuleOutput
- >>> from torch.distributed.device_mesh import init_device_mesh
- >>> ...
- >>> block = TransformerBlock(...) # block is a nn.Module that contains an "attn" Attention submodule
- >>> tp_mesh = init_device_mesh("cuda", (8,))
- >>>
- >>> # According to the style specified below, the output of the TransformerBlock will be converted to Replicated DTensor
- >>> # and then redistributed to Sharded DTensor.
- >>> parallelize_module(
- >>> block, # this can be a submodule or module
- >>> tp_mesh,
- >>> parallelize_plan = PrepareModuleOutput(
- >>> output_layouts=Replicate(),
- >>> desired_output_layouts=Shard(0)
- >>> )
- >>> )
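- A variant (illustrative): keep the redistributed output as a :class:`DTensor` instead
- of converting it back to a local tensor, e.g. when downstream modules also consume DTensors::
- >>> # xdoctest: +SKIP(failing)
- >>> PrepareModuleOutput(
- >>>     output_layouts=Replicate(),
- >>>     desired_output_layouts=Shard(0),
- >>>     use_local_output=False,
- >>> )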
- """
- def __init__(
- self,
- *,
- output_layouts: Union[Placement, Tuple[Placement]],
- desired_output_layouts: Union[Placement, Tuple[Placement]],
- use_local_output: bool = True
- ):
- self.output_layouts = (output_layouts,) if isinstance(output_layouts, Placement) else output_layouts
- self.desired_output_layouts = \
- (desired_output_layouts,) if isinstance(desired_output_layouts, Placement) else desired_output_layouts
- self.use_local_output = use_local_output
- assert len(self.output_layouts) == len(self.desired_output_layouts), \
- "output_layouts and desired_output_layouts should have same length!"
- def _prepare_out_fn(self, outputs, device_mesh):
- prepared_outputs = []
- if not isinstance(outputs, tuple):
- outputs = (outputs,)
- if len(outputs) != len(self.output_layouts):
- raise ValueError("module outputs and output_layouts should have same length!")
- for out, out_layout, desired_out_layout in zip(outputs, self.output_layouts, self.desired_output_layouts):
- if out_layout is not None:
- if isinstance(out, DTensor):
- # TODO: re-enable the check once we fix the compile path
- # assert out.placements[0] == out_layout
- dt_out = out
- else:
- dt_out = DTensor.from_local(out, device_mesh, (out_layout,), run_check=False)
- if out_layout != desired_out_layout:
- dt_out = dt_out.redistribute(placements=(desired_out_layout,))
- prepared_outputs.append(dt_out.to_local() if self.use_local_output else dt_out)
- else:
- prepared_outputs.append(out)
- if len(prepared_outputs) == 1:
- return prepared_outputs[0]
- else:
- return tuple(prepared_outputs)
- def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
- module.register_forward_hook(lambda _, inputs, outputs: self._prepare_out_fn(outputs, device_mesh)) # type: ignore[misc, call-arg]
- return module
|