# coding=utf-8
# Copyright 2022, UCLA NLP, The Facebook AI Research Team and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch PLBART model."""

import copy
import math
from typing import List, Optional, Tuple, Union

import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

from ...activations import ACT2FN
from ...generation import GenerationMixin
from ...modeling_attn_mask_utils import (
    _prepare_4d_attention_mask,
    _prepare_4d_attention_mask_for_sdpa,
    _prepare_4d_causal_attention_mask,
    _prepare_4d_causal_attention_mask_for_sdpa,
)
from ...modeling_outputs import (
    BaseModelOutput,
    BaseModelOutputWithPastAndCrossAttentions,
    CausalLMOutputWithCrossAttentions,
    Seq2SeqLMOutput,
    Seq2SeqModelOutput,
    Seq2SeqSequenceClassifierOutput,
)
from ...modeling_utils import PreTrainedModel
from ...utils import (
    add_code_sample_docstrings,
    add_end_docstrings,
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    logging,
    replace_return_docstrings,
)
from .configuration_plbart import PLBartConfig


logger = logging.get_logger(__name__)

_CHECKPOINT_FOR_DOC = "uclanlp/plbart-base"
_CONFIG_FOR_DOC = "PLBartConfig"


# Copied from transformers.models.mbart.modeling_mbart.shift_tokens_right
def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int):
    """
    Shift input ids one token to the right, and wrap the last non-pad token (the <LID> token). Note that PLBart, like
    MBart, does not have a single `decoder_start_token_id` in contrast to other Bart-like models.
    """
    prev_output_tokens = input_ids.clone()

    if pad_token_id is None:
        raise ValueError("self.model.config.pad_token_id has to be defined.")
    # replace possible -100 values in labels by `pad_token_id`
    prev_output_tokens.masked_fill_(prev_output_tokens == -100, pad_token_id)

    index_of_eos = (prev_output_tokens.ne(pad_token_id).sum(dim=1) - 1).unsqueeze(-1)
    decoder_start_tokens = prev_output_tokens.gather(1, index_of_eos).squeeze()
    prev_output_tokens[:, 1:] = prev_output_tokens[:, :-1].clone()
    prev_output_tokens[:, 0] = decoder_start_tokens

    return prev_output_tokens
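
# Illustrative sketch (added comment, not in the original file; token names are placeholders):
# for an unpadded row `[tok_a, tok_b, </s>, en_XX]`, the last non-pad token (the language id
# `en_XX`) is taken as the decoder start token, so `shift_tokens_right` returns
# `[en_XX, tok_a, tok_b, </s>]`.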


# Copied from transformers.models.bart.modeling_bart.BartLearnedPositionalEmbedding with Bart->PLBart
class PLBartLearnedPositionalEmbedding(nn.Embedding):
    """
    This module learns positional embeddings up to a fixed maximum size.
    """

    def __init__(self, num_embeddings: int, embedding_dim: int):
        # PLBart is set up so that if padding_idx is specified then offset the embedding ids by 2
        # and adjust num_embeddings appropriately. Other models don't have this hack
        self.offset = 2
        super().__init__(num_embeddings + self.offset, embedding_dim)

    def forward(self, input_ids: torch.Tensor, past_key_values_length: int = 0):
        """`input_ids` shape is expected to be [bsz x seqlen]."""

        bsz, seq_len = input_ids.shape[:2]
        positions = torch.arange(
            past_key_values_length, past_key_values_length + seq_len, dtype=torch.long, device=self.weight.device
        ).expand(bsz, -1)

        return super().forward(positions + self.offset)
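
# Note (added for clarity, not in the original file): because of `self.offset = 2`, a sequence of
# length 3 with `past_key_values_length=0` looks up embedding rows [2, 3, 4]; rows 0 and 1 are
# effectively reserved, which is why the table is allocated with `num_embeddings + 2` entries.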


# Copied from transformers.models.bart.modeling_bart.BartScaledWordEmbedding with Bart->PLBart
class PLBartScaledWordEmbedding(nn.Embedding):
    """
    This module overrides nn.Embedding's forward by multiplying with embeddings scale.
    """

    def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int, embed_scale: Optional[float] = 1.0):
        super().__init__(num_embeddings, embedding_dim, padding_idx)
        self.embed_scale = embed_scale

    def forward(self, input_ids: torch.Tensor):
        return super().forward(input_ids) * self.embed_scale


# Copied from transformers.models.bart.modeling_bart.BartAttention with Bart->PLBart
class PLBartAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        dropout: float = 0.0,
        is_decoder: bool = False,
        bias: bool = True,
        is_causal: bool = False,
        config: Optional[PLBartConfig] = None,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.dropout = dropout
        self.head_dim = embed_dim // num_heads
        self.config = config

        if (self.head_dim * num_heads) != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
                f" and `num_heads`: {num_heads})."
            )
        self.scaling = self.head_dim**-0.5
        self.is_decoder = is_decoder
        self.is_causal = is_causal

        self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

    def forward(
        self,
        hidden_states: torch.Tensor,
        key_value_states: Optional[torch.Tensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        layer_head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """Input shape: Batch x Time x Channel"""

        # if key_value_states are provided this layer is used as a cross-attention layer
        # for the decoder
        is_cross_attention = key_value_states is not None

        bsz, tgt_len, _ = hidden_states.size()

        # get query proj
        query_states = self.q_proj(hidden_states) * self.scaling
        # get key, value proj
        # `past_key_value[0].shape[2] == key_value_states.shape[1]`
        # is checking that the `sequence_length` of the `past_key_value` is the same as
        # the provided `key_value_states` to support prefix tuning
        if (
            is_cross_attention
            and past_key_value is not None
            and past_key_value[0].shape[2] == key_value_states.shape[1]
        ):
            # reuse k,v, cross_attentions
            key_states = past_key_value[0]
            value_states = past_key_value[1]
        elif is_cross_attention:
            # cross_attentions
            key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
            value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
        elif past_key_value is not None:
            # reuse k, v, self_attention
            key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
            value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
            key_states = torch.cat([past_key_value[0], key_states], dim=2)
            value_states = torch.cat([past_key_value[1], value_states], dim=2)
        else:
            # self_attention
            key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
            value_states = self._shape(self.v_proj(hidden_states), -1, bsz)

        if self.is_decoder:
            # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
            # Further calls to cross_attention layer can then reuse all cross-attention
            # key/value_states (first "if" case)
            # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
            # all previous decoder key/value_states. Further calls to uni-directional self-attention
            # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
            # if encoder bi-directional self-attention `past_key_value` is always `None`
            past_key_value = (key_states, value_states)

        proj_shape = (bsz * self.num_heads, -1, self.head_dim)
        query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
        key_states = key_states.reshape(*proj_shape)
        value_states = value_states.reshape(*proj_shape)

        src_len = key_states.size(1)
        attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))

        if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
            raise ValueError(
                f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
                f" {attn_weights.size()}"
            )

        if attention_mask is not None:
            if attention_mask.size() != (bsz, 1, tgt_len, src_len):
                raise ValueError(
                    f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
                )
            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

        attn_weights = nn.functional.softmax(attn_weights, dim=-1)

        if layer_head_mask is not None:
            if layer_head_mask.size() != (self.num_heads,):
                raise ValueError(
                    f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
                    f" {layer_head_mask.size()}"
                )
            attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

        if output_attentions:
            # this operation is a bit awkward, but it's required to
            # make sure that attn_weights keeps its gradient.
            # In order to do so, attn_weights have to be reshaped
            # twice and have to be reused in the following
            attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
            attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
        else:
            attn_weights_reshaped = None

        attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)

        attn_output = torch.bmm(attn_probs, value_states)

        if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )

        attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
        attn_output = attn_output.transpose(1, 2)

        # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
        # partitioned across GPUs when using tensor-parallelism.
        attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)

        attn_output = self.out_proj(attn_output)

        return attn_output, attn_weights_reshaped, past_key_value
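
# Shape reference (added comment, not in the original file): when `is_decoder=True`, the returned
# `past_key_value` holds `(key_states, value_states)`, each of shape
# `(bsz, num_heads, src_len, head_dim)`; the decoder layer below concatenates the self-attention
# pair with the cross-attention pair into a 4-tuple per layer.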


# Copied from transformers.models.bart.modeling_bart.BartEncoderLayer with Bart->PLBart, BART->PLBART
class PLBartEncoderLayer(nn.Module):
    def __init__(self, config: PLBartConfig):
        super().__init__()
        self.embed_dim = config.d_model

        self.self_attn = PLBART_ATTENTION_CLASSES[config._attn_implementation](
            embed_dim=self.embed_dim,
            num_heads=config.encoder_attention_heads,
            dropout=config.attention_dropout,
            config=config,
        )
        self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
        self.dropout = config.dropout
        self.activation_fn = ACT2FN[config.activation_function]
        self.activation_dropout = config.activation_dropout
        self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim)
        self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim)
        self.final_layer_norm = nn.LayerNorm(self.embed_dim)

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        attention_mask: torch.FloatTensor,
        layer_head_mask: torch.FloatTensor,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.FloatTensor, Optional[torch.FloatTensor]]:
        """
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`): attention mask of size
                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
            layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size
                `(encoder_attention_heads,)`.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
        """
        residual = hidden_states
        hidden_states, attn_weights, _ = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            layer_head_mask=layer_head_mask,
            output_attentions=output_attentions,
        )
        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
        hidden_states = residual + hidden_states
        hidden_states = self.self_attn_layer_norm(hidden_states)

        residual = hidden_states
        hidden_states = self.activation_fn(self.fc1(hidden_states))
        hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
        hidden_states = self.fc2(hidden_states)
        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
        hidden_states = residual + hidden_states
        hidden_states = self.final_layer_norm(hidden_states)

        if hidden_states.dtype == torch.float16 and (
            torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any()
        ):
            clamp_value = torch.finfo(hidden_states.dtype).max - 1000
            hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)

        outputs = (hidden_states,)

        if output_attentions:
            outputs += (attn_weights,)

        return outputs


# TODO: Implement attention with SDPA for PLBart.
PLBART_ATTENTION_CLASSES = {"eager": PLBartAttention}


# Copied from transformers.models.bart.modeling_bart.BartDecoderLayer with Bart->PLBart, BART->PLBART
class PLBartDecoderLayer(nn.Module):
    def __init__(self, config: PLBartConfig):
        super().__init__()
        self.embed_dim = config.d_model

        self.self_attn = PLBART_ATTENTION_CLASSES[config._attn_implementation](
            embed_dim=self.embed_dim,
            num_heads=config.decoder_attention_heads,
            dropout=config.attention_dropout,
            is_decoder=True,
            is_causal=True,
            config=config,
        )
        self.dropout = config.dropout
        self.activation_fn = ACT2FN[config.activation_function]
        self.activation_dropout = config.activation_dropout

        self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
        self.encoder_attn = PLBART_ATTENTION_CLASSES[config._attn_implementation](
            self.embed_dim,
            config.decoder_attention_heads,
            dropout=config.attention_dropout,
            is_decoder=True,
            config=config,
        )
        self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim)
        self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim)
        self.fc2 = nn.Linear(config.decoder_ffn_dim, self.embed_dim)
        self.final_layer_norm = nn.LayerNorm(self.embed_dim)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        layer_head_mask: Optional[torch.Tensor] = None,
        cross_attn_layer_head_mask: Optional[torch.Tensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = True,
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        """
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`): attention mask of size
                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
            encoder_hidden_states (`torch.FloatTensor`):
                cross attention input to the layer of shape `(batch, seq_len, embed_dim)`
            encoder_attention_mask (`torch.FloatTensor`): encoder attention mask of size
                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
            layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size
                `(decoder_attention_heads,)`.
            cross_attn_layer_head_mask (`torch.FloatTensor`): mask for cross-attention heads in a given layer of
                size `(decoder_attention_heads,)`.
            past_key_value (`Tuple(torch.FloatTensor)`): cached past key and value projection states
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
        """
        residual = hidden_states

        # Self Attention
        # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
        self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
        # add present self-attn cache to positions 1,2 of present_key_value tuple
        hidden_states, self_attn_weights, present_key_value = self.self_attn(
            hidden_states=hidden_states,
            past_key_value=self_attn_past_key_value,
            attention_mask=attention_mask,
            layer_head_mask=layer_head_mask,
            output_attentions=output_attentions,
        )
        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
        hidden_states = residual + hidden_states
        hidden_states = self.self_attn_layer_norm(hidden_states)

        # Cross-Attention Block
        cross_attn_present_key_value = None
        cross_attn_weights = None
        if encoder_hidden_states is not None:
            residual = hidden_states

            # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple
            cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
            hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn(
                hidden_states=hidden_states,
                key_value_states=encoder_hidden_states,
                attention_mask=encoder_attention_mask,
                layer_head_mask=cross_attn_layer_head_mask,
                past_key_value=cross_attn_past_key_value,
                output_attentions=output_attentions,
            )
            hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
            hidden_states = residual + hidden_states
            hidden_states = self.encoder_attn_layer_norm(hidden_states)

            # add cross-attn to positions 3,4 of present_key_value tuple
            present_key_value = present_key_value + cross_attn_present_key_value

        # Fully Connected
        residual = hidden_states
        hidden_states = self.activation_fn(self.fc1(hidden_states))
        hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
        hidden_states = self.fc2(hidden_states)
        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
        hidden_states = residual + hidden_states
        hidden_states = self.final_layer_norm(hidden_states)

        outputs = (hidden_states,)

        if output_attentions:
            outputs += (self_attn_weights, cross_attn_weights)

        if use_cache:
            outputs += (present_key_value,)

        return outputs


# Copied from transformers.models.bart.modeling_bart.BartClassificationHead with Bart->PLBart
class PLBartClassificationHead(nn.Module):
    """Head for sentence-level classification tasks."""

    def __init__(
        self,
        input_dim: int,
        inner_dim: int,
        num_classes: int,
        pooler_dropout: float,
    ):
        super().__init__()
        self.dense = nn.Linear(input_dim, inner_dim)
        self.dropout = nn.Dropout(p=pooler_dropout)
        self.out_proj = nn.Linear(inner_dim, num_classes)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.dense(hidden_states)
        hidden_states = torch.tanh(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.out_proj(hidden_states)
        return hidden_states


class PLBartPreTrainedModel(PreTrainedModel):
    config_class = PLBartConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["PLBartDecoderLayer", "PLBartEncoderLayer"]

    def _init_weights(self, module):
        std = self.config.init_std
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()


PLBART_START_DOCSTRING = r"""
    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
    etc.)

    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage
    and behavior.

    Parameters:
        config ([`PLBartConfig`]):
            Model configuration class with all the parameters of the model. Initializing with a config file does not
            load the weights associated with the model, only the configuration. Check out the
            [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""

PLBART_GENERATION_EXAMPLE = r"""
    Mask-filling example:

    ```python
    >>> from transformers import AutoTokenizer, PLBartForConditionalGeneration

    >>> model = PLBartForConditionalGeneration.from_pretrained("uclanlp/plbart-base")
    >>> tokenizer = AutoTokenizer.from_pretrained("uclanlp/plbart-base")

    >>> # en_XX is the language symbol id <LID> for English
    >>> TXT = "<s> Is 0 the <mask> Fibonacci number ? </s> en_XX"
    >>> input_ids = tokenizer([TXT], add_special_tokens=False, return_tensors="pt").input_ids

    >>> logits = model(input_ids).logits
    >>> masked_index = (input_ids[0] == tokenizer.mask_token_id).nonzero().item()
    >>> probs = logits[0, masked_index].softmax(dim=0)
    >>> values, predictions = probs.topk(5)

    >>> tokenizer.decode(predictions).split()
    ['first', 'same', 'highest', 'result', 'number']
    ```
"""

PLBART_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
            it.

            Indices can be obtained using [`AutoTokenizer`] or [`PLBartMultiTokenizer`] depending on the checkpoint.
            See [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)
        decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
            Indices of decoder input sequence tokens in the vocabulary.

            Indices can be obtained using [`AutoTokenizer`] or [`PLBartMultiTokenizer`] depending on the checkpoint.
            See [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details.

            [What are decoder input IDs?](../glossary#decoder-input-ids)

            PLBart uses a specific language id token as the starting token for `decoder_input_ids` generation that
            varies according to source and target language, *e.g.* 50003 for *en_XX*, and 50001 for *java*. If
            `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
            `past_key_values`).

            For translation and summarization training, `decoder_input_ids` should be provided. If no
            `decoder_input_ids` is provided, the model will create this tensor by shifting the `input_ids` to the right
            for denoising pre-training following the paper.
        decoder_attention_mask (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
            Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
            be used by default.
        head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*):
            Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in `[0, 1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.
        decoder_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
            Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in `[0, 1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.
        cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
            Mask to nullify selected heads of the cross-attention modules in the decoder. Mask values selected in
            `[0, 1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.
        encoder_outputs (`tuple(tuple(torch.FloatTensor))`, *optional*):
            Tuple consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: `attentions`)
            `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*) is a sequence of
            hidden-states at the output of the last layer of the encoder. Used in the cross-attention of the decoder.
        past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
            Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
            `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
            `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.

            Contains pre-computed hidden-states (key and values in the self-attention blocks and in the
            cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.

            If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
            don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
            `decoder_input_ids` of shape `(batch_size, sequence_length)`.
        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
            This is useful if you want more control over how to convert `input_ids` indices into associated vectors
            than the model's internal embedding lookup matrix.
        decoder_inputs_embeds (`torch.FloatTensor` of shape `(batch_size, target_sequence_length, hidden_size)`, *optional*):
            Optionally, instead of passing `decoder_input_ids` you can choose to directly pass an embedded
            representation. If `past_key_values` is used, optionally only the last `decoder_inputs_embeds` have to be
            input (see `past_key_values`). This is useful if you want more control over how to convert
            `decoder_input_ids` indices into associated vectors than the model's internal embedding lookup matrix.

            If `decoder_input_ids` and `decoder_inputs_embeds` are both unset, `decoder_inputs_embeds` takes the value
            of `inputs_embeds`.
        use_cache (`bool`, *optional*):
            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
            `past_key_values`).
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""


# Copied from transformers.models.bart.modeling_bart.BartEncoder with Bart->PLBart
class PLBartEncoder(PLBartPreTrainedModel):
    """
    Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a
    [`PLBartEncoderLayer`].

    Args:
        config: PLBartConfig
        embed_tokens (nn.Embedding): output embedding
    """

    def __init__(self, config: PLBartConfig, embed_tokens: Optional[nn.Embedding] = None):
        super().__init__(config)

        self.dropout = config.dropout
        self.layerdrop = config.encoder_layerdrop

        embed_dim = config.d_model
        self.padding_idx = config.pad_token_id
        self.max_source_positions = config.max_position_embeddings
        embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0

        self.embed_tokens = PLBartScaledWordEmbedding(
            config.vocab_size, embed_dim, self.padding_idx, embed_scale=embed_scale
        )

        if embed_tokens is not None:
            self.embed_tokens.weight = embed_tokens.weight

        self.embed_positions = PLBartLearnedPositionalEmbedding(
            config.max_position_embeddings,
            embed_dim,
        )
        self.layers = nn.ModuleList([PLBartEncoderLayer(config) for _ in range(config.encoder_layers)])
        self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
        self._use_sdpa = config._attn_implementation == "sdpa"
        self.layernorm_embedding = nn.LayerNorm(embed_dim)

        self.gradient_checkpointing = False
        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.embed_tokens

    def set_input_embeddings(self, value):
        self.embed_tokens = value

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutput]:
        r"""
        Args:
            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
                Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
                provide it.

                Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
                [`PreTrainedTokenizer.__call__`] for details.

                [What are input IDs?](../glossary#input-ids)
            attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
                Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

                - 1 for tokens that are **not masked**,
                - 0 for tokens that are **masked**.

                [What are attention masks?](../glossary#attention-mask)
            head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*):
                Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:

                - 1 indicates the head is **not masked**,
                - 0 indicates the head is **masked**.
            inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
                Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
                This is useful if you want more control over how to convert `input_ids` indices into associated
                vectors than the model's internal embedding lookup matrix.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            output_hidden_states (`bool`, *optional*):
                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
                for more detail.
            return_dict (`bool`, *optional*):
                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # retrieve input_ids and inputs_embeds
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            input = input_ids
            input_ids = input_ids.view(-1, input_ids.shape[-1])
        elif inputs_embeds is not None:
            input = inputs_embeds[:, :, -1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        embed_pos = self.embed_positions(input)
        embed_pos = embed_pos.to(inputs_embeds.device)

        hidden_states = inputs_embeds + embed_pos
        hidden_states = self.layernorm_embedding(hidden_states)
        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)

        # expand attention_mask
        if attention_mask is not None:
            if self._use_flash_attention_2:
                attention_mask = attention_mask if 0 in attention_mask else None
            elif self._use_sdpa and head_mask is None and not output_attentions:
                # output_attentions=True & head_mask can not be supported when using SDPA, fall back to
                # the manual implementation that requires a 4D causal mask in all cases.
                # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
                attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, inputs_embeds.dtype)
            else:
                # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
                attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype)

        encoder_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None

        # check if head_mask has a correct number of layers specified if desired
        if head_mask is not None:
            if head_mask.size()[0] != (len(self.layers)):
                raise ValueError(
                    f"The head_mask should be specified for {len(self.layers)} layers, but it is for"
                    f" {head_mask.size()[0]}."
                )

        for idx, encoder_layer in enumerate(self.layers):
            if output_hidden_states:
                encoder_states = encoder_states + (hidden_states,)
            # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
            to_drop = False
            if self.training:
                dropout_probability = torch.rand([])
                if dropout_probability < self.layerdrop:  # skip the layer
                    to_drop = True

            if to_drop:
                layer_outputs = (None, None)
            else:
                if self.gradient_checkpointing and self.training:
                    layer_outputs = self._gradient_checkpointing_func(
                        encoder_layer.__call__,
                        hidden_states,
                        attention_mask,
                        (head_mask[idx] if head_mask is not None else None),
                        output_attentions,
                    )
                else:
                    layer_outputs = encoder_layer(
                        hidden_states,
                        attention_mask,
                        layer_head_mask=(head_mask[idx] if head_mask is not None else None),
                        output_attentions=output_attentions,
                    )

                hidden_states = layer_outputs[0]

            if output_attentions:
                all_attentions = all_attentions + (layer_outputs[1],)

        if output_hidden_states:
            encoder_states = encoder_states + (hidden_states,)

        if not return_dict:
            return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
        return BaseModelOutput(
            last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
        )
  703. # Copied from transformers.models.bart.modeling_bart.BartDecoder with Bart->PLBart
  704. class PLBartDecoder(PLBartPreTrainedModel):
  705. """
  706. Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a [`PLBartDecoderLayer`]
  707. Args:
  708. config: PLBartConfig
  709. embed_tokens (nn.Embedding): output embedding
  710. """
  711. def __init__(self, config: PLBartConfig, embed_tokens: Optional[nn.Embedding] = None):
  712. super().__init__(config)
  713. self.dropout = config.dropout
  714. self.layerdrop = config.decoder_layerdrop
  715. self.padding_idx = config.pad_token_id
  716. self.max_target_positions = config.max_position_embeddings
  717. embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
  718. self.embed_tokens = PLBartScaledWordEmbedding(
  719. config.vocab_size, config.d_model, self.padding_idx, embed_scale=embed_scale
  720. )
  721. if embed_tokens is not None:
  722. self.embed_tokens.weight = embed_tokens.weight
  723. self.embed_positions = PLBartLearnedPositionalEmbedding(
  724. config.max_position_embeddings,
  725. config.d_model,
  726. )
  727. self.layers = nn.ModuleList([PLBartDecoderLayer(config) for _ in range(config.decoder_layers)])
  728. self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
  729. self._use_sdpa = config._attn_implementation == "sdpa"
  730. self.layernorm_embedding = nn.LayerNorm(config.d_model)
  731. self.gradient_checkpointing = False
  732. # Initialize weights and apply final processing
  733. self.post_init()
  734. def get_input_embeddings(self):
  735. return self.embed_tokens
  736. def set_input_embeddings(self, value):
  737. self.embed_tokens = value
  738. def forward(
  739. self,
  740. input_ids: torch.LongTensor = None,
  741. attention_mask: Optional[torch.Tensor] = None,
  742. encoder_hidden_states: Optional[torch.FloatTensor] = None,
  743. encoder_attention_mask: Optional[torch.LongTensor] = None,
  744. head_mask: Optional[torch.Tensor] = None,
  745. cross_attn_head_mask: Optional[torch.Tensor] = None,
  746. past_key_values: Optional[List[torch.FloatTensor]] = None,
  747. inputs_embeds: Optional[torch.FloatTensor] = None,
  748. use_cache: Optional[bool] = None,
  749. output_attentions: Optional[bool] = None,
  750. output_hidden_states: Optional[bool] = None,
  751. return_dict: Optional[bool] = None,
  752. ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
  753. r"""
  754. Args:
  755. input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
  756. Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
  757. provide it.
  758. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  759. [`PreTrainedTokenizer.__call__`] for details.
  760. [What are input IDs?](../glossary#input-ids)
  761. attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
  762. Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
  763. - 1 for tokens that are **not masked**,
  764. - 0 for tokens that are **masked**.
  765. [What are attention masks?](../glossary#attention-mask)
  766. encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*):
  767. Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention
  768. of the decoder.
  769. encoder_attention_mask (`torch.LongTensor` of shape `(batch_size, encoder_sequence_length)`, *optional*):
  770. Mask to avoid performing cross-attention on padding tokens indices of encoder input_ids. Mask values
  771. selected in `[0, 1]`:
  772. - 1 for tokens that are **not masked**,
  773. - 0 for tokens that are **masked**.
  774. [What are attention masks?](../glossary#attention-mask)
  775. head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
  776. Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
  777. - 1 indicates the head is **not masked**,
  778. - 0 indicates the head is **masked**.
  779. cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
  780. Mask to nullify selected heads of the cross-attention modules in the decoder to avoid performing
  781. cross-attention on hidden heads. Mask values selected in `[0, 1]`:
  782. - 1 indicates the head is **not masked**,
  783. - 0 indicates the head is **masked**.
  784. past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
  785. Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
  786. shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of
  787. shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
  788. Contains pre-computed hidden-states (key and values in the self-attention blocks and in the
  789. cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
  790. If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those
  791. that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of
  792. all `decoder_input_ids` of shape `(batch_size, sequence_length)`.
  793. inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
  794. Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
  795. This is useful if you want more control over how to convert `input_ids` indices into associated vectors
  796. than the model's internal embedding lookup matrix.
  797. output_attentions (`bool`, *optional*):
  798. Whether or not to return the attentions tensors of all attention layers. See `attentions` under
  799. returned tensors for more detail.
  800. output_hidden_states (`bool`, *optional*):
  801. Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
  802. for more detail.
  803. return_dict (`bool`, *optional*):
  804. Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
  805. """
  806. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  807. output_hidden_states = (
  808. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  809. )
  810. use_cache = use_cache if use_cache is not None else self.config.use_cache
  811. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  812. # retrieve input_ids and inputs_embeds
  813. if input_ids is not None and inputs_embeds is not None:
  814. raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
  815. elif input_ids is not None:
  816. input = input_ids
  817. input_shape = input.shape
  818. input_ids = input_ids.view(-1, input_shape[-1])
  819. elif inputs_embeds is not None:
  820. input_shape = inputs_embeds.size()[:-1]
  821. input = inputs_embeds[:, :, -1]
  822. else:
  823. raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
  824. # past_key_values_length
  825. past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
  826. if inputs_embeds is None:
  827. inputs_embeds = self.embed_tokens(input)
  828. if self._use_flash_attention_2:
  829. # 2d mask is passed through the layers
  830. attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
  831. elif self._use_sdpa and not output_attentions and cross_attn_head_mask is None:
  832. # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on
  833. # the manual implementation that requires a 4D causal mask in all cases.
  834. attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
  835. attention_mask,
  836. input_shape,
  837. inputs_embeds,
  838. past_key_values_length,
  839. )
  840. else:
  841. # 4d mask is passed through the layers
  842. attention_mask = _prepare_4d_causal_attention_mask(
  843. attention_mask, input_shape, inputs_embeds, past_key_values_length
  844. )
  845. # expand encoder attention mask
  846. if encoder_hidden_states is not None and encoder_attention_mask is not None:
  847. if self._use_flash_attention_2:
  848. encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None
  849. elif self._use_sdpa and cross_attn_head_mask is None and not output_attentions:
  850. # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on
  851. # the manual implementation that requires a 4D causal mask in all cases.
  852. # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
  853. encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa(
  854. encoder_attention_mask,
  855. inputs_embeds.dtype,
  856. tgt_len=input_shape[-1],
  857. )
  858. else:
  859. # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
  860. encoder_attention_mask = _prepare_4d_attention_mask(
  861. encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
  862. )
  863. # embed positions
  864. positions = self.embed_positions(input, past_key_values_length)
  865. positions = positions.to(inputs_embeds.device)
  866. hidden_states = inputs_embeds + positions
  867. hidden_states = self.layernorm_embedding(hidden_states)
  868. hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
  869. if self.gradient_checkpointing and self.training:
  870. if use_cache:
  871. logger.warning_once(
  872. "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
  873. )
  874. use_cache = False
  875. # decoder layers
  876. all_hidden_states = () if output_hidden_states else None
  877. all_self_attns = () if output_attentions else None
  878. all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
  879. next_decoder_cache = () if use_cache else None
  880. # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired
  881. for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]):
  882. if attn_mask is not None:
  883. if attn_mask.size()[0] != (len(self.layers)):
  884. raise ValueError(
  885. f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for"
  886. f" {head_mask.size()[0]}."
  887. )
        for idx, decoder_layer in enumerate(self.layers):
            # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
            if output_hidden_states:
                all_hidden_states += (hidden_states,)
            if self.training:
                dropout_probability = torch.rand([])
                if dropout_probability < self.layerdrop:
                    continue

            past_key_value = past_key_values[idx] if past_key_values is not None else None

            if self.gradient_checkpointing and self.training:
                layer_outputs = self._gradient_checkpointing_func(
                    decoder_layer.__call__,
                    hidden_states,
                    attention_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                    head_mask[idx] if head_mask is not None else None,
                    cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None,
                    None,
                    output_attentions,
                    use_cache,
                )
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=attention_mask,
                    encoder_hidden_states=encoder_hidden_states,
                    encoder_attention_mask=encoder_attention_mask,
                    layer_head_mask=(head_mask[idx] if head_mask is not None else None),
                    cross_attn_layer_head_mask=(
                        cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None
                    ),
                    past_key_value=past_key_value,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                )
            hidden_states = layer_outputs[0]

            if use_cache:
                next_decoder_cache += (layer_outputs[3 if output_attentions else 1],)

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

                if encoder_hidden_states is not None:
                    all_cross_attentions += (layer_outputs[2],)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        next_cache = next_decoder_cache if use_cache else None
        if not return_dict:
            return tuple(
                v
                for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_cross_attentions]
                if v is not None
            )
        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
            cross_attentions=all_cross_attentions,
        )
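

# Usage sketch for the standalone decoder above (illustrative only, randomly initialised weights, default
# `PLBartConfig` values; it relies on this module's imports and is not tied to any released checkpoint):
#
#     config = PLBartConfig()
#     decoder = PLBartDecoder(config)  # decoder-only stack, no encoder attached
#     ids = torch.randint(0, config.vocab_size, (1, 5))
#     out = decoder(input_ids=ids, use_cache=True, return_dict=True)
#     out.last_hidden_state.shape      # (1, 5, config.d_model)
#     len(out.past_key_values)         # config.decoder_layers tuples of cached key/value states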


@add_start_docstrings(
    "The bare PLBART Model outputting raw hidden-states without any specific head on top.",
    PLBART_START_DOCSTRING,
)
class PLBartModel(PLBartPreTrainedModel):
    _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]

    def __init__(self, config: PLBartConfig):
        super().__init__(config)

        padding_idx, vocab_size = config.pad_token_id, config.vocab_size
        embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
        self.shared = PLBartScaledWordEmbedding(vocab_size, config.d_model, padding_idx, embed_scale=embed_scale)

        self.encoder = PLBartEncoder(config, self.shared)
        self.decoder = PLBartDecoder(config, self.shared)

        self.init_weights()

    def get_input_embeddings(self):
        return self.shared

    def set_input_embeddings(self, value):
        self.shared = value
        self.encoder.embed_tokens = self.shared
        self.decoder.embed_tokens = self.shared

    def _tie_weights(self):
        if self.config.tie_word_embeddings:
            self._tie_or_clone_weights(self.encoder.embed_tokens, self.shared)
            self._tie_or_clone_weights(self.decoder.embed_tokens, self.shared)

    def get_encoder(self):
        return self.encoder

    def get_decoder(self):
        return self.decoder

    @add_start_docstrings_to_model_forward(PLBART_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=Seq2SeqModelOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.LongTensor] = None,
        decoder_input_ids: Optional[torch.LongTensor] = None,
        decoder_attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        decoder_head_mask: Optional[torch.LongTensor] = None,
        cross_attn_head_mask: Optional[torch.Tensor] = None,
        encoder_outputs: Optional[List[torch.FloatTensor]] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], Seq2SeqModelOutput]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # different to other models, PLBart automatically creates decoder_input_ids from
        # input_ids if no decoder_input_ids are provided
        if decoder_input_ids is None and decoder_inputs_embeds is None:
            decoder_input_ids = shift_tokens_right(input_ids, self.config.pad_token_id)
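        # Illustrative note: `shift_tokens_right` wraps the last non-pad token of each sequence (typically the
        # language id token) to position 0 and shifts everything else one step right, e.g.
        # [w1, w2, </s>, lang_id] -> [lang_id, w1, w2, </s>]; PLBart has no single `decoder_start_token_id`.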

        if encoder_outputs is None:
            encoder_outputs = self.encoder(
                input_ids=input_ids,
                attention_mask=attention_mask,
                head_mask=head_mask,
                inputs_embeds=inputs_embeds,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )
        # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True
        elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
            encoder_outputs = BaseModelOutput(
                last_hidden_state=encoder_outputs[0],
                hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
                attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
            )

        # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn)
        decoder_outputs = self.decoder(
            input_ids=decoder_input_ids,
            attention_mask=decoder_attention_mask,
            encoder_hidden_states=encoder_outputs[0],
            encoder_attention_mask=attention_mask,
            head_mask=decoder_head_mask,
            cross_attn_head_mask=cross_attn_head_mask,
            past_key_values=past_key_values,
            inputs_embeds=decoder_inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        if not return_dict:
            return decoder_outputs + encoder_outputs

        return Seq2SeqModelOutput(
            last_hidden_state=decoder_outputs.last_hidden_state,
            past_key_values=decoder_outputs.past_key_values,
            decoder_hidden_states=decoder_outputs.hidden_states,
            decoder_attentions=decoder_outputs.attentions,
            cross_attentions=decoder_outputs.cross_attentions,
            encoder_last_hidden_state=encoder_outputs.last_hidden_state,
            encoder_hidden_states=encoder_outputs.hidden_states,
            encoder_attentions=encoder_outputs.attentions,
        )
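

# Usage sketch for PLBartModel (illustrative; uses the `uclanlp/plbart-base` checkpoint referenced elsewhere in
# this file, and an arbitrary example snippet):
#
#     from transformers import AutoTokenizer, PLBartModel
#
#     tokenizer = AutoTokenizer.from_pretrained("uclanlp/plbart-base")
#     model = PLBartModel.from_pretrained("uclanlp/plbart-base")
#     inputs = tokenizer("def maximum(a, b): return a if a > b else b", return_tensors="pt")
#     outputs = model(**inputs)  # decoder_input_ids are derived from input_ids via shift_tokens_right
#     outputs.last_hidden_state.shape  # (batch_size, sequence_length, config.d_model)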


@add_start_docstrings(
    "The PLBART Model with a language modeling head. Can be used for code-to-text, text-to-code and code-to-code.",
    PLBART_START_DOCSTRING,
)
class PLBartForConditionalGeneration(PLBartPreTrainedModel, GenerationMixin):
    base_model_prefix = "model"
    _keys_to_ignore_on_load_missing = ["final_logits_bias"]
    _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight", "lm_head.weight"]

    def __init__(self, config: PLBartConfig):
        super().__init__(config)
        self.model = PLBartModel(config)
        self.register_buffer("final_logits_bias", torch.zeros((1, self.model.shared.num_embeddings)))
        self.lm_head = nn.Linear(config.d_model, self.model.shared.num_embeddings, bias=False)

        self.init_weights()

    def get_encoder(self):
        return self.model.get_encoder()

    def get_decoder(self):
        return self.model.get_decoder()

    def resize_token_embeddings(self, new_num_tokens: int, pad_to_multiple_of: Optional[int] = None) -> nn.Embedding:
        new_embeddings = super().resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
        self._resize_final_logits_bias(new_embeddings.weight.shape[0])
        return new_embeddings

    def _resize_final_logits_bias(self, new_num_tokens: int) -> None:
        old_num_tokens = self.final_logits_bias.shape[-1]
        if new_num_tokens <= old_num_tokens:
            new_bias = self.final_logits_bias[:, :new_num_tokens]
        else:
            extra_bias = torch.zeros((1, new_num_tokens - old_num_tokens), device=self.final_logits_bias.device)
            new_bias = torch.cat([self.final_logits_bias, extra_bias], dim=1)
        self.register_buffer("final_logits_bias", new_bias)
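
    # Note: `resize_token_embeddings` above routes through `_resize_final_logits_bias` so that the
    # `final_logits_bias` buffer always matches the resized vocabulary. A typical, purely illustrative call after
    # adding tokens to a tokenizer would be:
    #
    #     model.resize_token_embeddings(len(tokenizer))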

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    @add_start_docstrings_to_model_forward(PLBART_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
    @add_end_docstrings(PLBART_GENERATION_EXAMPLE)
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.LongTensor] = None,
        decoder_input_ids: Optional[torch.LongTensor] = None,
        decoder_attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        decoder_head_mask: Optional[torch.LongTensor] = None,
        cross_attn_head_mask: Optional[torch.Tensor] = None,
        encoder_outputs: Optional[List[torch.FloatTensor]] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.Tensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], Seq2SeqLMOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        Returns:
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if labels is not None:
            if decoder_input_ids is None and decoder_inputs_embeds is None:
                decoder_input_ids = shift_tokens_right(labels, self.config.pad_token_id)

        outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            encoder_outputs=encoder_outputs,
            decoder_attention_mask=decoder_attention_mask,
            head_mask=head_mask,
            decoder_head_mask=decoder_head_mask,
            cross_attn_head_mask=cross_attn_head_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            decoder_inputs_embeds=decoder_inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        lm_logits = self.lm_head(outputs[0])
        lm_logits = lm_logits + self.final_logits_bias.to(lm_logits.device)

        masked_lm_loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            masked_lm_loss = loss_fct(lm_logits.view(-1, self.config.vocab_size), labels.view(-1))

        if not return_dict:
            output = (lm_logits,) + outputs[1:]
            return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output

        return Seq2SeqLMOutput(
            loss=masked_lm_loss,
            logits=lm_logits,
            past_key_values=outputs.past_key_values,
            decoder_hidden_states=outputs.decoder_hidden_states,
            decoder_attentions=outputs.decoder_attentions,
            cross_attentions=outputs.cross_attentions,
            encoder_last_hidden_state=outputs.encoder_last_hidden_state,
            encoder_hidden_states=outputs.encoder_hidden_states,
            encoder_attentions=outputs.encoder_attentions,
        )

    def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
        return shift_tokens_right(labels, self.config.pad_token_id)

    @staticmethod
    def _reorder_cache(past_key_values, beam_idx):
        reordered_past = ()
        for layer_past in past_key_values:
            # cached cross_attention states don't have to be reordered -> they are always the same
            reordered_past += (
                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past[:2])
                + layer_past[2:],
            )
        return reordered_past
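

# Usage sketch for conditional generation (illustrative only; the base checkpoint is a denoising pretrained
# model, so for a real code-to-text or code-to-code task you would typically load a fine-tuned checkpoint):
#
#     from transformers import AutoTokenizer, PLBartForConditionalGeneration
#
#     tokenizer = AutoTokenizer.from_pretrained("uclanlp/plbart-base")
#     model = PLBartForConditionalGeneration.from_pretrained("uclanlp/plbart-base")
#     inputs = tokenizer("def add(a, b): return a + b", return_tensors="pt")
#     generated_ids = model.generate(**inputs, max_length=32, num_beams=4)
#     print(tokenizer.batch_decode(generated_ids, skip_special_tokens=True))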


@add_start_docstrings(
    """
    PLBart model with a sequence classification head on top (a linear layer on top of the pooled output), e.g. for
    code classification.
    """,
    PLBART_START_DOCSTRING,
)
class PLBartForSequenceClassification(PLBartPreTrainedModel):
    _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]

    def __init__(self, config: PLBartConfig, **kwargs):
        super().__init__(config, **kwargs)
        self.model = PLBartModel(config)
        self.classification_head = PLBartClassificationHead(
            config.d_model,
            config.d_model,
            config.num_labels,
            config.classifier_dropout,
        )

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(PLBART_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=Seq2SeqSequenceClassifierOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    # Copied from transformers.models.bart.modeling_bart.BartForSequenceClassification.forward
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        decoder_input_ids: Optional[torch.LongTensor] = None,
        decoder_attention_mask: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        decoder_head_mask: Optional[torch.Tensor] = None,
        cross_attn_head_mask: Optional[torch.Tensor] = None,
        encoder_outputs: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, Seq2SeqSequenceClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        if labels is not None:
            use_cache = False

        if input_ids is None and inputs_embeds is not None:
            raise NotImplementedError(
                f"Passing input embeddings is currently not supported for {self.__class__.__name__}"
            )

        outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            decoder_attention_mask=decoder_attention_mask,
            head_mask=head_mask,
            decoder_head_mask=decoder_head_mask,
            cross_attn_head_mask=cross_attn_head_mask,
            encoder_outputs=encoder_outputs,
            inputs_embeds=inputs_embeds,
            decoder_inputs_embeds=decoder_inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = outputs[0]  # last hidden state

        eos_mask = input_ids.eq(self.config.eos_token_id).to(hidden_states.device)

        if len(torch.unique_consecutive(eos_mask.sum(1))) > 1:
            raise ValueError("All examples must have the same number of <eos> tokens.")
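        # Pool the sequence by taking the decoder hidden state at the *last* <eos> token of each example; the
        # equal-<eos>-count check above is what makes the reshape below well-defined.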
        sentence_representation = hidden_states[eos_mask, :].view(hidden_states.size(0), -1, hidden_states.size(-1))[
            :, -1, :
        ]
        logits = self.classification_head(sentence_representation)

        loss = None
        if labels is not None:
            labels = labels.to(logits.device)
            if self.config.problem_type is None:
                if self.config.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.config.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                if self.config.num_labels == 1:
                    loss = loss_fct(logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(logits, labels)

        if not return_dict:
            output = (logits,) + outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return Seq2SeqSequenceClassifierOutput(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            decoder_hidden_states=outputs.decoder_hidden_states,
            decoder_attentions=outputs.decoder_attentions,
            cross_attentions=outputs.cross_attentions,
            encoder_last_hidden_state=outputs.encoder_last_hidden_state,
            encoder_hidden_states=outputs.encoder_hidden_states,
            encoder_attentions=outputs.encoder_attentions,
        )
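

# Usage sketch for sequence classification (illustrative; `num_labels=2` and the base checkpoint are placeholders,
# so the classification head is randomly initialised and needs fine-tuning before its predictions mean anything):
#
#     from transformers import AutoTokenizer, PLBartForSequenceClassification
#
#     tokenizer = AutoTokenizer.from_pretrained("uclanlp/plbart-base")
#     model = PLBartForSequenceClassification.from_pretrained("uclanlp/plbart-base", num_labels=2)
#     inputs = tokenizer("def is_even(n): return n % 2 == 0", return_tensors="pt")
#     logits = model(**inputs).logits  # (batch_size, num_labels), pooled from the final <eos> hidden state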


# Copied from transformers.models.bart.modeling_bart.BartDecoderWrapper with Bart->PLBart
class PLBartDecoderWrapper(PLBartPreTrainedModel):
    """
    This wrapper class is a helper class to correctly load pretrained checkpoints when the causal language model is
    used in combination with the [`EncoderDecoderModel`] framework.
    """

    def __init__(self, config):
        super().__init__(config)
        self.decoder = PLBartDecoder(config)

    def forward(self, *args, **kwargs):
        return self.decoder(*args, **kwargs)


# Copied from transformers.models.bart.modeling_bart.BartForCausalLM with Bart->PLBart, facebook/bart-base->uclanlp/plbart-base
class PLBartForCausalLM(PLBartPreTrainedModel, GenerationMixin):
    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config):
        config = copy.deepcopy(config)
        config.is_decoder = True
        config.is_encoder_decoder = False
        super().__init__(config)
        self.model = PLBartDecoderWrapper(config)

        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.model.decoder.embed_tokens

    def set_input_embeddings(self, value):
        self.model.decoder.embed_tokens = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        self.model.decoder = decoder

    def get_decoder(self):
        return self.model.decoder

    @replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        cross_attn_head_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
        r"""
        Args:
            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
                Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
                provide it.

                Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
                [`PreTrainedTokenizer.__call__`] for details.

                [What are input IDs?](../glossary#input-ids)
            attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
                Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

                - 1 for tokens that are **not masked**,
                - 0 for tokens that are **masked**.

                [What are attention masks?](../glossary#attention-mask)
            encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
                Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention
                if the model is configured as a decoder.
            encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Mask to avoid performing attention on the padding token indices of the encoder input. This mask is
                used in the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`:

                - 1 for tokens that are **not masked**,
                - 0 for tokens that are **masked**.
            head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
                Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:

                - 1 indicates the head is **not masked**,
                - 0 indicates the head is **masked**.
            cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
                Mask to nullify selected heads of the cross-attention modules. Mask values selected in `[0, 1]`:

                - 1 indicates the head is **not masked**,
                - 0 indicates the head is **masked**.
            past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
                Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
                shape `(batch_size, num_heads, sequence_length, embed_size_per_head)` and 2 additional tensors of
                shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. The two additional
                tensors are only required when the model is used as a decoder in a Sequence to Sequence model.

                Contains pre-computed hidden-states (key and values in the self-attention blocks and in the
                cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.

                If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those
                that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of
                all `decoder_input_ids` of shape `(batch_size, sequence_length)`.
            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
                (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
            use_cache (`bool`, *optional*):
                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
                (see `past_key_values`).
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            output_hidden_states (`bool`, *optional*):
                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
                for more detail.
            return_dict (`bool`, *optional*):
                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.

        Returns:

        Example:

        ```python
        >>> from transformers import AutoTokenizer, PLBartForCausalLM

        >>> tokenizer = AutoTokenizer.from_pretrained("uclanlp/plbart-base")
        >>> model = PLBartForCausalLM.from_pretrained("uclanlp/plbart-base", add_cross_attention=False)
        >>> assert model.config.is_decoder, f"{model.__class__} has to be configured as a decoder."
        >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
        >>> outputs = model(**inputs)

        >>> logits = outputs.logits
        >>> expected_shape = [1, inputs.input_ids.shape[-1], model.config.vocab_size]
        >>> list(logits.shape) == expected_shape
        True
        ```"""
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs = self.model.decoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            head_mask=head_mask,
            cross_attn_head_mask=cross_attn_head_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        logits = self.lm_head(outputs[0])

        loss = None
        if labels is not None:
            labels = labels.to(logits.device)
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1))

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithCrossAttentions(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            cross_attentions=outputs.cross_attentions,
        )

    @staticmethod
    def _reorder_cache(past_key_values, beam_idx):
        reordered_past = ()
        for layer_past in past_key_values:
            reordered_past += (
                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
            )
        return reordered_past
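

# Illustrative sketch of what `_reorder_cache` does during beam search (toy tensors, relying on this module's
# imports; shapes and the beam indices are made up, not from a real generation run). Each cached key/value has
# shape (batch_size * num_beams, num_heads, seq_len, head_dim), and `beam_idx` selects which beam's cache each
# surviving hypothesis should continue from:
#
#     past = ((torch.randn(4, 2, 5, 8), torch.randn(4, 2, 5, 8)),)  # one layer, 4 beams
#     beam_idx = torch.tensor([0, 0, 2, 3])                         # beams 0 and 1 both continue from beam 0
#     reordered = PLBartForCausalLM._reorder_cache(past, beam_idx)
#     reordered[0][0].shape                                         # torch.Size([4, 2, 5, 8])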