# distributed.py

# mypy: ignore-errors
import logging
import traceback
from dataclasses import dataclass, field
from typing import Any, List, Optional
from unittest import mock

import torch
from torch import fx
from torch._dynamo.output_graph import GraphCompileReason
from torch._dynamo.utils import deepcopy_to_fake_tensor, detect_fake_mode
from torch._logging import trace_structured
from torch.fx.node import Node

# Regular log messages should go through 'log'.
# ddp_graph_log is a separate artifact logger reserved for dumping graphs.
# See docs/source/logging.rst for more info.
log = logging.getLogger(__name__)
ddp_graph_log = torch._logging.getArtifactLogger(__name__, "ddp_graphs")


def args_str(args):
    # a debug helper
    if torch.is_tensor(args):
        return f"T[{args.shape}]"
    elif isinstance(args, tuple):
        return f"tuple({', '.join([args_str(x) for x in args])})"
    elif isinstance(args, list):
        return f"list({', '.join([args_str(x) for x in args])})"
    else:
        return str(args)


@dataclass
class Bucket:
    size: int = 0
    params: List[str] = field(default_factory=list)
    nodes: List[fx.Node] = field(default_factory=list)

    # param_ids is just used for unit testing
    param_ids: List = field(default_factory=list)

    # keep track of any buckets that were extended for logging purposes
    opcount_increased_to_capture_external_output: int = 0
    paramsize_before_opcount_increase: int = 0


def bucket_has_external_output(bucket: Bucket) -> bool:
    nodes_in_bucket = set()

    # we want to iterate in reverse order, and conveniently the bucket.nodes list was
    # already created backwards, so we don't reverse it here
    for node in bucket.nodes:
        # assume node.op != output, since those are filtered in the original iteration
        nodes_in_bucket.add(node)
        for user in node.users:
            if user not in nodes_in_bucket:
                return True
    return False
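

# Minimal sketch (illustrative only, not part of the original file) of what "external
# output" means above: a bucket has an external output when one of its nodes is used by
# a node outside the bucket, so the corresponding subgraph has something to return.
def _external_output_sketch():
    g = fx.Graph()
    x = g.placeholder("x")
    a = g.call_function(torch.relu, (x,))
    b = g.call_function(torch.sigmoid, (a,))
    g.output(b)

    # A bucket containing only `a` has an external output because `b`, a user of `a`,
    # is not in the bucket.  (Real buckets hold nodes in reverse graph order.)
    bucket = Bucket(nodes=[a])
    return bucket_has_external_output(bucket)  # True under this construction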


def pretty_print_buckets(buckets: List[Bucket], bucket_bytes_cap: int):
    headers = ("Index", "Size (b)", "Param Names")
    rows = []
    extended_buckets = []
    for idx, bucket in enumerate(reversed(buckets)):
        if len(bucket.params) > 0:
            rows.append((idx, bucket.size, bucket.params[0]))
            for param in bucket.params[1:]:
                rows.append((None, None, param))
        if bucket.opcount_increased_to_capture_external_output > 0:
            extended_buckets.append(
                (
                    idx,
                    bucket.opcount_increased_to_capture_external_output,
                    bucket.size - bucket.paramsize_before_opcount_increase,
                )
            )

    if len(rows):
        log.info(
            "\nDDPOptimizer used bucket cap %s and created %d buckets. Enable debug logs for detailed bucket info.",
            bucket_bytes_cap,
            len(buckets),
        )

        if len(extended_buckets):
            log.warning(
                "Some buckets were extended beyond their requested parameter capacities"
                " in order to ensure each subgraph has an output node, required for fx graph partitioning."
                " This can be the case when a subgraph would have only contained nodes performing inplace mutation,"
                " and returning no logical outputs. This should not be a problem, unless it results in too few graph"
                " partitions for optimal DDP performance."
            )

        try:
            from tabulate import tabulate

            log.debug(
                "\nDDPOptimizer produced the following bucket assignments:\n%s",
                tabulate(rows, headers=headers, tablefmt="simple_grid"),
            )

            if len(extended_buckets):
                log.warning(
                    "DDPOptimizer extended these buckets to ensure per-subgraph output nodes:\n%s",
                    tabulate(
                        extended_buckets,
                        headers=("Index", "Extra Ops", "Extra Param Size (b)"),
                        tablefmt="simple_grid",
                    ),
                )
        except ImportError:
            log.debug(
                "Please `pip install tabulate` in order to display ddp bucket sizes and diagnostic information."
            )
    else:
        log.debug("DDPOptimizer captured no parameters and did not split this graph.")


def has_higher_order_op(gm):
    # Check if there is a higher order op in the graph
    for node in gm.graph.nodes:
        if node.op == "get_attr":
            maybe_param = getattr(gm, node.target)
            if isinstance(maybe_param, torch.fx.GraphModule):
                return True
    return False


# 3 (lazy compile): Replace submodules with lazily compiling submodule
class SubmoduleReplacer(torch.fx.interpreter.Interpreter):
    def __init__(self, module, compiler):
        super().__init__(module)
        self.compiler = compiler

    def lazily_compiled_submod(self, input_mod):
        """
        Create a wrapper around submodules which:
        - lazily compiles each of the partitioned submodules using the user-provided compiler
        - unpacks singleton tuples/lists into flat arg
        """

        class LazilyCompiledModule(torch.nn.Module):
            def __init__(self, submod, compiler, unwrap_singleton_tuple):
                super().__init__()
                self.submod = submod
                self.compiler = compiler
                self.compiled = False
                self.unwrap_singleton_tuple = unwrap_singleton_tuple

            def forward(self, *args):
                if not self.compiled:
                    # First compile with args as example_inputs
                    # These args will be fakeified if using Inductor/AOTAutograd
                    new_submod = self.compiler(self.submod, args)
                    del self.submod
                    self.submod = new_submod
                    self.compiled = True
                    self.compiler = None

                x = self.submod(*args)
                # we must let 'input_mod' return a tuple, to make AOT happy.
                # (aot_autograd compile_fn literally requires that the output of a graph it compiles is a tuple).
                # however, we don't actually want this tuple to be returned, since the fx logic that calls the submod
                # will again wrap outputs from the submod in a tuple. So we unwrap it, and count on it being re-wrapped
                if self.unwrap_singleton_tuple and isinstance(x, (tuple, list)):
                    return x[0]
                return x

        unwrap_singleton_tuple = False
        for sn in input_mod.graph.nodes:
            if sn.op == "output":
                if not isinstance(sn.args[0], tuple):
                    unwrap_singleton_tuple = True
                    sn.args = (sn.args,)

        input_mod.recompile()
        input_mod.compile_subgraph_reason = GraphCompileReason(
            "DDPOptimizer intentional graph-break (See Note [DDPOptimizer])."
            " Set `torch._dynamo.config.optimize_ddp = False` to disable.",
            [
                # it's close to useless to get a real stacktrace here, and quite verbose.
                traceback.FrameSummary(__file__, 0, DDPOptimizer),
            ],
        )
        wrapper = LazilyCompiledModule(
            input_mod,
            self.compiler,
            unwrap_singleton_tuple,
        )
        return wrapper

    # We replace the submodules with lazy submodules which compile
    # the corresponding submodules when they are run with real values
    # Always returns `None` - we do not need to propagate values in order
    # to replace submodules.
    def run_node(self, n: Node) -> Any:
        if n.op == "call_module":
            real_mod = self.fetch_attr(n.target)

            ddp_graph_log.debug("\n---%s graph---\n%s", n.target, real_mod.graph)

            assert len(n.kwargs) == 0, "We assume only args for these modules"
            lazily_compiled_submod = self.lazily_compiled_submod(real_mod)

            # We update the original (outer) graph with a call into the compiled module
            # instead of the uncompiled one.
            self.module.delete_submodule(n.target)
            n.target = "compiled_" + n.target
            self.module.add_submodule(n.target, lazily_compiled_submod)
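

# Minimal sketch (illustrative only, not from the original file) of the singleton-tuple
# convention used above: AOT-style compilers require the fx graph to return a tuple, but
# the surrounding fx call_module logic re-wraps submodule outputs itself, so a 1-element
# tuple has to be unwrapped before handing the value back.
def _singleton_tuple_sketch():
    gm = fx.symbolic_trace(torch.nn.ReLU())
    for node in gm.graph.nodes:
        if node.op == "output" and not isinstance(node.args[0], tuple):
            node.args = (node.args,)  # the graph now returns a 1-element tuple
    gm.recompile()

    out = gm(torch.randn(3))
    assert isinstance(out, tuple) and len(out) == 1
    return out[0]  # unwrap, mirroring unwrap_singleton_tuple in the wrappers above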


# 3 (no lazy compile): compile each of the partitioned submodules using the user-provided compiler
class SubmodCompiler(torch.fx.interpreter.Interpreter):
    def __init__(self, module, compiler, fake_mode):
        super().__init__(module)
        self.compiler = compiler
        self.fake_mode = fake_mode

    def compile_submod(self, input_mod, args, kwargs):
        """
        Compile the submodule,
        using a wrapper to make sure its output is always a tuple,
        which is required by AotAutograd based compilers
        """
        assert len(kwargs) == 0, "We assume only args for these modules"

        class WrapperModule(torch.nn.Module):
            def __init__(self, submod, unwrap_singleton_tuple):
                super().__init__()
                self.submod = submod
                self.unwrap_singleton_tuple = unwrap_singleton_tuple

            def forward(self, *args):
                x = self.submod(*args)
                # TODO(whc)
                # for some reason the isinstance check is necessary if I split one node per submod
                # - even though I supposedly wrapped the output in a tuple in those cases, the real
                # compiled module was still returning a tensor
                if self.unwrap_singleton_tuple and isinstance(x, (tuple, list)):
                    return x[0]
                return x

        unwrap_singleton_tuple = False
        for sn in input_mod.graph.nodes:
            if sn.op == "output":
                if not isinstance(sn.args[0], tuple):
                    unwrap_singleton_tuple = True
                    sn.args = (sn.args,)

        input_mod.recompile()
        input_mod.compile_subgraph_reason = GraphCompileReason(
            "DDPOptimizer intentional graph-break (See Note [DDPOptimizer])."
            " Set `torch._dynamo.config.optimize_ddp = False` to disable.",
            [
                # it's close to useless to get a real stacktrace here, and quite verbose.
                traceback.FrameSummary(__file__, 0, DDPOptimizer),
            ],
        )
        wrapper = WrapperModule(
            self.compiler(input_mod, args),
            unwrap_singleton_tuple,
        )
        return wrapper

    # Note:
    #
    # The way distributed works today around fake tensors can be somewhat confusing.
    # Some of these codepaths are shared in both runtime, and compile time. The presence
    # of a fake_mode, read off of fake tensor inputs, dictates how we will operate.
    #
    # A few things to keep in mind:
    #
    # 1) We invoke `compile_submod` with a real module. The output of that gets stored
    # on the graph via `self.module.add_submodule(n.target, compiled_submod_real)`.
    #
    # 2) When running a call_module targeted node, if we have a fake_mode, we fakify the
    # module we got from self.fetch_attr(n.target). Regardless of fake_mode, we then execute it.
    #
    # 3) Fake tensors should always be around during compile time.
    #
    # 4) Fake tensors should never be around at runtime.
    #
    # 5) We end up with a compilation mode that takes a real submodule and fake tensors,
    # to match what aot_autograd expects. See Note: [Fake Modules and AOTAutograd]
    def run_node(self, n: Node) -> Any:
        args, kwargs = self.fetch_args_kwargs_from_env(n)
        new_args = []
        assert self.fake_mode
        for arg in args:
            if isinstance(arg, torch.Tensor) and not isinstance(
                arg, torch._subclasses.FakeTensor
            ):
                new_args.append(torch._dynamo.utils.to_fake_tensor(arg, self.fake_mode))
            else:
                new_args.append(arg)

        log.debug("run_node %s, %s got args %s", n.op, n.target, args_str(args))
        assert isinstance(args, tuple)
        assert isinstance(kwargs, dict)

        if n.op == "call_module":
            real_mod = self.fetch_attr(n.target)
            if self.fake_mode:
                curr_submod = deepcopy_to_fake_tensor(real_mod, self.fake_mode)
            else:
                curr_submod = real_mod

            ddp_graph_log.debug("\n---%s graph---\n%s", n.target, curr_submod.graph)

            # When calling the compiler on the submod, inputs (new_args) are expected to
            # be FakeTensors already since Dynamo would have made them FakeTensors in the
            # non-DDP flow. However, the parameters are _not_ expected to be FakeTensors,
            # since this wrapping happens during compilation

            # Note: Returning Fake Tensors on First AOT Autograd Call
            #
            # Inductor will optimize strides of outputs when it deems it profitable.
            # For instance, converting to channels last. When we split the graph here
            # into multiple inductor compilations, we need to make sure that the
            # output strides of one compilation are appropriately passed to the subsequent
            # compilations. However, the mapping from inductor output to dynamo output
            # is non-trivial due to aot_autograd's deduping, de-aliasing, mutation, re-writing,
            # subclass handling, etc. In order to replay all this logic we set a flag such that
            # the first invocation of inductor in aot_autograd will return Fake Tensors with
            # appropriate strides. Then, all of aot autograd's runtime logic is replayed.
            # This gives us the appropriately strided outputs here which will reflect runtime strides.

            class FakeifyFirstAOTInvocationGuard:
                def __init__(self):
                    self.tc = torch._guards.TracingContext.try_get()
                    assert self.tc
                    torch._guards.TracingContext.try_get().fakify_first_call = True

                def __del__(self):
                    self.tc.fakify_first_call = False

            # For aot_eager and other backends, tracing context is not set
            has_tracing_context = torch._guards.TracingContext.try_get() is not None
            if has_tracing_context:
                g = FakeifyFirstAOTInvocationGuard()

            from torch._dynamo.utils import counters

            init = counters["aot_autograd"]["total"]
            compiled_submod_real = self.compile_submod(real_mod, new_args, kwargs)

            # TODO - better way of doing this?
            # Only aot autograd handles fakifying first call
            invoked_aot_autograd = init != counters["aot_autograd"]["total"]

            # We update the original (outer) graph with a call into the compiled module
            # instead of the uncompiled one.
            self.module.delete_submodule(n.target)
            n.target = "compiled_" + n.target
            self.module.add_submodule(n.target, compiled_submod_real)

            # Finally, we have to produce inputs for use in compiling the next submodule,
            # and these need to be FakeTensors, so we execute the module under fake_mode.
            # Because parameters are not fake, we patch fake tensor mode to allow non-fake inputs.
            with self.fake_mode, mock.patch.object(
                self.fake_mode, "allow_non_fake_inputs", True
            ):
                if has_tracing_context and invoked_aot_autograd:
                    out = compiled_submod_real(*new_args, **kwargs)
                    # output should be fake or subclass
                    assert all(
                        (not isinstance(t, torch.Tensor) or type(t) is not torch.Tensor)
                        for t in (out if isinstance(out, (list, tuple)) else [out])
                    )
                    return out
                else:
                    return curr_submod(*new_args, **kwargs)
        else:
            # placeholder or output nodes don't need to get compiled, just executed
            return getattr(self, n.op)(n.target, new_args, kwargs)
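

# Minimal sketches (illustrative only, not part of the original file) for the two notes
# above: how fake tensors keep compile-time execution free of real compute, and why
# output strides must be threaded between split graphs.
def _fake_mode_sketch():
    # At compile time, tensors flowing between submodules are FakeTensors: they carry
    # metadata (shape, dtype, stride, device) but no real storage.
    from torch._subclasses.fake_tensor import FakeTensorMode

    fake_mode = FakeTensorMode()
    real = torch.randn(8, 8)
    fake = fake_mode.from_tensor(real)  # compile-time stand-in for the real tensor
    out = torch.relu(fake)  # only metadata is computed
    assert isinstance(out, torch._subclasses.FakeTensor)
    return out.shape


def _stride_mismatch_sketch():
    # Why strides matter across split graphs: a backend that converts an output to
    # channels-last changes its strides, so the next subgraph must be compiled (and fed)
    # with matching metadata.  The two calls below produce identical values but
    # different strides.
    x = torch.randn(2, 3, 4, 4)
    contiguous_out = torch.relu(x)
    channels_last_out = torch.relu(x.to(memory_format=torch.channels_last))
    return contiguous_out.stride(), channels_last_out.stride()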


class DDPOptimizer:
    """Note [DDPOptimizer]
    DDPOptimizer applies when dynamo compiles models wrapped in DistributedDataParallel (DDP),
    breaking the dynamo graph into chunks to compile separately, with the breaks aligning to
    the boundaries of gradient-allreduce buckets chosen by DDP.

    Background/Motivation
    - DDP uses allreduce collectives to synchronize partial gradients computed on different workers
    - DDP groups gradient allreduces into 'buckets' to optimize communication efficiency of all-reduce
    - Parameters grouped into buckets are assumed to be adjacent in time, so they become ready
      at around the same time during backward and thus can share the same allreduce efficiently
    - Allreduces must overlap with backward compute for optimal training performance
    - DDP schedules allreduces using 'hooks' fired from the c++ autograd engine in pytorch, which
      operates when individual grads become 'ready'
    - Dynamo+AOTAutograd produces a single fused graph that runs 'atomically' from the perspective of the
      autograd engine, such that all gradients become 'ready' at the same time.  Hooks fire after the whole
      fused backward function executes, preventing any overlap of compute and communication

    Algorithm
    - DDPOptimizer starts off with an FX graph traced by dynamo which represents forward.  It can traverse
      this graph in reverse order to determine the true order that gradients will become ready during backward.
    - Parameter sizes are counted in reverse order, up to a bucket size limit, at which point a new bucket is started
      and a graph break introduced
    - Each of the subgraphs is compiled by the compiler provided to dynamo by the user, and then fused back together
      into an outer module that is returned to the user

    Notes
    - It would be better to enforce (by adding an API to DDP) that the bucket splits chosen here are used by DDP,
      and that DDP does not need to detect or optimize bucket order by observing execution at runtime, as it does
      in eager.
    - If Dynamo can't capture a whole graph for the portion of the model wrapped by DDP, this algorithm will currently
      produce splits that do not necessarily align with the buckets used by DDP.  This should result in performance
      degradation approaching the baseline case where graph-splits are not used, but not worse.
    - If the backend compiler fails to compile a single subgraph, it will execute eagerly despite the rest of the
      subgraphs being compiled
    - DDP has a 'parameters_and_buffers_to_ignore' field, which DDPOptimizer attempts to honor by reading markers
      left by DDP on individual parameters.  In cases where other transformations, such as reparameterization, are
      also used, the ignore markers could be lost.  If DDPOptimizer fails to ignore a parameter ignored by DDP,
      it is not catastrophic but could impact performance by choosing sub-optimal bucket splits.
    - DDPOptimizer always ignores all buffers, regardless of their ignore flag, since buffers do not require gradients,
      and therefore aren't allreduced by DDP.  (They are broadcast during forward, but this is not covered by
      DDPOptimizer)

    Debugging
    - Generally, it is easiest to debug DDPOptimizer in a single process program, using pdb.
    - In many cases, the log messages are helpful (they show bucket size assignments);
      just set the TORCH_LOGS env var to include any of 'dynamo', 'distributed', or 'dist_ddp'.
    - See `benchmarks/dynamo/distributed.py` for a simple harness that will run a toy model or a torchbench model
      in a single process (or with torchrun, in multiple processes)

    Args:
        bucket_bytes_cap (int): Controls the size of buckets, in bytes, used to determine graphbreaks.  Should be
            set to match the equivalent parameter on the original DDP module.
        backend_compile_fn (callable): A dynamo compiler function, to be invoked to compile each subgraph.
        first_bucket_cap (int): Controls the size of the first bucket.  Should match DDP's first bucket cap.  DDP
            special-cases the first bucket size since it is sometimes optimal to start a small allreduce early.
    """

    def __init__(
        self,
        bucket_bytes_cap: int,
        backend_compile_fn,
        first_bucket_cap: Optional[int] = None,
    ):
        if first_bucket_cap is not None:
            self.first_bucket_cap = first_bucket_cap
        elif torch.distributed.is_available():
            # this constant comes from C10D lib which is not always built
            self.first_bucket_cap = torch.distributed._DEFAULT_FIRST_BUCKET_BYTES
        else:
            self.first_bucket_cap = bucket_bytes_cap

        self.bucket_bytes_cap = bucket_bytes_cap
        assert (
            self.first_bucket_cap <= self.bucket_bytes_cap
        ), "First bucket should be smaller/equal to other buckets to get comms warmed up ASAP"

        self.backend_compile_fn = backend_compile_fn

    def _ignore_parameter(self, parameter):
        return hasattr(parameter, "_ddp_ignored") and parameter._ddp_ignored

    def add_module_params_to_bucket(self, mod, bucket, processed_modules, prefix):
        processed_modules.add(mod)
        for name, param in mod.named_parameters():
            if param.requires_grad and not self._ignore_parameter(param):
                bucket.size += param.untyped_storage().nbytes()
                bucket.params.append(f"{prefix}_{name}")
                bucket.param_ids.append(id(param))

    def compile_fn(self, gm: fx.GraphModule, example_inputs: List[torch.Tensor]):
        """
        Implements graph splitting, first determining a set of buckets by counting
        parameter sizes in reverse graph order, then invoking the user/backend compiler
        to compile each subgraph.  Finally, stitches compiled graphs into one graphmodule
        and returns its callable.
        """
        if has_higher_order_op(gm):
            # This indicates presence of a higher order op. For now, we
            # have no way to break the higher order op into two buckets.
            # Allowing higher order ops in the graph also requires
            # changes in the split_module, because the graph splitter
            # currently assumes that all the args of all ops are
            # tensors, but in the case of higher order ops, it could be
            # a graph module. As a workaround, we short-circuit and raise here.
            raise NotImplementedError(
                "DDPOptimizer backend: Found a higher order op in the graph. "
                "This is not supported. Please turn off DDP optimizer using "
                "torch._dynamo.config.optimize_ddp=False. Note that this can "
                "cause performance degradation because there will be one bucket "
                "for the entire Dynamo graph. Please refer to this issue - "
                "https://github.com/pytorch/pytorch/issues/104674."
            )

        # 1: compute the partition map according to DDP bucket logic
        buckets = [Bucket()]  # (size, param_names)
        processed_modules = set()
        for node in reversed(gm.graph.nodes):
            if node.op in ("output", "placeholder"):
                continue

            if (
                buckets[0].size >= self.bucket_bytes_cap
                or len(buckets) == 1
                and buckets[0].size >= self.first_bucket_cap
            ):
                if bucket_has_external_output(buckets[0]):
                    buckets.insert(0, Bucket())
                else:
                    # continue building this bucket past the point of filling its parameter capacity,
                    # to increase chances it contains at least one node that is either a global output or
                    # passed as input to a subsequent graph
                    if buckets[0].opcount_increased_to_capture_external_output == 0:
                        buckets[0].paramsize_before_opcount_increase = buckets[0].size
                    buckets[0].opcount_increased_to_capture_external_output += 1

            if node.op == "call_module":
                target_mod = gm.get_submodule(node.target)
                if target_mod not in processed_modules:
                    self.add_module_params_to_bucket(
                        target_mod, buckets[0], processed_modules, node.target
                    )
            elif node.op == "call_method":
                if isinstance(node.args[0].target, str):
                    target_mod = None
                    try:
                        target_mod = gm.get_submodule(node.args[0].target)
                    except AttributeError:
                        pass
                    if target_mod is not None and target_mod not in processed_modules:
                        self.add_module_params_to_bucket(
                            target_mod, buckets[0], processed_modules, node.target
                        )
            elif node.op == "get_attr":
                maybe_param = getattr(gm, node.target)
                if (
                    isinstance(maybe_param, torch.nn.Parameter)
                    and maybe_param.requires_grad
                    and not self._ignore_parameter(maybe_param)
                ):
                    buckets[0].size += maybe_param.untyped_storage().nbytes()
                    buckets[0].params.append(node.target)
                    buckets[0].param_ids.append(id(maybe_param))

            # All nodes have to be mapped to a bucket, even if they don't have their own params
            # Ignored params still end up in buckets; we just don't count them towards the capacity
            buckets[0].nodes.append(node)

        if len(buckets) > 1 and buckets[0].size == 0:
            # we collected a small preamble graph with ops that don't include parameters, fuse it back
            buckets[1].nodes.extend(buckets[0].nodes)
            assert len(buckets[0].params) == 0, "Params should be empty if size is 0"
            del buckets[0]

        # stash buckets for testing/debugging purposes
        self.buckets = buckets
        pretty_print_buckets(buckets, self.bucket_bytes_cap)

        if len(buckets) == 1:
            # bypass split/fuse logic if there is only one bucket
            return self.backend_compile_fn(gm, example_inputs)

        # 2: partition the graphmodule according to bucket capacity
        partition_map = {}
        for idx, b in enumerate(buckets):
            for node in b.nodes:
                partition_map[node] = idx

        split_gm = fx.passes.split_module.split_module(
            gm, None, lambda node: partition_map[node]
        )

        debug_str = (
            f"\n---orig graph---\n{gm.graph}\n"
            + f"\n---split graph---\n{split_gm.graph}\n"
        )
        for name, module in split_gm.named_modules():
            if "." not in name and len(name):
                # only print the submod graphs, not their children
                debug_str += f"\n---{name} graph---\n{module.graph}\n"
        debug_str += "\n---------------\n"
        ddp_graph_log.debug(debug_str)

        trace_structured(
            "optimize_ddp_split_graph",
            payload_fn=lambda: split_gm.print_readable(print_output=False),
        )
        for name, module in split_gm.named_modules():
            if "." not in name and len(name):
                trace_structured(
                    "optimize_ddp_split_child",
                    lambda: {"name": name},
                    payload_fn=lambda: module.print_readable(print_output=False),
                )

        # NOTE, we want to enable `optimize_ddp_lazy_compile` by default as soon as possible,
        # because it will fix stride mismatch errors (see motivation: https://github.com/pytorch/pytorch/pull/114154).
        # However, lazy compile currently causes shape mismatch in other cases (`test_graph_split_inductor_transpose`)
        # and we need to fix them before we can enable it by default.
        if not torch._dynamo.config.optimize_ddp_lazy_compile:
            # Today, optimize_ddp=True and keep_output_stride=False can lead to silent
            # correctness issues. The problem is that ddp_optimizer works by partitioning
            # the dynamo graph, sending each subgraph through aot autograd to inductor,
            # and creates example inputs by eagerly interpreting each subgraph to get
            # an output with the same metadata that we'd get from eager mode.
            # This is a problem though, for torch._inductor.config.keep_output_stride.
            # The above config can cause the outputs of the first graph to have
            # **different** strides from eager, causing the inputs that we pass
            # to the second graph to be wrong.
            # To really fix this, we would need to faithfully ask inductor
            # what outputs it expects for each graph it compiles.
            fake_mode = detect_fake_mode(example_inputs)
            if fake_mode is None:
                fake_mode = torch._subclasses.fake_tensor.FakeTensorMode()

        if torch._dynamo.config.optimize_ddp_lazy_compile:
            submod_compiler = SubmoduleReplacer(split_gm, self.backend_compile_fn)
        else:
            submod_compiler = SubmodCompiler(
                split_gm, self.backend_compile_fn, fake_mode
            )
        submod_compiler.run(*example_inputs)
        split_gm.recompile()

        ddp_graph_log.debug(
            "\n---final graph---\n%s\n---------------\n", split_gm.graph
        )
        return split_gm
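

# Illustrative usage sketch (not part of the original file): wiring DDPOptimizer's
# compile_fn in as a Dynamo backend.  `inner_compiler`, `model`, and the 25 MB cap are
# hypothetical placeholders; in practice Dynamo constructs DDPOptimizer itself when
# torch._dynamo.config.optimize_ddp is enabled and the module is wrapped in DDP.
def _ddp_optimizer_usage_sketch(model: torch.nn.Module):
    def inner_compiler(gm: fx.GraphModule, example_inputs):
        # stand-in for a real backend such as inductor
        return gm.forward

    ddp_optimizer = DDPOptimizer(
        bucket_bytes_cap=25 * 1024 * 1024,  # should match DDP's bucket_cap_mb
        backend_compile_fn=inner_compiler,
    )
    # Each Dynamo-captured graph would be routed through ddp_optimizer.compile_fn,
    # which splits it along the bucket boundaries computed above.
    return torch._dynamo.optimize(ddp_optimizer.compile_fn)(model)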