# normalization.py

# mypy: allow-untyped-defs
import torch
import numbers
from torch.nn.parameter import Parameter
from .module import Module
from ._functions import CrossMapLRN2d as _cross_map_lrn2d
from .. import functional as F
from .. import init
from torch import Tensor, Size
from typing import Union, List, Optional, Tuple

__all__ = ['LocalResponseNorm', 'CrossMapLRN2d', 'LayerNorm', 'GroupNorm', 'RMSNorm']

class LocalResponseNorm(Module):
    r"""Applies local response normalization over an input signal.

    The input signal is composed of several input planes, where channels occupy the second dimension.
    Applies normalization across channels.

    .. math::
        b_{c} = a_{c}\left(k + \frac{\alpha}{n}
        \sum_{c'=\max(0, c-n/2)}^{\min(N-1,c+n/2)}a_{c'}^2\right)^{-\beta}

    Args:
        size: amount of neighbouring channels used for normalization
        alpha: multiplicative factor. Default: 0.0001
        beta: exponent. Default: 0.75
        k: additive factor. Default: 1

    Shape:
        - Input: :math:`(N, C, *)`
        - Output: :math:`(N, C, *)` (same shape as input)

    Examples::

        >>> lrn = nn.LocalResponseNorm(2)
        >>> signal_2d = torch.randn(32, 5, 24, 24)
        >>> signal_4d = torch.randn(16, 5, 7, 7, 7, 7)
        >>> output_2d = lrn(signal_2d)
        >>> output_4d = lrn(signal_4d)
    """

    __constants__ = ['size', 'alpha', 'beta', 'k']
    size: int
    alpha: float
    beta: float
    k: float

    def __init__(self, size: int, alpha: float = 1e-4, beta: float = 0.75, k: float = 1.) -> None:
        super().__init__()
        self.size = size
        self.alpha = alpha
        self.beta = beta
        self.k = k

    def forward(self, input: Tensor) -> Tensor:
        return F.local_response_norm(input, self.size, self.alpha, self.beta,
                                     self.k)

    def extra_repr(self):
        return '{size}, alpha={alpha}, beta={beta}, k={k}'.format(**self.__dict__)

class CrossMapLRN2d(Module):
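    r"""Applies cross-map local response normalization over an input signal.

    This module has no documentation in the original source; the summary below is a
    sketch inferred from the implementation. It takes the same ``size``, ``alpha``,
    ``beta`` and ``k`` parameters as :class:`LocalResponseNorm`, but dispatches to the
    ``CrossMapLRN2d`` autograd function imported from ``._functions``.

    Examples::

        >>> # Minimal usage sketch; assumes a 4D (N, C, H, W) input.
        >>> lrn = nn.CrossMapLRN2d(5)
        >>> input = torch.randn(20, 16, 32, 32)
        >>> output = lrn(input)
    """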
    size: int
    alpha: float
    beta: float
    k: float

    def __init__(self, size: int, alpha: float = 1e-4, beta: float = 0.75, k: float = 1) -> None:
        super().__init__()
        self.size = size
        self.alpha = alpha
        self.beta = beta
        self.k = k

    def forward(self, input: Tensor) -> Tensor:
        return _cross_map_lrn2d.apply(input, self.size, self.alpha, self.beta,
                                      self.k)

    def extra_repr(self) -> str:
        return '{size}, alpha={alpha}, beta={beta}, k={k}'.format(**self.__dict__)


_shape_t = Union[int, List[int], Size]

class LayerNorm(Module):
    r"""Applies Layer Normalization over a mini-batch of inputs.

    This layer implements the operation as described in
    the paper `Layer Normalization <https://arxiv.org/abs/1607.06450>`__

    .. math::
        y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta

    The mean and standard-deviation are calculated over the last `D` dimensions, where `D`
    is the dimension of :attr:`normalized_shape`. For example, if :attr:`normalized_shape`
    is ``(3, 5)`` (a 2-dimensional shape), the mean and standard-deviation are computed over
    the last 2 dimensions of the input (i.e. ``input.mean((-2, -1))``).
    :math:`\gamma` and :math:`\beta` are learnable affine transform parameters of
    :attr:`normalized_shape` if :attr:`elementwise_affine` is ``True``.
    The standard-deviation is calculated via the biased estimator, equivalent to
    `torch.var(input, unbiased=False)`.

    .. note::
        Unlike Batch Normalization and Instance Normalization, which apply a
        scalar scale and bias to each entire channel/plane with the
        :attr:`affine` option, Layer Normalization applies per-element scale and
        bias with :attr:`elementwise_affine`.

    This layer uses statistics computed from input data in both training and
    evaluation modes.

    Args:
        normalized_shape (int or list or torch.Size): input shape from an expected input
            of size

            .. math::
                [* \times \text{normalized\_shape}[0] \times \text{normalized\_shape}[1]
                    \times \ldots \times \text{normalized\_shape}[-1]]

            If a single integer is used, it is treated as a singleton list, and this module will
            normalize over the last dimension which is expected to be of that specific size.
        eps: a value added to the denominator for numerical stability. Default: 1e-5
        elementwise_affine: a boolean value that when set to ``True``, this module
            has learnable per-element affine parameters initialized to ones (for weights)
            and zeros (for biases). Default: ``True``.
        bias: If set to ``False``, the layer will not learn an additive bias (only relevant if
            :attr:`elementwise_affine` is ``True``). Default: ``True``.

    Attributes:
        weight: the learnable weights of the module of shape
            :math:`\text{normalized\_shape}` when :attr:`elementwise_affine` is set to ``True``.
            The values are initialized to 1.
        bias: the learnable bias of the module of shape
            :math:`\text{normalized\_shape}` when :attr:`elementwise_affine` is set to ``True``.
            The values are initialized to 0.

    Shape:
        - Input: :math:`(N, *)`
        - Output: :math:`(N, *)` (same shape as input)

    Examples::

        >>> # NLP Example
        >>> batch, sentence_length, embedding_dim = 20, 5, 10
        >>> embedding = torch.randn(batch, sentence_length, embedding_dim)
        >>> layer_norm = nn.LayerNorm(embedding_dim)
        >>> # Activate module
        >>> layer_norm(embedding)
        >>>
        >>> # Image Example
        >>> N, C, H, W = 20, 5, 10, 10
        >>> input = torch.randn(N, C, H, W)
        >>> # Normalize over the last three dimensions (i.e. the channel and spatial dimensions)
        >>> # as shown in the image below
        >>> layer_norm = nn.LayerNorm([C, H, W])
        >>> output = layer_norm(input)

    .. image:: ../_static/img/nn/layer_norm.jpg
        :scale: 50 %

    """
    __constants__ = ['normalized_shape', 'eps', 'elementwise_affine']
    normalized_shape: Tuple[int, ...]
    eps: float
    elementwise_affine: bool

    def __init__(self, normalized_shape: _shape_t, eps: float = 1e-5, elementwise_affine: bool = True,
                 bias: bool = True, device=None, dtype=None) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        if isinstance(normalized_shape, numbers.Integral):
            # mypy error: incompatible types in assignment
            normalized_shape = (normalized_shape,)  # type: ignore[assignment]
        self.normalized_shape = tuple(normalized_shape)  # type: ignore[arg-type]
        self.eps = eps
        self.elementwise_affine = elementwise_affine
        if self.elementwise_affine:
            self.weight = Parameter(torch.empty(self.normalized_shape, **factory_kwargs))
            if bias:
                self.bias = Parameter(torch.empty(self.normalized_shape, **factory_kwargs))
            else:
                self.register_parameter('bias', None)
        else:
            self.register_parameter('weight', None)
            self.register_parameter('bias', None)

        self.reset_parameters()

    def reset_parameters(self) -> None:
        if self.elementwise_affine:
            init.ones_(self.weight)
            if self.bias is not None:
                init.zeros_(self.bias)

    def forward(self, input: Tensor) -> Tensor:
        return F.layer_norm(
            input, self.normalized_shape, self.weight, self.bias, self.eps)

    def extra_repr(self) -> str:
        return '{normalized_shape}, eps={eps}, ' \
            'elementwise_affine={elementwise_affine}'.format(**self.__dict__)

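
# A minimal sketch of what the F.layer_norm call above computes for a floating-point
# input ``input`` (an illustration of the docstring formula, not the actual kernel);
# the statistics are taken over the last D dimensions given by ``normalized_shape``:
#
#     dims = tuple(range(-len(self.normalized_shape), 0))
#     mean = input.mean(dim=dims, keepdim=True)
#     var = input.var(dim=dims, unbiased=False, keepdim=True)
#     out = (input - mean) / torch.sqrt(var + self.eps)
#     out = out * self.weight + self.bias  # only when elementwise_affine=True (and bias=True)
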
class GroupNorm(Module):
    r"""Applies Group Normalization over a mini-batch of inputs.

    This layer implements the operation as described in
    the paper `Group Normalization <https://arxiv.org/abs/1803.08494>`__

    .. math::
        y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta

    The input channels are separated into :attr:`num_groups` groups, each containing
    ``num_channels / num_groups`` channels. :attr:`num_channels` must be divisible by
    :attr:`num_groups`. The mean and standard-deviation are calculated
    separately over each group. :math:`\gamma` and :math:`\beta` are learnable
    per-channel affine transform parameter vectors of size :attr:`num_channels` if
    :attr:`affine` is ``True``.
    The standard-deviation is calculated via the biased estimator, equivalent to
    `torch.var(input, unbiased=False)`.

    This layer uses statistics computed from input data in both training and
    evaluation modes.

    Args:
        num_groups (int): number of groups to separate the channels into
        num_channels (int): number of channels expected in input
        eps: a value added to the denominator for numerical stability. Default: 1e-5
        affine: a boolean value that when set to ``True``, this module
            has learnable per-channel affine parameters initialized to ones (for weights)
            and zeros (for biases). Default: ``True``.

    Shape:
        - Input: :math:`(N, C, *)` where :math:`C=\text{num\_channels}`
        - Output: :math:`(N, C, *)` (same shape as input)

    Examples::

        >>> input = torch.randn(20, 6, 10, 10)
        >>> # Separate 6 channels into 3 groups
        >>> m = nn.GroupNorm(3, 6)
        >>> # Separate 6 channels into 6 groups (equivalent with InstanceNorm)
        >>> m = nn.GroupNorm(6, 6)
        >>> # Put all 6 channels into a single group (equivalent with LayerNorm)
        >>> m = nn.GroupNorm(1, 6)
        >>> # Activating the module
        >>> output = m(input)
    """
    __constants__ = ['num_groups', 'num_channels', 'eps', 'affine']
    num_groups: int
    num_channels: int
    eps: float
    affine: bool

    def __init__(self, num_groups: int, num_channels: int, eps: float = 1e-5, affine: bool = True,
                 device=None, dtype=None) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        if num_channels % num_groups != 0:
            raise ValueError('num_channels must be divisible by num_groups')
        self.num_groups = num_groups
        self.num_channels = num_channels
        self.eps = eps
        self.affine = affine
        if self.affine:
            self.weight = Parameter(torch.empty(num_channels, **factory_kwargs))
            self.bias = Parameter(torch.empty(num_channels, **factory_kwargs))
        else:
            self.register_parameter('weight', None)
            self.register_parameter('bias', None)

        self.reset_parameters()

    def reset_parameters(self) -> None:
        if self.affine:
            init.ones_(self.weight)
            init.zeros_(self.bias)

    def forward(self, input: Tensor) -> Tensor:
        return F.group_norm(
            input, self.num_groups, self.weight, self.bias, self.eps)

    def extra_repr(self) -> str:
        return '{num_groups}, {num_channels}, eps={eps}, ' \
            'affine={affine}'.format(**self.__dict__)

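
# A minimal sketch of the per-group statistics behind the F.group_norm call above
# (an illustration, not the actual kernel), shown for a 4D input x of shape (N, C, H, W):
#
#     N, C, H, W = x.shape
#     g = x.view(N, self.num_groups, -1)
#     g = (g - g.mean(-1, keepdim=True)) / torch.sqrt(g.var(-1, unbiased=False, keepdim=True) + self.eps)
#     out = g.view(N, C, H, W)
#     out = out * self.weight.view(1, C, 1, 1) + self.bias.view(1, C, 1, 1)  # when affine=True
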
class RMSNorm(Module):
    r"""Applies Root Mean Square Layer Normalization over a mini-batch of inputs.

    This layer implements the operation as described in
    the paper `Root Mean Square Layer Normalization <https://arxiv.org/pdf/1910.07467.pdf>`__

    .. math::
        y = \frac{x}{\sqrt{\mathrm{RMS}[x] + \epsilon}} * \gamma

    The root mean squared norm is taken over the last ``D`` dimensions, where ``D``
    is the dimension of :attr:`normalized_shape`. For example, if :attr:`normalized_shape`
    is ``(3, 5)`` (a 2-dimensional shape), the rms norm is computed over
    the last 2 dimensions of the input.

    Args:
        normalized_shape (int or list or torch.Size): input shape from an expected input
            of size

            .. math::
                [* \times \text{normalized\_shape}[0] \times \text{normalized\_shape}[1]
                    \times \ldots \times \text{normalized\_shape}[-1]]

            If a single integer is used, it is treated as a singleton list, and this module will
            normalize over the last dimension which is expected to be of that specific size.
        eps: a value added to the denominator for numerical stability. Default: :func:`torch.finfo(x.dtype).eps`
        elementwise_affine: a boolean value that when set to ``True``, this module
            has learnable per-element affine parameters (weights) initialized to ones.
            Default: ``True``.

    Shape:
        - Input: :math:`(N, *)`
        - Output: :math:`(N, *)` (same shape as input)

    Examples::

        >>> rms_norm = nn.RMSNorm([2, 3])
        >>> input = torch.randn(2, 2, 3)
        >>> rms_norm(input)
    """
    __constants__ = ['normalized_shape', 'eps', 'elementwise_affine']
    normalized_shape: Tuple[int, ...]
    eps: Optional[float]
    elementwise_affine: bool

    def __init__(self, normalized_shape: _shape_t, eps: Optional[float] = None, elementwise_affine: bool = True,
                 device=None, dtype=None) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        if isinstance(normalized_shape, numbers.Integral):
            # mypy error: incompatible types in assignment
            normalized_shape = (normalized_shape,)  # type: ignore[assignment]
        self.normalized_shape = tuple(normalized_shape)  # type: ignore[arg-type]
        self.eps = eps
        self.elementwise_affine = elementwise_affine
        if self.elementwise_affine:
            self.weight = Parameter(torch.empty(self.normalized_shape, **factory_kwargs))
        else:
            self.register_parameter('weight', None)
        self.reset_parameters()

    def reset_parameters(self) -> None:
        """
        Resets parameters based on their initialization used in __init__.
        """
        if self.elementwise_affine:
            init.ones_(self.weight)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Runs forward pass.
        """
        return F.rms_norm(x, self.normalized_shape, self.weight, self.eps)

    def extra_repr(self) -> str:
        """
        Extra information about the module.
        """
        return '{normalized_shape}, eps={eps}, ' \
            'elementwise_affine={elementwise_affine}'.format(**self.__dict__)

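
# A minimal sketch of what F.rms_norm computes in RMSNorm.forward above (an
# illustration, not the actual kernel); per the docstring, when eps is None it falls
# back to the machine epsilon of the input dtype:
#
#     eps = self.eps if self.eps is not None else torch.finfo(x.dtype).eps
#     dims = tuple(range(-len(self.normalized_shape), 0))
#     rms = x.pow(2).mean(dim=dims, keepdim=True).add(eps).sqrt()
#     out = x / rms
#     out = out * self.weight  # only when elementwise_affine=True
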
# TODO: ContrastiveNorm2d
# TODO: DivisiveNorm2d
# TODO: SubtractiveNorm2d