dropout.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294
  1. from .module import Module
  2. from .. import functional as F
  3. from torch import Tensor
  4. __all__ = ['Dropout', 'Dropout1d', 'Dropout2d', 'Dropout3d', 'AlphaDropout', 'FeatureAlphaDropout']
  5. class _DropoutNd(Module):
  6. __constants__ = ['p', 'inplace']
  7. p: float
  8. inplace: bool
  9. def __init__(self, p: float = 0.5, inplace: bool = False) -> None:
  10. super().__init__()
  11. if p < 0 or p > 1:
  12. raise ValueError(f"dropout probability has to be between 0 and 1, but got {p}")
  13. self.p = p
  14. self.inplace = inplace
  15. def extra_repr(self) -> str:
  16. return f'p={self.p}, inplace={self.inplace}'
  17. class Dropout(_DropoutNd):
  18. r"""During training, randomly zeroes some of the elements of the input tensor with probability :attr:`p`.
  19. The zeroed elements are chosen independently for each forward call and are sampled from a Bernoulli distribution.
  20. Each channel will be zeroed out independently on every forward call.
  21. This has proven to be an effective technique for regularization and
  22. preventing the co-adaptation of neurons as described in the paper
  23. `Improving neural networks by preventing co-adaptation of feature
  24. detectors`_ .
  25. Furthermore, the outputs are scaled by a factor of :math:`\frac{1}{1-p}` during
  26. training. This means that during evaluation the module simply computes an
  27. identity function.
  28. Args:
  29. p: probability of an element to be zeroed. Default: 0.5
  30. inplace: If set to ``True``, will do this operation in-place. Default: ``False``
  31. Shape:
  32. - Input: :math:`(*)`. Input can be of any shape
  33. - Output: :math:`(*)`. Output is of the same shape as input
  34. Examples::
  35. >>> m = nn.Dropout(p=0.2)
  36. >>> input = torch.randn(20, 16)
  37. >>> output = m(input)
  38. .. _Improving neural networks by preventing co-adaptation of feature
  39. detectors: https://arxiv.org/abs/1207.0580
  40. """
  41. def forward(self, input: Tensor) -> Tensor:
  42. return F.dropout(input, self.p, self.training, self.inplace)
  43. class Dropout1d(_DropoutNd):
  44. r"""Randomly zero out entire channels.
  45. A channel is a 1D feature map,
  46. e.g., the :math:`j`-th channel of the :math:`i`-th sample in the
  47. batched input is a 1D tensor :math:`\text{input}[i, j]`.
  48. Each channel will be zeroed out independently on every forward call with
  49. probability :attr:`p` using samples from a Bernoulli distribution.
  50. Usually the input comes from :class:`nn.Conv1d` modules.
  51. As described in the paper
  52. `Efficient Object Localization Using Convolutional Networks`_ ,
  53. if adjacent pixels within feature maps are strongly correlated
  54. (as is normally the case in early convolution layers) then i.i.d. dropout
  55. will not regularize the activations and will otherwise just result
  56. in an effective learning rate decrease.
  57. In this case, :func:`nn.Dropout1d` will help promote independence between
  58. feature maps and should be used instead.
  59. Args:
  60. p (float, optional): probability of an element to be zero-ed.
  61. inplace (bool, optional): If set to ``True``, will do this operation
  62. in-place
  63. Shape:
  64. - Input: :math:`(N, C, L)` or :math:`(C, L)`.
  65. - Output: :math:`(N, C, L)` or :math:`(C, L)` (same shape as input).
  66. Examples::
  67. >>> m = nn.Dropout1d(p=0.2)
  68. >>> input = torch.randn(20, 16, 32)
  69. >>> output = m(input)
  70. .. _Efficient Object Localization Using Convolutional Networks:
  71. https://arxiv.org/abs/1411.4280
  72. """
  73. def forward(self, input: Tensor) -> Tensor:
  74. return F.dropout1d(input, self.p, self.training, self.inplace)
  75. class Dropout2d(_DropoutNd):
  76. r"""Randomly zero out entire channels.
  77. A channel is a 2D feature map,
  78. e.g., the :math:`j`-th channel of the :math:`i`-th sample in the
  79. batched input is a 2D tensor :math:`\text{input}[i, j]`.
  80. Each channel will be zeroed out independently on every forward call with
  81. probability :attr:`p` using samples from a Bernoulli distribution.
  82. Usually the input comes from :class:`nn.Conv2d` modules.
  83. As described in the paper
  84. `Efficient Object Localization Using Convolutional Networks`_ ,
  85. if adjacent pixels within feature maps are strongly correlated
  86. (as is normally the case in early convolution layers) then i.i.d. dropout
  87. will not regularize the activations and will otherwise just result
  88. in an effective learning rate decrease.
  89. In this case, :func:`nn.Dropout2d` will help promote independence between
  90. feature maps and should be used instead.
  91. Args:
  92. p (float, optional): probability of an element to be zero-ed.
  93. inplace (bool, optional): If set to ``True``, will do this operation
  94. in-place
  95. .. warning ::
  96. Due to historical reasons, this class will perform 1D channel-wise dropout
  97. for 3D inputs (as done by :class:`nn.Dropout1d`). Thus, it currently does NOT
  98. support inputs without a batch dimension of shape :math:`(C, H, W)`. This
  99. behavior will change in a future release to interpret 3D inputs as no-batch-dim
  100. inputs. To maintain the old behavior, switch to :class:`nn.Dropout1d`.
  101. Shape:
  102. - Input: :math:`(N, C, H, W)` or :math:`(N, C, L)`.
  103. - Output: :math:`(N, C, H, W)` or :math:`(N, C, L)` (same shape as input).
  104. Examples::
  105. >>> m = nn.Dropout2d(p=0.2)
  106. >>> input = torch.randn(20, 16, 32, 32)
  107. >>> output = m(input)
  108. .. _Efficient Object Localization Using Convolutional Networks:
  109. https://arxiv.org/abs/1411.4280
  110. """
  111. def forward(self, input: Tensor) -> Tensor:
  112. return F.dropout2d(input, self.p, self.training, self.inplace)
  113. class Dropout3d(_DropoutNd):
  114. r"""Randomly zero out entire channels.
  115. A channel is a 3D feature map,
  116. e.g., the :math:`j`-th channel of the :math:`i`-th sample in the
  117. batched input is a 3D tensor :math:`\text{input}[i, j]`.
  118. Each channel will be zeroed out independently on every forward call with
  119. probability :attr:`p` using samples from a Bernoulli distribution.
  120. Usually the input comes from :class:`nn.Conv3d` modules.
  121. As described in the paper
  122. `Efficient Object Localization Using Convolutional Networks`_ ,
  123. if adjacent pixels within feature maps are strongly correlated
  124. (as is normally the case in early convolution layers) then i.i.d. dropout
  125. will not regularize the activations and will otherwise just result
  126. in an effective learning rate decrease.
  127. In this case, :func:`nn.Dropout3d` will help promote independence between
  128. feature maps and should be used instead.
  129. Args:
  130. p (float, optional): probability of an element to be zeroed.
  131. inplace (bool, optional): If set to ``True``, will do this operation
  132. in-place
  133. Shape:
  134. - Input: :math:`(N, C, D, H, W)` or :math:`(C, D, H, W)`.
  135. - Output: :math:`(N, C, D, H, W)` or :math:`(C, D, H, W)` (same shape as input).
  136. Examples::
  137. >>> m = nn.Dropout3d(p=0.2)
  138. >>> input = torch.randn(20, 16, 4, 32, 32)
  139. >>> output = m(input)
  140. .. _Efficient Object Localization Using Convolutional Networks:
  141. https://arxiv.org/abs/1411.4280
  142. """
  143. def forward(self, input: Tensor) -> Tensor:
  144. return F.dropout3d(input, self.p, self.training, self.inplace)
  145. class AlphaDropout(_DropoutNd):
  146. r"""Applies Alpha Dropout over the input.
  147. Alpha Dropout is a type of Dropout that maintains the self-normalizing
  148. property.
  149. For an input with zero mean and unit standard deviation, the output of
  150. Alpha Dropout maintains the original mean and standard deviation of the
  151. input.
  152. Alpha Dropout goes hand-in-hand with SELU activation function, which ensures
  153. that the outputs have zero mean and unit standard deviation.
  154. During training, it randomly masks some of the elements of the input
  155. tensor with probability *p* using samples from a bernoulli distribution.
  156. The elements to masked are randomized on every forward call, and scaled
  157. and shifted to maintain zero mean and unit standard deviation.
  158. During evaluation the module simply computes an identity function.
  159. More details can be found in the paper `Self-Normalizing Neural Networks`_ .
  160. Args:
  161. p (float): probability of an element to be dropped. Default: 0.5
  162. inplace (bool, optional): If set to ``True``, will do this operation
  163. in-place
  164. Shape:
  165. - Input: :math:`(*)`. Input can be of any shape
  166. - Output: :math:`(*)`. Output is of the same shape as input
  167. Examples::
  168. >>> m = nn.AlphaDropout(p=0.2)
  169. >>> input = torch.randn(20, 16)
  170. >>> output = m(input)
  171. .. _Self-Normalizing Neural Networks: https://arxiv.org/abs/1706.02515
  172. """
  173. def forward(self, input: Tensor) -> Tensor:
  174. return F.alpha_dropout(input, self.p, self.training)
  175. class FeatureAlphaDropout(_DropoutNd):
  176. r"""Randomly masks out entire channels.
  177. A channel is a feature map,
  178. e.g. the :math:`j`-th channel of the :math:`i`-th sample in the batch input
  179. is a tensor :math:`\text{input}[i, j]` of the input tensor). Instead of
  180. setting activations to zero, as in regular Dropout, the activations are set
  181. to the negative saturation value of the SELU activation function. More details
  182. can be found in the paper `Self-Normalizing Neural Networks`_ .
  183. Each element will be masked independently for each sample on every forward
  184. call with probability :attr:`p` using samples from a Bernoulli distribution.
  185. The elements to be masked are randomized on every forward call, and scaled
  186. and shifted to maintain zero mean and unit variance.
  187. Usually the input comes from :class:`nn.AlphaDropout` modules.
  188. As described in the paper
  189. `Efficient Object Localization Using Convolutional Networks`_ ,
  190. if adjacent pixels within feature maps are strongly correlated
  191. (as is normally the case in early convolution layers) then i.i.d. dropout
  192. will not regularize the activations and will otherwise just result
  193. in an effective learning rate decrease.
  194. In this case, :func:`nn.AlphaDropout` will help promote independence between
  195. feature maps and should be used instead.
  196. Args:
  197. p (float, optional): probability of an element to be zeroed. Default: 0.5
  198. inplace (bool, optional): If set to ``True``, will do this operation
  199. in-place
  200. Shape:
  201. - Input: :math:`(N, C, D, H, W)` or :math:`(C, D, H, W)`.
  202. - Output: :math:`(N, C, D, H, W)` or :math:`(C, D, H, W)` (same shape as input).
  203. Examples::
  204. >>> m = nn.FeatureAlphaDropout(p=0.2)
  205. >>> input = torch.randn(20, 16, 4, 32, 32)
  206. >>> output = m(input)
  207. .. _Self-Normalizing Neural Networks: https://arxiv.org/abs/1706.02515
  208. .. _Efficient Object Localization Using Convolutional Networks:
  209. https://arxiv.org/abs/1411.4280
  210. """
  211. def forward(self, input: Tensor) -> Tensor:
  212. return F.feature_alpha_dropout(input, self.p, self.training)