- # mypy: allow-untyped-defs
- # Copyright (c) Meta Platforms, Inc. and affiliates
- import copy
- import logging
- import operator
- from collections import defaultdict
- from enum import Enum
- from inspect import Parameter, signature, Signature
- from types import MethodType
- from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
- import torch
- import torch.fx as fx
- from torch.distributed import ProcessGroup
- from torch.export import ExportedProgram
- from torch.export.unflatten import (
- _assign_attr,
- _AttrKind,
- _sink_params,
- InterpreterModule,
- )
- from torch.fx.node import map_aggregate
- from torch.fx.passes.split_module import split_module
- from ._backward import _null_coalesce_accumulate, stage_backward
- from ._unflatten import _outline_submodules
- from ._utils import PipeInfo
- from .stage import _PipelineStage
- logger = logging.getLogger(__name__)
- # TODO:
- # 1. investigate gradient sync for shared parameters. how does DDP do it?
- # 2. Add parameter movement to split_module
- def _find_loss_from_output_and_spec(output_val, spec_val):
- if spec_val is False:
- return None
- if spec_val is True:
- if not isinstance(output_val, fx.Node):
- raise RuntimeError(
- f"Loss spec must specify a dynamic value but got {output_val}"
- )
- return output_val
- if isinstance(spec_val, (tuple, list)):
- if not isinstance(output_val, (tuple, list)):
- raise RuntimeError(
- f"Output value {output_val} must match type of loss specification "
- f"{spec_val}"
- )
- if len(output_val) != len(spec_val):
- raise RuntimeError(
- f"Output value {output_val} must match length of loss specification "
- f"{spec_val}"
- )
- for out, spec in zip(output_val, spec_val):
- loss_val = _find_loss_from_output_and_spec(out, spec)
- if loss_val is not None:
- return loss_val
- raise RuntimeError(f"Did not find loss value in specification {spec_val}")
- if isinstance(spec_val, dict):
- if not isinstance(output_val, dict):
- raise RuntimeError(
- f"Output value {output_val} must match type of loss specification "
- f"{spec_val}"
- )
- if set(output_val.keys()) != set(spec_val.keys()):
- raise RuntimeError(
- f"Output value {output_val} must match keys of loss specification "
- f"{spec_val}"
- )
- for k in spec_val:
- loss_val = _find_loss_from_output_and_spec(output_val[k], spec_val[k])
- if loss_val is not None:
- return loss_val
- raise RuntimeError(f"Did not find loss value in specification {spec_val}")
- raise RuntimeError(f"Unsupported type {type(spec_val)} in loss specification")
- def _find_loss_output(mod: torch.nn.Module, g: fx.Graph, output_loss_value_spec):
- output_nodes = [n for n in g.nodes if n.op == "output"]
- assert len(output_nodes) == 1
- output_node = output_nodes[0]
- output_val = output_node.args[0]
- generated_spec: Any = None
- if isinstance(mod, TrivialLossWrapper):
- # TrivialLossWrapper is pre-defined by PiPPy.
- # It has loss as the only output so we can safely assume the first output arg is the loss.
- assert len(output_node.args) == 1
- loss_node = output_val
- generated_spec = TrivialLossWrapper.loss_spec
- elif output_loss_value_spec is None:
- # Use default spec, i.e. search for "loss" in output values
- if isinstance(output_val, dict) and "loss" in output_val.keys():
- loss_node = output_val["loss"]
- generated_spec = {k: k == "loss" for k in output_val}
- else:
- loss_node = None
- generated_spec = None
- else:
- loss_node = _find_loss_from_output_and_spec(output_val, output_loss_value_spec)
- generated_spec = output_loss_value_spec
- return loss_node, output_node, generated_spec
- def _insert_stage_symbolic_backward(
- g: fx.Graph,
- loss_node: fx.Node,
- output_node: fx.Node,
- ):
- # Collect metadata about tuple output values. TODO: move this to split_module or FX IR
- tuples: Dict[fx.Node, Tuple] = {}
- for node in reversed(g.nodes):
- if node.op == "call_function":
- # The forward-only graph should contain only placeholders, module calls,
- # and getitem calls. Any other call_function target at this point
- # indicates a bug.
- assert node.target == operator.getitem, (
- "Found non-getitem call in forward pass. "
- "Please report a bug to PiPPy"
- )
- assert (
- len(node.args) == 2
- ), "Found malformed getitem call. Please report a bug to PiPPy"
- indexed_value, node_idx = tuple(node.args)
- # indexed_value is a collection that we are indexing into. It could
- # exist in the tuples map if we've processed another `getitem`
- # already.
- existing_list_size = (
- len(tuples[indexed_value]) if indexed_value in tuples else -1
- )
- new_list_size = max(node_idx + 1, existing_list_size)
- reconstructed_list = [None for _ in range(new_list_size)]
- # Copy over existing elements if present
- if indexed_value in tuples:
- for i, val in enumerate(tuples[indexed_value]):
- reconstructed_list[i] = val
- # Populate value represented by this node
- reconstructed_list[node_idx] = node
- tuples[indexed_value] = tuple(reconstructed_list)
- # Keep track of nodes that the loss value depends on (i.e., its transitive inputs).
- # We will only emit backward operations for nodes that can contribute
- # to the specified loss value.
- live_nodes = {loss_node: None}
- val_to_grad: Dict[fx.Node, Optional[fx.Node]] = {loss_node: None}
- def assign_or_accumulate_grad(forward_node, grad_value):
- if forward_node in val_to_grad and forward_node.op != "placeholder":
- grad_value = g.call_function(
- _null_coalesce_accumulate,
- (val_to_grad[forward_node], grad_value),
- )
- val_to_grad[forward_node] = grad_value
- with g.inserting_before(output_node):
- for node in reversed(g.nodes):
- if node not in live_nodes:
- continue
- def add_to_live_nodes(n):
- live_nodes.setdefault(n, None)
- fx.node.map_arg(node.args, add_to_live_nodes)
- fx.node.map_arg(node.kwargs, add_to_live_nodes)
- if node.op == "call_module":
- output_grads: Union[Tuple[Optional[fx.Node], ...], Optional[fx.Node]]
- if node in tuples:
- stage_output = tuples[node]
- output_grads = tuple(val_to_grad.get(n, None) for n in tuples[node])
- outputs_with_grads_idxs = [
- i for i, n in enumerate(tuples[node]) if n in live_nodes
- ]
- else:
- stage_output = (node,)
- output_grads = val_to_grad[node]
- outputs_with_grads_idxs = [0]
- output_grads = (
- (output_grads,)
- if not isinstance(output_grads, tuple)
- else output_grads
- )
- grad_call = g.call_function(
- stage_backward,
- kwargs={
- "stage_output": stage_output,
- "output_grads": output_grads,
- "input_values": list(node.all_input_nodes),
- "outputs_with_grads_idxs": outputs_with_grads_idxs,
- },
- )
- # Re-assign kwargs as a plain dict copy on the backward call node
- kwargs_copy = dict(grad_call.kwargs)
- grad_call.kwargs = kwargs_copy
- grad_call_proxy = fx.Proxy(grad_call)
- grads = grad_call_proxy.node
- input_nodes = list(node.all_input_nodes)
- grads_proxy = fx.Proxy(grads)
- for i, input_node in enumerate(input_nodes):
- assign_or_accumulate_grad(input_node, grads_proxy[i].node)
- return g
- class PipeSequential(torch.nn.Sequential):
- @staticmethod
- def from_sequential(sequential_instance: torch.nn.Sequential):
- return PipeSequential(*[copy.copy(m) for m in sequential_instance])
- def forward(self, input):
- for i, module in enumerate(self):
- input = module(input)
- if i != len(self) - 1:
- pipe_split()
- return input
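- # Illustrative usage sketch (layer sizes are arbitrary): wrapping an existing
- # nn.Sequential so that every child becomes its own pipeline stage, because
- # `forward` above emits a pipe_split() between consecutive children. Run
- # eagerly, the pipe_split() markers are no-ops.
- def _example_pipe_sequential():
-     seq = torch.nn.Sequential(
-         torch.nn.Linear(16, 16),
-         torch.nn.ReLU(),
-         torch.nn.Linear(16, 4),
-     )
-     pipe_seq = PipeSequential.from_sequential(seq)
-     return pipe_seq(torch.randn(2, 16))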
- class LossWrapper(torch.nn.Module):
- """
- LossWrapper is a convenient abstract class that allows you to wrap up both
- your model as well as its loss function and specify the connectivity between
- the inputs, model, loss function, and output value. Example::
- class MyModelWrapper(LossWrapper):
- def forward(self, x, targets):
- model_out = self.module(x)
- loss_value = self.loss_fn(model_out, targets)
- return loss_value
- The above example defines a connectivity where we expect the forward/loss/backward
- training procedure to take two arguments (x and targets), pass x into the module
- to get the output of the feedforward computation, pass the model output and the
- targets value into the loss function, and get and return the loss value, which will
- be backpropagated by PiPPy. The above class would then be instantiated like::
- model = ... # instantiate the model
- loss_fn = torch.nn.MSELoss() # for the sake of demonstration
- wrapper = MyModelWrapper(model, loss_fn)
- pipe = Pipe.from_tracing(wrapper, ...)
- """
- def __init__(self, module, loss_fn):
- super().__init__()
- self.module = module
- self.loss_fn = loss_fn
- def forward(self, *args, **kwargs):
- raise NotImplementedError(
- "This instance of LossWrapper does not have an overridden"
- "forward(). Please implement forward() to specify the arguments, "
- "connection between the module and loss, and loss output "
- "value."
- )
- class TrivialLossWrapper(LossWrapper):
- def forward(self, x, targets):
- model_out = self.module(x)
- return self.loss_fn(model_out, targets)
- loss_spec = True
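- # Minimal usage sketch (shapes and loss function are illustrative):
- # TrivialLossWrapper covers the common case where the loss is simply
- # `loss_fn(model(x), targets)`, so the pipeline output is the loss itself.
- def _example_trivial_loss_wrapper():
-     model = torch.nn.Linear(8, 4)
-     wrapper = TrivialLossWrapper(model, torch.nn.MSELoss())
-     x, targets = torch.randn(2, 8), torch.randn(2, 4)
-     # Eager run: returns the scalar loss that PiPPy would backpropagate.
-     return wrapper(x, targets)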
- # Pipe model representation
- #
- # Pipe can be thought of as an `nn.Sequential++`. That is to say: it specifies
- # a single topological ordering of pipeline "stages" that, when run in series,
- # constitutes all of the operations of the program. However, unlike `nn.Sequential`,
- # Pipe allows non-local usages of values, so long as those uses still respect
- # topological ordering. In particular:
- #
- # 1. Non-local activations. This type of usage can appear in, for example, skip
- # connections. These values will be transmitted directly from the "def" stage
- # to all stages that use them, skipping intermediate stages. During autograd,
- # gradients will be propagated back through this skip connection in the
- # reverse of how activations propagated in the forward pass (see the sketch
- # below this comment block).
- # 2. Non-local parameter/module invocations. This occurs when a parameter is used
- # in a stage downstream of where it is resident. These values can be carried
- # forward similarly to (1), but in addition one might want to replicate the
- # value on multiple stages. Gradients for these shared parameters will be
- # accumulated separately on each stage, but there will be an additional
- # gradient accumulation before the optimizer step.
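- # Example sketch of item (1) above (module and sizes are made up): `a` is
- # produced in stage 0 and consumed again in the last stage, so it is a
- # non-local activation that Pipe would transmit directly from its defining
- # stage to the stage that uses it.
- class _ExampleSkipModel(torch.nn.Module):
-     def __init__(self):
-         super().__init__()
-         self.lin1 = torch.nn.Linear(8, 8)
-         self.lin2 = torch.nn.Linear(8, 8)
-     def forward(self, x):
-         a = self.lin1(x)   # defined in stage 0
-         pipe_split()
-         b = self.lin2(a)   # stage 1
-         pipe_split()
-         return a + b       # stage 2 reuses `a` across the stage boundary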
- # Register `_pipe_split()` as an ATen operator. This is required for Export to
- # preserve this marker in the graph.
- torch.library.define("pippy::_pipe_split", "() -> ()")
- @torch.library.impl("pippy::_pipe_split", "BackendSelect")
- def _pipe_split():
- return None
- @torch.library.register_fake("pippy::_pipe_split") # type: ignore[no-redef]
- def _pipe_split(): # noqa: F811
- return None
- # Add an alias for convenience
- aten_pipe_split_alias = torch.ops.pippy._pipe_split.default
- # Ask Export to preserve the `_pipe_split` op.
- # See examples in pytorch/torch/fx/node.py
- fx.node._side_effectful_functions.add(aten_pipe_split_alias)
- # User facing API
- def pipe_split():
- """
- pipe_split is a special operator that marks the boundary between two
- pipeline stages in a module; the tracing frontend splits the module at each
- such marker. It is a no-op if your annotated module is run eagerly.
- Example:
- >>> # xdoctest: +SKIP
- >>> def forward(self, x):
- >>> x = torch.mm(x, self.mm_param)
- >>> x = torch.relu(x)
- >>> pipe_split()
- >>> x = self.lin(x)
- >>> return x
- The above example will be split into two stages.
- """
- return torch.ops.pippy._pipe_split()
- class MultiUseParameterConfig(Enum):
- TRANSMIT = 1
- REPLICATE = 2
- MultiUseParamSpec = Union[MultiUseParameterConfig, Dict[str, MultiUseParameterConfig]]
- class DetachExecutor(fx.Interpreter):
- """
- Special interpreter to run the split_gm in testing that detaches all inputs to
- a module invocation. This is needed so that the values at the stage boundary
- are leaf tensors in autograd execution.
- """
- def __init__(self, module, garbage_collect_values=True):
- # Force value garbage collection off, regardless of the caller's argument
- garbage_collect_values = False
- super().__init__(module, garbage_collect_values)
- self.value_remap = {}
- def run(self, *args, initial_env=None):
- self.value_remap = {}
- return super().run(*args, initial_env=initial_env)
- def call_module(self, target, args, kwargs):
- def detach_tensors(a):
- if isinstance(a, torch.Tensor) and a.requires_grad:
- if a not in self.value_remap:
- new_val = a.detach().requires_grad_(True)
- self.value_remap[a] = new_val
- return self.value_remap[a]
- else:
- return a
- """
- def dont_traverse_size(a):
- return type(a) != torch.Size
- """
- args = map_aggregate(
- args,
- detach_tensors, # dont_traverse_size
- )
- kwargs = map_aggregate(
- kwargs,
- detach_tensors, # dont_traverse_size
- )
- return super().call_module(target, args, kwargs)
- def call_function(self, target, args, kwargs):
- # HACK to reroute saved input tensors to point to the detach()ed version
- if target == stage_backward:
- kwargs = dict(kwargs)
- kwargs["input_values"] = [
- self.value_remap.get(v, v) for v in kwargs["input_values"]
- ]
- return super().call_function(target, args, kwargs)
- class _NodeReference:
- def __init__(self, name):
- self.name = name
- name: str
- class _LinearNodeList:
- def __init__(self, node_list):
- self.serialize_node_list = []
- for node in node_list:
- node_args = fx.node.map_arg(node.args, lambda n: _NodeReference(n.name))
- node_kwargs = fx.node.map_arg(node.kwargs, lambda n: _NodeReference(n.name))
- serialize_node = fx.Node(
- graph=None,
- name=node.name,
- op=node.op,
- target=node.target,
- args=node_args,
- kwargs=node_kwargs,
- return_type=node.type,
- )
- serialize_node.meta = copy.copy(node.meta)
- self.serialize_node_list.append(serialize_node)
- def to_graph(self):
- graph = fx.Graph()
- ref_str_to_node: Dict[str, fx.Node] = {}
- def ref_to_node(arg):
- if isinstance(arg, _NodeReference):
- return ref_str_to_node[arg.name]
- else:
- return arg
- for node in self.serialize_node_list:
- node_args = map_aggregate(node.args, ref_to_node)
- node_kwargs = map_aggregate(node.kwargs, ref_to_node)
- deser_node = graph.create_node(
- op=node.op,
- target=node.target,
- args=node_args,
- kwargs=node_kwargs,
- name=node.name,
- type_expr=node.type,
- )
- ref_str_to_node[node.name] = deser_node
- return graph
- def _direct_serialization_deserialize(body, nodes):
- """
- Custom `__reduce__` method for serialization.
- DO AS I SAY -- NOT AS I DO. This violates the principle that
- GraphModules serialize via code export & re-tracing. We allow
- for this here because **PIPE STAGES SHOULD NOT BE PERSISTED
- TO DISK -- THIS IS ONLY FOR TRANSMISSION VIA RPC**. Persisting
- these instances to disk will expose internal implementation
- details of `fx.Graph` and related data structures and is
- NOT advised.
- """
- class DummyModule(torch.nn.Module):
- def __init__(self, body):
- super().__init__()
- self.__dict__.update(body)
- dummy = DummyModule(body)
- return fx.GraphModule(dummy, nodes.to_graph())
- def _direct_serialization_reduce(self):
- serialization_dict = dict(self.__dict__)
- serialization_dict.pop("_graph")
- return (
- _direct_serialization_deserialize,
- (serialization_dict, _LinearNodeList(self.graph.nodes)),
- )
- def _modify_graph_op_device(
- gm: torch.fx.GraphModule,
- new_device: torch.device,
- ):
- """
- Modify the device argument of all "call_function" nodes in the graph. This
- is useful for moving the graph to a different device, in particular for
- generator ops like torch.ones.
- """
- modified = False
- for node in gm.graph.nodes:
- if node.op == "call_function":
- if "device" in node.kwargs and node.kwargs["device"] != new_device:
- logger.debug(
- f"Changing device of Node {node.name} from {node.kwargs['device']} to {new_device}" # noqa: G004
- )
- node.update_kwarg("device", new_device)
- modified = True
- elif node.op == "call_module":
- # Recursively modify "device" in submodules
- submod = gm.get_submodule(node.target)
- if isinstance(submod, torch.fx.GraphModule):
- _modify_graph_op_device(submod, new_device)
- elif isinstance(submod, InterpreterModule):
- # If unflattening has been performed, we need to access its graph module by `.graph_module`
- _modify_graph_op_device(submod.graph_module, new_device)
- else:
- logger.warning(
- f"Skipping device modification for submodule {node.target} because it is a {type(submod)}" # noqa: G004
- )
- if modified:
- gm.recompile()
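- # Illustrative sketch (made-up module): after tracing, a generator op such as
- # `torch.ones` carries the tracing-time device as a constant kwarg; the pass
- # above rewrites that kwarg so the stage can be placed on a different device.
- # This example only rewrites the graph; it does not require a CUDA device to run.
- def _example_modify_graph_op_device():
-     class _Gen(torch.nn.Module):
-         def forward(self, x):
-             # `x.shape[0]` keeps torch.ones as a traced call_function node
-             return x + torch.ones(x.shape[0], device="cpu")
-     gm = fx.symbolic_trace(_Gen())
-     _modify_graph_op_device(gm, torch.device("cuda:0"))
-     return gm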
- class Pipe(torch.nn.Module):
- def __init__(
- self,
- split_gm: fx.GraphModule,
- num_stages: int,
- has_loss_and_backward: bool,
- loss_spec,
- ):
- # TODO: is there a way not to hard wire init?
- torch.nn.Module.__init__(self)
- self.split_gm: fx.GraphModule = split_gm
- self.executor: DetachExecutor = DetachExecutor(self.split_gm)
- self.num_stages: int = num_stages
- self.has_loss_and_backward = has_loss_and_backward
- self.loss_spec = loss_spec
- for node in split_gm.graph.nodes:
- assert (
- node.op in {"call_module", "placeholder", "output"}
- or (node.op, node.target) == ("call_function", operator.getitem)
- or (node.op, node.target) == ("call_method", "backward")
- or (node.op, node.target) == ("call_function", stage_backward)
- or (node.op, node.target)
- == ("call_function", _null_coalesce_accumulate)
- ), node
- # Detect replicated parameters so we know that we have to do an additional allreduce
- # before applying the optimizer
- #
- # Note that this also handles the case where there were multiple calls to a single
- # module from different stages, regardless of whether that module invocation
- # was handled by the logic above.
- # Map parameter value to a dictionary that maps the user pipeline module
- # to the local qualname within that module
- params_to_users: Dict[torch.nn.Parameter, Dict[str, str]] = {}
- for m_qualname, mod in self.split_gm.named_children():
- for p_qualname, param in mod.named_parameters():
- params_to_users.setdefault(param, {})
- params_to_users[param][m_qualname] = p_qualname
- self.replicated_params: List[Dict[str, str]] = [
- use_mapping
- for _, use_mapping in params_to_users.items()
- if len(use_mapping) > 1
- ]
- # We must break the aliasing relationship between the replicated parameters for correct
- # numerics in reference runs. If we do not do this, the autograd tape in separate stages
- # will have a reference to the same tensor value and will erroneously apply gradient
- # updates multiple times. Therefore, for each replicated parameter set, we deepcopy the
- # values so that we have separate instances.
- for param_mapping in self.replicated_params:
- for submod_name, param_qualname in param_mapping.items():
- submod = getattr(self.split_gm, submod_name)
- atoms = param_qualname.split(".")
- for atom in atoms[:-1]:
- submod = getattr(submod, atom)
- setattr(submod, atoms[-1], copy.deepcopy(getattr(submod, atoms[-1])))
- def throw(self, *args, **kwargs):
- raise RuntimeError(
- "To run pipeline locally, invoke the Pipe object directly, not `split_gm`"
- )
- self.split_gm.forward = throw
- # Make submodules use custom direct-serialized GraphModule
- i = 0
- while True:
- try:
- name = f"submod_{i}"
- submod = getattr(self.split_gm, name)
- submod.__class__.__reduce__ = _direct_serialization_reduce
- i += 1
- except AttributeError:
- break
- def forward(self, *args, **kwargs):
- executor_args = args
- if len(kwargs) > 0:
- parameters = []
- for node in self.split_gm.graph.nodes:
- if node.op == "placeholder":
- if node.args and len(node.args) > 0:
- parameters.append(
- Parameter(
- node.target,
- Parameter.POSITIONAL_OR_KEYWORD,
- default=node.args[0],
- )
- )
- else:
- parameter_kind = Parameter.POSITIONAL_OR_KEYWORD
- param_name = node.target
- if node.target.startswith("**"):
- parameter_kind = Parameter.VAR_KEYWORD # type: ignore[assignment]
- param_name = param_name[2:]
- elif node.target.startswith("*"):
- parameter_kind = Parameter.VAR_POSITIONAL # type: ignore[assignment]
- param_name = param_name[1:]
- parameters.append(Parameter(param_name, parameter_kind))
- signature = Signature(parameters)
- ba = signature.bind(*args, **kwargs)
- ba.apply_defaults()
- executor_args = ba.arguments.values() # type: ignore[assignment]
- res = self.executor.run(*executor_args)
- return res
- def get_stage_module(self, stage_idx: int) -> torch.nn.Module:
- """
- Return a stage module corresponding to `stage_idx` of the `pipe`.
- """
- if stage_idx < 0 or stage_idx >= self.num_stages:
- raise ValueError(f"Invalid stage index {stage_idx}!")
- return getattr(self.split_gm, f"submod_{stage_idx}")
- @staticmethod
- def _number_and_count_forward_stages(gm: fx.GraphModule):
- num_stages = 0
- found_idxs: Dict[int, None] = {}
- for node in gm.graph.nodes:
- if node.op == "call_module" and node.target.startswith("submod_"):
- node.meta["stage_idx"] = int(node.target[len("submod_") :])
- found_idxs.setdefault(node.meta["stage_idx"])
- num_stages += 1
- # this assert will fail if a split point is inserted before the first layer, which creates an empty first submodule
- # Update: the following assert may also fail with some torch versions >=
- # 2.2.0, because submodules that would be named
- # submod_0, submod_1, submod_2, ...
- # may instead be named
- # submod_0, submod_2, submod_4, ...
- # TODO: investigate
- # assert all(i in found_idxs for i in range(num_stages))
- return num_stages
- @staticmethod
- def _from_traced(
- mod: torch.nn.Module,
- exported_program: ExportedProgram,
- multi_use_param_spec: Optional[MultiUseParamSpec] = None,
- output_loss_value_spec=None,
- split_policy: Optional[
- Callable[[torch.fx.GraphModule], torch.fx.GraphModule]
- ] = None,
- ):
- """
- Additionally, the ``output_loss_value_spec`` value can be specified to disambiguate
- which value in the output of `forward` is the loss value on which PiPPy should apply
- backpropagation. For example, if your ``forward`` returns a tuple ``(loss, model_out)``,
- you can specify ``output_loss_value_spec=(True, False)``. Or, if your ``forward`` returns
- a dict ``{'loss': loss_value, 'model_out': model_out}``, you can specify
- ``output_loss_value_spec={'loss': True, 'model_out': False}``
- """
- traced = exported_program.module()
- if split_policy is not None:
- logger.info("Auto-splitting model")
- traced = split_policy(traced) # type: ignore[arg-type]
- logger.debug(traced.print_readable(print_output=False))
- # Deduplicate `get_attr` nodes that refer to the same parameter. Downstream code for moving
- # parameters relies on the invariant that parameter accesses happen once. This is not necessarily
- # the case (especially with custom tracers), so fix that up here.
- get_attr_nodes: Dict[str, fx.Node] = {}
- for node in traced.graph.nodes:
- if node.op == "get_attr":
- get_attr_nodes.setdefault(node.target, node)
- if get_attr_nodes[node.target] != node:
- node.replace_all_uses_with(get_attr_nodes[node.target])
- traced.graph.erase_node(node)
- # Remove back-to-back pipe_split markers by tracking the index of the previous pipe_split
- prev_pipe_split_idx = -1
- pipe_split_nodes_to_erase = set()
- for i, node in enumerate(traced.graph.nodes):
- if (node.op, node.target) == ("call_function", pipe_split):
- if prev_pipe_split_idx == i - 1:
- pipe_split_nodes_to_erase.add(node)
- prev_pipe_split_idx = i
- for node in pipe_split_nodes_to_erase:
- traced.graph.erase_node(node)
- traced.recompile()
- part_idx = 0
- def split_callback(n: fx.Node):
- nonlocal part_idx
- if (n.op, n.target) == (
- "call_function",
- aten_pipe_split_alias,
- ):
- logger.debug(f"Found pipe_split {part_idx}") # noqa: G004
- part_idx += 1
- return part_idx
- # TODO: what does split do with module invocations? does it move the modules
- # into the submodules?
- split = split_module(traced, mod, split_callback)
- # a (custom) tracer can produce dead code like orphan get_attr nodes
- split.graph.eliminate_dead_code()
- # peephole to remove pipe_split
- for submodule in split.modules():
- if isinstance(submodule, fx.GraphModule):
- for node in submodule.graph.nodes:
- if (node.op, node.target) == (
- "call_function",
- aten_pipe_split_alias,
- ):
- submodule.graph.erase_node(node)
- submodule.recompile()
- for name, submodule in split.named_children():
- if isinstance(submodule, fx.GraphModule):
- new_submod = _outline_submodules(submodule.graph)
- # Replace old submod
- split.register_module(name, new_submod)
- # TODO: backport this into split_module
- def delete_user_reference(node, user):
- """
- Delete reference of `node` from `user`'s arg list.
- Args:
- - node: a `get_attr` node at root.
- - user: a submodule node that uses `node`.
- """
- assert len(user.kwargs) == 0
- use_idxs = [i for i, arg in enumerate(user.args) if arg == node]
- assert len(use_idxs) == 1
- args_copy = list(user.args)
- args_copy.pop(use_idxs[0])
- user.args = tuple(args_copy)
- logger.debug(
- f"Deleted {node} from user {user}, arg index = {use_idxs[0]}" # noqa: G004
- )
- # A list of param referrals for deferred deletion.
- # To be accumulated in `move_param_to_callee`.
- to_delete = list()
- def _recursive_getattr_with_parent(mod, fqn):
- # Resolve a dotted FQN and return (parent module, attribute); the attribute is None if missing
- atoms = fqn.split(".")
- for atom in atoms[:-1]:
- if not hasattr(mod, atom):
- return None, None
- mod = getattr(mod, atom)
- if not hasattr(mod, atoms[-1]):
- return mod, None
- attr = getattr(mod, atoms[-1])
- return mod, attr
- def move_param_to_callee(
- root,
- callee_name,
- param_fqn,
- ):
- """
- Move a parameter from the root module to a submodule.
- Args:
- root: The root module.
- callee_name: The name of the submodule to move the parameter to.
- param_fqn: The fully qualified name of the parameter to move.
- """
- # `atoms` is a list of strings representing the path to the
- # parameter in the original model
- atoms = param_fqn.split(".")
- mod_itr, param_val = _recursive_getattr_with_parent(split, param_fqn)
- # Check whether the parameter is a buffer or a parameter
- is_buffer = atoms[-1] in mod_itr._buffers
- # Check whether the parameter is a tensor
- assert isinstance(param_val, torch.Tensor), (
- f"Expected '{param_fqn}' to be {torch.Tensor} but got {type(param_val)}."
- + (
- f" It might happen if module '{param_fqn}' was passed to some 'leaf function'"
- f"(see https://pytorch.org/docs/stable/fx.html#fx.wrap). Please inspect "
- f"usages of '{param_fqn}' in the traced graph."
- if isinstance(param_val, torch.nn.Module)
- else ""
- )
- )
- # Get submodule
- callee = root.get_submodule(callee_name)
- assert not hasattr(
- callee, param_fqn
- ), f"Module {callee_name} already has a parameter named {param_fqn}"
- # Assign the parameter to the submodule
- if is_buffer:
- _assign_attr(
- param_val,
- callee,
- param_fqn,
- attr_kind=_AttrKind.BUFFER,
- persistent=True, # TODO: handle non-persistent buffer
- )
- else:
- _assign_attr(
- param_val,
- callee,
- param_fqn,
- attr_kind=_AttrKind.PARAMETER,
- )
- logger.debug(f"Moved parameter {param_fqn} to {callee_name}") # noqa: G004
- # The next step is to replace each submodule placeholder with a get_attr.
- # Those placeholders are created by `split_module` inside each
- # submodule.
- # Update: this step has moved to `_sink_params` because
- # `_sink_params` can do it recursively (i.e. for modules nested inside
- # a submodule)
- to_delete.append((mod_itr, atoms[-1]))
- # Collect all `get_attr` nodes (parameter/buffer accesses) in the root graph
- attr_nodes = list(filter(lambda n: n.op == "get_attr", split.graph.nodes))
- for node in attr_nodes:
- # Check whether the parameter is used in only one submodule
- if len(node.users) > 1:
- logger.info(
- f"Parameter {node.target} used in multiple stages: {node.users}." # noqa: G004
- )
- for user in node.users:
- assert user.op == "call_module"
- # Move parameter into submodule
- move_param_to_callee(
- split,
- user.target,
- node.target,
- )
- # [aliasing] store tensor id -> list of FQNs, built from state dict
- # Also assign non-persistent buffers
- id_to_fqns: Dict[int, Set[str]] = defaultdict(set)
- for fqn, tensor in mod.state_dict(keep_vars=True).items():
- id_to_fqns[id(tensor)].add(fqn)
- for fqn, tensor in mod.named_buffers():
- id_to_fqns[id(tensor)].add(fqn)
- # After moving the params to their corresponding hierarchies, we also
- # need to move the `get_attr` nodes from the root of the graph to those
- # hierarchies.
- # [aliasing] use id -> fqn mapping to list out all valid FQNs
- inputs_to_state: Dict[str, List[str]] = {}
- for attr in attr_nodes:
- _, tensor = _recursive_getattr_with_parent(mod, attr.target)
- fqns = list(id_to_fqns[id(tensor)])
- if fqns:
- inputs_to_state[attr.name] = fqns
- elif attr.target in exported_program.constants: # lifted constants
- inputs_to_state[attr.name] = [attr.target]
- # [aliasing] for each submodule split, assign attributes on FQNs that may be used.
- # We determine this based on whether or not the FQN attribute parent exists.
- # i.e. if the last submodule exists, assign the attribute.
- added_attributes: Dict[str, List[str]] = defaultdict(list)
- for fqn, tensor in mod.state_dict(keep_vars=True).items():
- for name, submod in split.named_children():
- if isinstance(submod, fx.GraphModule):
- parent, child = _recursive_getattr_with_parent(submod, fqn)
- if (
- parent and child is None
- ): # parent exists, attribute doesn't -> assign
- added_attributes[name].append(fqn)
- setattr(parent, fqn.split(".")[-1], tensor)
- # Deferred deletion: remove the original attributes (referring to params) from the
- # root GraphModule
- for mod_itr, last_atom in to_delete:
- try:
- delattr(mod_itr, last_atom)
- except AttributeError:
- # This is expected if the parameter is used in multiple stages
- pass
- # Moving the `get_attr` nodes is done by (1) running `_sink_params` on each submodule;
- for name, submod in split.named_children():
- if isinstance(submod, fx.GraphModule):
- _sink_params(submod, inputs_to_state, [])
- submod.graph.lint()
- submod.recompile()
- # [aliasing] This step is not strictly necessary, but it helps reduce parameter duplication/memory.
- # After the _sink_params() routine has run, clean up unused attributes that we previously added,
- # based on the remaining get_attr nodes: if an attribute is not used, remove it.
- for name, attributes in added_attributes.items():
- submod = getattr(split, name)
- unused_attributes = set(attributes)
- # track used attributes in the submodule, running DFS on subgraph hierarchy
- stack = [("", submod)] # (scope, submodule)
- while stack:
- scope, _mod = stack.pop()
- if isinstance(_mod, (fx.GraphModule, InterpreterModule)):
- for node in _mod.graph.nodes:
- if node.op == "get_attr":
- # a get_attr may access an attribute at a deeper nesting level
- fqn = scope + "." + node.target if scope else node.target
- if fqn in unused_attributes: # it is used; drop it from the unused set
- unused_attributes.remove(fqn)
- for _name, _submod in _mod.named_children():
- stack.append((scope + "." + _name if scope else _name, _submod))
- # delete unused attributes
- for attr in unused_attributes:
- mod_itr, atoms = submod, attr.split(".")
- for atom in atoms[:-1]:
- mod_itr = getattr(mod_itr, atom)
- delattr(mod_itr, atoms[-1])
- for node in attr_nodes:
- # And (2): remove `get_attr` node from submod's arg list
- for user in copy.copy(node.users):
- assert user.op == "call_module"
- delete_user_reference(node, user)
- # And (3): remove the `get_attr` node from the root graph.
- split.graph.erase_node(node)
- split.delete_all_unused_submodules()
- split.graph.lint()
- split.recompile()
- num_stages = Pipe._number_and_count_forward_stages(split)
- has_loss_and_backward = False
- generated_loss_spec = output_loss_value_spec
- if output_loss_value_spec is not None:
- loss_node, output_node, generated_loss_spec = _find_loss_output(
- mod, split.graph, output_loss_value_spec
- )
- if loss_node is not None:
- _insert_stage_symbolic_backward(
- split.graph,
- loss_node,
- output_node,
- )
- split.recompile()
- has_loss_and_backward = True
- logger.debug("Pipeline is in training mode, backward pass generated")
- else:
- raise RuntimeError(
- f"Did not find any loss value according to {output_loss_value_spec=}"
- )
- else:
- logger.debug("Pipeline is in inference mode, backward pass not generated")
- logger.debug("Full pipe model:\n" f"{split}") # noqa: G004
- return Pipe(
- split,
- num_stages,
- has_loss_and_backward,
- generated_loss_spec,
- )
- def print_readable(self):
- """
- Print the pipe in a human-readable format.
- This will print both the root pipe and each stage module.
- """
- self.split_gm.print_readable()
- @staticmethod
- def _trace_with_export(
- mod: torch.nn.Module,
- example_args: Tuple[Any, ...],
- example_kwargs: Optional[Dict[str, Any]] = None,
- ) -> ExportedProgram:
- logger.info("Tracing model ...")
- try:
- ep = torch.export.export(
- mod,
- example_args,
- example_kwargs,
- )
- except Exception as e:
- raise RuntimeError(
- "It seems that we cannot capture your model as a full graph. "
- "Typical reasons include graph breaks, data/shape-dependent "
- "control flow, or missing meta kernels for custom operators. "
- "You can use our manual pipeline interfaces, or try to fix the "
- "graph breaks, see https://pytorch.org/docs/stable/export.html"
- ) from e
- return ep
- @staticmethod
- def from_tracing(
- mod: torch.nn.Module,
- example_args: Tuple[Any, ...],
- example_kwargs: Optional[Dict[str, Any]] = None,
- split_policy: Optional[Callable[[fx.GraphModule], fx.GraphModule]] = None,
- ):
- # If a param will be used in multiple pipeline stages, we default the strategy to REPLICATE'ing the param across
- # stages instead of TRANSMIT'ting it
- multi_use_param_spec = MultiUseParameterConfig.REPLICATE
- # Figure out which output is loss from output_chunk_spec
- output_loss_value_spec: Any = None
- # Deprecated
- """
- if output_chunk_spec is not None:
- output_loss_value_spec = map_aggregate(
- output_chunk_spec, lambda v: isinstance(v, _LossReducer)
- )
- """
- # Trace with export
- exported_program = Pipe._trace_with_export(
- mod,
- example_args,
- example_kwargs,
- )
- pipe = Pipe._from_traced(
- mod,
- exported_program,
- multi_use_param_spec,
- output_loss_value_spec=output_loss_value_spec,
- split_policy=split_policy,
- )
- # Users want the first pipeline stage to accept kwargs if the original
- # program does. This is controlled by the `_codegen` field of the graph,
- # so we make a copy here. Note: we only want the input spec and not the
- # output spec, because the output spec is for the last stage. Maybe a
- # TODO? Not sure yet.
- split = pipe.split_gm
- traced = exported_program.module()
- submod0 = next(iter(split.children()))
- submod0_sign = signature(submod0.forward)
- model_sign = signature(traced.forward)
- if len(model_sign.parameters) != len(submod0_sign.parameters):
- # We don't change the signature of the first stage if it takes a
- # different number of args than the original model
- logger.info(
- f"Original model takes {len(model_sign.parameters)} args but the " # noqa: G004
- f"first pipeline stage takes {len(submod0_sign.parameters)}. "
- "Please provide args to respective pipeline stages."
- )
- else:
- # Support kwargs for the first stage
- submod0.graph._codegen = copy.deepcopy(traced.graph._codegen)
- # `_replace` is actually not "private" or internal; per the namedtuple docs:
- # "To prevent conflicts with field names, the method and attribute names
- # start with an underscore."
- submod0.graph._codegen.pytree_info = (
- submod0.graph._codegen.pytree_info._replace(out_spec=None)
- )
- submod0.recompile()
- return pipe
- def __str__(self):
- return self.split_gm.__str__()
- def __repr__(self):
- return self.split_gm.__repr__()
- def info(self) -> PipeInfo:
- """
- Get information about the pipe.
- Returns
- -------
- PipeInfo
- A dataclass containing information about the pipe.
- """
- return PipeInfo(
- graph=self.split_gm.graph,
- num_stages=self.num_stages,
- has_loss_and_backward=self.has_loss_and_backward,
- )
- def build_stage(
- self,
- stage_index: int,
- device: torch.device,
- group: Optional[ProcessGroup] = None,
- ) -> _PipelineStage:
- """
- Create a `PipelineStage` given a stage index and distributed group.
- The `PipelineStage` can run with `PipelineSchedule`s.
- """
- # Find stage module
- stage_module = self.get_stage_module(stage_index)
- # Move op arguments to the target device.
- # Today the PT2 tracer does not treat `x.device` as a symbolic device;
- # instead, the device at tracing time gets burned into the generated
- # code. Here we provide a workaround that modifies the "device" kwarg of
- # such operations. These operations include:
- # `torch.ones`, `torch.zeros`, `torch.rand`, etc.
- if isinstance(stage_module, torch.fx.GraphModule):
- _modify_graph_op_device(stage_module, device)
- else:
- logger.warning(
- f"Expected a `torch.fx.GraphModule` but got {type(stage_module)}" # noqa: G004
- )
- # Detach pipe info
- # Note: be careful what's included in `pipe_info`. We don't want to keep
- # a reference to `Pipe` or `Pipe.split_gm`, which would stop Python from
- # garbage-collecting them. Once they are collected, stage modules that are
- # irrelevant to the current rank can be freed automatically.
- pipe_info = self.info()
- return _PipelineStage(stage_module, stage_index, pipe_info, device, group)
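- # Hedged sketch (names and device mapping are assumptions, not library
- # requirements): building one pipeline stage per rank. Assumes the default
- # process group has already been initialized elsewhere and that `rank` indexes
- # both the stage and the local accelerator.
- def _example_build_stage(pipe: "Pipe", rank: int) -> _PipelineStage:
-     device = (
-         torch.device("cuda", rank)
-         if torch.cuda.is_available()
-         else torch.device("cpu")
-     )
-     # The returned stage can then be driven by a PipelineSchedule.
-     return pipe.build_stage(stage_index=rank, device=device)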
- class SplitPoint(Enum):
- BEGINNING = 1
- END = 2
- # For backward compatibility, we keep the PipeSplitWrapper class because
- # `SplitPoint` used to be defined inside it.
- class PipeSplitWrapper:
- # Create a class alias for BC
- SplitPoint = SplitPoint
- def _split_before_forward(self, *args, **kwargs):
- pipe_split()
- return self._orig_forward(*args, **kwargs)
- def _split_after_forward(self, *args, **kwargs):
- try:
- return self._orig_forward(*args, **kwargs)
- finally:
- pipe_split()
- def annotate_split_points(mod: torch.nn.Module, spec: Dict[str, SplitPoint]):
- # TODO: make this implementation out-of-place?
- for qualname, split_type in spec.items():
- atoms = qualname.split(".")
- predecessor_module = mod
- for i, atom in enumerate(atoms[:-1]):
- try:
- predecessor_module = getattr(predecessor_module, atom)
- except AttributeError as e:
- raise AttributeError(
- f'Specified target {qualname} referenced nonexistent module {".".join(atoms[:i+1])}'
- ) from e
- mod_to_wrap = getattr(predecessor_module, atoms[-1])
- mod_to_wrap._orig_forward = mod_to_wrap.forward
- if split_type == SplitPoint.BEGINNING:
- mod_to_wrap.forward = MethodType(_split_before_forward, mod_to_wrap)
- elif split_type == SplitPoint.END:
- mod_to_wrap.forward = MethodType(_split_after_forward, mod_to_wrap)
- else:
- raise ValueError("Unknown split point type.")
- def pipeline(
- module: torch.nn.Module,
- mb_args: Tuple[Any, ...],
- mb_kwargs: Optional[Dict[str, Any]] = None,
- split_spec: Optional[Dict[str, SplitPoint]] = None,
- split_policy: Optional[Callable[[fx.GraphModule], fx.GraphModule]] = None,
- ) -> Pipe:
- """
- Split a module based on a specification.
- See `Pipe` for more details.
- Arguments
- ---------
- module:
- The module to be split.
- mb_args:
- Example positional inputs, in micro-batch form.
- mb_kwargs:
- Example keyword inputs, in micro-batch form. (default: `None`)
- split_spec:
- A dictionary using submodule names as split markers. (default: `None`)
- split_policy:
- The policy to use for splitting the module. (default: `None`)
- Returns
- -------
- A pipeline representation of class `Pipe`.
- """
- if split_spec is not None and split_policy is not None:
- raise ValueError(
- "Cannot specify both `split_spec` and `split_policy`. Please use only one of them."
- )
- if split_spec is not None:
- # Annotate split points in the module based on user spec
- annotate_split_points(module, split_spec)
- return Pipe.from_tracing(
- mod=module,
- example_args=mb_args,
- example_kwargs=mb_kwargs,
- )
- else:
- # Use split policy
- return Pipe.from_tracing(
- mod=module,
- example_args=mb_args,
- example_kwargs=mb_kwargs,
- split_policy=split_policy,
- )
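- # End-to-end usage sketch (module, sizes, and split point are illustrative,
- # not prescriptive): trace a small two-layer MLP into a 2-stage Pipe by
- # requesting a split before `fc2`.
- def _example_pipeline():
-     class _MLP(torch.nn.Module):
-         def __init__(self):
-             super().__init__()
-             self.fc1 = torch.nn.Linear(8, 8)
-             self.fc2 = torch.nn.Linear(8, 2)
-         def forward(self, x):
-             return self.fc2(torch.relu(self.fc1(x)))
-     mb = torch.randn(4, 8)  # one example micro-batch
-     pipe = pipeline(_MLP(), mb_args=(mb,), split_spec={"fc2": SplitPoint.BEGINNING})
-     # Expect two stages: (fc1 + relu) and (fc2).
-     return pipe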
|