transformer.py 46 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975
  1. # mypy: allow-untyped-defs
  2. import copy
  3. from typing import Optional, Any, Union, Callable
  4. import torch
  5. import warnings
  6. from torch import Tensor
  7. from .. import functional as F
  8. from .module import Module
  9. from .activation import MultiheadAttention
  10. from .container import ModuleList
  11. from ..init import xavier_uniform_
  12. from .dropout import Dropout
  13. from .linear import Linear
  14. from .normalization import LayerNorm
  15. __all__ = ['Transformer', 'TransformerEncoder', 'TransformerDecoder', 'TransformerEncoderLayer', 'TransformerDecoderLayer']
  16. def _generate_square_subsequent_mask(
  17. sz: int,
  18. device: Optional[torch.device] = None,
  19. dtype: Optional[torch.dtype] = None,
  20. ) -> Tensor:
  21. r"""Generate a square causal mask for the sequence.
  22. The masked positions are filled with float('-inf'). Unmasked positions are filled with float(0.0).
  23. """
  24. if device is None:
  25. device = torch.device('cpu')
  26. if dtype is None:
  27. dtype = torch.float32
  28. return torch.triu(
  29. torch.full((sz, sz), float('-inf'), dtype=dtype, device=device),
  30. diagonal=1,
  31. )
  32. def _get_seq_len(
  33. src: Tensor,
  34. batch_first: bool
  35. ) -> Optional[int]:
  36. if src.is_nested:
  37. return None
  38. else:
  39. src_size = src.size()
  40. if len(src_size) == 2:
  41. # unbatched: S, E
  42. return src_size[0]
  43. else:
  44. # batched: B, S, E if batch_first else S, B, E
  45. seq_len_pos = 1 if batch_first else 0
  46. return src_size[seq_len_pos]
  47. class Transformer(Module):
  48. r"""A transformer model.
  49. User is able to modify the attributes as needed. The architecture
  50. is based on the paper "Attention Is All You Need". Ashish Vaswani, Noam Shazeer,
  51. Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez, Lukasz Kaiser, and
  52. Illia Polosukhin. 2017. Attention is all you need. In Advances in Neural Information
  53. Processing Systems, pages 6000-6010.
  54. Args:
  55. d_model: the number of expected features in the encoder/decoder inputs (default=512).
  56. nhead: the number of heads in the multiheadattention models (default=8).
  57. num_encoder_layers: the number of sub-encoder-layers in the encoder (default=6).
  58. num_decoder_layers: the number of sub-decoder-layers in the decoder (default=6).
  59. dim_feedforward: the dimension of the feedforward network model (default=2048).
  60. dropout: the dropout value (default=0.1).
  61. activation: the activation function of encoder/decoder intermediate layer, can be a string
  62. ("relu" or "gelu") or a unary callable. Default: relu
  63. custom_encoder: custom encoder (default=None).
  64. custom_decoder: custom decoder (default=None).
  65. layer_norm_eps: the eps value in layer normalization components (default=1e-5).
  66. batch_first: If ``True``, then the input and output tensors are provided
  67. as (batch, seq, feature). Default: ``False`` (seq, batch, feature).
  68. norm_first: if ``True``, encoder and decoder layers will perform LayerNorms before
  69. other attention and feedforward operations, otherwise after. Default: ``False`` (after).
  70. bias: If set to ``False``, ``Linear`` and ``LayerNorm`` layers will not learn an additive
  71. bias. Default: ``True``.
  72. Examples::
  73. >>> transformer_model = nn.Transformer(nhead=16, num_encoder_layers=12)
  74. >>> src = torch.rand((10, 32, 512))
  75. >>> tgt = torch.rand((20, 32, 512))
  76. >>> out = transformer_model(src, tgt)
  77. Note: A full example to apply nn.Transformer module for the word language model is available in
  78. https://github.com/pytorch/examples/tree/master/word_language_model
  79. """
  80. def __init__(self, d_model: int = 512, nhead: int = 8, num_encoder_layers: int = 6,
  81. num_decoder_layers: int = 6, dim_feedforward: int = 2048, dropout: float = 0.1,
  82. activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,
  83. custom_encoder: Optional[Any] = None, custom_decoder: Optional[Any] = None,
  84. layer_norm_eps: float = 1e-5, batch_first: bool = False, norm_first: bool = False,
  85. bias: bool = True, device=None, dtype=None) -> None:
  86. factory_kwargs = {'device': device, 'dtype': dtype}
  87. super().__init__()
  88. torch._C._log_api_usage_once(f"torch.nn.modules.{self.__class__.__name__}")
  89. if custom_encoder is not None:
  90. self.encoder = custom_encoder
  91. else:
  92. encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward, dropout,
  93. activation, layer_norm_eps, batch_first, norm_first,
  94. bias, **factory_kwargs)
  95. encoder_norm = LayerNorm(d_model, eps=layer_norm_eps, bias=bias, **factory_kwargs)
  96. self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm)
  97. if custom_decoder is not None:
  98. self.decoder = custom_decoder
  99. else:
  100. decoder_layer = TransformerDecoderLayer(d_model, nhead, dim_feedforward, dropout,
  101. activation, layer_norm_eps, batch_first, norm_first,
  102. bias, **factory_kwargs)
  103. decoder_norm = LayerNorm(d_model, eps=layer_norm_eps, bias=bias, **factory_kwargs)
  104. self.decoder = TransformerDecoder(decoder_layer, num_decoder_layers, decoder_norm)
  105. self._reset_parameters()
  106. self.d_model = d_model
  107. self.nhead = nhead
  108. self.batch_first = batch_first
  109. def forward(self, src: Tensor, tgt: Tensor, src_mask: Optional[Tensor] = None, tgt_mask: Optional[Tensor] = None,
  110. memory_mask: Optional[Tensor] = None, src_key_padding_mask: Optional[Tensor] = None,
  111. tgt_key_padding_mask: Optional[Tensor] = None, memory_key_padding_mask: Optional[Tensor] = None,
  112. src_is_causal: Optional[bool] = None, tgt_is_causal: Optional[bool] = None,
  113. memory_is_causal: bool = False) -> Tensor:
  114. r"""Take in and process masked source/target sequences.
  115. .. note::
  116. If a boolean tensor is provided for any of the [src/tgt/memory]_mask arguments, positions with a ``True`` value are
  117. not allowed to participate in the attention,
  118. which is the opposite of the definition for :attr:`attn_mask`
  119. in :func:`torch.nn.functional.scaled_dot_product_attention`.
  120. Args:
  121. src: the sequence to the encoder (required).
  122. tgt: the sequence to the decoder (required).
  123. src_mask: the additive mask for the src sequence (optional).
  124. tgt_mask: the additive mask for the tgt sequence (optional).
  125. memory_mask: the additive mask for the encoder output (optional).
  126. src_key_padding_mask: the Tensor mask for src keys per batch (optional).
  127. tgt_key_padding_mask: the Tensor mask for tgt keys per batch (optional).
  128. memory_key_padding_mask: the Tensor mask for memory keys per batch (optional).
  129. src_is_causal: If specified, applies a causal mask as ``src_mask``.
  130. Default: ``None``; try to detect a causal mask.
  131. Warning:
  132. ``src_is_causal`` provides a hint that ``src_mask`` is
  133. the causal mask. Providing incorrect hints can result in
  134. incorrect execution, including forward and backward
  135. compatibility.
  136. tgt_is_causal: If specified, applies a causal mask as ``tgt_mask``.
  137. Default: ``None``; try to detect a causal mask.
  138. Warning:
  139. ``tgt_is_causal`` provides a hint that ``tgt_mask`` is
  140. the causal mask. Providing incorrect hints can result in
  141. incorrect execution, including forward and backward
  142. compatibility.
  143. memory_is_causal: If specified, applies a causal mask as
  144. ``memory_mask``.
  145. Default: ``False``.
  146. Warning:
  147. ``memory_is_causal`` provides a hint that
  148. ``memory_mask`` is the causal mask. Providing incorrect
  149. hints can result in incorrect execution, including
  150. forward and backward compatibility.
  151. Shape:
  152. - src: :math:`(S, E)` for unbatched input, :math:`(S, N, E)` if `batch_first=False` or
  153. `(N, S, E)` if `batch_first=True`.
  154. - tgt: :math:`(T, E)` for unbatched input, :math:`(T, N, E)` if `batch_first=False` or
  155. `(N, T, E)` if `batch_first=True`.
  156. - src_mask: :math:`(S, S)` or :math:`(N\cdot\text{num\_heads}, S, S)`.
  157. - tgt_mask: :math:`(T, T)` or :math:`(N\cdot\text{num\_heads}, T, T)`.
  158. - memory_mask: :math:`(T, S)`.
  159. - src_key_padding_mask: :math:`(S)` for unbatched input otherwise :math:`(N, S)`.
  160. - tgt_key_padding_mask: :math:`(T)` for unbatched input otherwise :math:`(N, T)`.
  161. - memory_key_padding_mask: :math:`(S)` for unbatched input otherwise :math:`(N, S)`.
  162. Note: [src/tgt/memory]_mask ensures that position :math:`i` is allowed to attend the unmasked
  163. positions. If a BoolTensor is provided, positions with ``True``
  164. are not allowed to attend while ``False`` values will be unchanged. If a FloatTensor
  165. is provided, it will be added to the attention weight.
  166. [src/tgt/memory]_key_padding_mask provides specified elements in the key to be ignored by
  167. the attention. If a BoolTensor is provided, the positions with the
  168. value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged.
  169. - output: :math:`(T, E)` for unbatched input, :math:`(T, N, E)` if `batch_first=False` or
  170. `(N, T, E)` if `batch_first=True`.
  171. Note: Due to the multi-head attention architecture in the transformer model,
  172. the output sequence length of a transformer is same as the input sequence
  173. (i.e. target) length of the decoder.
  174. where :math:`S` is the source sequence length, :math:`T` is the target sequence length, :math:`N` is the
  175. batch size, :math:`E` is the feature number
  176. Examples:
  177. >>> # xdoctest: +SKIP
  178. >>> output = transformer_model(src, tgt, src_mask=src_mask, tgt_mask=tgt_mask)
  179. """
  180. is_batched = src.dim() == 3
  181. if not self.batch_first and src.size(1) != tgt.size(1) and is_batched:
  182. raise RuntimeError("the batch number of src and tgt must be equal")
  183. elif self.batch_first and src.size(0) != tgt.size(0) and is_batched:
  184. raise RuntimeError("the batch number of src and tgt must be equal")
  185. if src.size(-1) != self.d_model or tgt.size(-1) != self.d_model:
  186. raise RuntimeError("the feature number of src and tgt must be equal to d_model")
  187. memory = self.encoder(src, mask=src_mask, src_key_padding_mask=src_key_padding_mask,
  188. is_causal=src_is_causal)
  189. output = self.decoder(tgt, memory, tgt_mask=tgt_mask, memory_mask=memory_mask,
  190. tgt_key_padding_mask=tgt_key_padding_mask,
  191. memory_key_padding_mask=memory_key_padding_mask,
  192. tgt_is_causal=tgt_is_causal, memory_is_causal=memory_is_causal)
  193. return output
  194. @staticmethod
  195. def generate_square_subsequent_mask(
  196. sz: int,
  197. device: Optional[torch.device] = None,
  198. dtype: Optional[torch.dtype] = None,
  199. ) -> Tensor:
  200. r"""Generate a square causal mask for the sequence.
  201. The masked positions are filled with float('-inf'). Unmasked positions are filled with float(0.0).
  202. """
  203. return _generate_square_subsequent_mask(sz, dtype=dtype, device=device)
  204. def _reset_parameters(self):
  205. r"""Initiate parameters in the transformer model."""
  206. for p in self.parameters():
  207. if p.dim() > 1:
  208. xavier_uniform_(p)
  209. class TransformerEncoder(Module):
  210. r"""TransformerEncoder is a stack of N encoder layers.
  211. Users can build the BERT(https://arxiv.org/abs/1810.04805) model with corresponding parameters.
  212. Args:
  213. encoder_layer: an instance of the TransformerEncoderLayer() class (required).
  214. num_layers: the number of sub-encoder-layers in the encoder (required).
  215. norm: the layer normalization component (optional).
  216. enable_nested_tensor: if True, input will automatically convert to nested tensor
  217. (and convert back on output). This will improve the overall performance of
  218. TransformerEncoder when padding rate is high. Default: ``True`` (enabled).
  219. Examples::
  220. >>> encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8)
  221. >>> transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=6)
  222. >>> src = torch.rand(10, 32, 512)
  223. >>> out = transformer_encoder(src)
  224. """
  225. __constants__ = ['norm']
  226. def __init__(
  227. self,
  228. encoder_layer: "TransformerEncoderLayer",
  229. num_layers: int,
  230. norm: Optional[Module] = None,
  231. enable_nested_tensor: bool = True,
  232. mask_check: bool = True
  233. ) -> None:
  234. super().__init__()
  235. torch._C._log_api_usage_once(f"torch.nn.modules.{self.__class__.__name__}")
  236. self.layers = _get_clones(encoder_layer, num_layers)
  237. self.num_layers = num_layers
  238. self.norm = norm
  239. # this attribute saves the value providedat object construction
  240. self.enable_nested_tensor = enable_nested_tensor
  241. # this attribute controls whether nested tensors are used
  242. self.use_nested_tensor = enable_nested_tensor
  243. self.mask_check = mask_check
  244. enc_layer = "encoder_layer"
  245. why_not_sparsity_fast_path = ''
  246. if not isinstance(encoder_layer, torch.nn.TransformerEncoderLayer):
  247. why_not_sparsity_fast_path = f"{enc_layer} was not TransformerEncoderLayer"
  248. elif encoder_layer.norm_first :
  249. why_not_sparsity_fast_path = f"{enc_layer}.norm_first was True"
  250. elif not encoder_layer.self_attn.batch_first:
  251. why_not_sparsity_fast_path = (f"{enc_layer}.self_attn.batch_first was not True" +
  252. "(use batch_first for better inference performance)")
  253. elif not encoder_layer.self_attn._qkv_same_embed_dim:
  254. why_not_sparsity_fast_path = f"{enc_layer}.self_attn._qkv_same_embed_dim was not True"
  255. elif encoder_layer.self_attn.in_proj_bias is None:
  256. why_not_sparsity_fast_path = f"{enc_layer}.self_attn was passed bias=False"
  257. elif not encoder_layer.activation_relu_or_gelu:
  258. why_not_sparsity_fast_path = f"{enc_layer}.activation_relu_or_gelu was not True"
  259. elif not (encoder_layer.norm1.eps == encoder_layer.norm2.eps) :
  260. why_not_sparsity_fast_path = f"{enc_layer}.norm1.eps was not equal to {enc_layer}.norm2.eps"
  261. elif encoder_layer.self_attn.num_heads % 2 == 1:
  262. why_not_sparsity_fast_path = f"{enc_layer}.self_attn.num_heads is odd"
  263. if enable_nested_tensor and why_not_sparsity_fast_path:
  264. warnings.warn(f"enable_nested_tensor is True, but self.use_nested_tensor is False because {why_not_sparsity_fast_path}")
  265. self.use_nested_tensor = False
  266. def forward(
  267. self,
  268. src: Tensor,
  269. mask: Optional[Tensor] = None,
  270. src_key_padding_mask: Optional[Tensor] = None,
  271. is_causal: Optional[bool] = None) -> Tensor:
  272. r"""Pass the input through the encoder layers in turn.
  273. Args:
  274. src: the sequence to the encoder (required).
  275. mask: the mask for the src sequence (optional).
  276. src_key_padding_mask: the mask for the src keys per batch (optional).
  277. is_causal: If specified, applies a causal mask as ``mask``.
  278. Default: ``None``; try to detect a causal mask.
  279. Warning:
  280. ``is_causal`` provides a hint that ``mask`` is the
  281. causal mask. Providing incorrect hints can result in
  282. incorrect execution, including forward and backward
  283. compatibility.
  284. Shape:
  285. see the docs in :class:`~torch.nn.Transformer`.
  286. """
  287. src_key_padding_mask = F._canonical_mask(
  288. mask=src_key_padding_mask,
  289. mask_name="src_key_padding_mask",
  290. other_type=F._none_or_dtype(mask),
  291. other_name="mask",
  292. target_type=src.dtype
  293. )
  294. mask = F._canonical_mask(
  295. mask=mask,
  296. mask_name="mask",
  297. other_type=None,
  298. other_name="",
  299. target_type=src.dtype,
  300. check_other=False,
  301. )
  302. output = src
  303. convert_to_nested = False
  304. first_layer = self.layers[0]
  305. src_key_padding_mask_for_layers = src_key_padding_mask
  306. why_not_sparsity_fast_path = ''
  307. str_first_layer = "self.layers[0]"
  308. batch_first = first_layer.self_attn.batch_first
  309. is_fastpath_enabled = torch.backends.mha.get_fastpath_enabled()
  310. if not is_fastpath_enabled:
  311. why_not_sparsity_fast_path = "torch.backends.mha.get_fastpath_enabled() was not True"
  312. elif not hasattr(self, "use_nested_tensor"):
  313. why_not_sparsity_fast_path = "use_nested_tensor attribute not present"
  314. elif not self.use_nested_tensor:
  315. why_not_sparsity_fast_path = "self.use_nested_tensor (set in init) was not True"
  316. elif first_layer.training:
  317. why_not_sparsity_fast_path = f"{str_first_layer} was in training mode"
  318. elif not src.dim() == 3:
  319. why_not_sparsity_fast_path = f"input not batched; expected src.dim() of 3 but got {src.dim()}"
  320. elif src_key_padding_mask is None:
  321. why_not_sparsity_fast_path = "src_key_padding_mask was None"
  322. elif (((not hasattr(self, "mask_check")) or self.mask_check)
  323. and not torch._nested_tensor_from_mask_left_aligned(src, src_key_padding_mask.logical_not())):
  324. why_not_sparsity_fast_path = "mask_check enabled, and src and src_key_padding_mask was not left aligned"
  325. elif output.is_nested:
  326. why_not_sparsity_fast_path = "NestedTensor input is not supported"
  327. elif mask is not None:
  328. why_not_sparsity_fast_path = "src_key_padding_mask and mask were both supplied"
  329. elif torch.is_autocast_enabled():
  330. why_not_sparsity_fast_path = "autocast is enabled"
  331. if not why_not_sparsity_fast_path:
  332. tensor_args = (
  333. src,
  334. first_layer.self_attn.in_proj_weight,
  335. first_layer.self_attn.in_proj_bias,
  336. first_layer.self_attn.out_proj.weight,
  337. first_layer.self_attn.out_proj.bias,
  338. first_layer.norm1.weight,
  339. first_layer.norm1.bias,
  340. first_layer.norm2.weight,
  341. first_layer.norm2.bias,
  342. first_layer.linear1.weight,
  343. first_layer.linear1.bias,
  344. first_layer.linear2.weight,
  345. first_layer.linear2.bias,
  346. )
  347. _supported_device_type = ["cpu", "cuda", torch.utils.backend_registration._privateuse1_backend_name]
  348. if torch.overrides.has_torch_function(tensor_args):
  349. why_not_sparsity_fast_path = "some Tensor argument has_torch_function"
  350. elif src.device.type not in _supported_device_type:
  351. why_not_sparsity_fast_path = f"src device is neither one of {_supported_device_type}"
  352. elif torch.is_grad_enabled() and any(x.requires_grad for x in tensor_args):
  353. why_not_sparsity_fast_path = ("grad is enabled and at least one of query or the "
  354. "input/output projection weights or biases requires_grad")
  355. if (not why_not_sparsity_fast_path) and (src_key_padding_mask is not None):
  356. convert_to_nested = True
  357. output = torch._nested_tensor_from_mask(output, src_key_padding_mask.logical_not(), mask_check=False)
  358. src_key_padding_mask_for_layers = None
  359. seq_len = _get_seq_len(src, batch_first)
  360. is_causal = _detect_is_causal_mask(mask, is_causal, seq_len)
  361. for mod in self.layers:
  362. output = mod(output, src_mask=mask, is_causal=is_causal, src_key_padding_mask=src_key_padding_mask_for_layers)
  363. if convert_to_nested:
  364. output = output.to_padded_tensor(0., src.size())
  365. if self.norm is not None:
  366. output = self.norm(output)
  367. return output
  368. class TransformerDecoder(Module):
  369. r"""TransformerDecoder is a stack of N decoder layers.
  370. Args:
  371. decoder_layer: an instance of the TransformerDecoderLayer() class (required).
  372. num_layers: the number of sub-decoder-layers in the decoder (required).
  373. norm: the layer normalization component (optional).
  374. Examples::
  375. >>> decoder_layer = nn.TransformerDecoderLayer(d_model=512, nhead=8)
  376. >>> transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers=6)
  377. >>> memory = torch.rand(10, 32, 512)
  378. >>> tgt = torch.rand(20, 32, 512)
  379. >>> out = transformer_decoder(tgt, memory)
  380. """
  381. __constants__ = ['norm']
  382. def __init__(
  383. self,
  384. decoder_layer: "TransformerDecoderLayer",
  385. num_layers: int,
  386. norm: Optional[Module] = None
  387. ) -> None:
  388. super().__init__()
  389. torch._C._log_api_usage_once(f"torch.nn.modules.{self.__class__.__name__}")
  390. self.layers = _get_clones(decoder_layer, num_layers)
  391. self.num_layers = num_layers
  392. self.norm = norm
  393. def forward(self, tgt: Tensor, memory: Tensor, tgt_mask: Optional[Tensor] = None,
  394. memory_mask: Optional[Tensor] = None, tgt_key_padding_mask: Optional[Tensor] = None,
  395. memory_key_padding_mask: Optional[Tensor] = None, tgt_is_causal: Optional[bool] = None,
  396. memory_is_causal: bool = False) -> Tensor:
  397. r"""Pass the inputs (and mask) through the decoder layer in turn.
  398. Args:
  399. tgt: the sequence to the decoder (required).
  400. memory: the sequence from the last layer of the encoder (required).
  401. tgt_mask: the mask for the tgt sequence (optional).
  402. memory_mask: the mask for the memory sequence (optional).
  403. tgt_key_padding_mask: the mask for the tgt keys per batch (optional).
  404. memory_key_padding_mask: the mask for the memory keys per batch (optional).
  405. tgt_is_causal: If specified, applies a causal mask as ``tgt mask``.
  406. Default: ``None``; try to detect a causal mask.
  407. Warning:
  408. ``tgt_is_causal`` provides a hint that ``tgt_mask`` is
  409. the causal mask. Providing incorrect hints can result in
  410. incorrect execution, including forward and backward
  411. compatibility.
  412. memory_is_causal: If specified, applies a causal mask as
  413. ``memory mask``.
  414. Default: ``False``.
  415. Warning:
  416. ``memory_is_causal`` provides a hint that
  417. ``memory_mask`` is the causal mask. Providing incorrect
  418. hints can result in incorrect execution, including
  419. forward and backward compatibility.
  420. Shape:
  421. see the docs in :class:`~torch.nn.Transformer`.
  422. """
  423. output = tgt
  424. seq_len = _get_seq_len(tgt, self.layers[0].self_attn.batch_first)
  425. tgt_is_causal = _detect_is_causal_mask(tgt_mask, tgt_is_causal, seq_len)
  426. for mod in self.layers:
  427. output = mod(output, memory, tgt_mask=tgt_mask,
  428. memory_mask=memory_mask,
  429. tgt_key_padding_mask=tgt_key_padding_mask,
  430. memory_key_padding_mask=memory_key_padding_mask,
  431. tgt_is_causal=tgt_is_causal,
  432. memory_is_causal=memory_is_causal)
  433. if self.norm is not None:
  434. output = self.norm(output)
  435. return output
  436. class TransformerEncoderLayer(Module):
  437. r"""TransformerEncoderLayer is made up of self-attn and feedforward network.
  438. This standard encoder layer is based on the paper "Attention Is All You Need".
  439. Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez,
  440. Lukasz Kaiser, and Illia Polosukhin. 2017. Attention is all you need. In Advances in
  441. Neural Information Processing Systems, pages 6000-6010. Users may modify or implement
  442. in a different way during application.
  443. TransformerEncoderLayer can handle either traditional torch.tensor inputs,
  444. or Nested Tensor inputs. Derived classes are expected to similarly accept
  445. both input formats. (Not all combinations of inputs are currently
  446. supported by TransformerEncoderLayer while Nested Tensor is in prototype
  447. state.)
  448. If you are implementing a custom layer, you may derive it either from
  449. the Module or TransformerEncoderLayer class. If your custom layer
  450. supports both torch.Tensors and Nested Tensors inputs, make its
  451. implementation a derived class of TransformerEncoderLayer. If your custom
  452. Layer supports only torch.Tensor inputs, derive its implementation from
  453. Module.
  454. Args:
  455. d_model: the number of expected features in the input (required).
  456. nhead: the number of heads in the multiheadattention models (required).
  457. dim_feedforward: the dimension of the feedforward network model (default=2048).
  458. dropout: the dropout value (default=0.1).
  459. activation: the activation function of the intermediate layer, can be a string
  460. ("relu" or "gelu") or a unary callable. Default: relu
  461. layer_norm_eps: the eps value in layer normalization components (default=1e-5).
  462. batch_first: If ``True``, then the input and output tensors are provided
  463. as (batch, seq, feature). Default: ``False`` (seq, batch, feature).
  464. norm_first: if ``True``, layer norm is done prior to attention and feedforward
  465. operations, respectively. Otherwise it's done after. Default: ``False`` (after).
  466. bias: If set to ``False``, ``Linear`` and ``LayerNorm`` layers will not learn an additive
  467. bias. Default: ``True``.
  468. Examples::
  469. >>> encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8)
  470. >>> src = torch.rand(10, 32, 512)
  471. >>> out = encoder_layer(src)
  472. Alternatively, when ``batch_first`` is ``True``:
  473. >>> encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8, batch_first=True)
  474. >>> src = torch.rand(32, 10, 512)
  475. >>> out = encoder_layer(src)
  476. Fast path:
  477. forward() will use a special optimized implementation described in
  478. `FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness`_ if all of the following
  479. conditions are met:
  480. - Either autograd is disabled (using ``torch.inference_mode`` or ``torch.no_grad``) or no tensor
  481. argument ``requires_grad``
  482. - training is disabled (using ``.eval()``)
  483. - batch_first is ``True`` and the input is batched (i.e., ``src.dim() == 3``)
  484. - activation is one of: ``"relu"``, ``"gelu"``, ``torch.nn.functional.relu``, or ``torch.nn.functional.gelu``
  485. - at most one of ``src_mask`` and ``src_key_padding_mask`` is passed
  486. - if src is a `NestedTensor <https://pytorch.org/docs/stable/nested.html>`_, neither ``src_mask``
  487. nor ``src_key_padding_mask`` is passed
  488. - the two ``LayerNorm`` instances have a consistent ``eps`` value (this will naturally be the case
  489. unless the caller has manually modified one without modifying the other)
  490. If the optimized implementation is in use, a
  491. `NestedTensor <https://pytorch.org/docs/stable/nested.html>`_ can be
  492. passed for ``src`` to represent padding more efficiently than using a padding
  493. mask. In this case, a `NestedTensor <https://pytorch.org/docs/stable/nested.html>`_ will be
  494. returned, and an additional speedup proportional to the fraction of the input that
  495. is padding can be expected.
  496. .. _`FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness`:
  497. https://arxiv.org/abs/2205.14135
  498. """
  499. __constants__ = ['norm_first']
  500. def __init__(self, d_model: int, nhead: int, dim_feedforward: int = 2048, dropout: float = 0.1,
  501. activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,
  502. layer_norm_eps: float = 1e-5, batch_first: bool = False, norm_first: bool = False,
  503. bias: bool = True, device=None, dtype=None) -> None:
  504. factory_kwargs = {'device': device, 'dtype': dtype}
  505. super().__init__()
  506. self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout,
  507. bias=bias, batch_first=batch_first,
  508. **factory_kwargs)
  509. # Implementation of Feedforward model
  510. self.linear1 = Linear(d_model, dim_feedforward, bias=bias, **factory_kwargs)
  511. self.dropout = Dropout(dropout)
  512. self.linear2 = Linear(dim_feedforward, d_model, bias=bias, **factory_kwargs)
  513. self.norm_first = norm_first
  514. self.norm1 = LayerNorm(d_model, eps=layer_norm_eps, bias=bias, **factory_kwargs)
  515. self.norm2 = LayerNorm(d_model, eps=layer_norm_eps, bias=bias, **factory_kwargs)
  516. self.dropout1 = Dropout(dropout)
  517. self.dropout2 = Dropout(dropout)
  518. # Legacy string support for activation function.
  519. if isinstance(activation, str):
  520. activation = _get_activation_fn(activation)
  521. # We can't test self.activation in forward() in TorchScript,
  522. # so stash some information about it instead.
  523. if activation is F.relu or isinstance(activation, torch.nn.ReLU):
  524. self.activation_relu_or_gelu = 1
  525. elif activation is F.gelu or isinstance(activation, torch.nn.GELU):
  526. self.activation_relu_or_gelu = 2
  527. else:
  528. self.activation_relu_or_gelu = 0
  529. self.activation = activation
  530. def __setstate__(self, state):
  531. super().__setstate__(state)
  532. if not hasattr(self, 'activation'):
  533. self.activation = F.relu
  534. def forward(
  535. self,
  536. src: Tensor,
  537. src_mask: Optional[Tensor] = None,
  538. src_key_padding_mask: Optional[Tensor] = None,
  539. is_causal: bool = False) -> Tensor:
  540. r"""Pass the input through the encoder layer.
  541. Args:
  542. src: the sequence to the encoder layer (required).
  543. src_mask: the mask for the src sequence (optional).
  544. src_key_padding_mask: the mask for the src keys per batch (optional).
  545. is_causal: If specified, applies a causal mask as ``src mask``.
  546. Default: ``False``.
  547. Warning:
  548. ``is_causal`` provides a hint that ``src_mask`` is the
  549. causal mask. Providing incorrect hints can result in
  550. incorrect execution, including forward and backward
  551. compatibility.
  552. Shape:
  553. see the docs in :class:`~torch.nn.Transformer`.
  554. """
  555. src_key_padding_mask = F._canonical_mask(
  556. mask=src_key_padding_mask,
  557. mask_name="src_key_padding_mask",
  558. other_type=F._none_or_dtype(src_mask),
  559. other_name="src_mask",
  560. target_type=src.dtype
  561. )
  562. src_mask = F._canonical_mask(
  563. mask=src_mask,
  564. mask_name="src_mask",
  565. other_type=None,
  566. other_name="",
  567. target_type=src.dtype,
  568. check_other=False,
  569. )
  570. is_fastpath_enabled = torch.backends.mha.get_fastpath_enabled()
  571. why_not_sparsity_fast_path = ''
  572. if not is_fastpath_enabled:
  573. why_not_sparsity_fast_path = "torch.backends.mha.get_fastpath_enabled() was not True"
  574. elif not src.dim() == 3:
  575. why_not_sparsity_fast_path = f"input not batched; expected src.dim() of 3 but got {src.dim()}"
  576. elif self.training:
  577. why_not_sparsity_fast_path = "training is enabled"
  578. elif not self.self_attn.batch_first:
  579. why_not_sparsity_fast_path = "self_attn.batch_first was not True"
  580. elif self.self_attn.in_proj_bias is None:
  581. why_not_sparsity_fast_path = "self_attn was passed bias=False"
  582. elif not self.self_attn._qkv_same_embed_dim:
  583. why_not_sparsity_fast_path = "self_attn._qkv_same_embed_dim was not True"
  584. elif not self.activation_relu_or_gelu:
  585. why_not_sparsity_fast_path = "activation_relu_or_gelu was not True"
  586. elif not (self.norm1.eps == self.norm2.eps):
  587. why_not_sparsity_fast_path = "norm1.eps is not equal to norm2.eps"
  588. elif src.is_nested and (src_key_padding_mask is not None or src_mask is not None):
  589. why_not_sparsity_fast_path = "neither src_key_padding_mask nor src_mask are not supported with NestedTensor input"
  590. elif self.self_attn.num_heads % 2 == 1:
  591. why_not_sparsity_fast_path = "num_head is odd"
  592. elif torch.is_autocast_enabled():
  593. why_not_sparsity_fast_path = "autocast is enabled"
  594. if not why_not_sparsity_fast_path:
  595. tensor_args = (
  596. src,
  597. self.self_attn.in_proj_weight,
  598. self.self_attn.in_proj_bias,
  599. self.self_attn.out_proj.weight,
  600. self.self_attn.out_proj.bias,
  601. self.norm1.weight,
  602. self.norm1.bias,
  603. self.norm2.weight,
  604. self.norm2.bias,
  605. self.linear1.weight,
  606. self.linear1.bias,
  607. self.linear2.weight,
  608. self.linear2.bias,
  609. )
  610. # We have to use list comprehensions below because TorchScript does not support
  611. # generator expressions.
  612. _supported_device_type = ["cpu", "cuda", torch.utils.backend_registration._privateuse1_backend_name]
  613. if torch.overrides.has_torch_function(tensor_args):
  614. why_not_sparsity_fast_path = "some Tensor argument has_torch_function"
  615. elif not all((x.device.type in _supported_device_type) for x in tensor_args):
  616. why_not_sparsity_fast_path = ("some Tensor argument's device is neither one of "
  617. f"{_supported_device_type}")
  618. elif torch.is_grad_enabled() and any(x.requires_grad for x in tensor_args):
  619. why_not_sparsity_fast_path = ("grad is enabled and at least one of query or the "
  620. "input/output projection weights or biases requires_grad")
  621. if not why_not_sparsity_fast_path:
  622. merged_mask, mask_type = self.self_attn.merge_masks(src_mask, src_key_padding_mask, src)
  623. return torch._transformer_encoder_layer_fwd(
  624. src,
  625. self.self_attn.embed_dim,
  626. self.self_attn.num_heads,
  627. self.self_attn.in_proj_weight,
  628. self.self_attn.in_proj_bias,
  629. self.self_attn.out_proj.weight,
  630. self.self_attn.out_proj.bias,
  631. self.activation_relu_or_gelu == 2,
  632. self.norm_first,
  633. self.norm1.eps,
  634. self.norm1.weight,
  635. self.norm1.bias,
  636. self.norm2.weight,
  637. self.norm2.bias,
  638. self.linear1.weight,
  639. self.linear1.bias,
  640. self.linear2.weight,
  641. self.linear2.bias,
  642. merged_mask,
  643. mask_type,
  644. )
  645. # see Fig. 1 of https://arxiv.org/pdf/2002.04745v1.pdf
  646. x = src
  647. if self.norm_first:
  648. x = x + self._sa_block(self.norm1(x), src_mask, src_key_padding_mask, is_causal=is_causal)
  649. x = x + self._ff_block(self.norm2(x))
  650. else:
  651. x = self.norm1(x + self._sa_block(x, src_mask, src_key_padding_mask, is_causal=is_causal))
  652. x = self.norm2(x + self._ff_block(x))
  653. return x
  654. # self-attention block
  655. def _sa_block(self, x: Tensor,
  656. attn_mask: Optional[Tensor], key_padding_mask: Optional[Tensor], is_causal: bool = False) -> Tensor:
  657. x = self.self_attn(x, x, x,
  658. attn_mask=attn_mask,
  659. key_padding_mask=key_padding_mask,
  660. need_weights=False, is_causal=is_causal)[0]
  661. return self.dropout1(x)
  662. # feed forward block
  663. def _ff_block(self, x: Tensor) -> Tensor:
  664. x = self.linear2(self.dropout(self.activation(self.linear1(x))))
  665. return self.dropout2(x)
  666. class TransformerDecoderLayer(Module):
  667. r"""TransformerDecoderLayer is made up of self-attn, multi-head-attn and feedforward network.
  668. This standard decoder layer is based on the paper "Attention Is All You Need".
  669. Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez,
  670. Lukasz Kaiser, and Illia Polosukhin. 2017. Attention is all you need. In Advances in
  671. Neural Information Processing Systems, pages 6000-6010. Users may modify or implement
  672. in a different way during application.
  673. Args:
  674. d_model: the number of expected features in the input (required).
  675. nhead: the number of heads in the multiheadattention models (required).
  676. dim_feedforward: the dimension of the feedforward network model (default=2048).
  677. dropout: the dropout value (default=0.1).
  678. activation: the activation function of the intermediate layer, can be a string
  679. ("relu" or "gelu") or a unary callable. Default: relu
  680. layer_norm_eps: the eps value in layer normalization components (default=1e-5).
  681. batch_first: If ``True``, then the input and output tensors are provided
  682. as (batch, seq, feature). Default: ``False`` (seq, batch, feature).
  683. norm_first: if ``True``, layer norm is done prior to self attention, multihead
  684. attention and feedforward operations, respectively. Otherwise it's done after.
  685. Default: ``False`` (after).
  686. bias: If set to ``False``, ``Linear`` and ``LayerNorm`` layers will not learn an additive
  687. bias. Default: ``True``.
  688. Examples::
  689. >>> decoder_layer = nn.TransformerDecoderLayer(d_model=512, nhead=8)
  690. >>> memory = torch.rand(10, 32, 512)
  691. >>> tgt = torch.rand(20, 32, 512)
  692. >>> out = decoder_layer(tgt, memory)
  693. Alternatively, when ``batch_first`` is ``True``:
  694. >>> decoder_layer = nn.TransformerDecoderLayer(d_model=512, nhead=8, batch_first=True)
  695. >>> memory = torch.rand(32, 10, 512)
  696. >>> tgt = torch.rand(32, 20, 512)
  697. >>> out = decoder_layer(tgt, memory)
  698. """
  699. __constants__ = ['norm_first']
  700. def __init__(self, d_model: int, nhead: int, dim_feedforward: int = 2048, dropout: float = 0.1,
  701. activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,
  702. layer_norm_eps: float = 1e-5, batch_first: bool = False, norm_first: bool = False,
  703. bias: bool = True, device=None, dtype=None) -> None:
  704. factory_kwargs = {'device': device, 'dtype': dtype}
  705. super().__init__()
  706. self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=batch_first,
  707. bias=bias, **factory_kwargs)
  708. self.multihead_attn = MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=batch_first,
  709. bias=bias, **factory_kwargs)
  710. # Implementation of Feedforward model
  711. self.linear1 = Linear(d_model, dim_feedforward, bias=bias, **factory_kwargs)
  712. self.dropout = Dropout(dropout)
  713. self.linear2 = Linear(dim_feedforward, d_model, bias=bias, **factory_kwargs)
  714. self.norm_first = norm_first
  715. self.norm1 = LayerNorm(d_model, eps=layer_norm_eps, bias=bias, **factory_kwargs)
  716. self.norm2 = LayerNorm(d_model, eps=layer_norm_eps, bias=bias, **factory_kwargs)
  717. self.norm3 = LayerNorm(d_model, eps=layer_norm_eps, bias=bias, **factory_kwargs)
  718. self.dropout1 = Dropout(dropout)
  719. self.dropout2 = Dropout(dropout)
  720. self.dropout3 = Dropout(dropout)
  721. # Legacy string support for activation function.
  722. if isinstance(activation, str):
  723. self.activation = _get_activation_fn(activation)
  724. else:
  725. self.activation = activation
  726. def __setstate__(self, state):
  727. if 'activation' not in state:
  728. state['activation'] = F.relu
  729. super().__setstate__(state)
  730. def forward(
  731. self,
  732. tgt: Tensor,
  733. memory: Tensor,
  734. tgt_mask: Optional[Tensor] = None,
  735. memory_mask: Optional[Tensor] = None,
  736. tgt_key_padding_mask: Optional[Tensor] = None,
  737. memory_key_padding_mask: Optional[Tensor] = None,
  738. tgt_is_causal: bool = False,
  739. memory_is_causal: bool = False,
  740. ) -> Tensor:
  741. r"""Pass the inputs (and mask) through the decoder layer.
  742. Args:
  743. tgt: the sequence to the decoder layer (required).
  744. memory: the sequence from the last layer of the encoder (required).
  745. tgt_mask: the mask for the tgt sequence (optional).
  746. memory_mask: the mask for the memory sequence (optional).
  747. tgt_key_padding_mask: the mask for the tgt keys per batch (optional).
  748. memory_key_padding_mask: the mask for the memory keys per batch (optional).
  749. tgt_is_causal: If specified, applies a causal mask as ``tgt mask``.
  750. Default: ``False``.
  751. Warning:
  752. ``tgt_is_causal`` provides a hint that ``tgt_mask`` is
  753. the causal mask. Providing incorrect hints can result in
  754. incorrect execution, including forward and backward
  755. compatibility.
  756. memory_is_causal: If specified, applies a causal mask as
  757. ``memory mask``.
  758. Default: ``False``.
  759. Warning:
  760. ``memory_is_causal`` provides a hint that
  761. ``memory_mask`` is the causal mask. Providing incorrect
  762. hints can result in incorrect execution, including
  763. forward and backward compatibility.
  764. Shape:
  765. see the docs in :class:`~torch.nn.Transformer`.
  766. """
  767. # see Fig. 1 of https://arxiv.org/pdf/2002.04745v1.pdf
  768. x = tgt
  769. if self.norm_first:
  770. x = x + self._sa_block(self.norm1(x), tgt_mask, tgt_key_padding_mask, tgt_is_causal)
  771. x = x + self._mha_block(self.norm2(x), memory, memory_mask, memory_key_padding_mask, memory_is_causal)
  772. x = x + self._ff_block(self.norm3(x))
  773. else:
  774. x = self.norm1(x + self._sa_block(x, tgt_mask, tgt_key_padding_mask, tgt_is_causal))
  775. x = self.norm2(x + self._mha_block(x, memory, memory_mask, memory_key_padding_mask, memory_is_causal))
  776. x = self.norm3(x + self._ff_block(x))
  777. return x
  778. # self-attention block
  779. def _sa_block(self, x: Tensor,
  780. attn_mask: Optional[Tensor], key_padding_mask: Optional[Tensor], is_causal: bool = False) -> Tensor:
  781. x = self.self_attn(x, x, x,
  782. attn_mask=attn_mask,
  783. key_padding_mask=key_padding_mask,
  784. is_causal=is_causal,
  785. need_weights=False)[0]
  786. return self.dropout1(x)
  787. # multihead attention block
  788. def _mha_block(self, x: Tensor, mem: Tensor,
  789. attn_mask: Optional[Tensor], key_padding_mask: Optional[Tensor], is_causal: bool = False) -> Tensor:
  790. x = self.multihead_attn(x, mem, mem,
  791. attn_mask=attn_mask,
  792. key_padding_mask=key_padding_mask,
  793. is_causal=is_causal,
  794. need_weights=False)[0]
  795. return self.dropout2(x)
  796. # feed forward block
  797. def _ff_block(self, x: Tensor) -> Tensor:
  798. x = self.linear2(self.dropout(self.activation(self.linear1(x))))
  799. return self.dropout3(x)
  800. def _get_clones(module, N):
  801. # FIXME: copy.deepcopy() is not defined on nn.module
  802. return ModuleList([copy.deepcopy(module) for i in range(N)])
  803. def _get_activation_fn(activation: str) -> Callable[[Tensor], Tensor]:
  804. if activation == "relu":
  805. return F.relu
  806. elif activation == "gelu":
  807. return F.gelu
  808. raise RuntimeError(f"activation should be relu/gelu, not {activation}")
  809. def _detect_is_causal_mask(
  810. mask: Optional[Tensor],
  811. is_causal: Optional[bool] = None,
  812. size: Optional[int] = None,
  813. ) -> bool:
  814. """Return whether the given attention mask is causal.
  815. Warning:
  816. If ``is_causal`` is not ``None``, its value will be returned as is. If a
  817. user supplies an incorrect ``is_causal`` hint,
  818. ``is_causal=False`` when the mask is in fact a causal attention.mask
  819. may lead to reduced performance relative to what would be achievable
  820. with ``is_causal=True``;
  821. ``is_causal=True`` when the mask is in fact not a causal attention.mask
  822. may lead to incorrect and unpredictable execution - in some scenarios,
  823. a causal mask may be applied based on the hint, in other execution
  824. scenarios the specified mask may be used. The choice may not appear
  825. to be deterministic, in that a number of factors like alignment,
  826. hardware SKU, etc influence the decision whether to use a mask or
  827. rely on the hint.
  828. ``size`` if not None, check whether the mask is a causal mask of the provided size
  829. Otherwise, checks for any causal mask.
  830. """
  831. # Prevent type refinement
  832. make_causal = (is_causal is True)
  833. if is_causal is None and mask is not None:
  834. sz = size if size is not None else mask.size(-2)
  835. causal_comparison = _generate_square_subsequent_mask(
  836. sz, device=mask.device, dtype=mask.dtype)
  837. # Do not use `torch.equal` so we handle batched masks by
  838. # broadcasting the comparison.
  839. if mask.size() == causal_comparison.size():
  840. make_causal = bool((mask == causal_comparison).all())
  841. else:
  842. make_causal = False
  843. return make_causal