# post_localSGD_optimizer.py

# mypy: allow-untyped-defs
import warnings

import torch
import torch.distributed.algorithms.model_averaging.averagers as averagers


class PostLocalSGDOptimizer(torch.optim.Optimizer):
    r"""
    Wraps an arbitrary :class:`torch.optim.Optimizer` and runs `post-local SGD <https://arxiv.org/abs/1808.07217>`_.
    This optimizer runs the local optimizer at every step.
    After the warm-up stage, it averages parameters periodically after the local optimizer is applied.

    Args:
        optim: The local optimizer.
        averager: A model averager instance to run the post-localSGD algorithm.

    Example::

        >>> # xdoctest: +SKIP("undefined variables")
        >>> import torch
        >>> import torch.distributed as dist
        >>> import torch.distributed.algorithms.model_averaging.averagers as averagers
        >>> import torch.nn as nn
        >>> from torch.distributed.optim import PostLocalSGDOptimizer
        >>> from torch.distributed.algorithms.ddp_comm_hooks.post_localSGD_hook import (
        >>>     PostLocalSGDState,
        >>>     post_localSGD_hook,
        >>> )
        >>>
        >>> model = nn.parallel.DistributedDataParallel(
        >>>     module, device_ids=[rank], output_device=rank
        >>> )
        >>>
        >>> # Register a post-localSGD communication hook.
        >>> state = PostLocalSGDState(process_group=None, subgroup=None, start_localSGD_iter=100)
        >>> model.register_comm_hook(state, post_localSGD_hook)
        >>>
        >>> # Create a post-localSGD optimizer that wraps a local optimizer.
        >>> # Note that ``warmup_steps`` used in ``PostLocalSGDOptimizer`` must be the same as
        >>> # ``start_localSGD_iter`` used in ``PostLocalSGDState``.
        >>> local_optim = torch.optim.SGD(params=model.parameters(), lr=0.01)
        >>> opt = PostLocalSGDOptimizer(
        >>>     optim=local_optim,
        >>>     averager=averagers.PeriodicModelAverager(period=4, warmup_steps=100)
        >>> )
        >>>
        >>> # In the first 100 steps, DDP runs global gradient averaging at every step.
        >>> # After 100 steps, DDP runs gradient averaging within each subgroup (intra-node by default),
        >>> # and the post-localSGD optimizer runs global model averaging every 4 steps
        >>> # after applying the local optimizer.
        >>> for step in range(0, 200):
        >>>     opt.zero_grad()
        >>>     loss = loss_fn(output, labels)
        >>>     loss.backward()
        >>>     opt.step()
    """

    def __init__(self, optim: torch.optim.Optimizer, averager: averagers.ModelAverager):
        self.optim = optim
        # Expose the wrapped optimizer's parameter groups so that code which
        # inspects ``param_groups`` (e.g. LR schedulers) continues to work.
        self.param_groups = self.optim.param_groups
        self.averager = averager

    @property
    def state(self):
        return self.optim.state

    def __repr__(self):
        return self.optim.__repr__()

    def state_dict(self):
        r"""
        This is the same as :class:`torch.optim.Optimizer` :meth:`state_dict`,
        but adds an extra entry to record the model averager's step to the checkpoint
        to ensure that reloading does not trigger an unnecessary warm-up again.
        """
        optim_state_dict = self.optim.state_dict()
        optim_state_dict["step"] = self.averager.step
        return optim_state_dict

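    # Illustrative sketch (not from the original file): because the averager's step
    # counter is stored alongside the regular optimizer state, a checkpoint can be
    # written with the usual call, e.g. (the names ``opt`` and ``checkpoint.pt`` are
    # assumed for illustration):
    #
    #   torch.save({"optimizer": opt.state_dict()}, "checkpoint.pt")
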
    def load_state_dict(self, state_dict):
        r"""
        This is the same as :class:`torch.optim.Optimizer` :meth:`load_state_dict`,
        but also restores the model averager's step value to the one
        saved in the provided ``state_dict``.

        If there is no ``"step"`` entry in ``state_dict``,
        it will raise a warning and initialize the model averager's step to 0.
        """
        self.optim.load_state_dict(state_dict)
        if "step" in state_dict:
            self.averager.step = state_dict["step"]
        else:
            warnings.warn(
                "Loaded state dict does not contain a step counter for an averager. "
                "Setting step counter to 0."
            )
            self.averager.step = 0

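    # Illustrative sketch (not from the original file): restoring the same checkpoint
    # brings back both the inner optimizer state and the averager's step counter, so a
    # resumed run does not repeat the warm-up phase (``opt`` and ``checkpoint.pt`` as above):
    #
    #   ckpt = torch.load("checkpoint.pt")
    #   opt.load_state_dict(ckpt["optimizer"])
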
    def step(self):
        r"""
        Performs a single optimization step (parameter update).
        """
        self.optim.step()
        self.averager.average_parameters(params=self.param_groups)

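    # Note (illustrative, assuming the ``PeriodicModelAverager(period=4, warmup_steps=100)``
    # setup from the class docstring): the wrapped optimizer runs on every call to
    # ``step()``, while the averager skips averaging during the warm-up stage and then
    # performs global model averaging once every ``period`` steps.
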
    def zero_grad(self, set_to_none: bool = True):  # type: ignore[override]
        self.optim.zero_grad(set_to_none=set_to_none)

    def add_param_group(self, param_group):
        self.optim.add_param_group(param_group)