functional_adamax.py 4.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118
  1. # mypy: allow-untyped-defs
  2. from typing import Dict, List, Optional, Tuple
  3. import torch
  4. import torch.optim._functional as F
  5. from torch import Tensor
  6. __all__: List[str] = []
# Define a TorchScript-compatible functional Adamax optimizer,
# i.e. one that is used in a functional way.
# Instead of using `param.grad` when updating parameters,
# we explicitly allow the distributed optimizer to pass gradients to
# the `step` function. In this way, we can separate the gradients
# from the parameters and allow a multithreaded trainer to update the
# parameters without data races from accumulating into the same .grad.
# NOTE: This should only be used by distributed optimizer internals
# and is not meant to be exposed to users.
  16. @torch.jit.script
  17. class _FunctionalAdamax:
  18. def __init__(
  19. self,
  20. params: List[Tensor],
  21. lr: float = 1e-3,
  22. betas: Tuple[float, float] = (0.9, 0.999),
  23. eps: float = 1e-8,
  24. weight_decay: float = 0.0,
  25. foreach: bool = False,
  26. maximize: bool = False,
  27. _allow_empty_param_list: bool = False,
  28. ):
  29. if not 0.0 <= lr:
  30. raise ValueError(f"Invalid learning rate: {lr}")
  31. if not 0.0 <= eps:
  32. raise ValueError(f"Invalid epsilon value: {eps}")
  33. if not 0.0 <= betas[0] < 1.0:
  34. raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}")
  35. if not 0.0 <= betas[1] < 1.0:
  36. raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}")
  37. if not 0.0 <= weight_decay:
  38. raise ValueError(f"Invalid weight_decay value: {weight_decay}")
  39. self.defaults = {
  40. "lr": lr,
  41. "eps": eps,
  42. "beta1": betas[0],
  43. "beta2": betas[1],
  44. "weight_decay": weight_decay,
  45. }
  46. self.foreach = foreach
  47. self.maximize = maximize
  48. self.state = torch.jit.annotate(Dict[torch.Tensor, Dict[str, torch.Tensor]], {})
  49. if len(params) == 0 and not _allow_empty_param_list:
  50. raise ValueError("optimizer got an empty parameter list")
  51. # NOTE: we only have one param_group and don't allow user to add additional
  52. # param group as it's not a common use case.
  53. self.param_group = {"params": params}
  54. def step(self, gradients: List[Optional[Tensor]]):
  55. params = self.param_group["params"]
  56. params_with_grad = []
  57. grads = []
  58. exp_avgs = []
  59. exp_infs = []
  60. state_steps: List[Tensor] = []
  61. if len(params) != len(gradients):
  62. raise ValueError(
  63. "the gradients passed in does not equal to the size of the parameters!"
  64. + f"Params length: {len(params)}. "
  65. + f"Gradients length: {len(gradients)}"
  66. )
  67. has_complex = False
  68. for param, gradient in zip(self.param_group["params"], gradients):
  69. if gradient is not None:
  70. has_complex |= torch.is_complex(param)
  71. params_with_grad.append(param)
  72. grads.append(gradient)
  73. # Lazy state initialization
  74. if param not in self.state:
  75. self.state[param] = {}
  76. state = self.state[param]
  77. state["step"] = torch.tensor(0.0)
  78. # Exponential moving average of gradient values
  79. state["exp_avg"] = torch.zeros_like(
  80. param, memory_format=torch.preserve_format
  81. )
  82. # Exponential moving average of squared gradient values
  83. state["exp_inf"] = torch.zeros_like(
  84. param, memory_format=torch.preserve_format
  85. )
  86. state = self.state[param]
  87. exp_avgs.append(state["exp_avg"])
  88. exp_infs.append(state["exp_inf"])
  89. state_steps.append(state["step"])
  90. with torch.no_grad():
  91. F.adamax(
  92. params_with_grad,
  93. grads,
  94. exp_avgs,
  95. exp_infs,
  96. state_steps,
  97. eps=self.defaults["eps"],
  98. beta1=self.defaults["beta1"],
  99. beta2=self.defaults["beta2"],
  100. lr=self.defaults["lr"],
  101. weight_decay=self.defaults["weight_decay"],
  102. foreach=self.foreach,
  103. maximize=self.maximize,
  104. has_complex=has_complex,
  105. )