functional_adagrad.py 4.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110
  1. # mypy: allow-untyped-defs
  2. from typing import Dict, List, Optional
  3. import torch
  4. import torch.optim._functional as F
  5. from torch import Tensor
  6. __all__: List[str] = []
  7. # Define a TorchScript compatible Functional Adagrad Optimizer
  8. # where we use these optimizer in a functional way.
  9. # Instead of using the `param.grad` when updating parameters,
  10. # we explicitly let the user pass gradients to the `step` function
  11. # this is so that we could separate the gradients and parameters
  12. # and allow multithreaded trainer to update the parameters
  13. # without data traces on accumulating to the same .grad.
  14. # NOTE: This should be only used by distributed optimizer internals
  15. # and not meant to expose to the user.
  16. @torch.jit.script
  17. class _FunctionalAdagrad:
  18. def __init__(
  19. self,
  20. params: List[Tensor],
  21. lr: float = 1e-2,
  22. lr_decay: float = 0.0,
  23. weight_decay: float = 0.0,
  24. initial_accumulator_value: float = 0.0,
  25. warmup_lr_multiplier: float = 1.0,
  26. warmup_num_iters: float = 0.0,
  27. eps: float = 1e-10,
  28. coalesce_grad: bool = True,
  29. foreach: bool = False,
  30. fused: bool = False,
  31. maximize: bool = False,
  32. _allow_empty_param_list: bool = False,
  33. ):
  34. self.defaults = {
  35. "lr": lr,
  36. "lr_decay": lr_decay,
  37. "eps": eps,
  38. "weight_decay": weight_decay,
  39. "initial_accumulator_value": initial_accumulator_value,
  40. "warmup_lr_multiplier": warmup_lr_multiplier,
  41. "warmup_num_iters": warmup_num_iters,
  42. }
  43. self.coalesce_grad = coalesce_grad
  44. self.foreach = foreach
  45. self.fused = fused
  46. self.maximize = maximize
  47. self.state = torch.jit.annotate(Dict[torch.Tensor, Dict[str, torch.Tensor]], {})
  48. if len(params) == 0 and not _allow_empty_param_list:
  49. raise ValueError("optimizer got an empty parameter list")
  50. # NOTE: we only have one param_group and don't allow user to add additional
  51. # param group as it's not a common use case.
  52. self.param_group = {"params": params}
  53. # TODO: no union or any types in TorchScript, make step a scalar tensor instead
  54. # This is also needed by if we want to share_memory on the step across processes
  55. for p in self.param_group["params"]:
  56. self.state[p] = {
  57. "sum": torch.full_like(p.data, initial_accumulator_value),
  58. "step": torch.tensor(0.0),
  59. }
  60. def step(self, gradients: List[Optional[Tensor]]):
  61. params = self.param_group["params"]
  62. params_with_grad = []
  63. grads = []
  64. state_sums = []
  65. state_steps: List[Tensor] = []
  66. if len(params) != len(gradients):
  67. raise ValueError(
  68. "the gradients passed in does not equal to the size of the parameters!"
  69. + f"Params length: {len(params)}. "
  70. + f"Gradients length: {len(gradients)}"
  71. )
  72. has_sparse_grad, has_complex = False, False
  73. for param, gradient in zip(self.param_group["params"], gradients):
  74. if gradient is not None:
  75. has_sparse_grad |= gradient.is_sparse
  76. has_complex |= torch.is_complex(param)
  77. params_with_grad.append(param)
  78. grads.append(gradient)
  79. state = self.state[param]
  80. state_sums.append(state["sum"])
  81. state_steps.append(state["step"])
  82. with torch.no_grad():
  83. F.adagrad(
  84. params,
  85. grads,
  86. state_sums,
  87. state_steps,
  88. lr=self.defaults["lr"],
  89. weight_decay=self.defaults["weight_decay"],
  90. lr_decay=self.defaults["lr_decay"],
  91. eps=self.defaults["eps"],
  92. has_sparse_grad=has_sparse_grad,
  93. foreach=self.foreach,
  94. maximize=self.maximize,
  95. has_complex=has_complex,
  96. fused=self.fused,
  97. grad_scale=None,
  98. found_inf=None,
  99. )