utils.py

# mypy: allow-untyped-defs
from typing import Type

from torch import optim

from .functional_adadelta import _FunctionalAdadelta
from .functional_adagrad import _FunctionalAdagrad
from .functional_adam import _FunctionalAdam
from .functional_adamax import _FunctionalAdamax
from .functional_adamw import _FunctionalAdamW
from .functional_rmsprop import _FunctionalRMSprop
from .functional_rprop import _FunctionalRprop
from .functional_sgd import _FunctionalSGD

# Dict that maps a user-passed-in optimizer class to a functional optimizer
# class, if we have already defined one inside the distributed.optim package.
# This lets us hide the functional optimizers from users while still
# providing the same API.
functional_optim_map = {
    optim.Adagrad: _FunctionalAdagrad,
    optim.Adam: _FunctionalAdam,
    optim.AdamW: _FunctionalAdamW,
    optim.SGD: _FunctionalSGD,
    optim.Adadelta: _FunctionalAdadelta,
    optim.RMSprop: _FunctionalRMSprop,
    optim.Rprop: _FunctionalRprop,
    optim.Adamax: _FunctionalAdamax,
}
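
# For example, ``as_functional_optim(optim.Adam, lr=1e-3)`` (defined below)
# looks up ``optim.Adam`` in this map and returns a ``_FunctionalAdam``
# instance; the ``lr`` value here is illustrative only.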


def register_functional_optim(key, optim):
    """
    Interface to insert a new functional optimizer into ``functional_optim_map``.

    ``fn_optim_key`` and ``fn_optimizer`` are user defined; the key and
    optimizer need not be of type :class:`torch.optim.Optimizer` (e.g. for
    custom optimizers).

    Example::
        >>> # import the new functional optimizer
        >>> # xdoctest: +SKIP
        >>> from xyz import fn_optimizer
        >>> from torch.distributed.optim.utils import register_functional_optim
        >>> fn_optim_key = "XYZ_optim"
        >>> register_functional_optim(fn_optim_key, fn_optimizer)
    """
    # Do not overwrite an existing mapping for this key.
    if key not in functional_optim_map:
        functional_optim_map[key] = optim


def as_functional_optim(optim_cls: Type, *args, **kwargs):
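    """
    Create the functional counterpart of ``optim_cls`` with the given args.

    Looks up ``optim_cls`` in ``functional_optim_map`` and constructs the
    matching functional optimizer, raising :class:`ValueError` if no
    functional counterpart has been registered.

    Example::
        >>> # xdoctest: +SKIP
        >>> from torch import optim
        >>> fn_sgd = as_functional_optim(optim.SGD, lr=0.1)
    """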
    try:
        functional_cls = functional_optim_map[optim_cls]
    except KeyError as e:
        raise ValueError(
            f"Optimizer {optim_cls} does not have a functional counterpart!"
        ) from e
    return _create_functional_optim(functional_cls, *args, **kwargs)


def _create_functional_optim(functional_optim_cls: Type, *args, **kwargs):
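    # Construct the functional optimizer with an empty parameter list;
    # ``_allow_empty_param_list=True`` skips the usual non-empty params check,
    # since parameters are typically supplied later (e.g. by the distributed
    # optimizer wrappers that consume these functional classes).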
    return functional_optim_cls(
        [],
        *args,
        **kwargs,
        _allow_empty_param_list=True,
    )