activation.py 55 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495969798991001011021031041051061071081091101111121131141151161171181191201211221231241251261271281291301311321331341351361371381391401411421431441451461471481491501511521531541551561571581591601611621631641651661671681691701711721731741751761771781791801811821831841851861871881891901911921931941951961971981992002012022032042052062072082092102112122132142152162172182192202212222232242252262272282292302312322332342352362372382392402412422432442452462472482492502512522532542552562572582592602612622632642652662672682692702712722732742752762772782792802812822832842852862872882892902912922932942952962972982993003013023033043053063073083093103113123133143153163173183193203213223233243253263273283293303313323333343353363373383393403413423433443453463473483493503513523533543553563573583593603613623633643653663673683693703713723733743753763773783793803813823833843853863873883893903913923933943953963973983994004014024034044054064074084094104114124134144154164174184194204214224234244254264274284294304314324334344354364374384394404414424434444454464474484494504514524534544554564574584594604614624634644654664674684694704714724734744754764774784794804814824834844854864874884894904914924934944954964974984995005015025035045055065075085095105115125135145155165175185195205215225235245255265275285295305315325335345355365375385395405415425435445455465475485495505515525535545555565575585595605615625635645655665675685695705715725735745755765775785795805815825835845855865875885895905915925935945955965975985996006016026036046056066076086096106116126136146156166176186196206216226236246256266276286296306316326336346356366376386396406416426436446456466476486496506516526536546556566576586596606616626636646656666676686696706716726736746756766776786796806816826836846856866876886896906916926936946956966976986997007017027037047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432143314341435143614371438143914401441144214431444144514461447144814491450145114521453145414551456145714581459146014611462146314641465146614671468146914701471147214731474147514761477147814791480148114821483148414851486148714881489149014911492149314941495149614971498149915001501150215031504150515061507150815091510151115121513151415151516151715181519152015211522152315241525152615271528152915301531153215331534153515361537153815391540154115421543154415451546154715481549155015511552155315541555155615571558155915601561156215631564156515661567156815691570157115721573157415751576157715781579158015811582158315841585158615871588158915901591159215931594159515961597159815991600160116021603160416051606160716081609161016111612161316141615161616171618161916201621162216231624162516261627162816291630163116321633
  1. # mypy: allow-untyped-defs
  2. import warnings
  3. from typing import Optional, Tuple
  4. import torch
  5. from torch import Tensor
  6. from .linear import NonDynamicallyQuantizableLinear
  7. from torch.nn.init import constant_, xavier_normal_, xavier_uniform_
  8. from torch.nn.parameter import Parameter
  9. from .module import Module
  10. from .. import functional as F
  11. __all__ = ['Threshold', 'ReLU', 'RReLU', 'Hardtanh', 'ReLU6', 'Sigmoid', 'Hardsigmoid', 'Tanh',
  12. 'SiLU', 'Mish', 'Hardswish', 'ELU', 'CELU', 'SELU', 'GLU', 'GELU', 'Hardshrink', 'LeakyReLU',
  13. 'LogSigmoid', 'Softplus', 'Softshrink', 'MultiheadAttention', 'PReLU', 'Softsign', 'Tanhshrink',
  14. 'Softmin', 'Softmax', 'Softmax2d', 'LogSoftmax']
  15. class Threshold(Module):
  16. r"""Thresholds each element of the input Tensor.
  17. Threshold is defined as:
  18. .. math::
  19. y =
  20. \begin{cases}
  21. x, &\text{ if } x > \text{threshold} \\
  22. \text{value}, &\text{ otherwise }
  23. \end{cases}
  24. Args:
  25. threshold: The value to threshold at
  26. value: The value to replace with
  27. inplace: can optionally do the operation in-place. Default: ``False``
  28. Shape:
  29. - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
  30. - Output: :math:`(*)`, same shape as the input.
  31. Examples::
  32. >>> m = nn.Threshold(0.1, 20)
  33. >>> input = torch.randn(2)
  34. >>> output = m(input)
  35. """
  36. __constants__ = ['threshold', 'value', 'inplace']
  37. threshold: float
  38. value: float
  39. inplace: bool
  40. def __init__(self, threshold: float, value: float, inplace: bool = False) -> None:
  41. super().__init__()
  42. self.threshold = threshold
  43. self.value = value
  44. self.inplace = inplace
  45. # TODO: check in THNN (if inplace == True, then assert value <= threshold)
  46. def forward(self, input: Tensor) -> Tensor:
  47. return F.threshold(input, self.threshold, self.value, self.inplace)
  48. def extra_repr(self):
  49. inplace_str = ', inplace=True' if self.inplace else ''
  50. return f'threshold={self.threshold}, value={self.value}{inplace_str}'
  51. class ReLU(Module):
  52. r"""Applies the rectified linear unit function element-wise.
  53. :math:`\text{ReLU}(x) = (x)^+ = \max(0, x)`
  54. Args:
  55. inplace: can optionally do the operation in-place. Default: ``False``
  56. Shape:
  57. - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
  58. - Output: :math:`(*)`, same shape as the input.
  59. .. image:: ../scripts/activation_images/ReLU.png
  60. Examples::
  61. >>> m = nn.ReLU()
  62. >>> input = torch.randn(2)
  63. >>> output = m(input)
  64. An implementation of CReLU - https://arxiv.org/abs/1603.05201
  65. >>> m = nn.ReLU()
  66. >>> input = torch.randn(2).unsqueeze(0)
  67. >>> output = torch.cat((m(input), m(-input)))
  68. """
  69. __constants__ = ['inplace']
  70. inplace: bool
  71. def __init__(self, inplace: bool = False):
  72. super().__init__()
  73. self.inplace = inplace
  74. def forward(self, input: Tensor) -> Tensor:
  75. return F.relu(input, inplace=self.inplace)
  76. def extra_repr(self) -> str:
  77. inplace_str = 'inplace=True' if self.inplace else ''
  78. return inplace_str
  79. class RReLU(Module):
  80. r"""Applies the randomized leaky rectified linear unit function, element-wise.
  81. Method described in the paper:
  82. `Empirical Evaluation of Rectified Activations in Convolutional Network <https://arxiv.org/abs/1505.00853>`_.
  83. The function is defined as:
  84. .. math::
  85. \text{RReLU}(x) =
  86. \begin{cases}
  87. x & \text{if } x \geq 0 \\
  88. ax & \text{ otherwise }
  89. \end{cases}
  90. where :math:`a` is randomly sampled from uniform distribution
  91. :math:`\mathcal{U}(\text{lower}, \text{upper})` during training while during
  92. evaluation :math:`a` is fixed with :math:`a = \frac{\text{lower} + \text{upper}}{2}`.
  93. Args:
  94. lower: lower bound of the uniform distribution. Default: :math:`\frac{1}{8}`
  95. upper: upper bound of the uniform distribution. Default: :math:`\frac{1}{3}`
  96. inplace: can optionally do the operation in-place. Default: ``False``
  97. Shape:
  98. - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
  99. - Output: :math:`(*)`, same shape as the input.
  100. .. image:: ../scripts/activation_images/RReLU.png
  101. Examples::
  102. >>> m = nn.RReLU(0.1, 0.3)
  103. >>> input = torch.randn(2)
  104. >>> output = m(input)
  105. """
  106. __constants__ = ['lower', 'upper', 'inplace']
  107. lower: float
  108. upper: float
  109. inplace: bool
  110. def __init__(
  111. self,
  112. lower: float = 1. / 8,
  113. upper: float = 1. / 3,
  114. inplace: bool = False
  115. ):
  116. super().__init__()
  117. self.lower = lower
  118. self.upper = upper
  119. self.inplace = inplace
  120. def forward(self, input: Tensor) -> Tensor:
  121. return F.rrelu(input, self.lower, self.upper, self.training, self.inplace)
  122. def extra_repr(self):
  123. inplace_str = ', inplace=True' if self.inplace else ''
  124. return f'lower={self.lower}, upper={self.upper}{inplace_str}'
  125. class Hardtanh(Module):
  126. r"""Applies the HardTanh function element-wise.
  127. HardTanh is defined as:
  128. .. math::
  129. \text{HardTanh}(x) = \begin{cases}
  130. \text{max\_val} & \text{ if } x > \text{ max\_val } \\
  131. \text{min\_val} & \text{ if } x < \text{ min\_val } \\
  132. x & \text{ otherwise } \\
  133. \end{cases}
  134. Args:
  135. min_val: minimum value of the linear region range. Default: -1
  136. max_val: maximum value of the linear region range. Default: 1
  137. inplace: can optionally do the operation in-place. Default: ``False``
  138. Keyword arguments :attr:`min_value` and :attr:`max_value`
  139. have been deprecated in favor of :attr:`min_val` and :attr:`max_val`.
  140. Shape:
  141. - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
  142. - Output: :math:`(*)`, same shape as the input.
  143. .. image:: ../scripts/activation_images/Hardtanh.png
  144. Examples::
  145. >>> m = nn.Hardtanh(-2, 2)
  146. >>> input = torch.randn(2)
  147. >>> output = m(input)
  148. """
  149. __constants__ = ['min_val', 'max_val', 'inplace']
  150. min_val: float
  151. max_val: float
  152. inplace: bool
  153. def __init__(
  154. self,
  155. min_val: float = -1.,
  156. max_val: float = 1.,
  157. inplace: bool = False,
  158. min_value: Optional[float] = None,
  159. max_value: Optional[float] = None
  160. ) -> None:
  161. super().__init__()
  162. if min_value is not None:
  163. warnings.warn(
  164. "keyword argument `min_value` is deprecated and rename to `min_val`",
  165. FutureWarning,
  166. stacklevel=2,
  167. )
  168. min_val = min_value
  169. if max_value is not None:
  170. warnings.warn(
  171. "keyword argument `max_value` is deprecated and rename to `max_val`",
  172. FutureWarning,
  173. stacklevel=2,
  174. )
  175. max_val = max_value
  176. self.min_val = min_val
  177. self.max_val = max_val
  178. self.inplace = inplace
  179. assert self.max_val > self.min_val
  180. def forward(self, input: Tensor) -> Tensor:
  181. return F.hardtanh(input, self.min_val, self.max_val, self.inplace)
  182. def extra_repr(self) -> str:
  183. inplace_str = ', inplace=True' if self.inplace else ''
  184. return f'min_val={self.min_val}, max_val={self.max_val}{inplace_str}'
  185. class ReLU6(Hardtanh):
  186. r"""Applies the ReLU6 function element-wise.
  187. .. math::
  188. \text{ReLU6}(x) = \min(\max(0,x), 6)
  189. Args:
  190. inplace: can optionally do the operation in-place. Default: ``False``
  191. Shape:
  192. - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
  193. - Output: :math:`(*)`, same shape as the input.
  194. .. image:: ../scripts/activation_images/ReLU6.png
  195. Examples::
  196. >>> m = nn.ReLU6()
  197. >>> input = torch.randn(2)
  198. >>> output = m(input)
  199. """
  200. def __init__(self, inplace: bool = False):
  201. super().__init__(0., 6., inplace)
  202. def extra_repr(self) -> str:
  203. inplace_str = 'inplace=True' if self.inplace else ''
  204. return inplace_str
  205. class Sigmoid(Module):
  206. r"""Applies the Sigmoid function element-wise.
  207. .. math::
  208. \text{Sigmoid}(x) = \sigma(x) = \frac{1}{1 + \exp(-x)}
  209. Shape:
  210. - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
  211. - Output: :math:`(*)`, same shape as the input.
  212. .. image:: ../scripts/activation_images/Sigmoid.png
  213. Examples::
  214. >>> m = nn.Sigmoid()
  215. >>> input = torch.randn(2)
  216. >>> output = m(input)
  217. """
  218. def forward(self, input: Tensor) -> Tensor:
  219. return torch.sigmoid(input)
  220. class Hardsigmoid(Module):
  221. r"""Applies the Hardsigmoid function element-wise.
  222. Hardsigmoid is defined as:
  223. .. math::
  224. \text{Hardsigmoid}(x) = \begin{cases}
  225. 0 & \text{if~} x \le -3, \\
  226. 1 & \text{if~} x \ge +3, \\
  227. x / 6 + 1 / 2 & \text{otherwise}
  228. \end{cases}
  229. Args:
  230. inplace: can optionally do the operation in-place. Default: ``False``
  231. Shape:
  232. - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
  233. - Output: :math:`(*)`, same shape as the input.
  234. .. image:: ../scripts/activation_images/Hardsigmoid.png
  235. Examples::
  236. >>> m = nn.Hardsigmoid()
  237. >>> input = torch.randn(2)
  238. >>> output = m(input)
  239. """
  240. __constants__ = ['inplace']
  241. inplace: bool
  242. def __init__(self, inplace : bool = False) -> None:
  243. super().__init__()
  244. self.inplace = inplace
  245. def forward(self, input: Tensor) -> Tensor:
  246. return F.hardsigmoid(input, self.inplace)
  247. class Tanh(Module):
  248. r"""Applies the Hyperbolic Tangent (Tanh) function element-wise.
  249. Tanh is defined as:
  250. .. math::
  251. \text{Tanh}(x) = \tanh(x) = \frac{\exp(x) - \exp(-x)} {\exp(x) + \exp(-x)}
  252. Shape:
  253. - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
  254. - Output: :math:`(*)`, same shape as the input.
  255. .. image:: ../scripts/activation_images/Tanh.png
  256. Examples::
  257. >>> m = nn.Tanh()
  258. >>> input = torch.randn(2)
  259. >>> output = m(input)
  260. """
  261. def forward(self, input: Tensor) -> Tensor:
  262. return torch.tanh(input)
  263. class SiLU(Module):
  264. r"""Applies the Sigmoid Linear Unit (SiLU) function, element-wise.
  265. The SiLU function is also known as the swish function.
  266. .. math::
  267. \text{silu}(x) = x * \sigma(x), \text{where } \sigma(x) \text{ is the logistic sigmoid.}
  268. .. note::
  269. See `Gaussian Error Linear Units (GELUs) <https://arxiv.org/abs/1606.08415>`_
  270. where the SiLU (Sigmoid Linear Unit) was originally coined, and see
  271. `Sigmoid-Weighted Linear Units for Neural Network Function Approximation
  272. in Reinforcement Learning <https://arxiv.org/abs/1702.03118>`_ and `Swish:
  273. a Self-Gated Activation Function <https://arxiv.org/abs/1710.05941v1>`_
  274. where the SiLU was experimented with later.
  275. Shape:
  276. - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
  277. - Output: :math:`(*)`, same shape as the input.
  278. .. image:: ../scripts/activation_images/SiLU.png
  279. Examples::
  280. >>> m = nn.SiLU()
  281. >>> input = torch.randn(2)
  282. >>> output = m(input)
  283. """
  284. __constants__ = ['inplace']
  285. inplace: bool
  286. def __init__(self, inplace: bool = False):
  287. super().__init__()
  288. self.inplace = inplace
  289. def forward(self, input: Tensor) -> Tensor:
  290. return F.silu(input, inplace=self.inplace)
  291. def extra_repr(self) -> str:
  292. inplace_str = 'inplace=True' if self.inplace else ''
  293. return inplace_str
  294. class Mish(Module):
  295. r"""Applies the Mish function, element-wise.
  296. Mish: A Self Regularized Non-Monotonic Neural Activation Function.
  297. .. math::
  298. \text{Mish}(x) = x * \text{Tanh}(\text{Softplus}(x))
  299. .. note::
  300. See `Mish: A Self Regularized Non-Monotonic Neural Activation Function <https://arxiv.org/abs/1908.08681>`_
  301. Shape:
  302. - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
  303. - Output: :math:`(*)`, same shape as the input.
  304. .. image:: ../scripts/activation_images/Mish.png
  305. Examples::
  306. >>> m = nn.Mish()
  307. >>> input = torch.randn(2)
  308. >>> output = m(input)
  309. """
  310. __constants__ = ['inplace']
  311. inplace: bool
  312. def __init__(self, inplace: bool = False):
  313. super().__init__()
  314. self.inplace = inplace
  315. def forward(self, input: Tensor) -> Tensor:
  316. return F.mish(input, inplace=self.inplace)
  317. def extra_repr(self) -> str:
  318. inplace_str = 'inplace=True' if self.inplace else ''
  319. return inplace_str
  320. class Hardswish(Module):
  321. r"""Applies the Hardswish function, element-wise.
  322. Method described in the paper: `Searching for MobileNetV3 <https://arxiv.org/abs/1905.02244>`_.
  323. Hardswish is defined as:
  324. .. math::
  325. \text{Hardswish}(x) = \begin{cases}
  326. 0 & \text{if~} x \le -3, \\
  327. x & \text{if~} x \ge +3, \\
  328. x \cdot (x + 3) /6 & \text{otherwise}
  329. \end{cases}
  330. Args:
  331. inplace: can optionally do the operation in-place. Default: ``False``
  332. Shape:
  333. - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
  334. - Output: :math:`(*)`, same shape as the input.
  335. .. image:: ../scripts/activation_images/Hardswish.png
  336. Examples::
  337. >>> m = nn.Hardswish()
  338. >>> input = torch.randn(2)
  339. >>> output = m(input)
  340. """
  341. __constants__ = ['inplace']
  342. inplace: bool
  343. def __init__(self, inplace : bool = False) -> None:
  344. super().__init__()
  345. self.inplace = inplace
  346. def forward(self, input: Tensor) -> Tensor:
  347. return F.hardswish(input, self.inplace)
  348. class ELU(Module):
  349. r"""Applies the Exponential Linear Unit (ELU) function, element-wise.
  350. Method described in the paper: `Fast and Accurate Deep Network Learning by Exponential Linear
  351. Units (ELUs) <https://arxiv.org/abs/1511.07289>`__.
  352. ELU is defined as:
  353. .. math::
  354. \text{ELU}(x) = \begin{cases}
  355. x, & \text{ if } x > 0\\
  356. \alpha * (\exp(x) - 1), & \text{ if } x \leq 0
  357. \end{cases}
  358. Args:
  359. alpha: the :math:`\alpha` value for the ELU formulation. Default: 1.0
  360. inplace: can optionally do the operation in-place. Default: ``False``
  361. Shape:
  362. - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
  363. - Output: :math:`(*)`, same shape as the input.
  364. .. image:: ../scripts/activation_images/ELU.png
  365. Examples::
  366. >>> m = nn.ELU()
  367. >>> input = torch.randn(2)
  368. >>> output = m(input)
  369. """
  370. __constants__ = ['alpha', 'inplace']
  371. alpha: float
  372. inplace: bool
  373. def __init__(self, alpha: float = 1., inplace: bool = False) -> None:
  374. super().__init__()
  375. self.alpha = alpha
  376. self.inplace = inplace
  377. def forward(self, input: Tensor) -> Tensor:
  378. return F.elu(input, self.alpha, self.inplace)
  379. def extra_repr(self) -> str:
  380. inplace_str = ', inplace=True' if self.inplace else ''
  381. return f'alpha={self.alpha}{inplace_str}'
  382. class CELU(Module):
  383. r"""Applies the CELU function element-wise.
  384. .. math::
  385. \text{CELU}(x) = \max(0,x) + \min(0, \alpha * (\exp(x/\alpha) - 1))
  386. More details can be found in the paper `Continuously Differentiable Exponential Linear Units`_ .
  387. Args:
  388. alpha: the :math:`\alpha` value for the CELU formulation. Default: 1.0
  389. inplace: can optionally do the operation in-place. Default: ``False``
  390. Shape:
  391. - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
  392. - Output: :math:`(*)`, same shape as the input.
  393. .. image:: ../scripts/activation_images/CELU.png
  394. Examples::
  395. >>> m = nn.CELU()
  396. >>> input = torch.randn(2)
  397. >>> output = m(input)
  398. .. _`Continuously Differentiable Exponential Linear Units`:
  399. https://arxiv.org/abs/1704.07483
  400. """
  401. __constants__ = ['alpha', 'inplace']
  402. alpha: float
  403. inplace: bool
  404. def __init__(self, alpha: float = 1., inplace: bool = False) -> None:
  405. super().__init__()
  406. self.alpha = alpha
  407. self.inplace = inplace
  408. def forward(self, input: Tensor) -> Tensor:
  409. return F.celu(input, self.alpha, self.inplace)
  410. def extra_repr(self) -> str:
  411. inplace_str = ', inplace=True' if self.inplace else ''
  412. return f'alpha={self.alpha}{inplace_str}'
  413. class SELU(Module):
  414. r"""Applies the SELU function element-wise.
  415. .. math::
  416. \text{SELU}(x) = \text{scale} * (\max(0,x) + \min(0, \alpha * (\exp(x) - 1)))
  417. with :math:`\alpha = 1.6732632423543772848170429916717` and
  418. :math:`\text{scale} = 1.0507009873554804934193349852946`.
  419. .. warning::
  420. When using ``kaiming_normal`` or ``kaiming_normal_`` for initialisation,
  421. ``nonlinearity='linear'`` should be used instead of ``nonlinearity='selu'``
  422. in order to get `Self-Normalizing Neural Networks`_.
  423. See :func:`torch.nn.init.calculate_gain` for more information.
  424. More details can be found in the paper `Self-Normalizing Neural Networks`_ .
  425. Args:
  426. inplace (bool, optional): can optionally do the operation in-place. Default: ``False``
  427. Shape:
  428. - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
  429. - Output: :math:`(*)`, same shape as the input.
  430. .. image:: ../scripts/activation_images/SELU.png
  431. Examples::
  432. >>> m = nn.SELU()
  433. >>> input = torch.randn(2)
  434. >>> output = m(input)
  435. .. _Self-Normalizing Neural Networks: https://arxiv.org/abs/1706.02515
  436. """
  437. __constants__ = ['inplace']
  438. inplace: bool
  439. def __init__(self, inplace: bool = False) -> None:
  440. super().__init__()
  441. self.inplace = inplace
  442. def forward(self, input: Tensor) -> Tensor:
  443. return F.selu(input, self.inplace)
  444. def extra_repr(self) -> str:
  445. inplace_str = 'inplace=True' if self.inplace else ''
  446. return inplace_str
  447. class GLU(Module):
  448. r"""Applies the gated linear unit function.
  449. :math:`{GLU}(a, b)= a \otimes \sigma(b)` where :math:`a` is the first half
  450. of the input matrices and :math:`b` is the second half.
  451. Args:
  452. dim (int): the dimension on which to split the input. Default: -1
  453. Shape:
  454. - Input: :math:`(\ast_1, N, \ast_2)` where `*` means, any number of additional
  455. dimensions
  456. - Output: :math:`(\ast_1, M, \ast_2)` where :math:`M=N/2`
  457. Examples::
  458. >>> m = nn.GLU()
  459. >>> input = torch.randn(4, 2)
  460. >>> output = m(input)
  461. """
  462. __constants__ = ['dim']
  463. dim: int
  464. def __init__(self, dim: int = -1) -> None:
  465. super().__init__()
  466. self.dim = dim
  467. def forward(self, input: Tensor) -> Tensor:
  468. return F.glu(input, self.dim)
  469. def extra_repr(self) -> str:
  470. return f'dim={self.dim}'
  471. class GELU(Module):
  472. r"""Applies the Gaussian Error Linear Units function.
  473. .. math:: \text{GELU}(x) = x * \Phi(x)
  474. where :math:`\Phi(x)` is the Cumulative Distribution Function for Gaussian Distribution.
  475. When the approximate argument is 'tanh', Gelu is estimated with:
  476. .. math:: \text{GELU}(x) = 0.5 * x * (1 + \text{Tanh}(\sqrt{2 / \pi} * (x + 0.044715 * x^3)))
  477. Args:
  478. approximate (str, optional): the gelu approximation algorithm to use:
  479. ``'none'`` | ``'tanh'``. Default: ``'none'``
  480. Shape:
  481. - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
  482. - Output: :math:`(*)`, same shape as the input.
  483. .. image:: ../scripts/activation_images/GELU.png
  484. Examples::
  485. >>> m = nn.GELU()
  486. >>> input = torch.randn(2)
  487. >>> output = m(input)
  488. """
  489. __constants__ = ['approximate']
  490. approximate: str
  491. def __init__(self, approximate: str = 'none') -> None:
  492. super().__init__()
  493. self.approximate = approximate
  494. def forward(self, input: Tensor) -> Tensor:
  495. return F.gelu(input, approximate=self.approximate)
  496. def extra_repr(self) -> str:
  497. return f'approximate={repr(self.approximate)}'
  498. class Hardshrink(Module):
  499. r"""Applies the Hard Shrinkage (Hardshrink) function element-wise.
  500. Hardshrink is defined as:
  501. .. math::
  502. \text{HardShrink}(x) =
  503. \begin{cases}
  504. x, & \text{ if } x > \lambda \\
  505. x, & \text{ if } x < -\lambda \\
  506. 0, & \text{ otherwise }
  507. \end{cases}
  508. Args:
  509. lambd: the :math:`\lambda` value for the Hardshrink formulation. Default: 0.5
  510. Shape:
  511. - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
  512. - Output: :math:`(*)`, same shape as the input.
  513. .. image:: ../scripts/activation_images/Hardshrink.png
  514. Examples::
  515. >>> m = nn.Hardshrink()
  516. >>> input = torch.randn(2)
  517. >>> output = m(input)
  518. """
  519. __constants__ = ['lambd']
  520. lambd: float
  521. def __init__(self, lambd: float = 0.5) -> None:
  522. super().__init__()
  523. self.lambd = lambd
  524. def forward(self, input: Tensor) -> Tensor:
  525. return F.hardshrink(input, self.lambd)
  526. def extra_repr(self) -> str:
  527. return f'{self.lambd}'
  528. class LeakyReLU(Module):
  529. r"""Applies the LeakyReLU function element-wise.
  530. .. math::
  531. \text{LeakyReLU}(x) = \max(0, x) + \text{negative\_slope} * \min(0, x)
  532. or
  533. .. math::
  534. \text{LeakyReLU}(x) =
  535. \begin{cases}
  536. x, & \text{ if } x \geq 0 \\
  537. \text{negative\_slope} \times x, & \text{ otherwise }
  538. \end{cases}
  539. Args:
  540. negative_slope: Controls the angle of the negative slope (which is used for
  541. negative input values). Default: 1e-2
  542. inplace: can optionally do the operation in-place. Default: ``False``
  543. Shape:
  544. - Input: :math:`(*)` where `*` means, any number of additional
  545. dimensions
  546. - Output: :math:`(*)`, same shape as the input
  547. .. image:: ../scripts/activation_images/LeakyReLU.png
  548. Examples::
  549. >>> m = nn.LeakyReLU(0.1)
  550. >>> input = torch.randn(2)
  551. >>> output = m(input)
  552. """
  553. __constants__ = ['inplace', 'negative_slope']
  554. inplace: bool
  555. negative_slope: float
  556. def __init__(self, negative_slope: float = 1e-2, inplace: bool = False) -> None:
  557. super().__init__()
  558. self.negative_slope = negative_slope
  559. self.inplace = inplace
  560. def forward(self, input: Tensor) -> Tensor:
  561. return F.leaky_relu(input, self.negative_slope, self.inplace)
  562. def extra_repr(self) -> str:
  563. inplace_str = ', inplace=True' if self.inplace else ''
  564. return f'negative_slope={self.negative_slope}{inplace_str}'
  565. class LogSigmoid(Module):
  566. r"""Applies the Logsigmoid function element-wise.
  567. .. math::
  568. \text{LogSigmoid}(x) = \log\left(\frac{ 1 }{ 1 + \exp(-x)}\right)
  569. Shape:
  570. - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
  571. - Output: :math:`(*)`, same shape as the input.
  572. .. image:: ../scripts/activation_images/LogSigmoid.png
  573. Examples::
  574. >>> m = nn.LogSigmoid()
  575. >>> input = torch.randn(2)
  576. >>> output = m(input)
  577. """
  578. def forward(self, input: Tensor) -> Tensor:
  579. return F.logsigmoid(input)
  580. class Softplus(Module):
  581. r"""Applies the Softplus function element-wise.
  582. .. math::
  583. \text{Softplus}(x) = \frac{1}{\beta} * \log(1 + \exp(\beta * x))
  584. SoftPlus is a smooth approximation to the ReLU function and can be used
  585. to constrain the output of a machine to always be positive.
  586. For numerical stability the implementation reverts to the linear function
  587. when :math:`input \times \beta > threshold`.
  588. Args:
  589. beta: the :math:`\beta` value for the Softplus formulation. Default: 1
  590. threshold: values above this revert to a linear function. Default: 20
  591. Shape:
  592. - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
  593. - Output: :math:`(*)`, same shape as the input.
  594. .. image:: ../scripts/activation_images/Softplus.png
  595. Examples::
  596. >>> m = nn.Softplus()
  597. >>> input = torch.randn(2)
  598. >>> output = m(input)
  599. """
  600. __constants__ = ['beta', 'threshold']
  601. beta: float
  602. threshold: float
  603. def __init__(self, beta: float = 1.0, threshold: float = 20.0) -> None:
  604. super().__init__()
  605. self.beta = beta
  606. self.threshold = threshold
  607. def forward(self, input: Tensor) -> Tensor:
  608. return F.softplus(input, self.beta, self.threshold)
  609. def extra_repr(self) -> str:
  610. return f'beta={self.beta}, threshold={self.threshold}'
  611. class Softshrink(Module):
  612. r"""Applies the soft shrinkage function element-wise.
  613. .. math::
  614. \text{SoftShrinkage}(x) =
  615. \begin{cases}
  616. x - \lambda, & \text{ if } x > \lambda \\
  617. x + \lambda, & \text{ if } x < -\lambda \\
  618. 0, & \text{ otherwise }
  619. \end{cases}
  620. Args:
  621. lambd: the :math:`\lambda` (must be no less than zero) value for the Softshrink formulation. Default: 0.5
  622. Shape:
  623. - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
  624. - Output: :math:`(*)`, same shape as the input.
  625. .. image:: ../scripts/activation_images/Softshrink.png
  626. Examples::
  627. >>> m = nn.Softshrink()
  628. >>> input = torch.randn(2)
  629. >>> output = m(input)
  630. """
  631. __constants__ = ['lambd']
  632. lambd: float
  633. def __init__(self, lambd: float = 0.5) -> None:
  634. super().__init__()
  635. self.lambd = lambd
  636. def forward(self, input: Tensor) -> Tensor:
  637. return F.softshrink(input, self.lambd)
  638. def extra_repr(self) -> str:
  639. return str(self.lambd)
  640. def _check_arg_device(x: Optional[torch.Tensor]) -> bool:
  641. if x is not None:
  642. return x.device.type in ["cpu", "cuda", torch.utils.backend_registration._privateuse1_backend_name]
  643. return True
  644. def _arg_requires_grad(x: Optional[torch.Tensor]) -> bool:
  645. if x is not None:
  646. return x.requires_grad
  647. return False
  648. def _is_make_fx_tracing():
  649. if not torch.jit.is_scripting():
  650. torch_dispatch_mode_stack = torch.utils._python_dispatch._get_current_dispatch_mode_stack()
  651. return any(type(x) == torch.fx.experimental.proxy_tensor.ProxyTorchDispatchMode for x in torch_dispatch_mode_stack)
  652. else:
  653. return False
  654. class MultiheadAttention(Module):
  655. r"""Allows the model to jointly attend to information from different representation subspaces.
  656. Method described in the paper:
  657. `Attention Is All You Need <https://arxiv.org/abs/1706.03762>`_.
  658. Multi-Head Attention is defined as:
  659. .. math::
  660. \text{MultiHead}(Q, K, V) = \text{Concat}(head_1,\dots,head_h)W^O
  661. where :math:`head_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)`.
  662. ``nn.MultiHeadAttention`` will use the optimized implementations of
  663. ``scaled_dot_product_attention()`` when possible.
  664. In addition to support for the new ``scaled_dot_product_attention()``
  665. function, for speeding up Inference, MHA will use
  666. fastpath inference with support for Nested Tensors, iff:
  667. - self attention is being computed (i.e., ``query``, ``key``, and ``value`` are the same tensor).
  668. - inputs are batched (3D) with ``batch_first==True``
  669. - Either autograd is disabled (using ``torch.inference_mode`` or ``torch.no_grad``) or no tensor argument ``requires_grad``
  670. - training is disabled (using ``.eval()``)
  671. - ``add_bias_kv`` is ``False``
  672. - ``add_zero_attn`` is ``False``
  673. - ``kdim`` and ``vdim`` are equal to ``embed_dim``
  674. - if a `NestedTensor <https://pytorch.org/docs/stable/nested.html>`_ is passed, neither ``key_padding_mask``
  675. nor ``attn_mask`` is passed
  676. - autocast is disabled
  677. If the optimized inference fastpath implementation is in use, a
  678. `NestedTensor <https://pytorch.org/docs/stable/nested.html>`_ can be passed for
  679. ``query``/``key``/``value`` to represent padding more efficiently than using a
  680. padding mask. In this case, a `NestedTensor <https://pytorch.org/docs/stable/nested.html>`_
  681. will be returned, and an additional speedup proportional to the fraction of the input
  682. that is padding can be expected.
  683. Args:
  684. embed_dim: Total dimension of the model.
  685. num_heads: Number of parallel attention heads. Note that ``embed_dim`` will be split
  686. across ``num_heads`` (i.e. each head will have dimension ``embed_dim // num_heads``).
  687. dropout: Dropout probability on ``attn_output_weights``. Default: ``0.0`` (no dropout).
  688. bias: If specified, adds bias to input / output projection layers. Default: ``True``.
  689. add_bias_kv: If specified, adds bias to the key and value sequences at dim=0. Default: ``False``.
  690. add_zero_attn: If specified, adds a new batch of zeros to the key and value sequences at dim=1.
  691. Default: ``False``.
  692. kdim: Total number of features for keys. Default: ``None`` (uses ``kdim=embed_dim``).
  693. vdim: Total number of features for values. Default: ``None`` (uses ``vdim=embed_dim``).
  694. batch_first: If ``True``, then the input and output tensors are provided
  695. as (batch, seq, feature). Default: ``False`` (seq, batch, feature).
  696. Examples::
  697. >>> # xdoctest: +SKIP
  698. >>> multihead_attn = nn.MultiheadAttention(embed_dim, num_heads)
  699. >>> attn_output, attn_output_weights = multihead_attn(query, key, value)
  700. .. _`FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness`:
  701. https://arxiv.org/abs/2205.14135
  702. """
  703. __constants__ = ['batch_first']
  704. bias_k: Optional[torch.Tensor]
  705. bias_v: Optional[torch.Tensor]
  706. def __init__(self, embed_dim, num_heads, dropout=0., bias=True, add_bias_kv=False, add_zero_attn=False,
  707. kdim=None, vdim=None, batch_first=False, device=None, dtype=None) -> None:
  708. if embed_dim <= 0 or num_heads <= 0:
  709. raise ValueError(
  710. f"embed_dim and num_heads must be greater than 0,"
  711. f" got embed_dim={embed_dim} and num_heads={num_heads} instead"
  712. )
  713. factory_kwargs = {'device': device, 'dtype': dtype}
  714. super().__init__()
  715. self.embed_dim = embed_dim
  716. self.kdim = kdim if kdim is not None else embed_dim
  717. self.vdim = vdim if vdim is not None else embed_dim
  718. self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim
  719. self.num_heads = num_heads
  720. self.dropout = dropout
  721. self.batch_first = batch_first
  722. self.head_dim = embed_dim // num_heads
  723. assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
  724. if not self._qkv_same_embed_dim:
  725. self.q_proj_weight = Parameter(torch.empty((embed_dim, embed_dim), **factory_kwargs))
  726. self.k_proj_weight = Parameter(torch.empty((embed_dim, self.kdim), **factory_kwargs))
  727. self.v_proj_weight = Parameter(torch.empty((embed_dim, self.vdim), **factory_kwargs))
  728. self.register_parameter('in_proj_weight', None)
  729. else:
  730. self.in_proj_weight = Parameter(torch.empty((3 * embed_dim, embed_dim), **factory_kwargs))
  731. self.register_parameter('q_proj_weight', None)
  732. self.register_parameter('k_proj_weight', None)
  733. self.register_parameter('v_proj_weight', None)
  734. if bias:
  735. self.in_proj_bias = Parameter(torch.empty(3 * embed_dim, **factory_kwargs))
  736. else:
  737. self.register_parameter('in_proj_bias', None)
  738. self.out_proj = NonDynamicallyQuantizableLinear(embed_dim, embed_dim, bias=bias, **factory_kwargs)
  739. if add_bias_kv:
  740. self.bias_k = Parameter(torch.empty((1, 1, embed_dim), **factory_kwargs))
  741. self.bias_v = Parameter(torch.empty((1, 1, embed_dim), **factory_kwargs))
  742. else:
  743. self.bias_k = self.bias_v = None
  744. self.add_zero_attn = add_zero_attn
  745. self._reset_parameters()
  746. def _reset_parameters(self):
  747. if self._qkv_same_embed_dim:
  748. xavier_uniform_(self.in_proj_weight)
  749. else:
  750. xavier_uniform_(self.q_proj_weight)
  751. xavier_uniform_(self.k_proj_weight)
  752. xavier_uniform_(self.v_proj_weight)
  753. if self.in_proj_bias is not None:
  754. constant_(self.in_proj_bias, 0.)
  755. constant_(self.out_proj.bias, 0.)
  756. if self.bias_k is not None:
  757. xavier_normal_(self.bias_k)
  758. if self.bias_v is not None:
  759. xavier_normal_(self.bias_v)
  760. def __setstate__(self, state):
  761. # Support loading old MultiheadAttention checkpoints generated by v1.1.0
  762. if '_qkv_same_embed_dim' not in state:
  763. state['_qkv_same_embed_dim'] = True
  764. super().__setstate__(state)
  765. def forward(
  766. self,
  767. query: Tensor,
  768. key: Tensor,
  769. value: Tensor,
  770. key_padding_mask: Optional[Tensor] = None,
  771. need_weights: bool = True,
  772. attn_mask: Optional[Tensor] = None,
  773. average_attn_weights: bool = True,
  774. is_causal : bool = False) -> Tuple[Tensor, Optional[Tensor]]:
  775. r"""Compute attention outputs using query, key, and value embeddings.
  776. Supports optional parameters for padding, masks and attention weights.
  777. Args:
  778. query: Query embeddings of shape :math:`(L, E_q)` for unbatched input, :math:`(L, N, E_q)` when ``batch_first=False``
  779. or :math:`(N, L, E_q)` when ``batch_first=True``, where :math:`L` is the target sequence length,
  780. :math:`N` is the batch size, and :math:`E_q` is the query embedding dimension ``embed_dim``.
  781. Queries are compared against key-value pairs to produce the output.
  782. See "Attention Is All You Need" for more details.
  783. key: Key embeddings of shape :math:`(S, E_k)` for unbatched input, :math:`(S, N, E_k)` when ``batch_first=False``
  784. or :math:`(N, S, E_k)` when ``batch_first=True``, where :math:`S` is the source sequence length,
  785. :math:`N` is the batch size, and :math:`E_k` is the key embedding dimension ``kdim``.
  786. See "Attention Is All You Need" for more details.
  787. value: Value embeddings of shape :math:`(S, E_v)` for unbatched input, :math:`(S, N, E_v)` when
  788. ``batch_first=False`` or :math:`(N, S, E_v)` when ``batch_first=True``, where :math:`S` is the source
  789. sequence length, :math:`N` is the batch size, and :math:`E_v` is the value embedding dimension ``vdim``.
  790. See "Attention Is All You Need" for more details.
  791. key_padding_mask: If specified, a mask of shape :math:`(N, S)` indicating which elements within ``key``
  792. to ignore for the purpose of attention (i.e. treat as "padding"). For unbatched `query`, shape should be :math:`(S)`.
  793. Binary and float masks are supported.
  794. For a binary mask, a ``True`` value indicates that the corresponding ``key`` value will be ignored for
  795. the purpose of attention. For a float mask, it will be directly added to the corresponding ``key`` value.
  796. need_weights: If specified, returns ``attn_output_weights`` in addition to ``attn_outputs``.
  797. Set ``need_weights=False`` to use the optimized ``scaled_dot_product_attention``
  798. and achieve the best performance for MHA.
  799. Default: ``True``.
  800. attn_mask: If specified, a 2D or 3D mask preventing attention to certain positions. Must be of shape
  801. :math:`(L, S)` or :math:`(N\cdot\text{num\_heads}, L, S)`, where :math:`N` is the batch size,
  802. :math:`L` is the target sequence length, and :math:`S` is the source sequence length. A 2D mask will be
  803. broadcasted across the batch while a 3D mask allows for a different mask for each entry in the batch.
  804. Binary and float masks are supported. For a binary mask, a ``True`` value indicates that the
  805. corresponding position is not allowed to attend. For a float mask, the mask values will be added to
  806. the attention weight.
  807. If both attn_mask and key_padding_mask are supplied, their types should match.
  808. average_attn_weights: If true, indicates that the returned ``attn_weights`` should be averaged across
  809. heads. Otherwise, ``attn_weights`` are provided separately per head. Note that this flag only has an
  810. effect when ``need_weights=True``. Default: ``True`` (i.e. average weights across heads)
  811. is_causal: If specified, applies a causal mask as attention mask.
  812. Default: ``False``.
  813. Warning:
  814. ``is_causal`` provides a hint that ``attn_mask`` is the
  815. causal mask. Providing incorrect hints can result in
  816. incorrect execution, including forward and backward
  817. compatibility.
  818. Outputs:
  819. - **attn_output** - Attention outputs of shape :math:`(L, E)` when input is unbatched,
  820. :math:`(L, N, E)` when ``batch_first=False`` or :math:`(N, L, E)` when ``batch_first=True``,
  821. where :math:`L` is the target sequence length, :math:`N` is the batch size, and :math:`E` is the
  822. embedding dimension ``embed_dim``.
  823. - **attn_output_weights** - Only returned when ``need_weights=True``. If ``average_attn_weights=True``,
  824. returns attention weights averaged across heads of shape :math:`(L, S)` when input is unbatched or
  825. :math:`(N, L, S)`, where :math:`N` is the batch size, :math:`L` is the target sequence length, and
  826. :math:`S` is the source sequence length. If ``average_attn_weights=False``, returns attention weights per
  827. head of shape :math:`(\text{num\_heads}, L, S)` when input is unbatched or :math:`(N, \text{num\_heads}, L, S)`.
  828. .. note::
  829. `batch_first` argument is ignored for unbatched inputs.
  830. """
  831. why_not_fast_path = ''
  832. if ((attn_mask is not None and torch.is_floating_point(attn_mask))
  833. or (key_padding_mask is not None) and torch.is_floating_point(key_padding_mask)):
  834. why_not_fast_path = "floating-point masks are not supported for fast path."
  835. is_batched = query.dim() == 3
  836. key_padding_mask = F._canonical_mask(
  837. mask=key_padding_mask,
  838. mask_name="key_padding_mask",
  839. other_type=F._none_or_dtype(attn_mask),
  840. other_name="attn_mask",
  841. target_type=query.dtype
  842. )
  843. attn_mask = F._canonical_mask(
  844. mask=attn_mask,
  845. mask_name="attn_mask",
  846. other_type=None,
  847. other_name="",
  848. target_type=query.dtype,
  849. check_other=False,
  850. )
  851. is_fastpath_enabled = torch.backends.mha.get_fastpath_enabled()
  852. if not is_fastpath_enabled:
  853. why_not_fast_path = "torch.backends.mha.get_fastpath_enabled() was not True"
  854. elif not is_batched:
  855. why_not_fast_path = f"input not batched; expected query.dim() of 3 but got {query.dim()}"
  856. elif query is not key or key is not value:
  857. # When lifting this restriction, don't forget to either
  858. # enforce that the dtypes all match or test cases where
  859. # they don't!
  860. why_not_fast_path = "non-self attention was used (query, key, and value are not the same Tensor)"
  861. elif self.in_proj_bias is not None and query.dtype != self.in_proj_bias.dtype:
  862. why_not_fast_path = f"dtypes of query ({query.dtype}) and self.in_proj_bias ({self.in_proj_bias.dtype}) don't match"
  863. elif self.in_proj_weight is None:
  864. why_not_fast_path = "in_proj_weight was None"
  865. elif query.dtype != self.in_proj_weight.dtype:
  866. # this case will fail anyway, but at least they'll get a useful error message.
  867. why_not_fast_path = f"dtypes of query ({query.dtype}) and self.in_proj_weight ({self.in_proj_weight.dtype}) don't match"
  868. elif self.training:
  869. why_not_fast_path = "training is enabled"
  870. elif (self.num_heads % 2) != 0:
  871. why_not_fast_path = "self.num_heads is not even"
  872. elif not self.batch_first:
  873. why_not_fast_path = "batch_first was not True"
  874. elif self.bias_k is not None:
  875. why_not_fast_path = "self.bias_k was not None"
  876. elif self.bias_v is not None:
  877. why_not_fast_path = "self.bias_v was not None"
  878. elif self.add_zero_attn:
  879. why_not_fast_path = "add_zero_attn was enabled"
  880. elif not self._qkv_same_embed_dim:
  881. why_not_fast_path = "_qkv_same_embed_dim was not True"
  882. elif query.is_nested and (key_padding_mask is not None or attn_mask is not None):
  883. why_not_fast_path = "supplying both src_key_padding_mask and src_mask at the same time \
  884. is not supported with NestedTensor input"
  885. elif torch.is_autocast_enabled():
  886. why_not_fast_path = "autocast is enabled"
  887. if not why_not_fast_path:
  888. tensor_args = (
  889. query,
  890. key,
  891. value,
  892. self.in_proj_weight,
  893. self.in_proj_bias,
  894. self.out_proj.weight,
  895. self.out_proj.bias,
  896. )
  897. # We have to use list comprehensions below because TorchScript does not support
  898. # generator expressions.
  899. if torch.overrides.has_torch_function(tensor_args):
  900. why_not_fast_path = "some Tensor argument has_torch_function"
  901. elif _is_make_fx_tracing():
  902. why_not_fast_path = "we are running make_fx tracing"
  903. elif not all(_check_arg_device(x) for x in tensor_args):
  904. why_not_fast_path = ("some Tensor argument's device is neither one of "
  905. f"cpu, cuda or {torch.utils.backend_registration._privateuse1_backend_name}")
  906. elif torch.is_grad_enabled() and any(_arg_requires_grad(x) for x in tensor_args):
  907. why_not_fast_path = ("grad is enabled and at least one of query or the "
  908. "input/output projection weights or biases requires_grad")
  909. if not why_not_fast_path:
  910. merged_mask, mask_type = self.merge_masks(attn_mask, key_padding_mask, query)
  911. if self.in_proj_bias is not None and self.in_proj_weight is not None:
  912. return torch._native_multi_head_attention(
  913. query,
  914. key,
  915. value,
  916. self.embed_dim,
  917. self.num_heads,
  918. self.in_proj_weight,
  919. self.in_proj_bias,
  920. self.out_proj.weight,
  921. self.out_proj.bias,
  922. merged_mask,
  923. need_weights,
  924. average_attn_weights,
  925. mask_type)
  926. any_nested = query.is_nested or key.is_nested or value.is_nested
  927. assert not any_nested, ("MultiheadAttention does not support NestedTensor outside of its fast path. " +
  928. f"The fast path was not hit because {why_not_fast_path}")
  929. if self.batch_first and is_batched:
  930. # make sure that the transpose op does not affect the "is" property
  931. if key is value:
  932. if query is key:
  933. query = key = value = query.transpose(1, 0)
  934. else:
  935. query, key = (x.transpose(1, 0) for x in (query, key))
  936. value = key
  937. else:
  938. query, key, value = (x.transpose(1, 0) for x in (query, key, value))
  939. if not self._qkv_same_embed_dim:
  940. attn_output, attn_output_weights = F.multi_head_attention_forward(
  941. query, key, value, self.embed_dim, self.num_heads,
  942. self.in_proj_weight, self.in_proj_bias,
  943. self.bias_k, self.bias_v, self.add_zero_attn,
  944. self.dropout, self.out_proj.weight, self.out_proj.bias,
  945. training=self.training,
  946. key_padding_mask=key_padding_mask, need_weights=need_weights,
  947. attn_mask=attn_mask,
  948. use_separate_proj_weight=True,
  949. q_proj_weight=self.q_proj_weight, k_proj_weight=self.k_proj_weight,
  950. v_proj_weight=self.v_proj_weight,
  951. average_attn_weights=average_attn_weights,
  952. is_causal=is_causal)
  953. else:
  954. attn_output, attn_output_weights = F.multi_head_attention_forward(
  955. query, key, value, self.embed_dim, self.num_heads,
  956. self.in_proj_weight, self.in_proj_bias,
  957. self.bias_k, self.bias_v, self.add_zero_attn,
  958. self.dropout, self.out_proj.weight, self.out_proj.bias,
  959. training=self.training,
  960. key_padding_mask=key_padding_mask,
  961. need_weights=need_weights,
  962. attn_mask=attn_mask,
  963. average_attn_weights=average_attn_weights,
  964. is_causal=is_causal)
  965. if self.batch_first and is_batched:
  966. return attn_output.transpose(1, 0), attn_output_weights
  967. else:
  968. return attn_output, attn_output_weights
  969. def merge_masks(self, attn_mask: Optional[Tensor], key_padding_mask: Optional[Tensor],
  970. query: Tensor) -> Tuple[Optional[Tensor], Optional[int]]:
  971. r"""Determine mask type and combine masks if necessary.
  972. If only one mask is provided, that mask
  973. and the corresponding mask type will be returned. If both masks are provided, they will be both
  974. expanded to shape ``(batch_size, num_heads, seq_len, seq_len)``, combined with logical ``or``
  975. and mask type 2 will be returned
  976. Args:
  977. attn_mask: attention mask of shape ``(seq_len, seq_len)``, mask type 0
  978. key_padding_mask: padding mask of shape ``(batch_size, seq_len)``, mask type 1
  979. query: query embeddings of shape ``(batch_size, seq_len, embed_dim)``
  980. Returns:
  981. merged_mask: merged mask
  982. mask_type: merged mask type (0, 1, or 2)
  983. """
  984. mask_type: Optional[int] = None
  985. merged_mask: Optional[Tensor] = None
  986. if key_padding_mask is not None:
  987. mask_type = 1
  988. merged_mask = key_padding_mask
  989. if attn_mask is not None:
  990. # In this branch query can't be a nested tensor, so it has a shape
  991. batch_size, seq_len, _ = query.shape
  992. mask_type = 2
  993. # Always expands attn_mask to 4D
  994. if attn_mask.dim() == 3:
  995. attn_mask_expanded = attn_mask.view(batch_size, -1, seq_len, seq_len)
  996. else: # attn_mask.dim() == 2:
  997. attn_mask_expanded = attn_mask.view(1, 1, seq_len, seq_len).expand(batch_size, self.num_heads, -1, -1)
  998. merged_mask = attn_mask_expanded
  999. if key_padding_mask is not None:
  1000. key_padding_mask_expanded = key_padding_mask.view(batch_size, 1, 1, seq_len).expand(-1, self.num_heads, -1, -1)
  1001. merged_mask = attn_mask_expanded + key_padding_mask_expanded
  1002. # no attn_mask and no key_padding_mask, returns None, None
  1003. return merged_mask, mask_type
  1004. class PReLU(Module):
  1005. r"""Applies the element-wise PReLU function.
  1006. .. math::
  1007. \text{PReLU}(x) = \max(0,x) + a * \min(0,x)
  1008. or
  1009. .. math::
  1010. \text{PReLU}(x) =
  1011. \begin{cases}
  1012. x, & \text{ if } x \ge 0 \\
  1013. ax, & \text{ otherwise }
  1014. \end{cases}
  1015. Here :math:`a` is a learnable parameter. When called without arguments, `nn.PReLU()` uses a single
  1016. parameter :math:`a` across all input channels. If called with `nn.PReLU(nChannels)`,
  1017. a separate :math:`a` is used for each input channel.
  1018. .. note::
  1019. weight decay should not be used when learning :math:`a` for good performance.
  1020. .. note::
  1021. Channel dim is the 2nd dim of input. When input has dims < 2, then there is
  1022. no channel dim and the number of channels = 1.
  1023. Args:
  1024. num_parameters (int): number of :math:`a` to learn.
  1025. Although it takes an int as input, there is only two values are legitimate:
  1026. 1, or the number of channels at input. Default: 1
  1027. init (float): the initial value of :math:`a`. Default: 0.25
  1028. Shape:
  1029. - Input: :math:`( *)` where `*` means, any number of additional
  1030. dimensions.
  1031. - Output: :math:`(*)`, same shape as the input.
  1032. Attributes:
  1033. weight (Tensor): the learnable weights of shape (:attr:`num_parameters`).
  1034. .. image:: ../scripts/activation_images/PReLU.png
  1035. Examples::
  1036. >>> m = nn.PReLU()
  1037. >>> input = torch.randn(2)
  1038. >>> output = m(input)
  1039. """
  1040. __constants__ = ['num_parameters']
  1041. num_parameters: int
  1042. def __init__(self, num_parameters: int = 1, init: float = 0.25,
  1043. device=None, dtype=None) -> None:
  1044. factory_kwargs = {'device': device, 'dtype': dtype}
  1045. self.num_parameters = num_parameters
  1046. super().__init__()
  1047. self.init = init
  1048. self.weight = Parameter(torch.empty(num_parameters, **factory_kwargs))
  1049. self.reset_parameters()
  1050. def reset_parameters(self):
  1051. torch.nn.init.constant_(self.weight, self.init)
  1052. def forward(self, input: Tensor) -> Tensor:
  1053. return F.prelu(input, self.weight)
  1054. def extra_repr(self) -> str:
  1055. return f'num_parameters={self.num_parameters}'
  1056. class Softsign(Module):
  1057. r"""Applies the element-wise Softsign function.
  1058. .. math::
  1059. \text{SoftSign}(x) = \frac{x}{ 1 + |x|}
  1060. Shape:
  1061. - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
  1062. - Output: :math:`(*)`, same shape as the input.
  1063. .. image:: ../scripts/activation_images/Softsign.png
  1064. Examples::
  1065. >>> m = nn.Softsign()
  1066. >>> input = torch.randn(2)
  1067. >>> output = m(input)
  1068. """
  1069. def forward(self, input: Tensor) -> Tensor:
  1070. return F.softsign(input)
  1071. class Tanhshrink(Module):
  1072. r"""Applies the element-wise Tanhshrink function.
  1073. .. math::
  1074. \text{Tanhshrink}(x) = x - \tanh(x)
  1075. Shape:
  1076. - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
  1077. - Output: :math:`(*)`, same shape as the input.
  1078. .. image:: ../scripts/activation_images/Tanhshrink.png
  1079. Examples::
  1080. >>> m = nn.Tanhshrink()
  1081. >>> input = torch.randn(2)
  1082. >>> output = m(input)
  1083. """
  1084. def forward(self, input: Tensor) -> Tensor:
  1085. return F.tanhshrink(input)
  1086. class Softmin(Module):
  1087. r"""Applies the Softmin function to an n-dimensional input Tensor.
  1088. Rescales them so that the elements of the n-dimensional output Tensor
  1089. lie in the range `[0, 1]` and sum to 1.
  1090. Softmin is defined as:
  1091. .. math::
  1092. \text{Softmin}(x_{i}) = \frac{\exp(-x_i)}{\sum_j \exp(-x_j)}
  1093. Shape:
  1094. - Input: :math:`(*)` where `*` means, any number of additional
  1095. dimensions
  1096. - Output: :math:`(*)`, same shape as the input
  1097. Args:
  1098. dim (int): A dimension along which Softmin will be computed (so every slice
  1099. along dim will sum to 1).
  1100. Returns:
  1101. a Tensor of the same dimension and shape as the input, with
  1102. values in the range [0, 1]
  1103. Examples::
  1104. >>> m = nn.Softmin(dim=1)
  1105. >>> input = torch.randn(2, 3)
  1106. >>> output = m(input)
  1107. """
  1108. __constants__ = ['dim']
  1109. dim: Optional[int]
  1110. def __init__(self, dim: Optional[int] = None) -> None:
  1111. super().__init__()
  1112. self.dim = dim
  1113. def __setstate__(self, state):
  1114. super().__setstate__(state)
  1115. if not hasattr(self, 'dim'):
  1116. self.dim = None
  1117. def forward(self, input: Tensor) -> Tensor:
  1118. return F.softmin(input, self.dim, _stacklevel=5)
  1119. def extra_repr(self):
  1120. return f'dim={self.dim}'
  1121. class Softmax(Module):
  1122. r"""Applies the Softmax function to an n-dimensional input Tensor.
  1123. Rescales them so that the elements of the n-dimensional output Tensor
  1124. lie in the range [0,1] and sum to 1.
  1125. Softmax is defined as:
  1126. .. math::
  1127. \text{Softmax}(x_{i}) = \frac{\exp(x_i)}{\sum_j \exp(x_j)}
  1128. When the input Tensor is a sparse tensor then the unspecified
  1129. values are treated as ``-inf``.
  1130. Shape:
  1131. - Input: :math:`(*)` where `*` means, any number of additional
  1132. dimensions
  1133. - Output: :math:`(*)`, same shape as the input
  1134. Returns:
  1135. a Tensor of the same dimension and shape as the input with
  1136. values in the range [0, 1]
  1137. Args:
  1138. dim (int): A dimension along which Softmax will be computed (so every slice
  1139. along dim will sum to 1).
  1140. .. note::
  1141. This module doesn't work directly with NLLLoss,
  1142. which expects the Log to be computed between the Softmax and itself.
  1143. Use `LogSoftmax` instead (it's faster and has better numerical properties).
  1144. Examples::
  1145. >>> m = nn.Softmax(dim=1)
  1146. >>> input = torch.randn(2, 3)
  1147. >>> output = m(input)
  1148. """
  1149. __constants__ = ['dim']
  1150. dim: Optional[int]
  1151. def __init__(self, dim: Optional[int] = None) -> None:
  1152. super().__init__()
  1153. self.dim = dim
  1154. def __setstate__(self, state):
  1155. super().__setstate__(state)
  1156. if not hasattr(self, 'dim'):
  1157. self.dim = None
  1158. def forward(self, input: Tensor) -> Tensor:
  1159. return F.softmax(input, self.dim, _stacklevel=5)
  1160. def extra_repr(self) -> str:
  1161. return f'dim={self.dim}'
  1162. class Softmax2d(Module):
  1163. r"""Applies SoftMax over features to each spatial location.
  1164. When given an image of ``Channels x Height x Width``, it will
  1165. apply `Softmax` to each location :math:`(Channels, h_i, w_j)`
  1166. Shape:
  1167. - Input: :math:`(N, C, H, W)` or :math:`(C, H, W)`.
  1168. - Output: :math:`(N, C, H, W)` or :math:`(C, H, W)` (same shape as input)
  1169. Returns:
  1170. a Tensor of the same dimension and shape as the input with
  1171. values in the range [0, 1]
  1172. Examples::
  1173. >>> m = nn.Softmax2d()
  1174. >>> # you softmax over the 2nd dimension
  1175. >>> input = torch.randn(2, 3, 12, 13)
  1176. >>> output = m(input)
  1177. """
  1178. def forward(self, input: Tensor) -> Tensor:
  1179. if input.dim() not in (3, 4):
  1180. raise ValueError(
  1181. f"Softmax2d: expected input to be 3D or 4D, got {input.dim()}D instead"
  1182. )
  1183. return F.softmax(input, -3, _stacklevel=5)
  1184. class LogSoftmax(Module):
  1185. r"""Applies the :math:`\log(\text{Softmax}(x))` function to an n-dimensional input Tensor.
  1186. The LogSoftmax formulation can be simplified as:
  1187. .. math::
  1188. \text{LogSoftmax}(x_{i}) = \log\left(\frac{\exp(x_i) }{ \sum_j \exp(x_j)} \right)
  1189. Shape:
  1190. - Input: :math:`(*)` where `*` means, any number of additional
  1191. dimensions
  1192. - Output: :math:`(*)`, same shape as the input
  1193. Args:
  1194. dim (int): A dimension along which LogSoftmax will be computed.
  1195. Returns:
  1196. a Tensor of the same dimension and shape as the input with
  1197. values in the range [-inf, 0)
  1198. Examples::
  1199. >>> m = nn.LogSoftmax(dim=1)
  1200. >>> input = torch.randn(2, 3)
  1201. >>> output = m(input)
  1202. """
  1203. __constants__ = ['dim']
  1204. dim: Optional[int]
  1205. def __init__(self, dim: Optional[int] = None) -> None:
  1206. super().__init__()
  1207. self.dim = dim
  1208. def __setstate__(self, state):
  1209. super().__setstate__(state)
  1210. if not hasattr(self, 'dim'):
  1211. self.dim = None
  1212. def forward(self, input: Tensor) -> Tensor:
  1213. return F.log_softmax(input, self.dim, _stacklevel=5)
  1214. def extra_repr(self):
  1215. return f'dim={self.dim}'