activations.py

# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
from collections import OrderedDict

import torch
from packaging import version
from torch import Tensor, nn

from .utils import logging


logger = logging.get_logger(__name__)


class PytorchGELUTanh(nn.Module):
    """
    A fast C implementation of the tanh approximation of the GeLU activation function. See
    https://arxiv.org/abs/1606.08415.

    This implementation is equivalent to NewGELU and FastGELU but much faster. However, it is not an exact numerical
    match due to rounding errors.
    """

    def __init__(self):
        super().__init__()
        if version.parse(torch.__version__) < version.parse("1.12.0"):
            raise ImportError(
                f"You are using torch=={torch.__version__}, but torch>=1.12.0 is required to use "
                "PytorchGELUTanh. Please upgrade torch."
            )

    def forward(self, input: Tensor) -> Tensor:
        return nn.functional.gelu(input, approximate="tanh")


class NewGELUActivation(nn.Module):
    """
    Implementation of the GELU activation function currently in Google BERT repo (identical to OpenAI GPT). Also see
    the Gaussian Error Linear Units paper: https://arxiv.org/abs/1606.08415
    """

    def forward(self, input: Tensor) -> Tensor:
        return 0.5 * input * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (input + 0.044715 * torch.pow(input, 3.0))))
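

# Illustrative comparison (assumes torch >= 1.12 for the `approximate="tanh"` kwarg): the Python
# formula above should agree with the C-level tanh approximation used by PytorchGELUTanh up to
# floating-point rounding, e.g.
#
#     x = torch.randn(8)
#     torch.allclose(NewGELUActivation()(x), nn.functional.gelu(x, approximate="tanh"), atol=1e-6)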


class GELUActivation(nn.Module):
    """
    Original Implementation of the GELU activation function in Google BERT repo when initially created. For
    information: OpenAI GPT's GELU is slightly different (and gives slightly different results): 0.5 * x * (1 +
    torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))). This is now written in C in nn.functional.
    Also see the Gaussian Error Linear Units paper: https://arxiv.org/abs/1606.08415
    """

    def __init__(self, use_gelu_python: bool = False):
        super().__init__()
        if use_gelu_python:
            self.act = self._gelu_python
        else:
            self.act = nn.functional.gelu

    def _gelu_python(self, input: Tensor) -> Tensor:
        return input * 0.5 * (1.0 + torch.erf(input / math.sqrt(2.0)))

    def forward(self, input: Tensor) -> Tensor:
        return self.act(input)


class FastGELUActivation(nn.Module):
    """
    Applies GELU approximation that is slower than QuickGELU but more accurate. See: https://github.com/hendrycks/GELUs
    """

    def forward(self, input: Tensor) -> Tensor:
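        # 0.7978845608 is a hard-coded approximation of math.sqrt(2.0 / math.pi)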
        return 0.5 * input * (1.0 + torch.tanh(input * 0.7978845608 * (1.0 + 0.044715 * input * input)))


class QuickGELUActivation(nn.Module):
    """
    Applies GELU approximation that is fast but somewhat inaccurate. See: https://github.com/hendrycks/GELUs
    """

    def forward(self, input: Tensor) -> Tensor:
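        # sigmoid approximation of GELU, x * sigmoid(1.702 * x), as proposed by Hendrycks & Gimpel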
        return input * torch.sigmoid(1.702 * input)


class ClippedGELUActivation(nn.Module):
    """
    Clip the range of possible GeLU outputs between [min, max]. This is especially useful for quantization purposes,
    as it allows mapping negative values in the GeLU spectrum. For more information on this trick, please refer to
    https://arxiv.org/abs/2004.09602.

    Gaussian Error Linear Unit. Original Implementation of the gelu activation function in Google Bert repo when
    initially created.

    For information: OpenAI GPT's gelu is slightly different (and gives slightly different results): 0.5 * x * (1 +
    torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))). See https://arxiv.org/abs/1606.08415
    """

    def __init__(self, min: float, max: float):
        if min > max:
            raise ValueError(f"min should be < max (got min: {min}, max: {max})")

        super().__init__()
        self.min = min
        self.max = max

    def forward(self, x: Tensor) -> Tensor:
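        # `gelu` here is the module-level GELUActivation instance created at the bottom of this file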
        return torch.clip(gelu(x), self.min, self.max)


class AccurateGELUActivation(nn.Module):
    """
    Applies GELU approximation that is faster than default and more accurate than QuickGELU. See:
    https://github.com/hendrycks/GELUs

    Implemented along with MEGA (Moving Average Equipped Gated Attention)
    """

    def __init__(self):
        super().__init__()
        self.precomputed_constant = math.sqrt(2 / math.pi)

    def forward(self, input: Tensor) -> Tensor:
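        # same tanh approximation as NewGELUActivation, with sqrt(2 / pi) precomputed once in __init__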
        return 0.5 * input * (1 + torch.tanh(self.precomputed_constant * (input + 0.044715 * torch.pow(input, 3))))


class MishActivation(nn.Module):
    """
    See Mish: A Self-Regularized Non-Monotonic Activation Function (Misra, https://arxiv.org/abs/1908.08681). Also
    visit the official repository for the paper: https://github.com/digantamisra98/Mish
    """

    def __init__(self):
        super().__init__()
        if version.parse(torch.__version__) < version.parse("1.9.0"):
            self.act = self._mish_python
        else:
            self.act = nn.functional.mish

    def _mish_python(self, input: Tensor) -> Tensor:
        return input * torch.tanh(nn.functional.softplus(input))

    def forward(self, input: Tensor) -> Tensor:
        return self.act(input)


class LinearActivation(nn.Module):
    """
    Applies the linear activation function, i.e. forwarding input directly to output.
    """

    def forward(self, input: Tensor) -> Tensor:
        return input


class LaplaceActivation(nn.Module):
    """
    Applies elementwise activation based on Laplace function, introduced in MEGA as an attention activation. See
    https://arxiv.org/abs/2209.10655

    Inspired by squared relu, but with bounded range and gradient for better stability
    """

    def forward(self, input, mu=0.707107, sigma=0.282095):
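        # the defaults correspond to mu ~= sqrt(1/2) and sigma ~= sqrt(1/(4 * pi)), the values used in the MEGA paper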
        input = (input - mu).div(sigma * math.sqrt(2.0))
        return 0.5 * (1.0 + torch.erf(input))


class ReLUSquaredActivation(nn.Module):
    """
    Applies the relu^2 activation introduced in https://arxiv.org/abs/2109.08668v2
    """

    def forward(self, input):
        relu_applied = nn.functional.relu(input)
        squared = torch.square(relu_applied)
        return squared


class ClassInstantier(OrderedDict):
    def __getitem__(self, key):
        content = super().__getitem__(key)
        cls, kwargs = content if isinstance(content, tuple) else (content, {})
        return cls(**kwargs)


ACT2CLS = {
    "gelu": GELUActivation,
    "gelu_10": (ClippedGELUActivation, {"min": -10, "max": 10}),
    "gelu_fast": FastGELUActivation,
    "gelu_new": NewGELUActivation,
    "gelu_python": (GELUActivation, {"use_gelu_python": True}),
    "gelu_pytorch_tanh": PytorchGELUTanh,
    "gelu_accurate": AccurateGELUActivation,
    "laplace": LaplaceActivation,
    "leaky_relu": nn.LeakyReLU,
    "linear": LinearActivation,
    "mish": MishActivation,
    "quick_gelu": QuickGELUActivation,
    "relu": nn.ReLU,
    "relu2": ReLUSquaredActivation,
    "relu6": nn.ReLU6,
    "sigmoid": nn.Sigmoid,
    "silu": nn.SiLU,
    "swish": nn.SiLU,
    "tanh": nn.Tanh,
}
ACT2FN = ClassInstantier(ACT2CLS)
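
# Note: because ACT2FN is a ClassInstantier, indexing it instantiates the mapped class on the fly,
# and tuple entries unpack into constructor kwargs. For example:
#
#     ACT2FN["gelu_10"]  # -> ClippedGELUActivation(min=-10, max=10), a new instance per lookup
#     ACT2FN["relu"]     # -> nn.ReLU()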


def get_activation(activation_string):
    if activation_string in ACT2FN:
        return ACT2FN[activation_string]
    else:
        raise KeyError(f"function {activation_string} not found in ACT2FN mapping {list(ACT2FN.keys())}")
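

# Typical usage (illustrative sketch, names below are hypothetical):
#
#     act = get_activation("silu")          # -> a fresh nn.SiLU() instance
#     y = act(torch.randn(2, 3))
#     get_activation("not_an_activation")   # raises KeyError listing the supported names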


# For backwards compatibility with: from activations import gelu_python
gelu_python = get_activation("gelu_python")
gelu_new = get_activation("gelu_new")
gelu = get_activation("gelu")
gelu_fast = get_activation("gelu_fast")
quick_gelu = get_activation("quick_gelu")
silu = get_activation("silu")
mish = get_activation("mish")
linear_act = get_activation("linear")