# coding=utf-8
# Copyright 2021 The Fairseq Authors and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch BART model."""

import copy
import math
import warnings
from typing import List, Optional, Tuple, Union

import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

from ...activations import ACT2FN
from ...generation import GenerationMixin
from ...modeling_attn_mask_utils import (
    _prepare_4d_attention_mask,
    _prepare_4d_attention_mask_for_sdpa,
    _prepare_4d_causal_attention_mask,
    _prepare_4d_causal_attention_mask_for_sdpa,
)
from ...modeling_outputs import (
    BaseModelOutput,
    BaseModelOutputWithPastAndCrossAttentions,
    CausalLMOutputWithCrossAttentions,
    Seq2SeqLMOutput,
    Seq2SeqModelOutput,
    Seq2SeqQuestionAnsweringModelOutput,
    Seq2SeqSequenceClassifierOutput,
)
from ...modeling_utils import PreTrainedModel
from ...utils import (
    add_code_sample_docstrings,
    add_end_docstrings,
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    is_flash_attn_2_available,
    is_flash_attn_greater_or_equal_2_10,
    logging,
    replace_return_docstrings,
)
from .configuration_bart import BartConfig


if is_flash_attn_2_available():
    from ...modeling_flash_attention_utils import _flash_attention_forward


logger = logging.get_logger(__name__)

_CHECKPOINT_FOR_DOC = "facebook/bart-base"
_CONFIG_FOR_DOC = "BartConfig"

# Base model docstring
_EXPECTED_OUTPUT_SHAPE = [1, 8, 768]

# SequenceClassification docstring
_CHECKPOINT_FOR_SEQUENCE_CLASSIFICATION = "valhalla/bart-large-sst2"
_SEQ_CLASS_EXPECTED_LOSS = 0.0
_SEQ_CLASS_EXPECTED_OUTPUT = "'POSITIVE'"

# QuestionAnswering docstring
_CHECKPOINT_FOR_QA = "valhalla/bart-large-finetuned-squadv1"
_QA_EXPECTED_LOSS = 0.59
_QA_EXPECTED_OUTPUT = "' nice puppet'"


def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int):
    """
    Shift input ids one token to the right.
    """
    shifted_input_ids = input_ids.new_zeros(input_ids.shape)
    shifted_input_ids[:, 1:] = input_ids[:, :-1].clone()
    shifted_input_ids[:, 0] = decoder_start_token_id

    if pad_token_id is None:
        raise ValueError("self.model.config.pad_token_id has to be defined.")
    # replace possible -100 values in labels by `pad_token_id`
    shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)

    return shifted_input_ids
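
# Illustrative note (added comment, not part of the upstream file): with
# pad_token_id=1 and decoder_start_token_id=2, a labels row such as
# [0, 31414, 232, 2] becomes [2, 0, 31414, 232]; the row is shifted one
# position to the right, the decoder start token is prepended, and any -100
# positions (ignored in the loss) are replaced by the pad token id.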


class BartLearnedPositionalEmbedding(nn.Embedding):
    """
    This module learns positional embeddings up to a fixed maximum size.
    """

    def __init__(self, num_embeddings: int, embedding_dim: int):
        # Bart is set up so that if padding_idx is specified then offset the embedding ids by 2
        # and adjust num_embeddings appropriately. Other models don't have this hack
        self.offset = 2
        super().__init__(num_embeddings + self.offset, embedding_dim)

    def forward(self, input_ids: torch.Tensor, past_key_values_length: int = 0):
        """`input_ids` shape is expected to be [bsz x seqlen]."""
        bsz, seq_len = input_ids.shape[:2]
        positions = torch.arange(
            past_key_values_length, past_key_values_length + seq_len, dtype=torch.long, device=self.weight.device
        ).expand(bsz, -1)

        return super().forward(positions + self.offset)
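
# Illustrative note (added comment): for a (1, 4) input with
# past_key_values_length=0, the raw positions are [0, 1, 2, 3] and the actual
# lookup indices are [2, 3, 4, 5] because of the legacy fairseq offset of 2;
# this is why the embedding table above is created with
# num_embeddings + self.offset rows.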


class BartScaledWordEmbedding(nn.Embedding):
    """
    This module overrides nn.Embedding's forward by multiplying with the embedding scale.
    """

    def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int, embed_scale: Optional[float] = 1.0):
        super().__init__(num_embeddings, embedding_dim, padding_idx)
        self.embed_scale = embed_scale

    def forward(self, input_ids: torch.Tensor):
        return super().forward(input_ids) * self.embed_scale
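
# Illustrative note (added comment): the encoder/decoder below instantiate this
# module with embed_scale = sqrt(config.d_model) when config.scale_embedding is
# True (for d_model=768 that is sqrt(768) ≈ 27.7), so token embeddings are
# scaled before positional embeddings are added; with the default
# embed_scale=1.0 the forward is a plain embedding lookup.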


class BartAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        dropout: float = 0.0,
        is_decoder: bool = False,
        bias: bool = True,
        is_causal: bool = False,
        config: Optional[BartConfig] = None,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.dropout = dropout
        self.head_dim = embed_dim // num_heads
        self.config = config

        if (self.head_dim * num_heads) != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
                f" and `num_heads`: {num_heads})."
            )
        self.scaling = self.head_dim**-0.5
        self.is_decoder = is_decoder
        self.is_causal = is_causal

        self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

    def forward(
        self,
        hidden_states: torch.Tensor,
        key_value_states: Optional[torch.Tensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        layer_head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """Input shape: Batch x Time x Channel"""

        # if key_value_states are provided this layer is used as a cross-attention layer
        # for the decoder
        is_cross_attention = key_value_states is not None

        bsz, tgt_len, _ = hidden_states.size()

        # get query proj
        query_states = self.q_proj(hidden_states) * self.scaling
        # get key, value proj
        # `past_key_value[0].shape[2] == key_value_states.shape[1]`
        # is checking that the `sequence_length` of the `past_key_value` is the same as
        # the provided `key_value_states` to support prefix tuning
        if (
            is_cross_attention
            and past_key_value is not None
            and past_key_value[0].shape[2] == key_value_states.shape[1]
        ):
            # reuse k,v, cross_attentions
            key_states = past_key_value[0]
            value_states = past_key_value[1]
        elif is_cross_attention:
            # cross_attentions
            key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
            value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
        elif past_key_value is not None:
            # reuse k, v, self_attention
            key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
            value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
            key_states = torch.cat([past_key_value[0], key_states], dim=2)
            value_states = torch.cat([past_key_value[1], value_states], dim=2)
        else:
            # self_attention
            key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
            value_states = self._shape(self.v_proj(hidden_states), -1, bsz)

        if self.is_decoder:
            # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
            # Further calls to cross_attention layer can then reuse all cross-attention
            # key/value_states (first "if" case)
            # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
            # all previous decoder key/value_states. Further calls to uni-directional self-attention
            # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
            # if encoder bi-directional self-attention `past_key_value` is always `None`
            past_key_value = (key_states, value_states)

        proj_shape = (bsz * self.num_heads, -1, self.head_dim)
        query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
        key_states = key_states.reshape(*proj_shape)
        value_states = value_states.reshape(*proj_shape)

        src_len = key_states.size(1)
        attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))

        if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
            raise ValueError(
                f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
                f" {attn_weights.size()}"
            )

        if attention_mask is not None:
            if attention_mask.size() != (bsz, 1, tgt_len, src_len):
                raise ValueError(
                    f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
                )
            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

        attn_weights = nn.functional.softmax(attn_weights, dim=-1)

        if layer_head_mask is not None:
            if layer_head_mask.size() != (self.num_heads,):
                raise ValueError(
                    f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
                    f" {layer_head_mask.size()}"
                )
            attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

        if output_attentions:
            # this operation is a bit awkward, but it's required to
            # make sure that attn_weights keeps its gradient.
            # In order to do so, attn_weights have to be reshaped
            # twice and have to be reused in the following
            attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
            attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
        else:
            attn_weights_reshaped = None

        attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)

        attn_output = torch.bmm(attn_probs, value_states)

        if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )

        attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
        attn_output = attn_output.transpose(1, 2)

        # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
        # partitioned across GPUs when using tensor-parallelism.
        attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)

        attn_output = self.out_proj(attn_output)

        return attn_output, attn_weights_reshaped, past_key_value
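
# Illustrative note on cache layout (added comment): for decoder self-attention,
# the returned `past_key_value` is a pair of tensors shaped
# (bsz, num_heads, seq_len_so_far, head_dim); with use_cache, each new call
# projects only the fresh tokens and concatenates them to the cached states
# along the sequence dimension, while cross-attention key/value states are
# computed once from the encoder output and then reused on later calls.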


class BartFlashAttention2(BartAttention):
    """
    Bart flash attention module. This module inherits from `BartAttention` as the weights of the module stay
    untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
    flash attention and deal with padding tokens in case the input contains any of them.
    """

    # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
        # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, which was made the default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
        # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
        self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()

    def _reshape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim)

    def forward(
        self,
        hidden_states: torch.Tensor,
        key_value_states: Optional[torch.Tensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        layer_head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        # BartFlashAttention2 attention does not support output_attentions
        if output_attentions:
            raise ValueError("BartFlashAttention2 attention does not support output_attentions")

        # if key_value_states are provided this layer is used as a cross-attention layer
        # for the decoder
        is_cross_attention = key_value_states is not None

        bsz, q_len, _ = hidden_states.size()

        # get query proj
        query_states = self._reshape(self.q_proj(hidden_states), -1, bsz)
        # get key, value proj
        # `past_key_value[0].shape[2] == key_value_states.shape[1]`
        # is checking that the `sequence_length` of the `past_key_value` is the same as
        # the provided `key_value_states` to support prefix tuning
        if (
            is_cross_attention
            and past_key_value is not None
            and past_key_value[0].shape[2] == key_value_states.shape[1]
        ):
            # reuse k,v, cross_attentions
            key_states = past_key_value[0].transpose(1, 2)
            value_states = past_key_value[1].transpose(1, 2)
        elif is_cross_attention:
            # cross_attentions
            key_states = self._reshape(self.k_proj(key_value_states), -1, bsz)
            value_states = self._reshape(self.v_proj(key_value_states), -1, bsz)
        elif past_key_value is not None:
            # reuse k, v, self_attention
            key_states = self._reshape(self.k_proj(hidden_states), -1, bsz)
            value_states = self._reshape(self.v_proj(hidden_states), -1, bsz)
            key_states = torch.cat([past_key_value[0].transpose(1, 2), key_states], dim=1)
            value_states = torch.cat([past_key_value[1].transpose(1, 2), value_states], dim=1)
        else:
            # self_attention
            key_states = self._reshape(self.k_proj(hidden_states), -1, bsz)
            value_states = self._reshape(self.v_proj(hidden_states), -1, bsz)

        if self.is_decoder:
            # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
            # Further calls to cross_attention layer can then reuse all cross-attention
            # key/value_states (first "if" case)
            # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
            # all previous decoder key/value_states. Further calls to uni-directional self-attention
            # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
            # if encoder bi-directional self-attention `past_key_value` is always `None`
            past_key_value = (key_states.transpose(1, 2), value_states.transpose(1, 2))

        kv_seq_len = key_states.shape[-2]
        if past_key_value is not None:
            kv_seq_len += past_key_value[0].shape[-2]

        # In PEFT, usually we cast the layer norms in float32 for training stability reasons,
        # therefore the input hidden states get silently cast to float32. Hence, we need to
        # cast them back to the correct dtype just to be sure everything works as expected.
        # This might slow down training & inference so it is recommended to not cast the LayerNorms
        # in fp32. (LlamaRMSNorm handles it correctly)

        input_dtype = query_states.dtype
        if input_dtype == torch.float32:
            if torch.is_autocast_enabled():
                target_dtype = torch.get_autocast_gpu_dtype()
            # Handle the case where the model is quantized
            elif hasattr(self.config, "_pre_quantization_dtype"):
                target_dtype = self.config._pre_quantization_dtype
            else:
                target_dtype = self.q_proj.weight.dtype

            logger.warning_once(
                f"The input hidden states seem to be silently cast to float32, this might be related to"
                f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
                f" {target_dtype}."
            )

            query_states = query_states.to(target_dtype)
            key_states = key_states.to(target_dtype)
            value_states = value_states.to(target_dtype)

        attn_output = _flash_attention_forward(
            query_states,
            key_states,
            value_states,
            attention_mask,
            q_len,
            dropout=self.dropout if self.training else 0.0,
            is_causal=self.is_causal,
            use_top_left_mask=self._flash_attn_uses_top_left_mask,
        )

        attn_output = attn_output.reshape(bsz, q_len, -1)
        attn_output = self.out_proj(attn_output)

        if not output_attentions:
            attn_weights = None

        return attn_output, attn_weights, past_key_value


class BartSdpaAttention(BartAttention):
    def forward(
        self,
        hidden_states: torch.Tensor,
        key_value_states: Optional[torch.Tensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        layer_head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """Input shape: Batch x Time x Channel"""
        if output_attentions or layer_head_mask is not None:
            # TODO: Improve this warning with e.g. `model.config._attn_implementation = "manual"` once this is implemented.
            logger.warning_once(
                "BartModel is using BartSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True` or `layer_head_mask` not None. Falling back to the manual attention"
                ' implementation, but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
            )
            return super().forward(
                hidden_states,
                key_value_states=key_value_states,
                past_key_value=past_key_value,
                attention_mask=attention_mask,
                layer_head_mask=layer_head_mask,
                output_attentions=output_attentions,
            )

        # if key_value_states are provided this layer is used as a cross-attention layer
        # for the decoder
        is_cross_attention = key_value_states is not None

        bsz, tgt_len, _ = hidden_states.size()

        # get query proj
        query_states = self.q_proj(hidden_states)
        # get key, value proj
        # `past_key_value[0].shape[2] == key_value_states.shape[1]`
        # is checking that the `sequence_length` of the `past_key_value` is the same as
        # the provided `key_value_states` to support prefix tuning
        if (
            is_cross_attention
            and past_key_value is not None
            and past_key_value[0].shape[2] == key_value_states.shape[1]
        ):
            # reuse k,v, cross_attentions
            key_states = past_key_value[0]
            value_states = past_key_value[1]
        elif is_cross_attention:
            # cross_attentions
            key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
            value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
        elif past_key_value is not None:
            # reuse k, v, self_attention
            key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
            value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
            key_states = torch.cat([past_key_value[0], key_states], dim=2)
            value_states = torch.cat([past_key_value[1], value_states], dim=2)
        else:
            # self_attention
            key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
            value_states = self._shape(self.v_proj(hidden_states), -1, bsz)

        if self.is_decoder:
            # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
            # Further calls to cross_attention layer can then reuse all cross-attention
            # key/value_states (first "if" case)
            # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
            # all previous decoder key/value_states. Further calls to uni-directional self-attention
            # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
            # if encoder bi-directional self-attention `past_key_value` is always `None`
            past_key_value = (key_states, value_states)

        query_states = self._shape(query_states, tgt_len, bsz)

        # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
        # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
        # The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case tgt_len == 1.
        is_causal = True if self.is_causal and attention_mask is None and tgt_len > 1 else False

        # NOTE: SDPA with memory-efficient backend is currently (torch==2.1.2) bugged when using non-contiguous inputs and a custom attn_mask,
        # but we are fine here as `_shape` does call `.contiguous()`. Reference: https://github.com/pytorch/pytorch/issues/112577
        attn_output = torch.nn.functional.scaled_dot_product_attention(
            query_states,
            key_states,
            value_states,
            attn_mask=attention_mask,
            dropout_p=self.dropout if self.training else 0.0,
            is_causal=is_causal,
        )

        if attn_output.size() != (bsz, self.num_heads, tgt_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )

        attn_output = attn_output.transpose(1, 2)

        # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
        # partitioned across GPUs when using tensor-parallelism.
        attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)

        attn_output = self.out_proj(attn_output)

        return attn_output, None, past_key_value


BART_ATTENTION_CLASSES = {
    "eager": BartAttention,
    "sdpa": BartSdpaAttention,
    "flash_attention_2": BartFlashAttention2,
}
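
# Illustrative note (added comment): the encoder and decoder layers below pick the
# attention implementation with BART_ATTENTION_CLASSES[config._attn_implementation],
# so loading a checkpoint with, e.g.,
#   BartModel.from_pretrained("facebook/bart-base", attn_implementation="sdpa")
# swaps every attention module for BartSdpaAttention without touching the weights.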


class BartEncoderLayer(nn.Module):
    def __init__(self, config: BartConfig):
        super().__init__()
        self.embed_dim = config.d_model

        self.self_attn = BART_ATTENTION_CLASSES[config._attn_implementation](
            embed_dim=self.embed_dim,
            num_heads=config.encoder_attention_heads,
            dropout=config.attention_dropout,
            config=config,
        )
        self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
        self.dropout = config.dropout
        self.activation_fn = ACT2FN[config.activation_function]
        self.activation_dropout = config.activation_dropout
        self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim)
        self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim)
        self.final_layer_norm = nn.LayerNorm(self.embed_dim)

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        attention_mask: torch.FloatTensor,
        layer_head_mask: torch.FloatTensor,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.FloatTensor, Optional[torch.FloatTensor]]:
        """
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`): attention mask of size
                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
            layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size
                `(encoder_attention_heads,)`.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
        """
        residual = hidden_states
        hidden_states, attn_weights, _ = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            layer_head_mask=layer_head_mask,
            output_attentions=output_attentions,
        )
        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
        hidden_states = residual + hidden_states
        hidden_states = self.self_attn_layer_norm(hidden_states)

        residual = hidden_states
        hidden_states = self.activation_fn(self.fc1(hidden_states))
        hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
        hidden_states = self.fc2(hidden_states)
        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
        hidden_states = residual + hidden_states
        hidden_states = self.final_layer_norm(hidden_states)

        if hidden_states.dtype == torch.float16 and (
            torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any()
        ):
            clamp_value = torch.finfo(hidden_states.dtype).max - 1000
            hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)

        outputs = (hidden_states,)

        if output_attentions:
            outputs += (attn_weights,)

        return outputs


class BartDecoderLayer(nn.Module):
    def __init__(self, config: BartConfig):
        super().__init__()
        self.embed_dim = config.d_model

        self.self_attn = BART_ATTENTION_CLASSES[config._attn_implementation](
            embed_dim=self.embed_dim,
            num_heads=config.decoder_attention_heads,
            dropout=config.attention_dropout,
            is_decoder=True,
            is_causal=True,
            config=config,
        )
        self.dropout = config.dropout
        self.activation_fn = ACT2FN[config.activation_function]
        self.activation_dropout = config.activation_dropout

        self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
        self.encoder_attn = BART_ATTENTION_CLASSES[config._attn_implementation](
            self.embed_dim,
            config.decoder_attention_heads,
            dropout=config.attention_dropout,
            is_decoder=True,
            config=config,
        )
        self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim)
        self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim)
        self.fc2 = nn.Linear(config.decoder_ffn_dim, self.embed_dim)
        self.final_layer_norm = nn.LayerNorm(self.embed_dim)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        layer_head_mask: Optional[torch.Tensor] = None,
        cross_attn_layer_head_mask: Optional[torch.Tensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = True,
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        """
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`): attention mask of size
                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
            encoder_hidden_states (`torch.FloatTensor`):
                cross attention input to the layer of shape `(batch, seq_len, embed_dim)`
            encoder_attention_mask (`torch.FloatTensor`): encoder attention mask of size
                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
            layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size
                `(decoder_attention_heads,)`.
            cross_attn_layer_head_mask (`torch.FloatTensor`): mask for cross-attention heads in a given layer of
                size `(decoder_attention_heads,)`.
            past_key_value (`Tuple(torch.FloatTensor)`): cached past key and value projection states
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
        """
        residual = hidden_states

        # Self Attention
        # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
        self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
        # add present self-attn cache to positions 1,2 of present_key_value tuple
        hidden_states, self_attn_weights, present_key_value = self.self_attn(
            hidden_states=hidden_states,
            past_key_value=self_attn_past_key_value,
            attention_mask=attention_mask,
            layer_head_mask=layer_head_mask,
            output_attentions=output_attentions,
        )
        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
        hidden_states = residual + hidden_states
        hidden_states = self.self_attn_layer_norm(hidden_states)

        # Cross-Attention Block
        cross_attn_present_key_value = None
        cross_attn_weights = None
        if encoder_hidden_states is not None:
            residual = hidden_states

            # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple
            cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
            hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn(
                hidden_states=hidden_states,
                key_value_states=encoder_hidden_states,
                attention_mask=encoder_attention_mask,
                layer_head_mask=cross_attn_layer_head_mask,
                past_key_value=cross_attn_past_key_value,
                output_attentions=output_attentions,
            )
            hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
            hidden_states = residual + hidden_states
            hidden_states = self.encoder_attn_layer_norm(hidden_states)

            # add cross-attn to positions 3,4 of present_key_value tuple
            present_key_value = present_key_value + cross_attn_present_key_value

        # Fully Connected
        residual = hidden_states
        hidden_states = self.activation_fn(self.fc1(hidden_states))
        hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
        hidden_states = self.fc2(hidden_states)
        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
        hidden_states = residual + hidden_states
        hidden_states = self.final_layer_norm(hidden_states)

        outputs = (hidden_states,)

        if output_attentions:
            outputs += (self_attn_weights, cross_attn_weights)

        if use_cache:
            outputs += (present_key_value,)

        return outputs


class BartClassificationHead(nn.Module):
    """Head for sentence-level classification tasks."""

    def __init__(
        self,
        input_dim: int,
        inner_dim: int,
        num_classes: int,
        pooler_dropout: float,
    ):
        super().__init__()
        self.dense = nn.Linear(input_dim, inner_dim)
        self.dropout = nn.Dropout(p=pooler_dropout)
        self.out_proj = nn.Linear(inner_dim, num_classes)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.dense(hidden_states)
        hidden_states = torch.tanh(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.out_proj(hidden_states)
        return hidden_states


class BartPreTrainedModel(PreTrainedModel):
    config_class = BartConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _keys_to_ignore_on_load_unexpected = ["encoder.version", "decoder.version"]
    _no_split_modules = [r"BartEncoderLayer", r"BartDecoderLayer"]
    _skip_keys_device_placement = "past_key_values"
    _supports_flash_attn_2 = True
    _supports_sdpa = True

    def _init_weights(self, module):
        std = self.config.init_std
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()

    @property
    def dummy_inputs(self):
        pad_token = self.config.pad_token_id
        input_ids = torch.tensor([[0, 6, 10, 4, 2], [0, 8, 12, 2, pad_token]], device=self.device)
        dummy_inputs = {
            "attention_mask": input_ids.ne(pad_token),
            "input_ids": input_ids,
        }
        return dummy_inputs


class PretrainedBartModel(BartPreTrainedModel):
    def __init_subclass__(self):
        warnings.warn(
            "The class `PretrainedBartModel` has been deprecated, please use `BartPreTrainedModel` instead.",
            FutureWarning,
        )


class BartPretrainedModel(BartPreTrainedModel):
    def __init_subclass__(self):
        warnings.warn(
            "The class `BartPretrainedModel` has been deprecated, please use `BartPreTrainedModel` instead.",
            FutureWarning,
        )


BART_START_DOCSTRING = r"""
    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning
    heads etc.)

    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage
    and behavior.

    Parameters:
        config ([`BartConfig`]):
            Model configuration class with all the parameters of the model. Initializing with a config file does not
            load the weights associated with the model, only the configuration. Check out the
            [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""

BART_GENERATION_EXAMPLE = r"""
    Summarization example:

    ```python
    >>> from transformers import AutoTokenizer, BartForConditionalGeneration

    >>> model = BartForConditionalGeneration.from_pretrained("facebook/bart-large-cnn")
    >>> tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large-cnn")

    >>> ARTICLE_TO_SUMMARIZE = (
    ...     "PG&E stated it scheduled the blackouts in response to forecasts for high winds "
    ...     "amid dry conditions. The aim is to reduce the risk of wildfires. Nearly 800 thousand customers were "
    ...     "scheduled to be affected by the shutoffs which were expected to last through at least midday tomorrow."
    ... )
    >>> inputs = tokenizer([ARTICLE_TO_SUMMARIZE], max_length=1024, return_tensors="pt")

    >>> # Generate Summary
    >>> summary_ids = model.generate(inputs["input_ids"], num_beams=2, min_length=0, max_length=20)
    >>> tokenizer.batch_decode(summary_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
    'PG&E scheduled the blackouts in response to forecasts for high winds amid dry conditions'
    ```

    Mask filling example:

    ```python
    >>> from transformers import AutoTokenizer, BartForConditionalGeneration

    >>> tokenizer = AutoTokenizer.from_pretrained("facebook/bart-base")
    >>> model = BartForConditionalGeneration.from_pretrained("facebook/bart-base")

    >>> TXT = "My friends are <mask> but they eat too many carbs."
    >>> input_ids = tokenizer([TXT], return_tensors="pt")["input_ids"]
    >>> logits = model(input_ids).logits

    >>> masked_index = (input_ids[0] == tokenizer.mask_token_id).nonzero().item()
    >>> probs = logits[0, masked_index].softmax(dim=0)
    >>> values, predictions = probs.topk(5)

    >>> tokenizer.decode(predictions).split()
    ['not', 'good', 'healthy', 'great', 'very']
    ```
"""

BART_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
            it.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)
        decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
            Indices of decoder input sequence tokens in the vocabulary.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are decoder input IDs?](../glossary#decoder-input-ids)

            Bart uses the `eos_token_id` as the starting token for `decoder_input_ids` generation. If `past_key_values`
            is used, optionally only the last `decoder_input_ids` have to be input (see `past_key_values`).

            For translation and summarization training, `decoder_input_ids` should be provided. If no
            `decoder_input_ids` is provided, the model will create this tensor by shifting the `input_ids` to the right
            for denoising pre-training following the paper.
        decoder_attention_mask (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
            Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
            be used by default.

            If you want to change padding behavior, you should read [`modeling_bart._prepare_decoder_attention_mask`]
            and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
            information on the default strategy.
        head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*):
            Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in `[0, 1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.

        decoder_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
            Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in `[0, 1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.

        cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
            Mask to nullify selected heads of the cross-attention modules in the decoder. Mask values selected in
            `[0, 1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.

        encoder_outputs (`tuple(tuple(torch.FloatTensor)`, *optional*):
            Tuple consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: `attentions`)
            `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*) is a sequence of
            hidden-states at the output of the last layer of the encoder. Used in the cross-attention of the decoder.
        past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
            Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
            `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
            `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.

            Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
            blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.

            If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
            don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
            `decoder_input_ids` of shape `(batch_size, sequence_length)`.
        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
            This is useful if you want more control over how to convert `input_ids` indices into associated vectors
            than the model's internal embedding lookup matrix.
        decoder_inputs_embeds (`torch.FloatTensor` of shape `(batch_size, target_sequence_length, hidden_size)`, *optional*):
            Optionally, instead of passing `decoder_input_ids` you can choose to directly pass an embedded
            representation. If `past_key_values` is used, optionally only the last `decoder_inputs_embeds` have to be
            input (see `past_key_values`). This is useful if you want more control over how to convert
            `decoder_input_ids` indices into associated vectors than the model's internal embedding lookup matrix.

            If `decoder_input_ids` and `decoder_inputs_embeds` are both unset, `decoder_inputs_embeds` takes the value
            of `inputs_embeds`.
        use_cache (`bool`, *optional*):
            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
            `past_key_values`).
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
  793. class BartEncoder(BartPreTrainedModel):
  794. """
  795. Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a
  796. [`BartEncoderLayer`].
  797. Args:
  798. config: BartConfig
  799. embed_tokens (nn.Embedding): output embedding
  800. """
  801. def __init__(self, config: BartConfig, embed_tokens: Optional[nn.Embedding] = None):
  802. super().__init__(config)
  803. self.dropout = config.dropout
  804. self.layerdrop = config.encoder_layerdrop
  805. embed_dim = config.d_model
  806. self.padding_idx = config.pad_token_id
  807. self.max_source_positions = config.max_position_embeddings
  808. embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0
  809. self.embed_tokens = BartScaledWordEmbedding(
  810. config.vocab_size, embed_dim, self.padding_idx, embed_scale=embed_scale
  811. )
  812. if embed_tokens is not None:
  813. self.embed_tokens.weight = embed_tokens.weight
  814. self.embed_positions = BartLearnedPositionalEmbedding(
  815. config.max_position_embeddings,
  816. embed_dim,
  817. )
  818. self.layers = nn.ModuleList([BartEncoderLayer(config) for _ in range(config.encoder_layers)])
  819. self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
  820. self._use_sdpa = config._attn_implementation == "sdpa"
  821. self.layernorm_embedding = nn.LayerNorm(embed_dim)
  822. self.gradient_checkpointing = False
  823. # Initialize weights and apply final processing
  824. self.post_init()
  825. def get_input_embeddings(self):
  826. return self.embed_tokens
  827. def set_input_embeddings(self, value):
  828. self.embed_tokens = value
  829. def forward(
  830. self,
  831. input_ids: torch.LongTensor = None,
  832. attention_mask: Optional[torch.Tensor] = None,
  833. head_mask: Optional[torch.Tensor] = None,
  834. inputs_embeds: Optional[torch.FloatTensor] = None,
  835. output_attentions: Optional[bool] = None,
  836. output_hidden_states: Optional[bool] = None,
  837. return_dict: Optional[bool] = None,
  838. ) -> Union[Tuple, BaseModelOutput]:
  839. r"""
  840. Args:
  841. input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
  842. Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
  843. provide it.
  844. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  845. [`PreTrainedTokenizer.__call__`] for details.
  846. [What are input IDs?](../glossary#input-ids)
  847. attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
  848. Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
  849. - 1 for tokens that are **not masked**,
  850. - 0 for tokens that are **masked**.
  851. [What are attention masks?](../glossary#attention-mask)
  852. head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*):
  853. Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
  854. - 1 indicates the head is **not masked**,
  855. - 0 indicates the head is **masked**.
  856. inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
  857. Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
  858. This is useful if you want more control over how to convert `input_ids` indices into associated vectors
  859. than the model's internal embedding lookup matrix.
  860. output_attentions (`bool`, *optional*):
  861. Whether or not to return the attentions tensors of all attention layers. See `attentions` under
  862. returned tensors for more detail.
  863. output_hidden_states (`bool`, *optional*):
  864. Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
  865. for more detail.
  866. return_dict (`bool`, *optional*):
  867. Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
  868. """
  869. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  870. output_hidden_states = (
  871. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  872. )
  873. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  874. # retrieve input_ids and inputs_embeds
  875. if input_ids is not None and inputs_embeds is not None:
  876. raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
  877. elif input_ids is not None:
  878. input = input_ids
  879. input_ids = input_ids.view(-1, input_ids.shape[-1])
  880. elif inputs_embeds is not None:
  881. input = inputs_embeds[:, :, -1]
  882. else:
  883. raise ValueError("You have to specify either input_ids or inputs_embeds")
  884. if inputs_embeds is None:
  885. inputs_embeds = self.embed_tokens(input_ids)
  886. embed_pos = self.embed_positions(input)
  887. embed_pos = embed_pos.to(inputs_embeds.device)
  888. hidden_states = inputs_embeds + embed_pos
  889. hidden_states = self.layernorm_embedding(hidden_states)
  890. hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
  891. # expand attention_mask
  892. if attention_mask is not None:
  893. if self._use_flash_attention_2:
  894. attention_mask = attention_mask if 0 in attention_mask else None
  895. elif self._use_sdpa and head_mask is None and not output_attentions:
  896. # output_attentions=True & head_mask can not be supported when using SDPA, fall back to
  897. # the manual implementation that requires a 4D causal mask in all cases.
  898. # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
  899. attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, inputs_embeds.dtype)
  900. else:
  901. # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
  902. attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype)
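# Both helpers expand the 2D padding mask to (bsz, 1, tgt_seq_len, src_seq_len) and invert it, so that
# padded positions carry a large negative value that is added to the attention scores; the SDPA variant
# may also return None when nothing is masked, allowing the fused kernel's fast path.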
  903. encoder_states = () if output_hidden_states else None
  904. all_attentions = () if output_attentions else None
  905. # check if head_mask has a correct number of layers specified if desired
  906. if head_mask is not None:
  907. if head_mask.size()[0] != (len(self.layers)):
  908. raise ValueError(
  909. f"The head_mask should be specified for {len(self.layers)} layers, but it is for"
  910. f" {head_mask.size()[0]}."
  911. )
  912. for idx, encoder_layer in enumerate(self.layers):
  913. if output_hidden_states:
  914. encoder_states = encoder_states + (hidden_states,)
  915. # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
  916. to_drop = False
  917. if self.training:
  918. dropout_probability = torch.rand([])
  919. if dropout_probability < self.layerdrop: # skip the layer
  920. to_drop = True
  921. if to_drop:
  922. layer_outputs = (None, None)
  923. else:
  924. if self.gradient_checkpointing and self.training:
  925. layer_outputs = self._gradient_checkpointing_func(
  926. encoder_layer.__call__,
  927. hidden_states,
  928. attention_mask,
  929. (head_mask[idx] if head_mask is not None else None),
  930. output_attentions,
  931. )
  932. else:
  933. layer_outputs = encoder_layer(
  934. hidden_states,
  935. attention_mask,
  936. layer_head_mask=(head_mask[idx] if head_mask is not None else None),
  937. output_attentions=output_attentions,
  938. )
  939. hidden_states = layer_outputs[0]
  940. if output_attentions:
  941. all_attentions = all_attentions + (layer_outputs[1],)
  942. if output_hidden_states:
  943. encoder_states = encoder_states + (hidden_states,)
  944. if not return_dict:
  945. return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
  946. return BaseModelOutput(
  947. last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
  948. )
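# Illustrative sketch: running the encoder defined above on its own. This is only a hedged usage
# example, assuming the public "facebook/bart-base" checkpoint and the AutoTokenizer API; it is
# wrapped in a helper so that importing this module never executes it.
def _example_run_bart_encoder():
    import torch
    from transformers import AutoTokenizer, BartModel

    tokenizer = AutoTokenizer.from_pretrained("facebook/bart-base")
    model = BartModel.from_pretrained("facebook/bart-base")
    inputs = tokenizer("BART is a denoising sequence-to-sequence model.", return_tensors="pt")
    with torch.no_grad():
        encoder_outputs = model.get_encoder()(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            output_hidden_states=True,
            return_dict=True,
        )
    # last_hidden_state: (batch_size, sequence_length, d_model); hidden_states: a tuple with
    # config.encoder_layers + 1 entries (the embedding output plus one per layer).
    return encoder_outputs.last_hidden_state.shape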
  949. class BartDecoder(BartPreTrainedModel):
  950. """
  951. Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a [`BartDecoderLayer`]
  952. Args:
  953. config: BartConfig
  954. embed_tokens (nn.Embedding): output embedding
  955. """
  956. def __init__(self, config: BartConfig, embed_tokens: Optional[nn.Embedding] = None):
  957. super().__init__(config)
  958. self.dropout = config.dropout
  959. self.layerdrop = config.decoder_layerdrop
  960. self.padding_idx = config.pad_token_id
  961. self.max_target_positions = config.max_position_embeddings
  962. embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
  963. self.embed_tokens = BartScaledWordEmbedding(
  964. config.vocab_size, config.d_model, self.padding_idx, embed_scale=embed_scale
  965. )
  966. if embed_tokens is not None:
  967. self.embed_tokens.weight = embed_tokens.weight
  968. self.embed_positions = BartLearnedPositionalEmbedding(
  969. config.max_position_embeddings,
  970. config.d_model,
  971. )
  972. self.layers = nn.ModuleList([BartDecoderLayer(config) for _ in range(config.decoder_layers)])
  973. self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
  974. self._use_sdpa = config._attn_implementation == "sdpa"
  975. self.layernorm_embedding = nn.LayerNorm(config.d_model)
  976. self.gradient_checkpointing = False
  977. # Initialize weights and apply final processing
  978. self.post_init()
  979. def get_input_embeddings(self):
  980. return self.embed_tokens
  981. def set_input_embeddings(self, value):
  982. self.embed_tokens = value
  983. def forward(
  984. self,
  985. input_ids: torch.LongTensor = None,
  986. attention_mask: Optional[torch.Tensor] = None,
  987. encoder_hidden_states: Optional[torch.FloatTensor] = None,
  988. encoder_attention_mask: Optional[torch.LongTensor] = None,
  989. head_mask: Optional[torch.Tensor] = None,
  990. cross_attn_head_mask: Optional[torch.Tensor] = None,
  991. past_key_values: Optional[List[torch.FloatTensor]] = None,
  992. inputs_embeds: Optional[torch.FloatTensor] = None,
  993. use_cache: Optional[bool] = None,
  994. output_attentions: Optional[bool] = None,
  995. output_hidden_states: Optional[bool] = None,
  996. return_dict: Optional[bool] = None,
  997. ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
  998. r"""
  999. Args:
  1000. input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
  1001. Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
  1002. provide it.
  1003. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  1004. [`PreTrainedTokenizer.__call__`] for details.
  1005. [What are input IDs?](../glossary#input-ids)
  1006. attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
  1007. Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
  1008. - 1 for tokens that are **not masked**,
  1009. - 0 for tokens that are **masked**.
  1010. [What are attention masks?](../glossary#attention-mask)
  1011. encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*):
  1012. Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention
  1013. of the decoder.
  1014. encoder_attention_mask (`torch.LongTensor` of shape `(batch_size, encoder_sequence_length)`, *optional*):
  1015. Mask to avoid performing cross-attention on padding tokens indices of encoder input_ids. Mask values
  1016. selected in `[0, 1]`:
  1017. - 1 for tokens that are **not masked**,
  1018. - 0 for tokens that are **masked**.
  1019. [What are attention masks?](../glossary#attention-mask)
  1020. head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
  1021. Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
  1022. - 1 indicates the head is **not masked**,
  1023. - 0 indicates the head is **masked**.
  1024. cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
  1025. Mask to nullify selected heads of the cross-attention modules in the decoder to avoid performing
  1026. cross-attention on hidden heads. Mask values selected in `[0, 1]`:
  1027. - 1 indicates the head is **not masked**,
  1028. - 0 indicates the head is **masked**.
  1029. past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
  1030. Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
  1031. shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of
  1032. shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
  1033. Contains pre-computed hidden-states (key and values in the self-attention blocks and in the
  1034. cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
  1035. If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those
  1036. that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of
  1037. all `decoder_input_ids` of shape `(batch_size, sequence_length)`.
  1038. inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
  1039. Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
  1040. This is useful if you want more control over how to convert `input_ids` indices into associated vectors
  1041. than the model's internal embedding lookup matrix.
  1042. output_attentions (`bool`, *optional*):
  1043. Whether or not to return the attentions tensors of all attention layers. See `attentions` under
  1044. returned tensors for more detail.
  1045. output_hidden_states (`bool`, *optional*):
  1046. Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
  1047. for more detail.
  1048. return_dict (`bool`, *optional*):
  1049. Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
  1050. """
  1051. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  1052. output_hidden_states = (
  1053. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  1054. )
  1055. use_cache = use_cache if use_cache is not None else self.config.use_cache
  1056. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  1057. # retrieve input_ids and inputs_embeds
  1058. if input_ids is not None and inputs_embeds is not None:
  1059. raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
  1060. elif input_ids is not None:
  1061. input = input_ids
  1062. input_shape = input.shape
  1063. input_ids = input_ids.view(-1, input_shape[-1])
  1064. elif inputs_embeds is not None:
  1065. input_shape = inputs_embeds.size()[:-1]
  1066. input = inputs_embeds[:, :, -1]
  1067. else:
  1068. raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
  1069. # past_key_values_length
  1070. past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
  1071. if inputs_embeds is None:
  1072. inputs_embeds = self.embed_tokens(input)
  1073. if self._use_flash_attention_2:
  1074. # 2d mask is passed through the layers
  1075. attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
  1076. elif self._use_sdpa and not output_attentions and cross_attn_head_mask is None:
  1077. # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on
  1078. # the manual implementation that requires a 4D causal mask in all cases.
  1079. attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
  1080. attention_mask,
  1081. input_shape,
  1082. inputs_embeds,
  1083. past_key_values_length,
  1084. )
  1085. else:
  1086. # 4d mask is passed through the layers
  1087. attention_mask = _prepare_4d_causal_attention_mask(
  1088. attention_mask, input_shape, inputs_embeds, past_key_values_length
  1089. )
  1090. # expand encoder attention mask
  1091. if encoder_hidden_states is not None and encoder_attention_mask is not None:
  1092. if self._use_flash_attention_2:
  1093. encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None
  1094. elif self._use_sdpa and cross_attn_head_mask is None and not output_attentions:
  1095. # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on
  1096. # the manual implementation that requires a 4D causal mask in all cases.
  1097. # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
  1098. encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa(
  1099. encoder_attention_mask,
  1100. inputs_embeds.dtype,
  1101. tgt_len=input_shape[-1],
  1102. )
  1103. else:
  1104. # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
  1105. encoder_attention_mask = _prepare_4d_attention_mask(
  1106. encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
  1107. )
  1108. # embed positions
  1109. positions = self.embed_positions(input, past_key_values_length)
  1110. positions = positions.to(inputs_embeds.device)
  1111. hidden_states = inputs_embeds + positions
  1112. hidden_states = self.layernorm_embedding(hidden_states)
  1113. hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
  1114. if self.gradient_checkpointing and self.training:
  1115. if use_cache:
  1116. logger.warning_once(
  1117. "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
  1118. )
  1119. use_cache = False
  1120. # decoder layers
  1121. all_hidden_states = () if output_hidden_states else None
  1122. all_self_attns = () if output_attentions else None
  1123. all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
  1124. next_decoder_cache = () if use_cache else None
  1125. # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired
  1126. for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]):
  1127. if attn_mask is not None:
  1128. if attn_mask.size()[0] != (len(self.layers)):
  1129. raise ValueError(
  1130. f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for"
1131. f" {attn_mask.size()[0]}."
  1132. )
  1133. for idx, decoder_layer in enumerate(self.layers):
  1134. # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
  1135. if output_hidden_states:
  1136. all_hidden_states += (hidden_states,)
  1137. if self.training:
  1138. dropout_probability = torch.rand([])
  1139. if dropout_probability < self.layerdrop:
  1140. continue
  1141. past_key_value = past_key_values[idx] if past_key_values is not None else None
  1142. if self.gradient_checkpointing and self.training:
  1143. layer_outputs = self._gradient_checkpointing_func(
  1144. decoder_layer.__call__,
  1145. hidden_states,
  1146. attention_mask,
  1147. encoder_hidden_states,
  1148. encoder_attention_mask,
  1149. head_mask[idx] if head_mask is not None else None,
  1150. cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None,
  1151. None,
  1152. output_attentions,
  1153. use_cache,
  1154. )
  1155. else:
  1156. layer_outputs = decoder_layer(
  1157. hidden_states,
  1158. attention_mask=attention_mask,
  1159. encoder_hidden_states=encoder_hidden_states,
  1160. encoder_attention_mask=encoder_attention_mask,
  1161. layer_head_mask=(head_mask[idx] if head_mask is not None else None),
  1162. cross_attn_layer_head_mask=(
  1163. cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None
  1164. ),
  1165. past_key_value=past_key_value,
  1166. output_attentions=output_attentions,
  1167. use_cache=use_cache,
  1168. )
  1169. hidden_states = layer_outputs[0]
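# BartDecoderLayer returns (hidden_states, self_attn_weights, cross_attn_weights, present_key_value)
# when output_attentions=True and (hidden_states, present_key_value) otherwise, hence the 3-vs-1
# index used when collecting the cache below.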
  1170. if use_cache:
  1171. next_decoder_cache += (layer_outputs[3 if output_attentions else 1],)
  1172. if output_attentions:
  1173. all_self_attns += (layer_outputs[1],)
  1174. if encoder_hidden_states is not None:
  1175. all_cross_attentions += (layer_outputs[2],)
  1176. # add hidden states from the last decoder layer
  1177. if output_hidden_states:
  1178. all_hidden_states += (hidden_states,)
  1179. next_cache = next_decoder_cache if use_cache else None
  1180. if not return_dict:
  1181. return tuple(
  1182. v
  1183. for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_cross_attentions]
  1184. if v is not None
  1185. )
  1186. return BaseModelOutputWithPastAndCrossAttentions(
  1187. last_hidden_state=hidden_states,
  1188. past_key_values=next_cache,
  1189. hidden_states=all_hidden_states,
  1190. attentions=all_self_attns,
  1191. cross_attentions=all_cross_attentions,
  1192. )
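# Illustrative sketch of how `past_key_values` speeds up decoding with the decoder defined above:
# the first call caches per-layer key/value states, and later calls feed only the newest token plus
# the cache. A hedged example assuming the public "facebook/bart-base" checkpoint; the token ids are
# arbitrary and only the shapes matter.
def _example_incremental_decoding():
    import torch
    from transformers import BartModel

    model = BartModel.from_pretrained("facebook/bart-base")
    decoder = model.get_decoder()
    prefix = torch.tensor([[model.config.decoder_start_token_id, 31414, 232]])  # arbitrary ids
    first = decoder(input_ids=prefix, use_cache=True, return_dict=True)
    past = first.past_key_values  # one tuple per layer; self-attention (key, value) tensors here
    # Only the newest token is needed once the cache is available.
    next_token = torch.tensor([[2]])
    second = decoder(input_ids=next_token, past_key_values=past, use_cache=True, return_dict=True)
    return second.last_hidden_state.shape  # (1, 1, d_model)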
  1193. @add_start_docstrings(
  1194. "The bare BART Model outputting raw hidden-states without any specific head on top.",
  1195. BART_START_DOCSTRING,
  1196. )
  1197. class BartModel(BartPreTrainedModel):
  1198. _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
  1199. def __init__(self, config: BartConfig):
  1200. super().__init__(config)
  1201. padding_idx, vocab_size = config.pad_token_id, config.vocab_size
  1202. embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
  1203. self.shared = BartScaledWordEmbedding(vocab_size, config.d_model, padding_idx, embed_scale=embed_scale)
  1204. self.encoder = BartEncoder(config, self.shared)
  1205. self.decoder = BartDecoder(config, self.shared)
  1206. # Initialize weights and apply final processing
  1207. self.post_init()
  1208. def _tie_weights(self):
  1209. if self.config.tie_word_embeddings:
  1210. self._tie_or_clone_weights(self.encoder.embed_tokens, self.shared)
  1211. self._tie_or_clone_weights(self.decoder.embed_tokens, self.shared)
  1212. def get_input_embeddings(self):
  1213. return self.shared
  1214. def set_input_embeddings(self, value):
  1215. self.shared = value
  1216. self.encoder.embed_tokens = self.shared
  1217. self.decoder.embed_tokens = self.shared
  1218. def get_encoder(self):
  1219. return self.encoder
  1220. def get_decoder(self):
  1221. return self.decoder
  1222. @add_start_docstrings_to_model_forward(BART_INPUTS_DOCSTRING)
  1223. @add_code_sample_docstrings(
  1224. checkpoint=_CHECKPOINT_FOR_DOC,
  1225. output_type=Seq2SeqModelOutput,
  1226. config_class=_CONFIG_FOR_DOC,
  1227. expected_output=_EXPECTED_OUTPUT_SHAPE,
  1228. )
  1229. def forward(
  1230. self,
  1231. input_ids: torch.LongTensor = None,
  1232. attention_mask: Optional[torch.Tensor] = None,
  1233. decoder_input_ids: Optional[torch.LongTensor] = None,
  1234. decoder_attention_mask: Optional[torch.LongTensor] = None,
  1235. head_mask: Optional[torch.Tensor] = None,
  1236. decoder_head_mask: Optional[torch.Tensor] = None,
  1237. cross_attn_head_mask: Optional[torch.Tensor] = None,
  1238. encoder_outputs: Optional[List[torch.FloatTensor]] = None,
  1239. past_key_values: Optional[List[torch.FloatTensor]] = None,
  1240. inputs_embeds: Optional[torch.FloatTensor] = None,
  1241. decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
  1242. use_cache: Optional[bool] = None,
  1243. output_attentions: Optional[bool] = None,
  1244. output_hidden_states: Optional[bool] = None,
  1245. return_dict: Optional[bool] = None,
  1246. ) -> Union[Tuple, Seq2SeqModelOutput]:
  1247. # different to other models, Bart automatically creates decoder_input_ids from
  1248. # input_ids if no decoder_input_ids are provided
  1249. if decoder_input_ids is None and decoder_inputs_embeds is None:
  1250. if input_ids is None:
  1251. raise ValueError(
  1252. "If no `decoder_input_ids` or `decoder_inputs_embeds` are "
  1253. "passed, `input_ids` cannot be `None`. Please pass either "
  1254. "`input_ids` or `decoder_input_ids` or `decoder_inputs_embeds`."
  1255. )
  1256. decoder_input_ids = shift_tokens_right(
  1257. input_ids, self.config.pad_token_id, self.config.decoder_start_token_id
  1258. )
  1259. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  1260. output_hidden_states = (
  1261. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  1262. )
  1263. use_cache = use_cache if use_cache is not None else self.config.use_cache
  1264. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  1265. if encoder_outputs is None:
  1266. encoder_outputs = self.encoder(
  1267. input_ids=input_ids,
  1268. attention_mask=attention_mask,
  1269. head_mask=head_mask,
  1270. inputs_embeds=inputs_embeds,
  1271. output_attentions=output_attentions,
  1272. output_hidden_states=output_hidden_states,
  1273. return_dict=return_dict,
  1274. )
  1275. # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True
  1276. elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
  1277. encoder_outputs = BaseModelOutput(
  1278. last_hidden_state=encoder_outputs[0],
  1279. hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
  1280. attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
  1281. )
  1282. # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn)
  1283. decoder_outputs = self.decoder(
  1284. input_ids=decoder_input_ids,
  1285. attention_mask=decoder_attention_mask,
  1286. encoder_hidden_states=encoder_outputs[0],
  1287. encoder_attention_mask=attention_mask,
  1288. head_mask=decoder_head_mask,
  1289. cross_attn_head_mask=cross_attn_head_mask,
  1290. past_key_values=past_key_values,
  1291. inputs_embeds=decoder_inputs_embeds,
  1292. use_cache=use_cache,
  1293. output_attentions=output_attentions,
  1294. output_hidden_states=output_hidden_states,
  1295. return_dict=return_dict,
  1296. )
  1297. if not return_dict:
  1298. return decoder_outputs + encoder_outputs
  1299. return Seq2SeqModelOutput(
  1300. last_hidden_state=decoder_outputs.last_hidden_state,
  1301. past_key_values=decoder_outputs.past_key_values,
  1302. decoder_hidden_states=decoder_outputs.hidden_states,
  1303. decoder_attentions=decoder_outputs.attentions,
  1304. cross_attentions=decoder_outputs.cross_attentions,
  1305. encoder_last_hidden_state=encoder_outputs.last_hidden_state,
  1306. encoder_hidden_states=encoder_outputs.hidden_states,
  1307. encoder_attentions=encoder_outputs.attentions,
  1308. )
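# Illustrative sketch of the behaviour noted in the comment above: unlike most encoder-decoder
# models, BartModel builds `decoder_input_ids` itself by right-shifting `input_ids` (prepending
# `config.decoder_start_token_id`) when none are supplied. Hedged example, assuming the public
# "facebook/bart-base" checkpoint.
def _example_bart_model_auto_shift():
    import torch
    from transformers import AutoTokenizer, BartModel

    tokenizer = AutoTokenizer.from_pretrained("facebook/bart-base")
    model = BartModel.from_pretrained("facebook/bart-base")
    inputs = tokenizer("Hello world", return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)  # no decoder_input_ids: they are derived via shift_tokens_right
    return outputs.last_hidden_state.shape  # (1, sequence_length, d_model)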
  1309. @add_start_docstrings(
  1310. "The BART Model with a language modeling head. Can be used for summarization.", BART_START_DOCSTRING
  1311. )
  1312. class BartForConditionalGeneration(BartPreTrainedModel, GenerationMixin):
  1313. base_model_prefix = "model"
  1314. _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight", "lm_head.weight"]
  1315. _keys_to_ignore_on_load_missing = ["final_logits_bias"]
  1316. def __init__(self, config: BartConfig):
  1317. super().__init__(config)
  1318. self.model = BartModel(config)
  1319. self.register_buffer("final_logits_bias", torch.zeros((1, self.model.shared.num_embeddings)))
  1320. self.lm_head = nn.Linear(config.d_model, self.model.shared.num_embeddings, bias=False)
  1321. # Initialize weights and apply final processing
  1322. self.post_init()
  1323. def get_encoder(self):
  1324. return self.model.get_encoder()
  1325. def get_decoder(self):
  1326. return self.model.get_decoder()
  1327. def resize_token_embeddings(self, new_num_tokens: int, pad_to_multiple_of: Optional[int] = None) -> nn.Embedding:
  1328. new_embeddings = super().resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
  1329. self._resize_final_logits_bias(new_embeddings.weight.shape[0])
  1330. return new_embeddings
  1331. def _resize_final_logits_bias(self, new_num_tokens: int) -> None:
  1332. old_num_tokens = self.final_logits_bias.shape[-1]
  1333. if new_num_tokens <= old_num_tokens:
  1334. new_bias = self.final_logits_bias[:, :new_num_tokens]
  1335. else:
  1336. extra_bias = torch.zeros((1, new_num_tokens - old_num_tokens), device=self.final_logits_bias.device)
  1337. new_bias = torch.cat([self.final_logits_bias, extra_bias], dim=1)
  1338. self.register_buffer("final_logits_bias", new_bias)
  1339. def get_output_embeddings(self):
  1340. return self.lm_head
  1341. def set_output_embeddings(self, new_embeddings):
  1342. self.lm_head = new_embeddings
  1343. @add_start_docstrings_to_model_forward(BART_INPUTS_DOCSTRING)
  1344. @replace_return_docstrings(output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
  1345. @add_end_docstrings(BART_GENERATION_EXAMPLE)
  1346. def forward(
  1347. self,
  1348. input_ids: torch.LongTensor = None,
  1349. attention_mask: Optional[torch.Tensor] = None,
  1350. decoder_input_ids: Optional[torch.LongTensor] = None,
  1351. decoder_attention_mask: Optional[torch.LongTensor] = None,
  1352. head_mask: Optional[torch.Tensor] = None,
  1353. decoder_head_mask: Optional[torch.Tensor] = None,
  1354. cross_attn_head_mask: Optional[torch.Tensor] = None,
  1355. encoder_outputs: Optional[List[torch.FloatTensor]] = None,
  1356. past_key_values: Optional[List[torch.FloatTensor]] = None,
  1357. inputs_embeds: Optional[torch.FloatTensor] = None,
  1358. decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
  1359. labels: Optional[torch.LongTensor] = None,
  1360. use_cache: Optional[bool] = None,
  1361. output_attentions: Optional[bool] = None,
  1362. output_hidden_states: Optional[bool] = None,
  1363. return_dict: Optional[bool] = None,
  1364. ) -> Union[Tuple, Seq2SeqLMOutput]:
  1365. r"""
  1366. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  1367. Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
  1368. config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
  1369. (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
  1370. Returns:
  1371. """
  1372. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  1373. if labels is not None:
  1374. if use_cache:
  1375. logger.warning("The `use_cache` argument is changed to `False` since `labels` is provided.")
  1376. use_cache = False
  1377. if decoder_input_ids is None and decoder_inputs_embeds is None:
  1378. decoder_input_ids = shift_tokens_right(
  1379. labels, self.config.pad_token_id, self.config.decoder_start_token_id
  1380. )
  1381. outputs = self.model(
  1382. input_ids,
  1383. attention_mask=attention_mask,
  1384. decoder_input_ids=decoder_input_ids,
  1385. encoder_outputs=encoder_outputs,
  1386. decoder_attention_mask=decoder_attention_mask,
  1387. head_mask=head_mask,
  1388. decoder_head_mask=decoder_head_mask,
  1389. cross_attn_head_mask=cross_attn_head_mask,
  1390. past_key_values=past_key_values,
  1391. inputs_embeds=inputs_embeds,
  1392. decoder_inputs_embeds=decoder_inputs_embeds,
  1393. use_cache=use_cache,
  1394. output_attentions=output_attentions,
  1395. output_hidden_states=output_hidden_states,
  1396. return_dict=return_dict,
  1397. )
  1398. lm_logits = self.lm_head(outputs[0])
  1399. lm_logits = lm_logits + self.final_logits_bias.to(lm_logits.device)
  1400. masked_lm_loss = None
  1401. if labels is not None:
  1402. labels = labels.to(lm_logits.device)
  1403. loss_fct = CrossEntropyLoss()
  1404. masked_lm_loss = loss_fct(lm_logits.view(-1, self.config.vocab_size), labels.view(-1))
  1405. if not return_dict:
  1406. output = (lm_logits,) + outputs[1:]
  1407. return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
  1408. return Seq2SeqLMOutput(
  1409. loss=masked_lm_loss,
  1410. logits=lm_logits,
  1411. past_key_values=outputs.past_key_values,
  1412. decoder_hidden_states=outputs.decoder_hidden_states,
  1413. decoder_attentions=outputs.decoder_attentions,
  1414. cross_attentions=outputs.cross_attentions,
  1415. encoder_last_hidden_state=outputs.encoder_last_hidden_state,
  1416. encoder_hidden_states=outputs.encoder_hidden_states,
  1417. encoder_attentions=outputs.encoder_attentions,
  1418. )
  1419. def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
  1420. return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id)
  1421. @staticmethod
  1422. def _reorder_cache(past_key_values, beam_idx):
  1423. reordered_past = ()
  1424. for layer_past in past_key_values:
  1425. # cached cross_attention states don't have to be reordered -> they are always the same
  1426. reordered_past += (
  1427. tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past[:2])
  1428. + layer_past[2:],
  1429. )
  1430. return reordered_past
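# Illustrative sketch of the most common use of BartForConditionalGeneration: summarization with
# `generate()`. A hedged example assuming the public "facebook/bart-large-cnn" checkpoint; when
# training instead, passing `labels` alone is enough, since `decoder_input_ids` are derived from the
# labels by `shift_tokens_right` in the `forward` above.
def _example_bart_summarization():
    from transformers import AutoTokenizer, BartForConditionalGeneration

    tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large-cnn")
    model = BartForConditionalGeneration.from_pretrained("facebook/bart-large-cnn")
    article = "PG&E stated it scheduled the blackouts in response to forecasts for high winds."
    inputs = tokenizer(article, return_tensors="pt")
    summary_ids = model.generate(inputs["input_ids"], num_beams=4, max_length=40)
    return tokenizer.batch_decode(summary_ids, skip_special_tokens=True)[0]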
  1431. @add_start_docstrings(
  1432. """
1433. Bart model with a sequence classification head on top (a linear layer on top of the pooled output), e.g. for GLUE
1434. tasks.
  1435. """,
  1436. BART_START_DOCSTRING,
  1437. )
  1438. class BartForSequenceClassification(BartPreTrainedModel):
  1439. _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
  1440. def __init__(self, config: BartConfig, **kwargs):
  1441. super().__init__(config, **kwargs)
  1442. self.model = BartModel(config)
  1443. self.classification_head = BartClassificationHead(
  1444. config.d_model,
  1445. config.d_model,
  1446. config.num_labels,
  1447. config.classifier_dropout,
  1448. )
  1449. # Initialize weights and apply final processing
  1450. self.post_init()
  1451. @add_start_docstrings_to_model_forward(BART_INPUTS_DOCSTRING)
  1452. @add_code_sample_docstrings(
  1453. checkpoint=_CHECKPOINT_FOR_SEQUENCE_CLASSIFICATION,
  1454. output_type=Seq2SeqSequenceClassifierOutput,
  1455. config_class=_CONFIG_FOR_DOC,
  1456. expected_output=_SEQ_CLASS_EXPECTED_OUTPUT,
  1457. expected_loss=_SEQ_CLASS_EXPECTED_LOSS,
  1458. )
  1459. def forward(
  1460. self,
  1461. input_ids: torch.LongTensor = None,
  1462. attention_mask: Optional[torch.Tensor] = None,
  1463. decoder_input_ids: Optional[torch.LongTensor] = None,
  1464. decoder_attention_mask: Optional[torch.LongTensor] = None,
  1465. head_mask: Optional[torch.Tensor] = None,
  1466. decoder_head_mask: Optional[torch.Tensor] = None,
  1467. cross_attn_head_mask: Optional[torch.Tensor] = None,
  1468. encoder_outputs: Optional[List[torch.FloatTensor]] = None,
  1469. inputs_embeds: Optional[torch.FloatTensor] = None,
  1470. decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
  1471. labels: Optional[torch.LongTensor] = None,
  1472. use_cache: Optional[bool] = None,
  1473. output_attentions: Optional[bool] = None,
  1474. output_hidden_states: Optional[bool] = None,
  1475. return_dict: Optional[bool] = None,
  1476. ) -> Union[Tuple, Seq2SeqSequenceClassifierOutput]:
  1477. r"""
  1478. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  1479. Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
  1480. config.num_labels - 1]`. If `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
  1481. """
  1482. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  1483. if labels is not None:
  1484. use_cache = False
  1485. if input_ids is None and inputs_embeds is not None:
  1486. raise NotImplementedError(
  1487. f"Passing input embeddings is currently not supported for {self.__class__.__name__}"
  1488. )
  1489. outputs = self.model(
  1490. input_ids,
  1491. attention_mask=attention_mask,
  1492. decoder_input_ids=decoder_input_ids,
  1493. decoder_attention_mask=decoder_attention_mask,
  1494. head_mask=head_mask,
  1495. decoder_head_mask=decoder_head_mask,
  1496. cross_attn_head_mask=cross_attn_head_mask,
  1497. encoder_outputs=encoder_outputs,
  1498. inputs_embeds=inputs_embeds,
  1499. decoder_inputs_embeds=decoder_inputs_embeds,
  1500. use_cache=use_cache,
  1501. output_attentions=output_attentions,
  1502. output_hidden_states=output_hidden_states,
  1503. return_dict=return_dict,
  1504. )
  1505. hidden_states = outputs[0] # last hidden state
  1506. eos_mask = input_ids.eq(self.config.eos_token_id).to(hidden_states.device)
  1507. if len(torch.unique_consecutive(eos_mask.sum(1))) > 1:
  1508. raise ValueError("All examples must have the same number of <eos> tokens.")
  1509. sentence_representation = hidden_states[eos_mask, :].view(hidden_states.size(0), -1, hidden_states.size(-1))[
  1510. :, -1, :
  1511. ]
  1512. logits = self.classification_head(sentence_representation)
  1513. loss = None
  1514. if labels is not None:
  1515. labels = labels.to(logits.device)
  1516. if self.config.problem_type is None:
  1517. if self.config.num_labels == 1:
  1518. self.config.problem_type = "regression"
  1519. elif self.config.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
  1520. self.config.problem_type = "single_label_classification"
  1521. else:
  1522. self.config.problem_type = "multi_label_classification"
  1523. if self.config.problem_type == "regression":
  1524. loss_fct = MSELoss()
  1525. if self.config.num_labels == 1:
  1526. loss = loss_fct(logits.squeeze(), labels.squeeze())
  1527. else:
  1528. loss = loss_fct(logits, labels)
  1529. elif self.config.problem_type == "single_label_classification":
  1530. loss_fct = CrossEntropyLoss()
  1531. loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))
  1532. elif self.config.problem_type == "multi_label_classification":
  1533. loss_fct = BCEWithLogitsLoss()
  1534. loss = loss_fct(logits, labels)
  1535. if not return_dict:
  1536. output = (logits,) + outputs[1:]
  1537. return ((loss,) + output) if loss is not None else output
  1538. return Seq2SeqSequenceClassifierOutput(
  1539. loss=loss,
  1540. logits=logits,
  1541. past_key_values=outputs.past_key_values,
  1542. decoder_hidden_states=outputs.decoder_hidden_states,
  1543. decoder_attentions=outputs.decoder_attentions,
  1544. cross_attentions=outputs.cross_attentions,
  1545. encoder_last_hidden_state=outputs.encoder_last_hidden_state,
  1546. encoder_hidden_states=outputs.encoder_hidden_states,
  1547. encoder_attentions=outputs.encoder_attentions,
  1548. )
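# Illustrative sketch for BartForSequenceClassification. The head pools the decoder hidden state at
# the final <eos> token, as the forward pass above shows, so every example must contain the same
# number of <eos> tokens. Hedged example; "valhalla/bart-large-sst2" is the checkpoint the docstring
# samples in this file assume, and any classification-finetuned BART checkpoint works the same way.
def _example_bart_sequence_classification():
    import torch
    from transformers import AutoTokenizer, BartForSequenceClassification

    tokenizer = AutoTokenizer.from_pretrained("valhalla/bart-large-sst2")
    model = BartForSequenceClassification.from_pretrained("valhalla/bart-large-sst2")
    inputs = tokenizer("This movie was great!", return_tensors="pt")
    with torch.no_grad():
        logits = model(**inputs).logits
    return model.config.id2label[int(logits.argmax(dim=-1))]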
  1549. @add_start_docstrings(
  1550. """
  1551. BART Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear
  1552. layer on top of the hidden-states output to compute `span start logits` and `span end logits`).
  1553. """,
  1554. BART_START_DOCSTRING,
  1555. )
  1556. class BartForQuestionAnswering(BartPreTrainedModel):
  1557. _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
  1558. def __init__(self, config):
  1559. super().__init__(config)
  1560. config.num_labels = 2
  1561. self.num_labels = config.num_labels
  1562. self.model = BartModel(config)
  1563. self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
  1564. # Initialize weights and apply final processing
  1565. self.post_init()
  1566. @add_start_docstrings_to_model_forward(BART_INPUTS_DOCSTRING)
  1567. @add_code_sample_docstrings(
  1568. checkpoint=_CHECKPOINT_FOR_QA,
  1569. output_type=Seq2SeqQuestionAnsweringModelOutput,
  1570. config_class=_CONFIG_FOR_DOC,
  1571. expected_loss=_QA_EXPECTED_LOSS,
  1572. expected_output=_QA_EXPECTED_OUTPUT,
  1573. )
  1574. def forward(
  1575. self,
  1576. input_ids: torch.Tensor = None,
  1577. attention_mask: Optional[torch.Tensor] = None,
  1578. decoder_input_ids: Optional[torch.LongTensor] = None,
  1579. decoder_attention_mask: Optional[torch.LongTensor] = None,
  1580. head_mask: Optional[torch.Tensor] = None,
  1581. decoder_head_mask: Optional[torch.Tensor] = None,
  1582. cross_attn_head_mask: Optional[torch.Tensor] = None,
  1583. encoder_outputs: Optional[List[torch.FloatTensor]] = None,
  1584. start_positions: Optional[torch.LongTensor] = None,
  1585. end_positions: Optional[torch.LongTensor] = None,
  1586. inputs_embeds: Optional[torch.FloatTensor] = None,
  1587. decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
  1588. use_cache: Optional[bool] = None,
  1589. output_attentions: Optional[bool] = None,
  1590. output_hidden_states: Optional[bool] = None,
  1591. return_dict: Optional[bool] = None,
  1592. ) -> Union[Tuple, Seq2SeqQuestionAnsweringModelOutput]:
  1593. r"""
  1594. start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  1595. Labels for position (index) of the start of the labelled span for computing the token classification loss.
1596. Positions are clamped to the length of the sequence (*sequence_length*). Positions outside of the sequence
  1597. are not taken into account for computing the loss.
  1598. end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  1599. Labels for position (index) of the end of the labelled span for computing the token classification loss.
1600. Positions are clamped to the length of the sequence (*sequence_length*). Positions outside of the sequence
  1601. are not taken into account for computing the loss.
  1602. """
  1603. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  1604. if start_positions is not None and end_positions is not None:
  1605. use_cache = False
  1606. outputs = self.model(
  1607. input_ids,
  1608. attention_mask=attention_mask,
  1609. decoder_input_ids=decoder_input_ids,
  1610. decoder_attention_mask=decoder_attention_mask,
  1611. head_mask=head_mask,
  1612. decoder_head_mask=decoder_head_mask,
  1613. cross_attn_head_mask=cross_attn_head_mask,
  1614. encoder_outputs=encoder_outputs,
  1615. inputs_embeds=inputs_embeds,
  1616. decoder_inputs_embeds=decoder_inputs_embeds,
  1617. use_cache=use_cache,
  1618. output_attentions=output_attentions,
  1619. output_hidden_states=output_hidden_states,
  1620. return_dict=return_dict,
  1621. )
  1622. sequence_output = outputs[0]
  1623. logits = self.qa_outputs(sequence_output)
  1624. start_logits, end_logits = logits.split(1, dim=-1)
  1625. start_logits = start_logits.squeeze(-1).contiguous()
  1626. end_logits = end_logits.squeeze(-1).contiguous()
  1627. total_loss = None
  1628. if start_positions is not None and end_positions is not None:
1629. # If we are on multi-GPU, splitting adds a dimension
  1630. if len(start_positions.size()) > 1:
  1631. start_positions = start_positions.squeeze(-1)
  1632. if len(end_positions.size()) > 1:
  1633. end_positions = end_positions.squeeze(-1)
1634. # sometimes the start/end positions are outside our model inputs; we ignore these terms
  1635. ignored_index = start_logits.size(1)
  1636. start_positions = start_positions.clamp(0, ignored_index)
  1637. end_positions = end_positions.clamp(0, ignored_index)
  1638. loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
  1639. start_loss = loss_fct(start_logits, start_positions)
  1640. end_loss = loss_fct(end_logits, end_positions)
  1641. total_loss = (start_loss + end_loss) / 2
  1642. if not return_dict:
  1643. output = (
  1644. start_logits,
  1645. end_logits,
  1646. ) + outputs[1:]
  1647. return ((total_loss,) + output) if total_loss is not None else output
  1648. return Seq2SeqQuestionAnsweringModelOutput(
  1649. loss=total_loss,
  1650. start_logits=start_logits,
  1651. end_logits=end_logits,
  1652. past_key_values=outputs.past_key_values,
  1653. decoder_hidden_states=outputs.decoder_hidden_states,
  1654. decoder_attentions=outputs.decoder_attentions,
  1655. cross_attentions=outputs.cross_attentions,
  1656. encoder_last_hidden_state=outputs.encoder_last_hidden_state,
  1657. encoder_hidden_states=outputs.encoder_hidden_states,
  1658. encoder_attentions=outputs.encoder_attentions,
  1659. )
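# Illustrative sketch for BartForQuestionAnswering: the span head above produces one start and one
# end logit per token, and the answer is read off the argmax of each. Hedged example;
# "valhalla/bart-large-finetuned-squadv1" is the checkpoint assumed by the docstring samples in this
# file, and any QA-finetuned BART checkpoint follows the same recipe.
def _example_bart_question_answering():
    import torch
    from transformers import AutoTokenizer, BartForQuestionAnswering

    tokenizer = AutoTokenizer.from_pretrained("valhalla/bart-large-finetuned-squadv1")
    model = BartForQuestionAnswering.from_pretrained("valhalla/bart-large-finetuned-squadv1")
    question = "What does BART stand for?"
    context = "BART is a denoising autoencoder for pretraining sequence-to-sequence models."
    inputs = tokenizer(question, context, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)
    start = int(outputs.start_logits.argmax())
    end = int(outputs.end_logits.argmax()) + 1
    return tokenizer.decode(inputs["input_ids"][0][start:end])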
  1660. class BartDecoderWrapper(BartPreTrainedModel):
  1661. """
  1662. This wrapper class is a helper class to correctly load pretrained checkpoints when the causal language model is
  1663. used in combination with the [`EncoderDecoderModel`] framework.
  1664. """
  1665. def __init__(self, config):
  1666. super().__init__(config)
  1667. self.decoder = BartDecoder(config)
  1668. def forward(self, *args, **kwargs):
  1669. return self.decoder(*args, **kwargs)
  1670. @add_start_docstrings(
  1671. """
  1672. BART decoder with a language modeling head on top (linear layer with weights tied to the input embeddings).
  1673. """,
  1674. BART_START_DOCSTRING,
  1675. )
  1676. class BartForCausalLM(BartPreTrainedModel, GenerationMixin):
  1677. _tied_weights_keys = ["lm_head.weight"]
  1678. def __init__(self, config):
  1679. config = copy.deepcopy(config)
  1680. config.is_decoder = True
  1681. config.is_encoder_decoder = False
  1682. super().__init__(config)
  1683. self.model = BartDecoderWrapper(config)
  1684. self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
  1685. # Initialize weights and apply final processing
  1686. self.post_init()
  1687. def get_input_embeddings(self):
  1688. return self.model.decoder.embed_tokens
  1689. def set_input_embeddings(self, value):
  1690. self.model.decoder.embed_tokens = value
  1691. def get_output_embeddings(self):
  1692. return self.lm_head
  1693. def set_output_embeddings(self, new_embeddings):
  1694. self.lm_head = new_embeddings
  1695. def set_decoder(self, decoder):
  1696. self.model.decoder = decoder
  1697. def get_decoder(self):
  1698. return self.model.decoder
  1699. @replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC)
  1700. def forward(
  1701. self,
  1702. input_ids: torch.LongTensor = None,
  1703. attention_mask: Optional[torch.Tensor] = None,
  1704. encoder_hidden_states: Optional[torch.FloatTensor] = None,
  1705. encoder_attention_mask: Optional[torch.FloatTensor] = None,
  1706. head_mask: Optional[torch.Tensor] = None,
  1707. cross_attn_head_mask: Optional[torch.Tensor] = None,
  1708. past_key_values: Optional[List[torch.FloatTensor]] = None,
  1709. inputs_embeds: Optional[torch.FloatTensor] = None,
  1710. labels: Optional[torch.LongTensor] = None,
  1711. use_cache: Optional[bool] = None,
  1712. output_attentions: Optional[bool] = None,
  1713. output_hidden_states: Optional[bool] = None,
  1714. return_dict: Optional[bool] = None,
  1715. ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
  1716. r"""
  1717. Args:
  1718. input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
  1719. Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
  1720. provide it.
  1721. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  1722. [`PreTrainedTokenizer.__call__`] for details.
  1723. [What are input IDs?](../glossary#input-ids)
  1724. attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
  1725. Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
  1726. - 1 for tokens that are **not masked**,
  1727. - 0 for tokens that are **masked**.
  1728. [What are attention masks?](../glossary#attention-mask)
  1729. encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
  1730. Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention
  1731. if the model is configured as a decoder.
  1732. encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
  1733. Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used
1734. in the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
  1735. head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
  1736. Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
  1737. - 1 indicates the head is **not masked**,
  1738. - 0 indicates the head is **masked**.
  1739. cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
  1740. Mask to nullify selected heads of the cross-attention modules. Mask values selected in `[0, 1]`:
  1741. - 1 indicates the head is **not masked**,
  1742. - 0 indicates the head is **masked**.
  1743. past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
  1744. Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
  1745. shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of
  1746. shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. The two additional
  1747. tensors are only required when the model is used as a decoder in a Sequence to Sequence model.
  1748. Contains pre-computed hidden-states (key and values in the self-attention blocks and in the
  1749. cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
  1750. If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those
  1751. that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of
  1752. all `decoder_input_ids` of shape `(batch_size, sequence_length)`.
  1753. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  1754. Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
  1755. config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
  1756. (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
  1757. use_cache (`bool`, *optional*):
  1758. If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
  1759. (see `past_key_values`).
  1762. output_attentions (`bool`, *optional*):
  1763. Whether or not to return the attentions tensors of all attention layers. See `attentions` under
  1764. returned tensors for more detail.
  1765. output_hidden_states (`bool`, *optional*):
  1766. Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
  1767. for more detail.
  1768. return_dict (`bool`, *optional*):
  1769. Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
  1770. Returns:
  1771. Example:
  1772. ```python
  1773. >>> from transformers import AutoTokenizer, BartForCausalLM
  1774. >>> tokenizer = AutoTokenizer.from_pretrained("facebook/bart-base")
  1775. >>> model = BartForCausalLM.from_pretrained("facebook/bart-base", add_cross_attention=False)
  1776. >>> assert model.config.is_decoder, f"{model.__class__} has to be configured as a decoder."
  1777. >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
  1778. >>> outputs = model(**inputs)
  1779. >>> logits = outputs.logits
  1780. >>> expected_shape = [1, inputs.input_ids.shape[-1], model.config.vocab_size]
  1781. >>> list(logits.shape) == expected_shape
  1782. True
  1783. ```"""
  1784. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  1785. output_hidden_states = (
  1786. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  1787. )
  1788. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  1789. # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
  1790. outputs = self.model.decoder(
  1791. input_ids=input_ids,
  1792. attention_mask=attention_mask,
  1793. encoder_hidden_states=encoder_hidden_states,
  1794. encoder_attention_mask=encoder_attention_mask,
  1795. head_mask=head_mask,
  1796. cross_attn_head_mask=cross_attn_head_mask,
  1797. past_key_values=past_key_values,
  1798. inputs_embeds=inputs_embeds,
  1799. use_cache=use_cache,
  1800. output_attentions=output_attentions,
  1801. output_hidden_states=output_hidden_states,
  1802. return_dict=return_dict,
  1803. )
  1804. logits = self.lm_head(outputs[0])
  1805. loss = None
  1806. if labels is not None:
  1807. labels = labels.to(logits.device)
  1808. loss_fct = CrossEntropyLoss()
  1809. loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1))
  1810. if not return_dict:
  1811. output = (logits,) + outputs[1:]
  1812. return (loss,) + output if loss is not None else output
  1813. return CausalLMOutputWithCrossAttentions(
  1814. loss=loss,
  1815. logits=logits,
  1816. past_key_values=outputs.past_key_values,
  1817. hidden_states=outputs.hidden_states,
  1818. attentions=outputs.attentions,
  1819. cross_attentions=outputs.cross_attentions,
  1820. )
  1821. @staticmethod
  1822. def _reorder_cache(past_key_values, beam_idx):
  1823. reordered_past = ()
  1824. for layer_past in past_key_values:
  1825. reordered_past += (
  1826. tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
  1827. )
  1828. return reordered_past
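# Illustrative sketch of what `_reorder_cache` does during beam search: `beam_idx` says which batch
# row each beam should copy its cached key/value states from, and every cached tensor is re-indexed
# along the batch dimension. The shapes below are hypothetical and chosen only for the demonstration.
def _example_reorder_cache():
    import torch

    # A fake cache: 2 layers, each holding (key, value) of shape
    # (batch_size=3, num_heads=2, seq_len=4, head_dim=8).
    past = tuple((torch.randn(3, 2, 4, 8), torch.randn(3, 2, 4, 8)) for _ in range(2))
    beam_idx = torch.tensor([2, 0, 0])  # beam 0 continues from batch row 2, beams 1 and 2 from row 0
    reordered = BartForCausalLM._reorder_cache(past, beam_idx)
    assert torch.equal(reordered[0][0][0], past[0][0][2])
    return reordered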