# mypy: allow-untyped-defs
# Copyright (c) Meta Platforms, Inc. and affiliates
import copy
import logging
import operator
from collections import defaultdict
from enum import Enum
from inspect import Parameter, signature, Signature
from types import MethodType
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union

import torch
import torch.fx as fx
from torch.distributed import ProcessGroup
from torch.export import ExportedProgram
from torch.export.unflatten import (
    _assign_attr,
    _AttrKind,
    _sink_params,
    InterpreterModule,
)
from torch.fx.node import map_aggregate
from torch.fx.passes.split_module import split_module

from ._backward import _null_coalesce_accumulate, stage_backward
from ._unflatten import _outline_submodules
from ._utils import PipeInfo
from .stage import _PipelineStage


logger = logging.getLogger(__name__)

# TODO:
# 1. investigate gradient sync for shared parameters. how does DDP do it?
# 2. Add parameter movement to split_module


def _find_loss_from_output_and_spec(output_val, spec_val):
    if spec_val is False:
        return None
    if spec_val is True:
        if not isinstance(output_val, fx.Node):
            raise RuntimeError(
                f"Loss spec must specify a dynamic value but got {output_val}"
            )
        return output_val

    if isinstance(spec_val, (tuple, list)):
        if not isinstance(output_val, (tuple, list)):
            raise RuntimeError(
                f"Output value {output_val} must match type of loss specification "
                f"{spec_val}"
            )
        if len(output_val) != len(spec_val):
            raise RuntimeError(
                f"Output value {output_val} must match length of loss specification "
                f"{spec_val}"
            )
        for out, spec in zip(output_val, spec_val):
            loss_val = _find_loss_from_output_and_spec(out, spec)
            if loss_val is not None:
                return loss_val
        raise RuntimeError(f"Did not find loss value in specification {spec_val}")

    if isinstance(spec_val, dict):
        if not isinstance(output_val, dict):
            raise RuntimeError(
                f"Output value {output_val} must match type of loss specification "
                f"{spec_val}"
            )
        if set(output_val.keys()) != set(spec_val.keys()):
            raise RuntimeError(
                f"Output value {output_val} must match keys of loss specification "
                f"{spec_val}"
            )
        for k in spec_val:
            loss_val = _find_loss_from_output_and_spec(output_val[k], spec_val[k])
            if loss_val is not None:
                return loss_val
        raise RuntimeError(f"Did not find loss value in specification {spec_val}")

    raise RuntimeError(f"Unsupported type {type(spec_val)} in loss specification")


def _find_loss_output(mod: torch.nn.Module, g: fx.Graph, output_loss_value_spec):
    output_nodes = [n for n in g.nodes if n.op == "output"]
    assert len(output_nodes) == 1
    output_node = output_nodes[0]
    output_val = output_node.args[0]
    generated_spec: Any = None

    if isinstance(mod, TrivialLossWrapper):
        # TrivialLossWrapper is pre-defined by PiPPy.
        # It has loss as the only output so we can safely assume the first output arg is the loss.
        assert len(output_node.args) == 1
        loss_node = output_val
        generated_spec = TrivialLossWrapper.loss_spec
    elif output_loss_value_spec is None:
        # Use default spec, i.e. search for "loss" in output values
        if isinstance(output_val, dict) and "loss" in output_val.keys():
            loss_node = output_val["loss"]
            generated_spec = {k: k == "loss" for k in output_val}
        else:
            loss_node = None
            generated_spec = None
    else:
        loss_node = _find_loss_from_output_and_spec(output_val, output_loss_value_spec)
        generated_spec = output_loss_value_spec

    return loss_node, output_node, generated_spec


def _insert_stage_symbolic_backward(
    g: fx.Graph,
    loss_node: fx.Node,
    output_node: fx.Node,
):
    # Collect metadata about tuple output values. TODO: move this to split_module or FX IR
    tuples: Dict[fx.Node, Tuple] = {}
    for node in reversed(g.nodes):
        if node.op == "call_function":
            # In the forward pass, only emit placeholder, module calls, and
            # getitem calls. If we have a target other than getitem in this
            # (forward-only) code, there is a bug.
            assert node.target == operator.getitem, (
                "Found non-getitem call in forward pass. "
                "Please report a bug to PiPPy"
            )
            assert (
                len(node.args) == 2
            ), "Found malformed getitem call. Please report a bug to PiPPy"
            indexed_value, node_idx = tuple(node.args)

            # indexed_value is a collection that we are indexing into. It could
            # exist in the tuples map if we've processed another `getitem`
            # already.
            existing_list_size = (
                len(tuples[indexed_value]) if indexed_value in tuples else -1
            )
            new_list_size = max(node_idx + 1, existing_list_size)

            reconstructed_list = [None for _ in range(new_list_size)]

            # Copy over existing elements if present
            if indexed_value in tuples:
                for i, val in enumerate(tuples[indexed_value]):
                    reconstructed_list[i] = val

            # Populate value represented by this node
            reconstructed_list[node_idx] = node

            tuples[indexed_value] = tuple(reconstructed_list)

    # Keep track of nodes that dominate the loss node.
    # We will only emit backward operations for nodes that can contribute
    # to the specified loss value.
    live_nodes = {loss_node: None}
    val_to_grad: Dict[fx.Node, Optional[fx.Node]] = {loss_node: None}

    def assign_or_accumulate_grad(forward_node, grad_value):
        if forward_node in val_to_grad and forward_node.op != "placeholder":
            grad_value = g.call_function(
                _null_coalesce_accumulate,
                (val_to_grad[forward_node], grad_value),
            )
        val_to_grad[forward_node] = grad_value

    with g.inserting_before(output_node):
        for node in reversed(g.nodes):
            if node not in live_nodes:
                continue

            def add_to_live_nodes(n):
                live_nodes.setdefault(n, None)

            fx.node.map_arg(node.args, add_to_live_nodes)
            fx.node.map_arg(node.kwargs, add_to_live_nodes)

            if node.op == "call_module":
                output_grads: Union[Tuple[Optional[fx.Node], ...], Optional[fx.Node]]
                if node in tuples:
                    stage_output = tuples[node]
                    output_grads = tuple(val_to_grad.get(n, None) for n in tuples[node])
                    outputs_with_grads_idxs = [
                        i for i, n in enumerate(tuples[node]) if n in live_nodes
                    ]
                else:
                    stage_output = (node,)
                    output_grads = val_to_grad[node]
                    outputs_with_grads_idxs = [0]

                output_grads = (
                    (output_grads,)
                    if not isinstance(output_grads, tuple)
                    else output_grads
                )

                grad_call = g.call_function(
                    stage_backward,
                    kwargs={
                        "stage_output": stage_output,
                        "output_grads": output_grads,
                        "input_values": list(node.all_input_nodes),
                        "outputs_with_grads_idxs": outputs_with_grads_idxs,
                    },
                )
                # Insert backward stage debug info
                kwargs_copy = dict(grad_call.kwargs)
                grad_call.kwargs = kwargs_copy

                grad_call_proxy = fx.Proxy(grad_call)
                grads = grad_call_proxy.node

                input_nodes = list(node.all_input_nodes)
                grads_proxy = fx.Proxy(grads)
                for i, input_node in enumerate(input_nodes):
                    assign_or_accumulate_grad(input_node, grads_proxy[i].node)

    return g


class PipeSequential(torch.nn.Sequential):
    @staticmethod
    def from_sequential(sequential_instance: torch.nn.Sequential):
        return PipeSequential(*[copy.copy(m) for m in sequential_instance])

    def forward(self, input):
        for i, module in enumerate(self):
            input = module(input)
            if i != len(self) - 1:
                pipe_split()
        return input
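
# Illustrative sketch (hypothetical model, not part of the library API):
# `PipeSequential` behaves like `nn.Sequential` but inserts a `pipe_split()`
# between consecutive layers, so each layer becomes its own pipeline stage
# when the wrapper is later traced and split.
#
#     seq = torch.nn.Sequential(
#         torch.nn.Linear(16, 32),
#         torch.nn.ReLU(),
#         torch.nn.Linear(32, 4),
#     )
#     pipe_seq = PipeSequential.from_sequential(seq)
#     pipe = pipeline(pipe_seq, mb_args=(torch.randn(8, 16),))  # 3 stages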


class LossWrapper(torch.nn.Module):
    """
    LossWrapper is a convenient abstract class that allows you to wrap up both
    your model as well as its loss function and specify the connectivity between
    the inputs, model, loss function, and output value. Example::

        class MyModelWrapper(LossWrapper):
            def forward(self, x, targets):
                model_out = self.module(x)
                loss_value = self.loss_fn(model_out, targets)
                return loss_value

    The above example defines a connectivity where we expect the forward/loss/backward
    training procedure to take two arguments (x and targets), pass x into the module
    to get the output of the feedforward computation, pass the model output and the
    targets value into the loss function, and get and return the loss value, which will
    be backpropagated by PiPPy. The above class would then be instantiated like::

        model = ...  # instantiate the model
        loss_fn = torch.nn.MSELoss()  # for the sake of demonstration
        wrapper = MyModelWrapper(model, loss_fn)
        pipe = Pipe.from_tracing(wrapper, ...)
    """

    def __init__(self, module, loss_fn):
        super().__init__()
        self.module = module
        self.loss_fn = loss_fn

    def forward(self, *args, **kwargs):
        raise NotImplementedError(
            "This instance of LossWrapper does not have an overridden "
            "forward(). Please implement forward() to specify the arguments, "
            "connection between the module and loss, and loss output "
            "value."
        )


class TrivialLossWrapper(LossWrapper):
    def forward(self, x, targets):
        model_out = self.module(x)
        return self.loss_fn(model_out, targets)

    loss_spec = True
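
# Illustrative sketch (hypothetical model and loss, not part of the library):
# `TrivialLossWrapper` covers the common case where the loss is simply
# `loss_fn(model(x), targets)`; the wrapper is what gets traced and split, so
# the loss (and hence the generated backward pass) lives in the last stage.
#
#     model = torch.nn.Linear(16, 4)
#     wrapper = TrivialLossWrapper(model, torch.nn.MSELoss())
#     loss = wrapper(torch.randn(8, 16), torch.randn(8, 4))  # eager run returns the loss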

# Pipe model representation
#
# Pipe can be thought of as an `nn.Sequential++`. That is to say: it specifies
# a single topological ordering of pipeline "stages" that, when run in series,
# constitutes all of the operations of the program. However, unlike `nn.Sequential`,
# Pipe allows non-local usages of values, so long as those uses still respect
# topological ordering. In particular:
#
# 1. Non-local activations. This type of usage can appear in, for example, skip
#    connections. These values will be directly transmitted from the "def" stage
#    to all stages that use them, skipping intermediate stages. During autograd,
#    gradients will be propagated back through this skip connection in the
#    reverse of how activations propagated in the forward pass.
# 2. Non-local parameter/module invocations. This occurs when a parameter is used
#    in a stage downstream of where it is resident. These values can be carried
#    forward similarly to (1), but in addition one might want to replicate the
#    value on multiple stages. Gradients for these shared parameters will be
#    accumulated separately on each stage, but there will be an additional
#    gradient accumulation before the optimizer step.
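
# Illustrative sketch of the skip-connection case (1) above (hypothetical toy
# model, not part of the library): `x` is defined before the split point and
# used again after it, so Pipe transmits it across the stage boundary and
# routes its gradient back along the same edge.
#
#     class ToySkip(torch.nn.Module):
#         def __init__(self):
#             super().__init__()
#             self.lin1 = torch.nn.Linear(16, 16)
#             self.lin2 = torch.nn.Linear(16, 16)
#
#         def forward(self, x):
#             a = self.lin1(x)
#             pipe_split()
#             b = self.lin2(a)
#             return b + x  # non-local use of `x` from the first stage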

# Register `_pipe_split()` as an ATen operator. This is required for Export to
# preserve this marker in the graph.
torch.library.define("pippy::_pipe_split", "() -> ()")


@torch.library.impl("pippy::_pipe_split", "BackendSelect")
def _pipe_split():
    return None


@torch.library.register_fake("pippy::_pipe_split")  # type: ignore[no-redef]
def _pipe_split():  # noqa: F811
    return None


# Add an alias for convenience
aten_pipe_split_alias = torch.ops.pippy._pipe_split.default

# Ask Export to preserve the `_pipe_split` op.
# See examples in pytorch/torch/fx/node.py
fx.node._side_effectful_functions.add(aten_pipe_split_alias)


# User facing API
def pipe_split():
    """
    pipe_split is a special operator that is used to mark the boundary between
    stages in a module. It is used to split the module into stages. It is a
    no-op if your annotated module is run eagerly.

    Example:
        >>> # xdoctest: +SKIP
        >>> def forward(self, x):
        >>>     x = torch.mm(x, self.mm_param)
        >>>     x = torch.relu(x)
        >>>     pipe_split()
        >>>     x = self.lin(x)
        >>>     return x

    The above example will be split into two stages.
    """
    return torch.ops.pippy._pipe_split()


class MultiUseParameterConfig(Enum):
    TRANSMIT = 1
    REPLICATE = 2


MultiUseParamSpec = Union[MultiUseParameterConfig, Dict[str, MultiUseParameterConfig]]


class DetachExecutor(fx.Interpreter):
    """
    Special interpreter to run the split_gm in testing that detaches all inputs to
    a module invocation. This is needed so that the values at stage boundaries are
    leaf nodes in autograd execution.
    """

    def __init__(self, module, garbage_collect_values=True):
        garbage_collect_values = False
        super().__init__(module, garbage_collect_values)
        self.value_remap = {}

    def run(self, *args, initial_env=None):
        self.value_remap = {}
        return super().run(*args, initial_env=initial_env)

    def call_module(self, target, args, kwargs):
        def detach_tensors(a):
            if isinstance(a, torch.Tensor) and a.requires_grad:
                if a not in self.value_remap:
                    new_val = a.detach().requires_grad_(True)
                    self.value_remap[a] = new_val
                return self.value_remap[a]
            else:
                return a

        """
        def dont_traverse_size(a):
            return type(a) != torch.Size
        """

        args = map_aggregate(
            args,
            detach_tensors,  # dont_traverse_size
        )
        kwargs = map_aggregate(
            kwargs,
            detach_tensors,  # dont_traverse_size
        )

        return super().call_module(target, args, kwargs)

    def call_function(self, target, args, kwargs):
        # HACK to reroute saved input tensors to point to the detach()ed version
        if target == stage_backward:
            kwargs = dict(kwargs)
            kwargs["input_values"] = [
                self.value_remap.get(v, v) for v in kwargs["input_values"]
            ]
        return super().call_function(target, args, kwargs)


class _NodeReference:
    def __init__(self, name):
        self.name = name

    name: str


class _LinearNodeList:
    def __init__(self, node_list):
        self.serialize_node_list = []
        for node in node_list:
            node_args = fx.node.map_arg(node.args, lambda n: _NodeReference(n.name))
            node_kwargs = fx.node.map_arg(node.kwargs, lambda n: _NodeReference(n.name))
            serialize_node = fx.Node(
                graph=None,
                name=node.name,
                op=node.op,
                target=node.target,
                args=node_args,
                kwargs=node_kwargs,
                return_type=node.type,
            )
            serialize_node.meta = copy.copy(node.meta)
            self.serialize_node_list.append(serialize_node)

    def to_graph(self):
        graph = fx.Graph()

        ref_str_to_node: Dict[str, fx.Node] = {}

        def ref_to_node(arg):
            if isinstance(arg, _NodeReference):
                return ref_str_to_node[arg.name]
            else:
                return arg

        for node in self.serialize_node_list:
            node_args = map_aggregate(node.args, ref_to_node)
            node_kwargs = map_aggregate(node.kwargs, ref_to_node)
            deser_node = graph.create_node(
                op=node.op,
                target=node.target,
                args=node_args,
                kwargs=node_kwargs,
                name=node.name,
                type_expr=node.type,
            )
            ref_str_to_node[node.name] = deser_node

        return graph


def _direct_serialization_deserialize(body, nodes):
    """
    Custom `__reduce__` method for serialization.
    DO AS I SAY -- NOT AS I DO. This violates the principle that
    GraphModules serialize via code export & re-tracing. We allow
    for this here because **PIPE STAGES SHOULD NOT BE PERSISTED
    TO DISK -- THIS IS ONLY FOR TRANSMISSION VIA RPC**. Persisting
    these instances to disk will expose internal implementation
    details of `fx.Graph` and related data structures and is
    NOT advised.
    """

    class DummyModule(torch.nn.Module):
        def __init__(self, body):
            super().__init__()
            self.__dict__.update(body)

    dummy = DummyModule(body)

    return fx.GraphModule(dummy, nodes.to_graph())


def _direct_serialization_reduce(self):
    serialization_dict = dict(self.__dict__)
    serialization_dict.pop("_graph")
    return (
        _direct_serialization_deserialize,
        (serialization_dict, _LinearNodeList(self.graph.nodes)),
    )


def _modify_graph_op_device(
    gm: torch.fx.GraphModule,
    new_device: torch.device,
):
    """
    Modify the device argument of all "call_function" nodes in the graph. This
    is useful for moving the graph to a different device, in particular for
    generator ops like `torch.ones`.
    """
    modified = False
    for node in gm.graph.nodes:
        if node.op == "call_function":
            if "device" in node.kwargs and node.kwargs["device"] != new_device:
                logger.debug(
                    f"Changing device of Node {node.name} from {node.kwargs['device']} to {new_device}"  # noqa: G004
                )
                node.update_kwarg("device", new_device)
                modified = True
        elif node.op == "call_module":
            # Recursively modify "device" in submodules
            submod = gm.get_submodule(node.target)
            if isinstance(submod, torch.fx.GraphModule):
                _modify_graph_op_device(submod, new_device)
            elif isinstance(submod, InterpreterModule):
                # If unflattening has been performed, we need to access its graph module by `.graph_module`
                _modify_graph_op_device(submod.graph_module, new_device)
            else:
                logger.warning(
                    f"Skipping device modification for submodule {node.target} because it is a {type(submod)}"  # noqa: G004
                )

    if modified:
        gm.recompile()


class Pipe(torch.nn.Module):
    def __init__(
        self,
        split_gm: fx.GraphModule,
        num_stages: int,
        has_loss_and_backward: bool,
        loss_spec,
    ):
        # TODO: is there a way not to hard wire init?
        torch.nn.Module.__init__(self)
        self.split_gm: fx.GraphModule = split_gm
        self.executor: DetachExecutor = DetachExecutor(self.split_gm)
        self.num_stages: int = num_stages
        self.has_loss_and_backward = has_loss_and_backward
        self.loss_spec = loss_spec

        for node in split_gm.graph.nodes:
            assert (
                node.op in {"call_module", "placeholder", "output"}
                or (node.op, node.target) == ("call_function", operator.getitem)
                or (node.op, node.target) == ("call_method", "backward")
                or (node.op, node.target) == ("call_function", stage_backward)
                or (node.op, node.target)
                == ("call_function", _null_coalesce_accumulate)
            ), node

        # Detect replicated parameters so we know that we have to do an additional allreduce
        # before applying the optimizer
        #
        # Note that this also handles the case where there were multiple calls to a single
        # module from different stages, regardless of whether that module invocation
        # was handled by the logic above.

        # Map parameter value to a dictionary that maps the user pipeline module
        # to the local qualname within that module
        params_to_users: Dict[torch.nn.Parameter, Dict[str, str]] = {}

        for m_qualname, mod in self.split_gm.named_children():
            for p_qualname, param in mod.named_parameters():
                params_to_users.setdefault(param, {})
                params_to_users[param][m_qualname] = p_qualname

        self.replicated_params: List[Dict[str, str]] = [
            use_mapping
            for _, use_mapping in params_to_users.items()
            if len(use_mapping) > 1
        ]

        # We must break the aliasing relationship between the replicated parameters for correct
        # numerics in reference runs. If we do not do this, the autograd tape in separate stages
        # will have a reference to the same tensor value and will erroneously apply gradient
        # updates multiple times. Therefore, for each replicated parameter set, we deepcopy the
        # values so that we have separate instances.
        for param_mapping in self.replicated_params:
            for submod_name, param_qualname in param_mapping.items():
                submod = getattr(self.split_gm, submod_name)
                atoms = param_qualname.split(".")
                for atom in atoms[:-1]:
                    submod = getattr(submod, atom)
                setattr(submod, atoms[-1], copy.deepcopy(getattr(submod, atoms[-1])))

        def throw(self, *args, **kwargs):
            raise RuntimeError(
                "To run pipeline locally, invoke the Pipe object directly, not `split_gm`"
            )

        self.split_gm.forward = throw

        # Make submodules use custom direct-serialized GraphModule
        i = 0
        while True:
            try:
                name = f"submod_{i}"
                submod = getattr(self.split_gm, name)
                submod.__class__.__reduce__ = _direct_serialization_reduce
                i += 1
            except AttributeError:
                break

    def forward(self, *args, **kwargs):
        executor_args = args
        if len(kwargs) > 0:
            parameters = []
            for node in self.split_gm.graph.nodes:
                if node.op == "placeholder":
                    if node.args and len(node.args) > 0:
                        parameters.append(
                            Parameter(
                                node.target,
                                Parameter.POSITIONAL_OR_KEYWORD,
                                default=node.args[0],
                            )
                        )
                    else:
                        parameter_kind = Parameter.POSITIONAL_OR_KEYWORD
                        param_name = node.target
                        if node.target.startswith("**"):
                            parameter_kind = Parameter.VAR_KEYWORD  # type: ignore[assignment]
                            param_name = param_name[2:]
                        elif node.target.startswith("*"):
                            parameter_kind = Parameter.VAR_POSITIONAL  # type: ignore[assignment]
                            param_name = param_name[1:]
                        parameters.append(Parameter(param_name, parameter_kind))
            signature = Signature(parameters)
            ba = signature.bind(*args, **kwargs)
            ba.apply_defaults()
            executor_args = ba.arguments.values()  # type: ignore[assignment]

        res = self.executor.run(*executor_args)

        return res

    def get_stage_module(self, stage_idx: int) -> torch.nn.Module:
        """
        Return a stage module corresponding to `stage_idx` of the `pipe`.
        """
        if stage_idx < 0 or stage_idx >= self.num_stages:
            raise ValueError(f"Invalid stage index {stage_idx}!")
        return getattr(self.split_gm, f"submod_{stage_idx}")

    @staticmethod
    def _number_and_count_forward_stages(gm: fx.GraphModule):
        num_stages = 0
        found_idxs: Dict[int, None] = {}
        for node in gm.graph.nodes:
            if node.op == "call_module" and node.target.startswith("submod_"):
                node.meta["stage_idx"] = int(node.target[len("submod_") :])
                found_idxs.setdefault(node.meta["stage_idx"])
                num_stages += 1

        # this assert will fail if a split point is inserted before the first layer, which creates empty first submodule
        # Update: the following assert may fail against some torch versions >=
        # 2.2.0, as:
        # submod_0, submod_1, submod_2, ...
        # may be named as
        # submod_0, submod_2, submod_4, ...
        # TODO: investigate
        # assert all(i in found_idxs for i in range(num_stages))

        return num_stages

    @staticmethod
    def _from_traced(
        mod: torch.nn.Module,
        exported_program: ExportedProgram,
        multi_use_param_spec: Optional[MultiUseParamSpec] = None,
        output_loss_value_spec=None,
        split_policy: Optional[
            Callable[[torch.fx.GraphModule], torch.fx.GraphModule]
        ] = None,
    ):
        """
        Additionally, the ``output_loss_value_spec`` value can be specified to disambiguate
        which value in the output of `forward` is the loss value on which PiPPy should apply
        backpropagation. For example, if your ``forward`` returns a tuple ``(loss, model_out)``,
        you can specify ``output_loss_value_spec=(True, False)``. Or, if your ``forward`` returns
        a dict ``{'loss': loss_value, 'model_out': model_out}``, you can specify
        ``output_loss_value_spec={'loss': True, 'model_out': False}``
        """

        traced = exported_program.module()

        if split_policy is not None:
            logger.info("Auto-splitting model")
            traced = split_policy(traced)  # type: ignore[arg-type]

        logger.debug(traced.print_readable(print_output=False))

        # Deduplicate `get_attr` nodes that refer to the same parameter. Downstream code for moving
        # parameters relies on the invariant that parameter accesses happen once. This is not necessarily
        # the case (especially with custom tracers), so fix that up here.
        get_attr_nodes: Dict[str, fx.Node] = {}
        for node in traced.graph.nodes:
            if node.op == "get_attr":
                get_attr_nodes.setdefault(node.target, node)

                if get_attr_nodes[node.target] != node:
                    node.replace_all_uses_with(get_attr_nodes[node.target])
                    traced.graph.erase_node(node)

        # avoid looking at next node by keeping track of previous pipe_split
        prev_pipe_split_idx = -1
        pipe_split_nodes_to_erase = set()
        for i, node in enumerate(traced.graph.nodes):
            if (node.op, node.target) == ("call_function", pipe_split):
                if prev_pipe_split_idx == i - 1:
                    pipe_split_nodes_to_erase.add(node)
                prev_pipe_split_idx = i

        for node in pipe_split_nodes_to_erase:
            traced.graph.erase_node(node)

        traced.recompile()

        part_idx = 0

        def split_callback(n: fx.Node):
            nonlocal part_idx
            if (n.op, n.target) == (
                "call_function",
                aten_pipe_split_alias,
            ):
                logger.debug(f"Found pipe_split {part_idx}")  # noqa: G004
                part_idx += 1
            return part_idx

        # TODO: what does split do with module invocations? does it move the modules
        # into the submodules?
        split = split_module(traced, mod, split_callback)
        # a (custom) tracer can produce dead code like orphan get_attr nodes
        split.graph.eliminate_dead_code()

        # peephole to remove pipe_split
        for submodule in split.modules():
            if isinstance(submodule, fx.GraphModule):
                for node in submodule.graph.nodes:
                    if (node.op, node.target) == (
                        "call_function",
                        aten_pipe_split_alias,
                    ):
                        submodule.graph.erase_node(node)
                submodule.recompile()

        for name, submodule in split.named_children():
            if isinstance(submodule, fx.GraphModule):
                new_submod = _outline_submodules(submodule.graph)
                # Replace old submod
                split.register_module(name, new_submod)

        # TODO: backport this into split_module
        def delete_user_reference(node, user):
            """
            Delete reference of `node` from `user`'s arg list.
            Args:
                - node: a `get_attr` node at root.
                - user: a submodule node that uses `node`.
            """
            assert len(user.kwargs) == 0
            use_idxs = [i for i, arg in enumerate(user.args) if arg == node]
            assert len(use_idxs) == 1
            args_copy = list(user.args)
            args_copy.pop(use_idxs[0])
            user.args = tuple(args_copy)
            logger.debug(
                f"Deleted {node} from user {user}, arg index = {use_idxs[0]}"  # noqa: G004
            )

        # A list of param referrals for deferred deletion.
        # To be accumulated in `move_param_to_callee`.
        to_delete = list()

        def _recursive_getattr_with_parent(mod, fqn):
            # Returns getattr call given a nested FQN, and the last parent
            atoms = fqn.split(".")
            for atom in atoms[:-1]:
                if not hasattr(mod, atom):
                    return None, None
                mod = getattr(mod, atom)
            if not hasattr(mod, atoms[-1]):
                return mod, None
            attr = getattr(mod, atoms[-1])
            return mod, attr

        def move_param_to_callee(
            root,
            callee_name,
            param_fqn,
        ):
            """
            Move a parameter from the root module to a submodule.
            Args:
                root: The root module.
                callee_name: The name of the submodule to move the parameter to.
                param_fqn: The fully qualified name of the parameter to move.
            """
            # `atoms` is a list of strings representing the path to the
            # parameter in the original model
            atoms = param_fqn.split(".")
            mod_itr, param_val = _recursive_getattr_with_parent(split, param_fqn)
            # Check whether the parameter is a buffer or a parameter
            is_buffer = atoms[-1] in mod_itr._buffers

            # Check whether the parameter is a tensor
            assert isinstance(param_val, torch.Tensor), (
                f"Expected '{param_fqn}' to be {torch.Tensor} but got {type(param_val)}."
                + (
                    f" It might happen if module '{param_fqn}' was passed to some 'leaf function' "
                    f"(see https://pytorch.org/docs/stable/fx.html#fx.wrap). Please inspect "
                    f"usages of '{param_fqn}' in the traced graph."
                    if isinstance(param_val, torch.nn.Module)
                    else ""
                )
            )

            # Get submodule
            callee = root.get_submodule(callee_name)
            assert not hasattr(
                callee, param_fqn
            ), f"Module {callee_name} already has a parameter named {param_fqn}"

            # Assign the parameter to the submodule
            if is_buffer:
                _assign_attr(
                    param_val,
                    callee,
                    param_fqn,
                    attr_kind=_AttrKind.BUFFER,
                    persistent=True,  # TODO: handle non-persistent buffer
                )
            else:
                _assign_attr(
                    param_val,
                    callee,
                    param_fqn,
                    attr_kind=_AttrKind.PARAMETER,
                )
            logger.debug(f"Moved parameter {param_fqn} to {callee_name}")  # noqa: G004

            # Next step is to replace placeholder of submodule with a get_attr.
            # Those placeholders are created by `split_module` inside each
            # submodule.
            # Update: this step is now moved to `_sink_params` because
            # `_sink_params` can do it recursively (i.e. for modules inside
            # submodule)

            to_delete.append((mod_itr, atoms[-1]))

        # Get the list of all parameters in the root module
        attr_nodes = list(filter(lambda n: n.op == "get_attr", split.graph.nodes))
        for node in attr_nodes:
            # Check whether the parameter is used in only one submodule
            if len(node.users) > 1:
                logger.info(
                    f"Parameter {node.target} used in multiple stages: {node.users}."  # noqa: G004
                )
            for user in node.users:
                assert user.op == "call_module"
                # Move parameter into submodule
                move_param_to_callee(
                    split,
                    user.target,
                    node.target,
                )

        # [aliasing] store tensor id -> list of FQNs, built from state dict
        # Also assign non-persistent buffers
        id_to_fqns: Dict[int, Set[str]] = defaultdict(set)
        for fqn, tensor in mod.state_dict(keep_vars=True).items():
            id_to_fqns[id(tensor)].add(fqn)
        for fqn, tensor in mod.named_buffers():
            id_to_fqns[id(tensor)].add(fqn)

        # After moving the params to their corresponding hierarchies, we also
        # need to move the `get_attr` nodes from the root of the graph to those
        # hierarchies.
        # [aliasing] use id -> fqn mapping to list out all valid FQNs
        inputs_to_state: Dict[str, List[str]] = {}
        for attr in attr_nodes:
            _, tensor = _recursive_getattr_with_parent(mod, attr.target)
            fqns = list(id_to_fqns[id(tensor)])
            if fqns:
                inputs_to_state[attr.name] = fqns
            elif attr.target in exported_program.constants:  # lifted constants
                inputs_to_state[attr.name] = [attr.target]

        # [aliasing] for each submodule split, assign attributes on FQNs that may be used.
        # We determine this based on whether or not the FQN attribute parent exists.
        # i.e. if the last submodule exists, assign the attribute.
        added_attributes: Dict[str, List[str]] = defaultdict(list)
        for fqn, tensor in mod.state_dict(keep_vars=True).items():
            for name, submod in split.named_children():
                if isinstance(submod, fx.GraphModule):
                    parent, child = _recursive_getattr_with_parent(submod, fqn)
                    if (
                        parent and child is None
                    ):  # parent exists, attribute doesn't -> assign
                        added_attributes[name].append(fqn)
                        setattr(parent, fqn.split(".")[-1], tensor)

        # Deferred deletion: remove the original attributes (to params) from the
        # root GraphModule
        for mod_itr, last_atom in to_delete:
            try:
                delattr(mod_itr, last_atom)
            except AttributeError:
                # This is expected if the parameter is used in multiple stages
                pass

        # This is done by (1) `_sink_params` at each submodule;
        for name, submod in split.named_children():
            if isinstance(submod, fx.GraphModule):
                _sink_params(submod, inputs_to_state, [])
                submod.graph.lint()
                submod.recompile()

        # [aliasing] This step is not super necessary, but helps reduce parameter usage/memory.
        # After _sink_params() routine has run, clean up unused attributes that we previously added.
        # Determine this based on the get_attr nodes - if not used, remove it.
        for name, attributes in added_attributes.items():
            submod = getattr(split, name)
            unused_attributes = set(attributes)
            # track used attributes in the submodule, running DFS on subgraph hierarchy
            stack = [("", submod)]  # (scope, submodule)
            while stack:
                scope, _mod = stack.pop()
                if isinstance(_mod, (fx.GraphModule, InterpreterModule)):
                    for node in _mod.graph.nodes:
                        if node.op == "get_attr":
                            # get_attr might access a deeper-level attribute
                            fqn = scope + "." + node.target if scope else node.target
                            if fqn in unused_attributes:  # used, so drop from the unused set
                                unused_attributes.remove(fqn)
                for _name, _submod in _mod.named_children():
                    stack.append((scope + "." + _name if scope else _name, _submod))
            # delete unused attributes
            for attr in unused_attributes:
                mod_itr, atoms = submod, attr.split(".")
                for atom in atoms[:-1]:
                    mod_itr = getattr(mod_itr, atom)
                delattr(mod_itr, atoms[-1])

        for node in attr_nodes:
            # And (2): remove `get_attr` node from submod's arg list
            for user in copy.copy(node.users):
                assert user.op == "call_module"
                delete_user_reference(node, user)
            # And (3): remove the `get_attr` node from the root graph.
            split.graph.erase_node(node)

        split.delete_all_unused_submodules()
        split.graph.lint()
        split.recompile()

        num_stages = Pipe._number_and_count_forward_stages(split)

        has_loss_and_backward = False
        generated_loss_spec = output_loss_value_spec

        if output_loss_value_spec is not None:
            loss_node, output_node, generated_loss_spec = _find_loss_output(
                mod, split.graph, output_loss_value_spec
            )
            if loss_node is not None:
                _insert_stage_symbolic_backward(
                    split.graph,
                    loss_node,
                    output_node,
                )
                split.recompile()
                has_loss_and_backward = True
                logger.debug("Pipeline is in training mode, backward pass generated")
            else:
                raise RuntimeError(
                    f"Did not find any loss value according to {output_loss_value_spec=}"
                )
        else:
            logger.debug("Pipeline is in inference mode, backward pass not generated")

        logger.debug("Full pipe model:\n" f"{split}")  # noqa: G004

        return Pipe(
            split,
            num_stages,
            has_loss_and_backward,
            generated_loss_spec,
        )

    def print_readable(self):
        """
        Print the pipe in a human-readable format.
        This will print both the root pipe and each stage module.
        """
        self.split_gm.print_readable()

    @staticmethod
    def _trace_with_export(
        mod: torch.nn.Module,
        example_args: Tuple[Any, ...],
        example_kwargs: Optional[Dict[str, Any]] = None,
    ) -> ExportedProgram:
        logger.info("Tracing model ...")
        try:
            ep = torch.export.export(
                mod,
                example_args,
                example_kwargs,
            )
        except Exception as e:
            raise RuntimeError(
                "It seems that we cannot capture your model as a full graph. "
                "Typical reasons include graph breaks, data/shape-dependent "
                "control flow, or missing meta kernels for custom operators. "
                "You can use our manual pipeline interfaces, or try to fix the "
                "graph breaks, see https://pytorch.org/docs/stable/export.html"
            ) from e

        return ep

    @staticmethod
    def from_tracing(
        mod: torch.nn.Module,
        example_args: Tuple[Any, ...],
        example_kwargs: Optional[Dict[str, Any]] = None,
        split_policy: Optional[Callable[[fx.GraphModule], fx.GraphModule]] = None,
    ):
        # If a param will be used in multiple pipeline stages, we default the strategy to REPLICATE'ing the param across
        # stages instead of TRANSMIT'ting it
        multi_use_param_spec = MultiUseParameterConfig.REPLICATE

        # Figure out which output is loss from output_chunk_spec
        output_loss_value_spec: Any = None
        # Deprecated
        """
        if output_chunk_spec is not None:
            output_loss_value_spec = map_aggregate(
                output_chunk_spec, lambda v: isinstance(v, _LossReducer)
            )
        """

        # Trace with export
        exported_program = Pipe._trace_with_export(
            mod,
            example_args,
            example_kwargs,
        )

        pipe = Pipe._from_traced(
            mod,
            exported_program,
            multi_use_param_spec,
            output_loss_value_spec=output_loss_value_spec,
            split_policy=split_policy,
        )

        # Users want the first pipeline stage to accept kwargs if the original
        # program does. This is controlled by the `_codegen` field of the graph,
        # so we make a copy here. Note: we only want the input spec and not the
        # output spec, because the output spec is for the last stage. Maybe a
        # TODO? Not sure yet.
        split = pipe.split_gm
        traced = exported_program.module()
        submod0 = next(iter(split.children()))
        submod0_sign = signature(submod0.forward)
        model_sign = signature(traced.forward)
        if len(model_sign.parameters) != len(submod0_sign.parameters):
            # We don't change the signature of the first stage if it takes a
            # different number of args than the original model
            logger.info(
                f"Original model takes {len(model_sign.parameters)} args but the "  # noqa: G004
                f"first pipeline stage takes {len(submod0_sign.parameters)}. "
                "Please provide args to respective pipeline stages."
            )
        else:
            # Support kwargs for the first stage
            submod0.graph._codegen = copy.deepcopy(traced.graph._codegen)
            # `_replace` is actually not "private" or internal. Based on this doc:
            # To prevent conflicts with field names, the method and attribute names
            # start with an underscore
            submod0.graph._codegen.pytree_info = (
                submod0.graph._codegen.pytree_info._replace(out_spec=None)
            )
            submod0.recompile()

        return pipe

    def __str__(self):
        return self.split_gm.__str__()

    def __repr__(self):
        return self.split_gm.__repr__()

    def info(self) -> PipeInfo:
        """
        Get information about the pipe.

        Returns
        -------
        PipeInfo
            A dataclass containing information about the pipe.
        """
        return PipeInfo(
            graph=self.split_gm.graph,
            num_stages=self.num_stages,
            has_loss_and_backward=self.has_loss_and_backward,
        )

    def build_stage(
        self,
        stage_index: int,
        device: torch.device,
        group: Optional[ProcessGroup] = None,
    ) -> _PipelineStage:
        """
        Create a `PipelineStage` given a stage index and distributed group.
        The `PipelineStage` can run with `PipelineSchedule`s.
        """
        # Find stage module
        stage_module = self.get_stage_module(stage_index)

        # Move ops argument to device
        # Today PT2 tracer does not treat `x.device` as a symbolic device;
        # instead, the device at tracing time gets burned into the generated
        # code. Here we provide a workaround for users to manually modify the
        # "device" kwarg of operations. Such operations may include:
        # `torch.ones`, `torch.zeros`, `torch.rand`, etc.
        if isinstance(stage_module, torch.fx.GraphModule):
            _modify_graph_op_device(stage_module, device)
        else:
            logger.warning(
                f"Expected a `torch.fx.GraphModule` but got {type(stage_module)}"  # noqa: G004
            )

        # Detach pipe info
        # Note: be careful what's included in `pipe_info`. We don't want to keep
        # a reference to `Pipe` or `Pipe.split_gm`, which would stop Python from
        # recycling them. When Python recycles them, other stage modules (which
        # are irrelevant to the current rank) can be automatically freed.
        pipe_info = self.info()
        return _PipelineStage(stage_module, stage_index, pipe_info, device, group)


class SplitPoint(Enum):
    BEGINNING = 1
    END = 2


# For backward compatibility, we kept the PipeSplitWrapper class because `class
# SplitPoint` used to be defined in this class.
class PipeSplitWrapper:
    # Create a class alias for BC
    SplitPoint = SplitPoint


def _split_before_forward(self, *args, **kwargs):
    pipe_split()
    return self._orig_forward(*args, **kwargs)


def _split_after_forward(self, *args, **kwargs):
    try:
        return self._orig_forward(*args, **kwargs)
    finally:
        pipe_split()


def annotate_split_points(mod: torch.nn.Module, spec: Dict[str, SplitPoint]):
    # TODO: make this implementation out-of-place?
    for qualname, split_type in spec.items():
        atoms = qualname.split(".")
        predecessor_module = mod
        for i, atom in enumerate(atoms[:-1]):
            try:
                predecessor_module = getattr(predecessor_module, atom)
            except AttributeError as e:
                raise AttributeError(
                    f'Specified target {qualname} referenced nonexistent module {".".join(atoms[:i+1])}'
                ) from e

        mod_to_wrap = getattr(predecessor_module, atoms[-1])
        mod_to_wrap._orig_forward = mod_to_wrap.forward
        if split_type == SplitPoint.BEGINNING:
            mod_to_wrap.forward = MethodType(_split_before_forward, mod_to_wrap)
        elif split_type == SplitPoint.END:
            mod_to_wrap.forward = MethodType(_split_after_forward, mod_to_wrap)
        else:
            raise ValueError("Unknown split point type.")
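
# Illustrative sketch (hypothetical submodule names, not part of the library):
# `annotate_split_points` wraps the `forward` of the named submodules so that a
# `pipe_split()` runs before (SplitPoint.BEGINNING) or after (SplitPoint.END)
# them, yielding one additional stage per entry when the module is later traced.
#
#     annotate_split_points(
#         model,
#         {
#             "layers.4": SplitPoint.BEGINNING,  # a new stage starts at layers.4
#             "layers.8": SplitPoint.BEGINNING,
#         },
#     )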


def pipeline(
    module: torch.nn.Module,
    mb_args: Tuple[Any, ...],
    mb_kwargs: Optional[Dict[str, Any]] = None,
    split_spec: Optional[Dict[str, SplitPoint]] = None,
    split_policy: Optional[Callable[[fx.GraphModule], fx.GraphModule]] = None,
) -> Pipe:
    """
    Split a module based on a specification.

    See `Pipe` for more details.

    Arguments
    ---------
    module:
        The module to be split.
    mb_args:
        Example positional inputs, in micro-batch form.
    mb_kwargs:
        Example keyword inputs, in micro-batch form. (default: `None`)
    split_spec:
        A dictionary using submodule names as split markers. (default: `None`)
    split_policy:
        The policy to use for splitting the module. (default: `None`)

    Returns
    -------
    A pipeline representation of class `Pipe`.
    """
    if split_spec is not None and split_policy is not None:
        raise ValueError(
            "Cannot specify both `split_spec` and `split_policy`. Please use only one of them."
        )

    if split_spec is not None:
        # Annotate split points in the module based on user spec
        annotate_split_points(module, split_spec)
        return Pipe.from_tracing(
            mod=module,
            example_args=mb_args,
            example_kwargs=mb_kwargs,
        )
    else:
        # Use split policy
        return Pipe.from_tracing(
            mod=module,
            example_args=mb_args,
            example_kwargs=mb_kwargs,
            split_policy=split_policy,
        )
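
# End-to-end usage sketch (hypothetical setup; names such as `model`, `rank`,
# and `device` are assumptions, not defined in this file): build a `Pipe` from
# micro-batch examples plus a split spec, then materialize this rank's stage
# for use with a pipeline schedule.
#
#     pipe = pipeline(
#         model,
#         mb_args=(torch.randn(8, 16),),
#         split_spec={"layers.4": SplitPoint.BEGINNING},
#     )
#     print(pipe.num_stages)
#     stage_mod = pipe.get_stage_module(rank)  # inspect this rank's submodule
#     stage = pipe.build_stage(rank, device)   # wrap it as a _PipelineStage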