# coding=utf-8
# Copyright 2021 The Facebook, Inc. and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch Blenderbot model."""

import copy
import math
import os
import warnings
from typing import List, Optional, Tuple, Union

import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import CrossEntropyLoss

from ...activations import ACT2FN
from ...generation import GenerationMixin
from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask
from ...modeling_outputs import (
    BaseModelOutput,
    BaseModelOutputWithPastAndCrossAttentions,
    CausalLMOutputWithCrossAttentions,
    Seq2SeqLMOutput,
    Seq2SeqModelOutput,
)
from ...modeling_utils import PreTrainedModel
from ...utils import (
    add_end_docstrings,
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    logging,
    replace_return_docstrings,
)
from ..blenderbot_small import BlenderbotSmallForConditionalGeneration, BlenderbotSmallModel
from .configuration_blenderbot import BlenderbotConfig


logger = logging.get_logger(__name__)

_CONFIG_FOR_DOC = "BlenderbotConfig"
_CHECKPOINT_FOR_DOC = "facebook/blenderbot-400M-distill"


# Copied from transformers.models.bart.modeling_bart.shift_tokens_right
def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int):
    """
    Shift input ids one token to the right.
    """
    shifted_input_ids = input_ids.new_zeros(input_ids.shape)
    shifted_input_ids[:, 1:] = input_ids[:, :-1].clone()
    shifted_input_ids[:, 0] = decoder_start_token_id

    if pad_token_id is None:
        raise ValueError("self.model.config.pad_token_id has to be defined.")
    # replace possible -100 values in labels by `pad_token_id`
    shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)

    return shifted_input_ids
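

# Illustrative sketch, not part of the original module: how `shift_tokens_right`
# turns labels into decoder inputs. The token ids, pad id (0), and start id (1)
# below are made up purely for demonstration.
def _example_shift_tokens_right():
    labels = torch.tensor([[5, -100, 7, 8]])  # -100 marks label positions ignored by the loss
    decoder_input_ids = shift_tokens_right(labels, pad_token_id=0, decoder_start_token_id=1)
    # The last label is dropped, everything else moves one slot to the right, and the
    # remaining -100 is replaced by the pad id: tensor([[1, 5, 0, 7]])
    return decoder_input_ids
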
class BlenderbotLearnedPositionalEmbedding(nn.Embedding):
    """
    This module learns positional embeddings up to a fixed maximum size.
    """

    def __init__(self, num_embeddings: int, embedding_dim: int):
        super().__init__(num_embeddings, embedding_dim)

    def forward(self, input_ids_shape: torch.Size, past_key_values_length: int = 0):
        """`input_ids_shape` is expected to be [bsz x seqlen]."""
        bsz, seq_len = input_ids_shape[:2]
        positions = torch.arange(
            past_key_values_length, past_key_values_length + seq_len, dtype=torch.long, device=self.weight.device
        )
        return super().forward(positions)

# Copied from transformers.models.bart.modeling_bart.BartScaledWordEmbedding with Bart->Blenderbot
class BlenderbotScaledWordEmbedding(nn.Embedding):
    """
    This module overrides nn.Embedding's forward by multiplying with the embedding scale.
    """

    def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int, embed_scale: Optional[float] = 1.0):
        super().__init__(num_embeddings, embedding_dim, padding_idx)
        self.embed_scale = embed_scale

    def forward(self, input_ids: torch.Tensor):
        return super().forward(input_ids) * self.embed_scale

# Copied from transformers.models.bart.modeling_bart.BartAttention with Bart->Blenderbot
class BlenderbotAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        dropout: float = 0.0,
        is_decoder: bool = False,
        bias: bool = True,
        is_causal: bool = False,
        config: Optional[BlenderbotConfig] = None,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.dropout = dropout
        self.head_dim = embed_dim // num_heads
        self.config = config

        if (self.head_dim * num_heads) != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
                f" and `num_heads`: {num_heads})."
            )
        self.scaling = self.head_dim**-0.5
        self.is_decoder = is_decoder
        self.is_causal = is_causal

        self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

    def forward(
        self,
        hidden_states: torch.Tensor,
        key_value_states: Optional[torch.Tensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        layer_head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """Input shape: Batch x Time x Channel"""

        # if key_value_states are provided this layer is used as a cross-attention layer
        # for the decoder
        is_cross_attention = key_value_states is not None

        bsz, tgt_len, _ = hidden_states.size()

        # get query proj
        query_states = self.q_proj(hidden_states) * self.scaling
        # get key, value proj
        # `past_key_value[0].shape[2] == key_value_states.shape[1]`
        # is checking that the `sequence_length` of the `past_key_value` is the same as
        # the provided `key_value_states` to support prefix tuning
        if (
            is_cross_attention
            and past_key_value is not None
            and past_key_value[0].shape[2] == key_value_states.shape[1]
        ):
            # reuse k,v, cross_attentions
            key_states = past_key_value[0]
            value_states = past_key_value[1]
        elif is_cross_attention:
            # cross_attentions
            key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
            value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
        elif past_key_value is not None:
            # reuse k, v, self_attention
            key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
            value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
            key_states = torch.cat([past_key_value[0], key_states], dim=2)
            value_states = torch.cat([past_key_value[1], value_states], dim=2)
        else:
            # self_attention
            key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
            value_states = self._shape(self.v_proj(hidden_states), -1, bsz)

        if self.is_decoder:
            # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
            # Further calls to cross_attention layer can then reuse all cross-attention
            # key/value_states (first "if" case)
            # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
            # all previous decoder key/value_states. Further calls to uni-directional self-attention
            # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
            # if encoder bi-directional self-attention `past_key_value` is always `None`
            past_key_value = (key_states, value_states)

        proj_shape = (bsz * self.num_heads, -1, self.head_dim)
        query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
        key_states = key_states.reshape(*proj_shape)
        value_states = value_states.reshape(*proj_shape)

        src_len = key_states.size(1)
        attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))

        if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
            raise ValueError(
                f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
                f" {attn_weights.size()}"
            )

        if attention_mask is not None:
            if attention_mask.size() != (bsz, 1, tgt_len, src_len):
                raise ValueError(
                    f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
                )
            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

        attn_weights = nn.functional.softmax(attn_weights, dim=-1)

        if layer_head_mask is not None:
            if layer_head_mask.size() != (self.num_heads,):
                raise ValueError(
                    f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
                    f" {layer_head_mask.size()}"
                )
            attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

        if output_attentions:
            # this operation is a bit awkward, but it's required to
            # make sure that attn_weights keeps its gradient.
            # In order to do so, attn_weights have to be reshaped
            # twice and have to be reused in the following
            attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
            attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
        else:
            attn_weights_reshaped = None

        attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)

        attn_output = torch.bmm(attn_probs, value_states)

        if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )

        attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
        attn_output = attn_output.transpose(1, 2)

        # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
        # partitioned across GPUs when using tensor-parallelism.
        attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)

        attn_output = self.out_proj(attn_output)

        return attn_output, attn_weights_reshaped, past_key_value


BLENDERBOT_ATTENTION_CLASSES = {"eager": BlenderbotAttention}
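

# Illustrative sketch, not part of the original module: the eager attention above maps
# `(batch, tgt_len, embed_dim)` hidden states back to the same shape. The sizes used
# here are arbitrary and only meant to show the expected shapes.
def _example_blenderbot_attention_shapes():
    attn = BlenderbotAttention(embed_dim=16, num_heads=4)
    hidden_states = torch.randn(2, 5, 16)  # (batch, tgt_len, embed_dim)
    attn_output, attn_weights, past_key_value = attn(hidden_states, output_attentions=True)
    # attn_output: (2, 5, 16); attn_weights: (2, 4, 5, 5); past_key_value stays None
    # because `is_decoder=False` for plain encoder-style self-attention.
    return attn_output.shape, attn_weights.shape
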
# Copied from transformers.models.mbart.modeling_mbart.MBartEncoderLayer with MBart->Blenderbot, MBART->BLENDERBOT
class BlenderbotEncoderLayer(nn.Module):
    def __init__(self, config: BlenderbotConfig):
        super().__init__()
        self.embed_dim = config.d_model

        self.self_attn = BLENDERBOT_ATTENTION_CLASSES[config._attn_implementation](
            embed_dim=self.embed_dim,
            num_heads=config.encoder_attention_heads,
            dropout=config.attention_dropout,
            config=config,
        )
        self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
        self.dropout = config.dropout
        self.activation_fn = ACT2FN[config.activation_function]
        self.activation_dropout = config.activation_dropout
        self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim)
        self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim)
        self.final_layer_norm = nn.LayerNorm(self.embed_dim)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
        layer_head_mask: torch.Tensor,
        output_attentions: bool = False,
    ) -> torch.Tensor:
        """
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`): attention mask of size
                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
            layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size
                `(encoder_attention_heads,)`.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
        """
        residual = hidden_states
        hidden_states = self.self_attn_layer_norm(hidden_states)
        hidden_states, attn_weights, _ = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            layer_head_mask=layer_head_mask,
            output_attentions=output_attentions,
        )
        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
        hidden_states = residual + hidden_states

        residual = hidden_states
        hidden_states = self.final_layer_norm(hidden_states)
        hidden_states = self.activation_fn(self.fc1(hidden_states))
        hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
        hidden_states = self.fc2(hidden_states)
        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
        hidden_states = residual + hidden_states

        if hidden_states.dtype == torch.float16 and (
            torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any()
        ):
            clamp_value = torch.finfo(hidden_states.dtype).max - 1000
            hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)

        outputs = (hidden_states,)

        if output_attentions:
            outputs += (attn_weights,)

        return outputs
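

# Illustrative sketch, not part of the original module: a single pre-norm encoder layer
# preserves the `(batch, seq_len, d_model)` shape. The tiny config values below are
# made up solely for this shape check.
def _example_blenderbot_encoder_layer():
    config = BlenderbotConfig(d_model=16, encoder_attention_heads=4, encoder_ffn_dim=32)
    layer = BlenderbotEncoderLayer(config)
    hidden_states = torch.randn(2, 5, config.d_model)
    # output_attentions defaults to False, so the layer returns a 1-tuple
    (out,) = layer(hidden_states, attention_mask=None, layer_head_mask=None)
    return out.shape  # torch.Size([2, 5, 16])
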
# Copied from transformers.models.mbart.modeling_mbart.MBartDecoderLayer with MBart->Blenderbot, MBART->BLENDERBOT
class BlenderbotDecoderLayer(nn.Module):
    def __init__(self, config: BlenderbotConfig):
        super().__init__()
        self.embed_dim = config.d_model

        self.self_attn = BLENDERBOT_ATTENTION_CLASSES[config._attn_implementation](
            embed_dim=self.embed_dim,
            num_heads=config.decoder_attention_heads,
            dropout=config.attention_dropout,
            is_decoder=True,
            is_causal=True,
            config=config,
        )
        self.dropout = config.dropout
        self.activation_fn = ACT2FN[config.activation_function]
        self.activation_dropout = config.activation_dropout

        self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
        self.encoder_attn = BLENDERBOT_ATTENTION_CLASSES[config._attn_implementation](
            self.embed_dim,
            config.decoder_attention_heads,
            dropout=config.attention_dropout,
            is_decoder=True,
            config=config,
        )
        self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim)
        self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim)
        self.fc2 = nn.Linear(config.decoder_ffn_dim, self.embed_dim)
        self.final_layer_norm = nn.LayerNorm(self.embed_dim)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        layer_head_mask: Optional[torch.Tensor] = None,
        cross_attn_layer_head_mask: Optional[torch.Tensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = True,
    ) -> torch.Tensor:
        """
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`): attention mask of size
                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
            encoder_hidden_states (`torch.FloatTensor`):
                cross attention input to the layer of shape `(batch, seq_len, embed_dim)`
            encoder_attention_mask (`torch.FloatTensor`): encoder attention mask of size
                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
            layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size
                `(encoder_attention_heads,)`.
            cross_attn_layer_head_mask (`torch.FloatTensor`): mask for cross-attention heads in a given layer of
                size `(decoder_attention_heads,)`.
            past_key_value (`Tuple(torch.FloatTensor)`): cached past key and value projection states
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
        """
        residual = hidden_states
        hidden_states = self.self_attn_layer_norm(hidden_states)

        # Self Attention
        # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
        self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
        # add present self-attn cache to positions 1,2 of present_key_value tuple
        hidden_states, self_attn_weights, present_key_value = self.self_attn(
            hidden_states=hidden_states,
            past_key_value=self_attn_past_key_value,
            attention_mask=attention_mask,
            layer_head_mask=layer_head_mask,
            output_attentions=output_attentions,
        )
        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
        hidden_states = residual + hidden_states

        # Cross-Attention Block
        cross_attn_present_key_value = None
        cross_attn_weights = None
        if encoder_hidden_states is not None:
            residual = hidden_states
            hidden_states = self.encoder_attn_layer_norm(hidden_states)

            # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple
            cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
            hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn(
                hidden_states=hidden_states,
                key_value_states=encoder_hidden_states,
                attention_mask=encoder_attention_mask,
                layer_head_mask=cross_attn_layer_head_mask,
                past_key_value=cross_attn_past_key_value,
                output_attentions=output_attentions,
            )
            hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
            hidden_states = residual + hidden_states

            # add cross-attn to positions 3,4 of present_key_value tuple
            present_key_value = present_key_value + cross_attn_present_key_value

        # Fully Connected
        residual = hidden_states
        hidden_states = self.final_layer_norm(hidden_states)
        hidden_states = self.activation_fn(self.fc1(hidden_states))
        hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
        hidden_states = self.fc2(hidden_states)
        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
        hidden_states = residual + hidden_states

        outputs = (hidden_states,)

        if output_attentions:
            outputs += (self_attn_weights, cross_attn_weights)

        if use_cache:
            outputs += (present_key_value,)

        return outputs

class BlenderbotPreTrainedModel(PreTrainedModel):
    config_class = BlenderbotConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True

    def _init_weights(self, module):
        std = self.config.init_std
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()

    @property
    def dummy_inputs(self):
        pad_token = self.config.pad_token_id
        input_ids = torch.tensor([[0, 6, 10, 4, 2], [0, 8, 12, 2, pad_token]], device=self.device)
        dummy_inputs = {
            "attention_mask": input_ids.ne(pad_token),
            "input_ids": input_ids,
            "decoder_input_ids": input_ids,
        }
        return dummy_inputs

BLENDERBOT_START_DOCSTRING = r"""
    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning
    heads, etc.).

    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage
    and behavior.

    Parameters:
        config ([`BlenderbotConfig`]):
            Model configuration class with all the parameters of the model. Initializing with a config file does not
            load the weights associated with the model, only the configuration. Check out the
            [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
BLENDERBOT_GENERATION_EXAMPLE = r"""
    Conversation example:

    ```python
    >>> from transformers import AutoTokenizer, BlenderbotForConditionalGeneration

    >>> mname = "facebook/blenderbot-400M-distill"
    >>> model = BlenderbotForConditionalGeneration.from_pretrained(mname)
    >>> tokenizer = AutoTokenizer.from_pretrained(mname)
    >>> UTTERANCE = "My friends are cool but they eat too many carbs."
    >>> print("Human: ", UTTERANCE)
    Human: My friends are cool but they eat too many carbs.

    >>> inputs = tokenizer([UTTERANCE], return_tensors="pt")
    >>> reply_ids = model.generate(**inputs)
    >>> print("Bot: ", tokenizer.batch_decode(reply_ids, skip_special_tokens=True)[0])
    Bot: That's unfortunate. Are they trying to lose weight or are they just trying to be healthier?

    >>> REPLY = "I'm not sure"
    >>> print("Human: ", REPLY)
    Human: I'm not sure

    >>> NEXT_UTTERANCE = (
    ...     "My friends are cool but they eat too many carbs.</s> <s>That's unfortunate. "
    ...     "Are they trying to lose weight or are they just trying to be healthier?</s> "
    ...     "<s> I'm not sure."
    ... )
    >>> inputs = tokenizer([NEXT_UTTERANCE], return_tensors="pt")
    >>> next_reply_ids = model.generate(**inputs)
    >>> print("Bot: ", tokenizer.batch_decode(next_reply_ids, skip_special_tokens=True)[0])
    Bot: I see. Well, it's good that they're trying to change their eating habits.
    ```
"""
BLENDERBOT_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
            it.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)
        decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
            Indices of decoder input sequence tokens in the vocabulary.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are decoder input IDs?](../glossary#decoder-input-ids)

            Blenderbot uses the `bos_token_id` as the starting token for `decoder_input_ids` generation. If
            `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
            `past_key_values`).
        decoder_attention_mask (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
            Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
            be used by default.
        head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*):
            Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in `[0, 1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.

        decoder_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
            Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in `[0, 1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.

        cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
            Mask to nullify selected heads of the cross-attention modules in the decoder. Mask values selected in `[0,
            1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.

        encoder_outputs (`tuple(tuple(torch.FloatTensor)`, *optional*):
            Tuple consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: `attentions`)
            `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*) is a sequence of
            hidden-states at the output of the last layer of the encoder. Used in the cross-attention of the decoder.
        past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
            Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
            `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
            `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.

            Contains pre-computed hidden-states (key and values in the self-attention blocks and in the
            cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.

            If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
            don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
            `decoder_input_ids` of shape `(batch_size, sequence_length)`.
        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
            This is useful if you want more control over how to convert `input_ids` indices into associated vectors
            than the model's internal embedding lookup matrix.
        decoder_inputs_embeds (`torch.FloatTensor` of shape `(batch_size, target_sequence_length, hidden_size)`, *optional*):
            Optionally, instead of passing `decoder_input_ids` you can choose to directly pass an embedded
            representation. If `past_key_values` is used, optionally only the last `decoder_inputs_embeds` have to be
            input (see `past_key_values`). This is useful if you want more control over how to convert
            `decoder_input_ids` indices into associated vectors than the model's internal embedding lookup matrix.

            If `decoder_input_ids` and `decoder_inputs_embeds` are both unset, `decoder_inputs_embeds` takes the value
            of `inputs_embeds`.
        use_cache (`bool`, *optional*):
            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
            (see `past_key_values`).
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""

class BlenderbotEncoder(BlenderbotPreTrainedModel):
    """
    Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a
    [`BlenderbotEncoderLayer`].

    Args:
        config: BlenderbotConfig
        embed_tokens (nn.Embedding): output embedding
    """

    def __init__(self, config: BlenderbotConfig, embed_tokens: Optional[nn.Embedding] = None):
        super().__init__(config)

        self.dropout = config.dropout
        self.layerdrop = config.encoder_layerdrop

        embed_dim = config.d_model
        self.padding_idx = config.pad_token_id
        self.max_source_positions = config.max_position_embeddings
        embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0

        if embed_tokens is not None:
            self.embed_tokens = embed_tokens
        else:
            self.embed_tokens = BlenderbotScaledWordEmbedding(
                config.vocab_size, embed_dim, self.padding_idx, embed_scale=embed_scale
            )

        self.embed_positions = BlenderbotLearnedPositionalEmbedding(
            config.max_position_embeddings,
            embed_dim,
        )
        self.layers = nn.ModuleList([BlenderbotEncoderLayer(config) for _ in range(config.encoder_layers)])
        self.layer_norm = nn.LayerNorm(config.d_model)

        self.gradient_checkpointing = False
        # Initialize weights and apply final processing
        self.post_init()

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        head_mask=None,
        inputs_embeds=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        r"""
        Args:
            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
                Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
                provide it.

                Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
                [`PreTrainedTokenizer.__call__`] for details.

                [What are input IDs?](../glossary#input-ids)
            attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
                Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

                - 1 for tokens that are **not masked**,
                - 0 for tokens that are **masked**.

                [What are attention masks?](../glossary#attention-mask)
            head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*):
                Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:

                - 1 indicates the head is **not masked**,
                - 0 indicates the head is **masked**.

            inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
                Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
                This is useful if you want more control over how to convert `input_ids` indices into associated
                vectors than the model's internal embedding lookup matrix.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            output_hidden_states (`bool`, *optional*):
                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
                for more detail.
            return_dict (`bool`, *optional*):
                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # retrieve input_ids and inputs_embeds
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
            input_shape = input_ids.size()
            input_ids = input_ids.view(-1, input_shape[-1])
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        embed_pos = self.embed_positions(input_shape)

        hidden_states = inputs_embeds + embed_pos
        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)

        # expand attention_mask
        if attention_mask is not None:
            # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
            attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype)

        encoder_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None

        # check if head_mask has a correct number of layers specified if desired
        if head_mask is not None:
            if head_mask.size()[0] != len(self.layers):
                raise ValueError(
                    f"The head_mask should be specified for {len(self.layers)} layers, but it is for"
                    f" {head_mask.size()[0]}."
                )

        for idx, encoder_layer in enumerate(self.layers):
            if output_hidden_states:
                encoder_states = encoder_states + (hidden_states,)
            # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
            to_drop = False
            if self.training:
                dropout_probability = torch.rand([])
                if dropout_probability < self.layerdrop:  # skip the layer
                    to_drop = True

            if to_drop:
                layer_outputs = (None, None)
            else:
                if self.gradient_checkpointing and self.training:
                    layer_outputs = self._gradient_checkpointing_func(
                        encoder_layer.__call__,
                        hidden_states,
                        attention_mask,
                        (head_mask[idx] if head_mask is not None else None),
                        output_attentions,
                    )
                else:
                    layer_outputs = encoder_layer(
                        hidden_states,
                        attention_mask,
                        layer_head_mask=(head_mask[idx] if head_mask is not None else None),
                        output_attentions=output_attentions,
                    )

                hidden_states = layer_outputs[0]

            if output_attentions:
                all_attentions = all_attentions + (layer_outputs[1],)

        # add final layer norm
        hidden_states = self.layer_norm(hidden_states)

        if output_hidden_states:
            encoder_states = encoder_states + (hidden_states,)

        if not return_dict:
            return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
        return BaseModelOutput(
            last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
        )
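

# Illustrative sketch, not part of the original module: running the encoder stack on
# its own. The tiny config values and token ids are made up; with a real checkpoint
# you would instead call `BlenderbotModel.from_pretrained(...)` and use `get_encoder()`.
def _example_blenderbot_encoder():
    config = BlenderbotConfig(
        vocab_size=32,
        d_model=16,
        encoder_layers=2,
        encoder_attention_heads=4,
        encoder_ffn_dim=32,
        max_position_embeddings=64,
        pad_token_id=0,
    )
    encoder = BlenderbotEncoder(config)
    input_ids = torch.tensor([[4, 7, 9, 0]])  # last position is padding
    attention_mask = input_ids.ne(config.pad_token_id).long()
    outputs = encoder(input_ids=input_ids, attention_mask=attention_mask)
    return outputs.last_hidden_state.shape  # torch.Size([1, 4, 16])
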
class BlenderbotDecoder(BlenderbotPreTrainedModel):
    """
    Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a [`BlenderbotDecoderLayer`]

    Args:
        config: BlenderbotConfig
        embed_tokens (nn.Embedding): output embedding
    """

    def __init__(self, config: BlenderbotConfig, embed_tokens: Optional[nn.Embedding] = None):
        super().__init__(config)
        self.dropout = config.dropout
        self.layerdrop = config.decoder_layerdrop
        self.padding_idx = config.pad_token_id
        self.max_target_positions = config.max_position_embeddings
        embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0

        if embed_tokens is not None:
            self.embed_tokens = embed_tokens
        else:
            self.embed_tokens = BlenderbotScaledWordEmbedding(
                config.vocab_size, config.d_model, self.padding_idx, embed_scale=embed_scale
            )

        self.embed_positions = BlenderbotLearnedPositionalEmbedding(
            config.max_position_embeddings,
            config.d_model,
        )
        self.layers = nn.ModuleList([BlenderbotDecoderLayer(config) for _ in range(config.decoder_layers)])
        self.layer_norm = nn.LayerNorm(config.d_model)

        self.gradient_checkpointing = False
        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.embed_tokens

    def set_input_embeddings(self, value):
        self.embed_tokens = value

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        head_mask=None,
        cross_attn_head_mask=None,
        past_key_values=None,
        inputs_embeds=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        r"""
        Args:
            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
                Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
                provide it.

                Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
                [`PreTrainedTokenizer.__call__`] for details.

                [What are input IDs?](../glossary#input-ids)
            attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
                Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

                - 1 for tokens that are **not masked**,
                - 0 for tokens that are **masked**.

                [What are attention masks?](../glossary#attention-mask)
            encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*):
                Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention
                of the decoder.
            encoder_attention_mask (`torch.LongTensor` of shape `(batch_size, encoder_sequence_length)`, *optional*):
                Mask to avoid performing cross-attention on padding tokens indices of encoder input_ids. Mask values
                selected in `[0, 1]`:

                - 1 for tokens that are **not masked**,
                - 0 for tokens that are **masked**.

                [What are attention masks?](../glossary#attention-mask)
            head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*):
                Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in `[0,
                1]`:

                - 1 indicates the head is **not masked**,
                - 0 indicates the head is **masked**.

            cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
                Mask to nullify selected heads of the cross-attention modules in the decoder to avoid performing
                cross-attention on hidden heads. Mask values selected in `[0, 1]`:

                - 1 indicates the head is **not masked**,
                - 0 indicates the head is **masked**.

            past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
                Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
                shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of
                shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.

                Contains pre-computed hidden-states (key and values in the self-attention blocks and in the
                cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential
                decoding.

                If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those
                that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of
                all `decoder_input_ids` of shape `(batch_size, sequence_length)`.
            inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
                Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
                This is useful if you want more control over how to convert `input_ids` indices into associated
                vectors than the model's internal embedding lookup matrix.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            output_hidden_states (`bool`, *optional*):
                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
                for more detail.
            return_dict (`bool`, *optional*):
                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # retrieve input_ids and inputs_embeds
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
        elif input_ids is not None:
            input_shape = input_ids.size()
            input_ids = input_ids.view(-1, input_shape[-1])
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
        else:
            raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")

        # past_key_values_length
        past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        attention_mask = _prepare_4d_causal_attention_mask(
            attention_mask, input_shape, inputs_embeds, past_key_values_length
        )

        # expand encoder attention mask
        if encoder_hidden_states is not None and encoder_attention_mask is not None:
            # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
            encoder_attention_mask = _prepare_4d_attention_mask(
                encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
            )

        # embed positions
        positions = self.embed_positions(input_shape, past_key_values_length)

        hidden_states = inputs_embeds + positions

        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)

        if self.gradient_checkpointing and self.training:
            if use_cache:
                logger.warning(
                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                )
                use_cache = False

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
        next_decoder_cache = () if use_cache else None

        # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired
        for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]):
            if attn_mask is not None:
                if attn_mask.size()[0] != len(self.layers):
                    raise ValueError(
                        f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for"
                        f" {head_mask.size()[0]}."
                    )
        for idx, decoder_layer in enumerate(self.layers):
            # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
            if output_hidden_states:
                all_hidden_states += (hidden_states,)
            if self.training:
                dropout_probability = torch.rand([])
                if dropout_probability < self.layerdrop:
                    continue

            past_key_value = past_key_values[idx] if past_key_values is not None else None

            if self.gradient_checkpointing and self.training:
                layer_outputs = self._gradient_checkpointing_func(
                    decoder_layer.__call__,
                    hidden_states,
                    attention_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                    head_mask[idx] if head_mask is not None else None,
                    cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None,
                    None,
                    output_attentions,
                    use_cache,
                )
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=attention_mask,
                    encoder_hidden_states=encoder_hidden_states,
                    encoder_attention_mask=encoder_attention_mask,
                    layer_head_mask=(head_mask[idx] if head_mask is not None else None),
                    cross_attn_layer_head_mask=(
                        cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None
                    ),
                    past_key_value=past_key_value,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                )
            hidden_states = layer_outputs[0]

            if use_cache:
                next_decoder_cache += (layer_outputs[3 if output_attentions else 1],)

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

                if encoder_hidden_states is not None:
                    all_cross_attentions += (layer_outputs[2],)

        # add final layer norm
        hidden_states = self.layer_norm(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        next_cache = next_decoder_cache if use_cache else None
        if not return_dict:
            return tuple(
                v
                for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_cross_attentions]
                if v is not None
            )
        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
            cross_attentions=all_cross_attentions,
        )
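

# Illustrative sketch, not part of the original module: the decoder caches per-layer
# key/value states, so a follow-up call only needs the newest token plus the cache.
# The tiny config values and token ids are made up for demonstration.
def _example_blenderbot_decoder_cache():
    config = BlenderbotConfig(
        vocab_size=32,
        d_model=16,
        decoder_layers=2,
        decoder_attention_heads=4,
        decoder_ffn_dim=32,
        max_position_embeddings=64,
        pad_token_id=0,
    )
    decoder = BlenderbotDecoder(config)
    first = decoder(input_ids=torch.tensor([[1, 5, 6]]), use_cache=True)
    # feed only the most recent token together with the cached key/value states
    second = decoder(input_ids=torch.tensor([[7]]), past_key_values=first.past_key_values, use_cache=True)
    return second.last_hidden_state.shape  # torch.Size([1, 1, 16])
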
  886. @add_start_docstrings(
  887. "The bare Blenderbot Model outputting raw hidden-states without any specific head on top.",
  888. BLENDERBOT_START_DOCSTRING,
  889. )
  890. class BlenderbotModel(BlenderbotPreTrainedModel):
  891. _tied_weights_keys = ["decoder.embed_tokens.weight", "encoder.embed_tokens.weight"]
  892. def __init__(self, config: BlenderbotConfig):
  893. super().__init__(config)
  894. padding_idx, vocab_size = config.pad_token_id, config.vocab_size
  895. embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
  896. self.shared = BlenderbotScaledWordEmbedding(vocab_size, config.d_model, padding_idx, embed_scale=embed_scale)
  897. self.encoder = BlenderbotEncoder(config, self.shared)
  898. self.decoder = BlenderbotDecoder(config, self.shared)
  899. # Initialize weights and apply final processing
  900. self.post_init()
  901. @classmethod
  902. def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], *model_args, **kwargs):
  903. if pretrained_model_name_or_path == "facebook/blenderbot-90M":
  904. warnings.warn(
  905. "The checkpoint `facebook/blenderbot-90M` is deprecated. In the future, please use the identical"
  906. " checkpoint `facebook/small_blenderbot-90M` with"
  907. " `BlenderbotSmallModel.from_pretrained('facebook/small_blenderbot-90M')` instead.",
  908. FutureWarning,
  909. )
  910. return BlenderbotSmallModel.from_pretrained(pretrained_model_name_or_path)
  911. return super(BlenderbotModel, cls).from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
  912. def get_input_embeddings(self):
  913. return self.shared
  914. def set_input_embeddings(self, value):
  915. self.shared = value
  916. self.encoder.embed_tokens = self.shared
  917. self.decoder.embed_tokens = self.shared
  918. def get_encoder(self):
  919. return self.encoder
  920. def get_decoder(self):
  921. return self.decoder
  922. @add_start_docstrings_to_model_forward(BLENDERBOT_INPUTS_DOCSTRING)
  923. @replace_return_docstrings(output_type=Seq2SeqModelOutput, config_class=_CONFIG_FOR_DOC)
  924. def forward(
  925. self,
  926. input_ids: Optional[torch.LongTensor] = None,
  927. attention_mask: Optional[torch.Tensor] = None,
  928. decoder_input_ids: Optional[torch.LongTensor] = None,
  929. decoder_attention_mask: Optional[torch.LongTensor] = None,
  930. head_mask: Optional[torch.Tensor] = None,
  931. decoder_head_mask: Optional[torch.Tensor] = None,
  932. cross_attn_head_mask: Optional[torch.Tensor] = None,
  933. encoder_outputs: Optional[Union[Tuple, BaseModelOutput]] = None,
  934. past_key_values: Optional[List[torch.FloatTensor]] = None,
  935. inputs_embeds: Optional[torch.Tensor] = None,
  936. decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
  937. use_cache: Optional[bool] = None,
  938. output_attentions: Optional[bool] = None,
  939. output_hidden_states: Optional[bool] = None,
  940. return_dict: Optional[bool] = None,
  941. ) -> Union[Tuple[torch.FloatTensor], Seq2SeqModelOutput]:
  942. r"""
  943. Returns:
  944. Example:
  945. ```python
  946. >>> from transformers import AutoTokenizer, BlenderbotModel
  947. >>> model = BlenderbotModel.from_pretrained("facebook/blenderbot-400M-distill")
  948. >>> tokenizer = AutoTokenizer.from_pretrained("facebook/blenderbot-400M-distill")
  949. >>> inputs = tokenizer("Studies have been shown that owning a dog is good for you", return_tensors="pt")
  950. >>> decoder_input_ids = tokenizer("Studies show that", return_tensors="pt").input_ids # Batch size 1
  951. >>> outputs = model(input_ids=inputs.input_ids, decoder_input_ids=decoder_input_ids)
  952. >>> last_hidden_states = outputs.last_hidden_state
  953. >>> list(last_hidden_states.shape)
  954. [1, 6, 1280]
  955. ```"""
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if encoder_outputs is None:
            encoder_outputs = self.encoder(
                input_ids=input_ids,
                attention_mask=attention_mask,
                head_mask=head_mask,
                inputs_embeds=inputs_embeds,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )
        # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True
        elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
            encoder_outputs = BaseModelOutput(
                last_hidden_state=encoder_outputs[0],
                hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
                attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
            )

        # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn)
        decoder_outputs = self.decoder(
            input_ids=decoder_input_ids,
            attention_mask=decoder_attention_mask,
            encoder_hidden_states=encoder_outputs[0],
            encoder_attention_mask=attention_mask,
            head_mask=decoder_head_mask,
            cross_attn_head_mask=cross_attn_head_mask,
            past_key_values=past_key_values,
            inputs_embeds=decoder_inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        if not return_dict:
            return decoder_outputs + encoder_outputs

        return Seq2SeqModelOutput(
            last_hidden_state=decoder_outputs.last_hidden_state,
            past_key_values=decoder_outputs.past_key_values,
            decoder_hidden_states=decoder_outputs.hidden_states,
            decoder_attentions=decoder_outputs.attentions,
            cross_attentions=decoder_outputs.cross_attentions,
            encoder_last_hidden_state=encoder_outputs.last_hidden_state,
            encoder_hidden_states=encoder_outputs.hidden_states,
            encoder_attentions=encoder_outputs.attentions,
        )
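
# Because `BlenderbotModel.forward` accepts a precomputed `encoder_outputs`, the encoder pass can be
# run once and reused across several decoder calls. A minimal sketch (the `inputs` and
# `decoder_input_ids` tensors are placeholders, built as in the docstring example above):
#
#     encoder = model.get_encoder()
#     encoder_outputs = encoder(input_ids=inputs.input_ids, attention_mask=inputs.attention_mask)
#     outputs = model(
#         encoder_outputs=encoder_outputs,
#         attention_mask=inputs.attention_mask,
#         decoder_input_ids=decoder_input_ids,
#     )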

@add_start_docstrings(
    "The Blenderbot Model with a language modeling head. Can be used for summarization.", BLENDERBOT_START_DOCSTRING
)
class BlenderbotForConditionalGeneration(BlenderbotPreTrainedModel, GenerationMixin):
    base_model_prefix = "model"
    _keys_to_ignore_on_load_missing = ["final_logits_bias"]
    _tied_weights_keys = ["decoder.embed_tokens.weight", "encoder.embed_tokens.weight", "lm_head.weight"]

    def __init__(self, config: BlenderbotConfig):
        super().__init__(config)
        self.model = BlenderbotModel(config)
        self.register_buffer("final_logits_bias", torch.zeros((1, self.model.shared.num_embeddings)))
        self.lm_head = nn.Linear(config.d_model, self.model.shared.num_embeddings, bias=False)

        # Initialize weights and apply final processing
        self.post_init()
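
    # `final_logits_bias` is registered as a buffer rather than a parameter: it is saved and loaded
    # with the state dict but is not returned by `parameters()`, so the optimizer never updates it;
    # it is simply added on top of the `lm_head` output in `forward`.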

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], *model_args, **kwargs):
        if pretrained_model_name_or_path == "facebook/blenderbot-90M":
            warnings.warn(
                "The checkpoint `facebook/blenderbot-90M` is deprecated. In the future, please use the identical"
                " checkpoint `facebook/small_blenderbot-90M` with"
                " `BlenderbotSmallForConditionalGeneration.from_pretrained('facebook/small_blenderbot-90M')` instead.",
                FutureWarning,
            )
            return BlenderbotSmallForConditionalGeneration.from_pretrained(pretrained_model_name_or_path)

        return super(BlenderbotForConditionalGeneration, cls).from_pretrained(
            pretrained_model_name_or_path, *model_args, **kwargs
        )

    def get_encoder(self):
        return self.model.get_encoder()

    def get_decoder(self):
        return self.model.get_decoder()

    def resize_token_embeddings(self, new_num_tokens: int, pad_to_multiple_of: Optional[int] = None) -> nn.Embedding:
        new_embeddings = super().resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
        self._resize_final_logits_bias(new_embeddings.weight.shape[0])
        return new_embeddings

    def _resize_final_logits_bias(self, new_num_tokens: int) -> None:
        old_num_tokens = self.final_logits_bias.shape[-1]
        if new_num_tokens <= old_num_tokens:
            new_bias = self.final_logits_bias[:, :new_num_tokens]
        else:
            extra_bias = torch.zeros((1, new_num_tokens - old_num_tokens), device=self.final_logits_bias.device)
            new_bias = torch.cat([self.final_logits_bias, extra_bias], dim=1)
        self.register_buffer("final_logits_bias", new_bias)
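
    # Illustrative sketch of the resize behaviour (the target size 8010 is a made-up example for a
    # checkpoint whose vocabulary is 8008 tokens): growing the vocabulary appends zero-initialized
    # bias entries, shrinking it truncates them.
    #
    #     model.resize_token_embeddings(8010)   # adds 2 rows to the shared embedding and the lm_head
    #     model.final_logits_bias.shape         # -> torch.Size([1, 8010]); the 2 new entries are 0.0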

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    @add_start_docstrings_to_model_forward(BLENDERBOT_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
    @add_end_docstrings(BLENDERBOT_GENERATION_EXAMPLE)
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        decoder_input_ids: Optional[torch.LongTensor] = None,
        decoder_attention_mask: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        decoder_head_mask: Optional[torch.Tensor] = None,
        cross_attn_head_mask: Optional[torch.Tensor] = None,
        encoder_outputs: Optional[Union[Tuple, BaseModelOutput]] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.FloatTensor], Seq2SeqLMOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        Returns:
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if labels is not None:
            if use_cache:
                logger.warning("The `use_cache` argument is changed to `False` since `labels` is provided.")
            use_cache = False
            if decoder_input_ids is None and decoder_inputs_embeds is None:
                decoder_input_ids = shift_tokens_right(
                    labels, self.config.pad_token_id, self.config.decoder_start_token_id
                )
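
        # When only `labels` are supplied, the decoder inputs above are derived by shifting the labels
        # one position to the right and prepending `decoder_start_token_id` (with `-100` entries
        # replaced by `pad_token_id`), i.e. the usual teacher-forcing setup for this model family.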
        outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            encoder_outputs=encoder_outputs,
            decoder_attention_mask=decoder_attention_mask,
            head_mask=head_mask,
            decoder_head_mask=decoder_head_mask,
            cross_attn_head_mask=cross_attn_head_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            decoder_inputs_embeds=decoder_inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        lm_logits = self.lm_head(outputs[0]) + self.final_logits_bias

        masked_lm_loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            masked_lm_loss = loss_fct(lm_logits.view(-1, self.config.vocab_size), labels.view(-1))

        if not return_dict:
            output = (lm_logits,) + outputs[1:]
            return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output

        return Seq2SeqLMOutput(
            loss=masked_lm_loss,
            logits=lm_logits,
            past_key_values=outputs.past_key_values,
            decoder_hidden_states=outputs.decoder_hidden_states,
            decoder_attentions=outputs.decoder_attentions,
            cross_attentions=outputs.cross_attentions,
            encoder_last_hidden_state=outputs.encoder_last_hidden_state,
            encoder_hidden_states=outputs.encoder_hidden_states,
            encoder_attentions=outputs.encoder_attentions,
        )

    @staticmethod
    def _reorder_cache(past_key_values, beam_idx):
        reordered_past = ()
        for layer_past in past_key_values:
            # cached cross_attention states don't have to be reordered -> they are always the same
            reordered_past += (
                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past[:2])
                + layer_past[2:],
            )
        return reordered_past
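
# A minimal generation sketch for `BlenderbotForConditionalGeneration`, along the lines of the
# `BLENDERBOT_GENERATION_EXAMPLE` docstring attached to its `forward` (checkpoint name and prompt are
# only illustrative):
#
#     tokenizer = AutoTokenizer.from_pretrained("facebook/blenderbot-400M-distill")
#     model = BlenderbotForConditionalGeneration.from_pretrained("facebook/blenderbot-400M-distill")
#     inputs = tokenizer(["My friends are cool but they eat too many carbs."], return_tensors="pt")
#     reply_ids = model.generate(**inputs)
#     print(tokenizer.batch_decode(reply_ids, skip_special_tokens=True))
#
# During beam search, `generate` calls `_reorder_cache` above so the cached self-attention keys and
# values follow the surviving beams; the cross-attention cache (entries `[2:]`) is left untouched
# because it only depends on the encoder output, which is identical across beams.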

# Copied from transformers.models.bart.modeling_bart.BartDecoderWrapper with Bart->Blenderbot
class BlenderbotDecoderWrapper(BlenderbotPreTrainedModel):
    """
    This wrapper class is a helper class to correctly load pretrained checkpoints when the causal language model is
    used in combination with the [`EncoderDecoderModel`] framework.
    """

    def __init__(self, config):
        super().__init__(config)
        self.decoder = BlenderbotDecoder(config)

    def forward(self, *args, **kwargs):
        return self.decoder(*args, **kwargs)


# Copied from transformers.models.bart.modeling_bart.BartForCausalLM with Bart->Blenderbot, facebook/bart-base->facebook/blenderbot-400M-distill
class BlenderbotForCausalLM(BlenderbotPreTrainedModel, GenerationMixin):
    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config):
        config = copy.deepcopy(config)
        config.is_decoder = True
        config.is_encoder_decoder = False
        super().__init__(config)
        self.model = BlenderbotDecoderWrapper(config)

        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.model.decoder.embed_tokens

    def set_input_embeddings(self, value):
        self.model.decoder.embed_tokens = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        self.model.decoder = decoder

    def get_decoder(self):
        return self.model.decoder

    @replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        cross_attn_head_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
  1189. r"""
  1190. Args:
  1191. input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
  1192. Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
  1193. provide it.
  1194. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  1195. [`PreTrainedTokenizer.__call__`] for details.
  1196. [What are input IDs?](../glossary#input-ids)
  1197. attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
  1198. Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
  1199. - 1 for tokens that are **not masked**,
  1200. - 0 for tokens that are **masked**.
  1201. [What are attention masks?](../glossary#attention-mask)
  1202. encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
  1203. Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention
  1204. if the model is configured as a decoder.
  1205. encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
  1206. Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used
  1207. in the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`:
  1208. head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
  1209. Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
  1210. - 1 indicates the head is **not masked**,
  1211. - 0 indicates the head is **masked**.
  1212. cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
  1213. Mask to nullify selected heads of the cross-attention modules. Mask values selected in `[0, 1]`:
  1214. - 1 indicates the head is **not masked**,
  1215. - 0 indicates the head is **masked**.
  1216. past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
  1217. Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
  1218. shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of
  1219. shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. The two additional
  1220. tensors are only required when the model is used as a decoder in a Sequence to Sequence model.
  1221. Contains pre-computed hidden-states (key and values in the self-attention blocks and in the
  1222. cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
  1223. If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those
  1224. that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of
  1225. all `decoder_input_ids` of shape `(batch_size, sequence_length)`.
  1226. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  1227. Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
  1228. config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
  1229. (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
  1230. use_cache (`bool`, *optional*):
  1231. If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
  1232. (see `past_key_values`).
  1233. - 1 for tokens that are **not masked**,
  1234. - 0 for tokens that are **masked**.
  1235. output_attentions (`bool`, *optional*):
  1236. Whether or not to return the attentions tensors of all attention layers. See `attentions` under
  1237. returned tensors for more detail.
  1238. output_hidden_states (`bool`, *optional*):
  1239. Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
  1240. for more detail.
  1241. return_dict (`bool`, *optional*):
  1242. Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
  1243. Returns:
  1244. Example:
  1245. ```python
  1246. >>> from transformers import AutoTokenizer, BlenderbotForCausalLM
  1247. >>> tokenizer = AutoTokenizer.from_pretrained("facebook/blenderbot-400M-distill")
  1248. >>> model = BlenderbotForCausalLM.from_pretrained("facebook/blenderbot-400M-distill", add_cross_attention=False)
  1249. >>> assert model.config.is_decoder, f"{model.__class__} has to be configured as a decoder."
  1250. >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
  1251. >>> outputs = model(**inputs)
  1252. >>> logits = outputs.logits
  1253. >>> expected_shape = [1, inputs.input_ids.shape[-1], model.config.vocab_size]
  1254. >>> list(logits.shape) == expected_shape
  1255. True
  1256. ```"""
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs = self.model.decoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            head_mask=head_mask,
            cross_attn_head_mask=cross_attn_head_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        logits = self.lm_head(outputs[0])

        loss = None
        if labels is not None:
            labels = labels.to(logits.device)
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1))

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithCrossAttentions(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            cross_attentions=outputs.cross_attentions,
        )

    @staticmethod
    def _reorder_cache(past_key_values, beam_idx):
        reordered_past = ()
        for layer_past in past_key_values:
            reordered_past += (
                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
            )
        return reordered_past
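
    # `beam_idx` is a 1-D LongTensor of the beam indices selected at a decoding step; for example,
    # with 3 beams, beam_idx = torch.tensor([0, 0, 2]) duplicates beam 0 and drops beam 1, and
    # `index_select(0, ...)` applies that reordering to the batch/beam dimension of every cached
    # key and value tensor.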