_functional.py

# mypy: allow-untyped-defs
r"""Functional interface."""
import math
from typing import List

from torch import Tensor

from .adadelta import adadelta  # type: ignore[attr-defined] # noqa: F401
from .adagrad import _make_sparse, adagrad  # type: ignore[attr-defined] # noqa: F401
from .adam import adam  # type: ignore[attr-defined] # noqa: F401
from .adamax import adamax  # type: ignore[attr-defined] # noqa: F401
from .adamw import adamw  # type: ignore[attr-defined] # noqa: F401
from .asgd import asgd  # type: ignore[attr-defined] # noqa: F401
from .nadam import nadam  # type: ignore[attr-defined] # noqa: F401
from .radam import radam  # type: ignore[attr-defined] # noqa: F401
from .rmsprop import rmsprop  # type: ignore[attr-defined] # noqa: F401
from .rprop import rprop  # type: ignore[attr-defined] # noqa: F401
from .sgd import sgd  # type: ignore[attr-defined] # noqa: F401


# TODO: use foreach API in optim._functional to do all the computation
def sparse_adam(
    params: List[Tensor],
    grads: List[Tensor],
    exp_avgs: List[Tensor],
    exp_avg_sqs: List[Tensor],
    state_steps: List[int],
    *,
    eps: float,
    beta1: float,
    beta2: float,
    lr: float,
    maximize: bool,
):
    r"""Functional API that performs Sparse Adam algorithm computation.

    See :class:`~torch.optim.SparseAdam` for details.
    """
    for i, param in enumerate(params):
        grad = grads[i]
        grad = grad if not maximize else -grad
        grad = grad.coalesce()  # the update is non-linear so indices must be unique
        grad_indices = grad._indices()
        grad_values = grad._values()
        if grad_values.numel() == 0:
            # Skip update for empty grad
            continue
        size = grad.size()

        exp_avg = exp_avgs[i]
        exp_avg_sq = exp_avg_sqs[i]
        step = state_steps[i]
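
        # Rebuilds a sparse tensor that shares ``grad``'s indices but holds the given
        # ``values``, so each update below lands exactly on the gradient's nonzero
        # entries; degenerate inputs yield an empty tensor resized like ``grad``.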
        def make_sparse(values):
            constructor = grad.new
            if grad_indices.dim() == 0 or values.dim() == 0:
                return constructor().resize_as_(grad)
            return constructor(grad_indices, values, size)

        # Decay the first and second moment running averages:
        #      old <- b * old + (1 - b) * new
        # <==> old += (1 - b) * (new - old)
        old_exp_avg_values = exp_avg.sparse_mask(grad)._values()
        exp_avg_update_values = grad_values.sub(old_exp_avg_values).mul_(1 - beta1)
        exp_avg.add_(make_sparse(exp_avg_update_values))
        old_exp_avg_sq_values = exp_avg_sq.sparse_mask(grad)._values()
        exp_avg_sq_update_values = (
            grad_values.pow(2).sub_(old_exp_avg_sq_values).mul_(1 - beta2)
        )
        exp_avg_sq.add_(make_sparse(exp_avg_sq_update_values))

        # Dense addition on the value tensors is intended here: it rebuilds the new
        # first/second moment values at ``grad``'s indices without another sparse_mask.
        numer = exp_avg_update_values.add_(old_exp_avg_values)
        exp_avg_sq_update_values.add_(old_exp_avg_sq_values)
        denom = exp_avg_sq_update_values.sqrt_().add_(eps)
        del exp_avg_update_values, exp_avg_sq_update_values

        # Standard Adam bias correction, folded into the step size.
        bias_correction1 = 1 - beta1**step
        bias_correction2 = 1 - beta2**step
        step_size = lr * math.sqrt(bias_correction2) / bias_correction1

        param.add_(make_sparse(-step_size * numer.div_(denom)))
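

if __name__ == "__main__":
    # Minimal usage sketch, not part of the optimizer API: it assumes a dense
    # parameter with dense optimizer state and a sparse COO gradient (the layout
    # produced by, e.g., nn.Embedding(sparse=True)); the shapes and hyperparameter
    # values below are arbitrary example choices.
    import torch

    param = torch.zeros(4, 3)
    exp_avg = torch.zeros_like(param)
    exp_avg_sq = torch.zeros_like(param)

    # A gradient that touches only entries (0, 1) and (2, 0) of ``param``.
    grad = torch.sparse_coo_tensor(
        indices=torch.tensor([[0, 2], [1, 0]]),
        values=torch.tensor([0.5, -1.0]),
        size=param.shape,
    )

    sparse_adam(
        [param],
        [grad],
        [exp_avg],
        [exp_avg_sq],
        [1],  # state_steps: this is the first update
        eps=1e-8,
        beta1=0.9,
        beta2=0.999,
        lr=1e-2,
        maximize=False,
    )
    print(param)  # only the two touched entries have moved away from zero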