modeling_bert.py 88 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
7177817791780178117821783178417851786178717881789179017911792179317941795179617971798179918001801180218031804180518061807180818091810181118121813181418151816181718181819182018211822182318241825182618271828182918301831183218331834183518361837183818391840184118421843184418451846184718481849185018511852185318541855185618571858185918601861186218631864186518661867186818691870187118721873187418751876187718781879188018811882188318841885188618871888188918901891189218931894189518961897189818991900190119021903190419051906190719081909191019111912191319141915191619171918191919201921192219231924192519261927192819291930193119321933193419351936193719381939194019411942194319441945194619471948194919501951195219531954195519561957195819591960196119621963196419651966196719681969197019711972197319741975197619771978197919801981198219831984198519861987198819891990199119921993199419951996
  1. # coding=utf-8
  2. # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
  3. # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. """PyTorch BERT model."""
  17. import math
  18. import os
  19. import warnings
  20. from dataclasses import dataclass
  21. from typing import List, Optional, Tuple, Union
  22. import torch
  23. import torch.utils.checkpoint
  24. from packaging import version
  25. from torch import nn
  26. from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
  27. from ...activations import ACT2FN
  28. from ...generation import GenerationMixin
  29. from ...modeling_attn_mask_utils import (
  30. _prepare_4d_attention_mask_for_sdpa,
  31. _prepare_4d_causal_attention_mask_for_sdpa,
  32. )
  33. from ...modeling_outputs import (
  34. BaseModelOutputWithPastAndCrossAttentions,
  35. BaseModelOutputWithPoolingAndCrossAttentions,
  36. CausalLMOutputWithCrossAttentions,
  37. MaskedLMOutput,
  38. MultipleChoiceModelOutput,
  39. NextSentencePredictorOutput,
  40. QuestionAnsweringModelOutput,
  41. SequenceClassifierOutput,
  42. TokenClassifierOutput,
  43. )
  44. from ...modeling_utils import PreTrainedModel
  45. from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
  46. from ...utils import (
  47. ModelOutput,
  48. add_code_sample_docstrings,
  49. add_start_docstrings,
  50. add_start_docstrings_to_model_forward,
  51. get_torch_version,
  52. logging,
  53. replace_return_docstrings,
  54. )
  55. from .configuration_bert import BertConfig
  logger = logging.get_logger(__name__)

  # Defaults used by the generic code-sample docstrings.
  _CHECKPOINT_FOR_DOC = "google-bert/bert-base-uncased"
  _CONFIG_FOR_DOC = "BertConfig"

  # TokenClassification docstring: checkpoint plus the expected predicted labels and loss.
  _CHECKPOINT_FOR_TOKEN_CLASSIFICATION = "dbmdz/bert-large-cased-finetuned-conll03-english"
  _TOKEN_CLASS_EXPECTED_OUTPUT = (
      "['O', 'I-ORG', 'I-ORG', 'I-ORG', 'O', 'O', 'O', 'O', 'O', 'I-LOC', 'O', 'I-LOC', 'I-LOC'] "
  )
  _TOKEN_CLASS_EXPECTED_LOSS = 0.01

  # QuestionAnswering docstring: checkpoint, expected answer, loss and answer-span token indices.
  _CHECKPOINT_FOR_QA = "deepset/bert-base-cased-squad2"
  _QA_EXPECTED_OUTPUT = "'a nice puppet'"
  _QA_EXPECTED_LOSS = 7.41
  _QA_TARGET_START_INDEX, _QA_TARGET_END_INDEX = 14, 15

  # SequenceClassification docstring: checkpoint plus the expected predicted label and loss.
  _CHECKPOINT_FOR_SEQUENCE_CLASSIFICATION = "textattack/bert-base-uncased-yelp-polarity"
  _SEQ_CLASS_EXPECTED_OUTPUT = "'LABEL_1'"
  _SEQ_CLASS_EXPECTED_LOSS = 0.01
  def load_tf_weights_in_bert(model, config, tf_checkpoint_path):
      """Copy the variables of a TensorFlow checkpoint into a PyTorch BERT model, in place.

      Each TF variable name is split on "/" and walked component by component through the
      attribute tree of ``model`` to find the target tensor, whose ``.data`` is then replaced.

      Args:
          model: The PyTorch module to populate.
          config: Accepted for API compatibility; not read by this function.
          tf_checkpoint_path: Path to the TensorFlow checkpoint.

      Returns:
          The same ``model`` object.

      Raises:
          ImportError: If TensorFlow (or NumPy) cannot be imported.
          ValueError: If a checkpoint array's shape differs from the resolved target's shape;
              the two shapes are appended to the exception's ``args``.
      """
      try:
          import re

          import numpy as np
          import tensorflow as tf
      except ImportError:
          logger.error(
              "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see "
              "https://www.tensorflow.org/install/ for installation instructions."
          )
          raise

      tf_path = os.path.abspath(tf_checkpoint_path)
      logger.info(f"Converting TensorFlow checkpoint from {tf_path}")

      # Read every (variable name, array) pair out of the checkpoint before touching the model.
      checkpoint = []
      for var_name, var_shape in tf.train.list_variables(tf_path):
          logger.info(f"Loading TF weight {var_name} with shape {var_shape}")
          checkpoint.append((var_name, tf.train.load_variable(tf_path, var_name)))

      # Optimizer slots (adam_m / adam_v) and the step counter are not model weights.
      optimizer_scopes = ("adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step")

      for full_name, array in checkpoint:
          scopes = full_name.split("/")
          if any(scope in optimizer_scopes for scope in scopes):
              logger.info(f"Skipping {'/'.join(scopes)}")
              continue

          pointer = model
          for component in scopes:
              # "layer_3" -> ["layer", "3", ""]: attribute name followed by a sequence index.
              if re.fullmatch(r"[A-Za-z]+_\d+", component):
                  parts = re.split(r"_(\d+)", component)
              else:
                  parts = [component]
              head = parts[0]

              if head in ("kernel", "gamma", "output_weights"):
                  pointer = getattr(pointer, "weight")
              elif head in ("output_bias", "beta"):
                  pointer = getattr(pointer, "bias")
              elif head == "squad":
                  pointer = getattr(pointer, "classifier")
              else:
                  try:
                      pointer = getattr(pointer, head)
                  except AttributeError:
                      # NOTE: this `continue` skips only this path component (and its index);
                      # resolution carries on with the next component from the current pointer.
                      logger.info(f"Skipping {'/'.join(scopes)}")
                      continue

              if len(parts) >= 2:
                  pointer = pointer[int(parts[1])]

          # `component` still holds the last path element of the variable name here.
          if component.endswith("_embeddings"):
              pointer = getattr(pointer, "weight")
          elif component == "kernel":
              # Kernels are transposed before copying; presumably TF stores dense kernels as
              # (in, out) while torch Linear weights are (out, in) -- confirm per layer type.
              array = np.transpose(array)

          if pointer.shape != array.shape:
              error = ValueError(f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched")
              error.args += (pointer.shape, array.shape)
              raise error

          logger.info(f"Initialize PyTorch weight {scopes}")
          pointer.data = torch.from_numpy(array)
      return model
  class BertEmbeddings(nn.Module):
      """Build input embeddings from word, token-type and (absolute) position embeddings.

      The three lookups are summed, then passed through LayerNorm and dropout.
      """

      def __init__(self, config):
          super().__init__()
          hidden = config.hidden_size
          self.word_embeddings = nn.Embedding(config.vocab_size, hidden, padding_idx=config.pad_token_id)
          self.position_embeddings = nn.Embedding(config.max_position_embeddings, hidden)
          self.token_type_embeddings = nn.Embedding(config.type_vocab_size, hidden)
          # `LayerNorm` keeps its CamelCase attribute name so TensorFlow checkpoint variable names
          # map onto it when loading.
          self.LayerNorm = nn.LayerNorm(hidden, eps=config.layer_norm_eps)
          self.dropout = nn.Dropout(config.hidden_dropout_prob)
          self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
          # Non-persistent (1, max_position_embeddings) buffers, sliced per call: default position
          # ids 0..max-1 and all-zero token type ids.
          all_positions = torch.arange(config.max_position_embeddings).expand((1, -1))
          self.register_buffer("position_ids", all_positions, persistent=False)
          self.register_buffer(
              "token_type_ids", torch.zeros(self.position_ids.size(), dtype=torch.long), persistent=False
          )

      def forward(
          self,
          input_ids: Optional[torch.LongTensor] = None,
          token_type_ids: Optional[torch.LongTensor] = None,
          position_ids: Optional[torch.LongTensor] = None,
          inputs_embeds: Optional[torch.FloatTensor] = None,
          past_key_values_length: int = 0,
      ) -> torch.Tensor:
          """Return embeddings of shape ``input_shape + (hidden_size,)``.

          Either ``input_ids`` (batch, seq) or ``inputs_embeds`` (batch, seq, hidden) must be given.
          Missing ``position_ids`` start at ``past_key_values_length``; missing ``token_type_ids``
          default to zeros.
          """
          input_shape = inputs_embeds.size()[:-1] if input_ids is None else input_ids.size()
          seq_length = input_shape[1]

          if position_ids is None:
              start = past_key_values_length
              position_ids = self.position_ids[:, start : start + seq_length]

          # Using the registered all-zeros buffer (instead of a freshly created tensor) lets the model
          # be traced without passing token_type_ids (issue #5664).
          if token_type_ids is None:
              if not hasattr(self, "token_type_ids"):
                  token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
              else:
                  token_type_ids = self.token_type_ids[:, :seq_length].expand(input_shape[0], seq_length)

          if inputs_embeds is None:
              inputs_embeds = self.word_embeddings(input_ids)

          embeddings = inputs_embeds + self.token_type_embeddings(token_type_ids)
          if self.position_embedding_type == "absolute":
              embeddings += self.position_embeddings(position_ids)
          return self.dropout(self.LayerNorm(embeddings))
  class BertSelfAttention(nn.Module):
      """Eager multi-head scaled dot-product attention (self- or cross-attention).

      Supports absolute, ``relative_key`` and ``relative_key_query`` position handling, an additive
      attention mask, per-head masking, and key/value caching when the config is a decoder.
      """

      def __init__(self, config, position_embedding_type=None):
          super().__init__()
          if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
              raise ValueError(
                  f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
                  f"heads ({config.num_attention_heads})"
              )
          self.num_attention_heads = config.num_attention_heads
          self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
          self.all_head_size = self.num_attention_heads * self.attention_head_size

          self.query = nn.Linear(config.hidden_size, self.all_head_size)
          self.key = nn.Linear(config.hidden_size, self.all_head_size)
          self.value = nn.Linear(config.hidden_size, self.all_head_size)
          self.dropout = nn.Dropout(config.attention_probs_dropout_prob)

          # An explicit (truthy) argument wins over the config value.
          self.position_embedding_type = position_embedding_type or getattr(
              config, "position_embedding_type", "absolute"
          )
          if self.position_embedding_type in ("relative_key", "relative_key_query"):
              self.max_position_embeddings = config.max_position_embeddings
              # One embedding per signed distance in [-(max - 1), max - 1].
              self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
          self.is_decoder = config.is_decoder

      def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
          """Split the last dim into heads: (batch, seq, all_head_size) -> (batch, heads, seq, head_size)."""
          split = x.view(x.size()[:-1] + (self.num_attention_heads, self.attention_head_size))
          return split.permute(0, 2, 1, 3)

      def forward(
          self,
          hidden_states: torch.Tensor,
          attention_mask: Optional[torch.FloatTensor] = None,
          head_mask: Optional[torch.FloatTensor] = None,
          encoder_hidden_states: Optional[torch.FloatTensor] = None,
          encoder_attention_mask: Optional[torch.FloatTensor] = None,
          past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
          output_attentions: Optional[bool] = False,
      ) -> Tuple[torch.Tensor]:
          """Return ``(context,)``, plus attention probs if requested, plus the K/V cache for decoders."""
          is_cross_attention = encoder_hidden_states is not None
          # Must reflect the cache passed in, before `past_key_value` is rebound below.
          use_cache = past_key_value is not None
          mixed_query_layer = self.query(hidden_states)

          if is_cross_attention:
              # Keys/values come from the encoder, masked with the encoder's padding mask.
              attention_mask = encoder_attention_mask
              if past_key_value is not None:
                  key_layer, value_layer = past_key_value[0], past_key_value[1]
              else:
                  key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
                  value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
          else:
              key_layer = self.transpose_for_scores(self.key(hidden_states))
              value_layer = self.transpose_for_scores(self.value(hidden_states))
              if past_key_value is not None:
                  # Append this step's keys/values after the cached ones along the sequence axis.
                  key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
                  value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
          query_layer = self.transpose_for_scores(mixed_query_layer)

          if self.is_decoder:
              # Decoders hand back the (possibly extended) keys/values for reuse on the next call.
              past_key_value = (key_layer, value_layer)

          scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))

          if self.position_embedding_type in ("relative_key", "relative_key_query"):
              query_length, key_length = query_layer.shape[2], key_layer.shape[2]
              device = hidden_states.device
              if use_cache:
                  # A single query row at position key_length - 1 (assumes one new token per step).
                  position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=device).view(-1, 1)
              else:
                  position_ids_l = torch.arange(query_length, dtype=torch.long, device=device).view(-1, 1)
              position_ids_r = torch.arange(key_length, dtype=torch.long, device=device).view(1, -1)
              distance = position_ids_l - position_ids_r
              # Shift signed distances into valid embedding indices.
              positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
              positional_embedding = positional_embedding.to(dtype=query_layer.dtype)  # fp16 compatibility
              if self.position_embedding_type == "relative_key":
                  scores = scores + torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
              else:
                  scores = (
                      scores
                      + torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
                      + torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
                  )

          scores = scores / math.sqrt(self.attention_head_size)
          if attention_mask is not None:
              # Additive mask, precomputed once for all layers in BertModel.forward().
              scores = scores + attention_mask

          probs = nn.functional.softmax(scores, dim=-1)
          # Dropout here removes whole attended tokens, following the original Transformer paper.
          probs = self.dropout(probs)
          if head_mask is not None:
              probs = probs * head_mask

          context = torch.matmul(probs, value_layer).permute(0, 2, 1, 3).contiguous()
          context = context.view(context.size()[:-2] + (self.all_head_size,))

          outputs = (context, probs) if output_attentions else (context,)
          if self.is_decoder:
              outputs += (past_key_value,)
          return outputs
  class BertSdpaSelfAttention(BertSelfAttention):
      """BertSelfAttention computed with ``torch.nn.functional.scaled_dot_product_attention``.

      Relative position embeddings, ``output_attentions=True`` and ``head_mask`` are not handled on
      this path; any of them makes ``forward`` fall back to the eager parent implementation.
      """

      def __init__(self, config, position_embedding_type=None):
          super().__init__(config, position_embedding_type=position_embedding_type)
          self.dropout_prob = config.attention_probs_dropout_prob
          # torch < 2.2.0 needs contiguous q/k/v for SDPA on CUDA with a custom mask (see forward).
          self.require_contiguous_qkv = version.parse(get_torch_version()) < version.parse("2.2.0")

      # Adapted from BertSelfAttention
      def forward(
          self,
          hidden_states: torch.Tensor,
          attention_mask: Optional[torch.Tensor] = None,
          head_mask: Optional[torch.FloatTensor] = None,
          encoder_hidden_states: Optional[torch.FloatTensor] = None,
          encoder_attention_mask: Optional[torch.FloatTensor] = None,
          past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
          output_attentions: Optional[bool] = False,
      ) -> Tuple[torch.Tensor]:
          """Return ``(attn_output,)``, plus the K/V cache when the config is a decoder."""
          needs_eager = self.position_embedding_type != "absolute" or output_attentions or head_mask is not None
          if needs_eager:
              # TODO: Improve this warning with e.g. `model.config._attn_implementation = "manual"` once implemented.
              logger.warning_once(
                  "BertSdpaSelfAttention is used but `torch.nn.functional.scaled_dot_product_attention` does not support "
                  "non-absolute `position_embedding_type` or `output_attentions=True` or `head_mask`. Falling back to "
                  "the manual attention implementation, but specifying the manual implementation will be required from "
                  "Transformers version v5.0.0 onwards. This warning can be removed using the argument "
                  '`attn_implementation="eager"` when loading the model.'
              )
              return super().forward(
                  hidden_states,
                  attention_mask,
                  head_mask,
                  encoder_hidden_states,
                  encoder_attention_mask,
                  past_key_value,
                  output_attentions,
              )

          bsz, tgt_len, _ = hidden_states.size()
          query_layer = self.transpose_for_scores(self.query(hidden_states))

          # Cross-attention reads keys/values from the encoder and masks the encoder's padding.
          is_cross_attention = encoder_hidden_states is not None
          if is_cross_attention:
              current_states = encoder_hidden_states
              attention_mask = encoder_attention_mask
          else:
              current_states = hidden_states

          # Cached cross-attention keys/values are reused only when their sequence length matches
          # `current_states` (the length check supports prefix tuning).
          if is_cross_attention and past_key_value and past_key_value[0].shape[2] == current_states.shape[1]:
              key_layer, value_layer = past_key_value
          else:
              key_layer = self.transpose_for_scores(self.key(current_states))
              value_layer = self.transpose_for_scores(self.value(current_states))
              if past_key_value is not None and not is_cross_attention:
                  key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
                  value_layer = torch.cat([past_key_value[1], value_layer], dim=2)

          if self.is_decoder:
              # Decoders hand back keys/values so the next call can reuse or extend them.
              past_key_value = (key_layer, value_layer)

          # The memory-efficient SDPA backend in torch==2.1.2 misbehaves with non-contiguous inputs and a
          # custom attn_mask; fixed in torch==2.2.0. https://github.com/pytorch/pytorch/issues/112577
          if self.require_contiguous_qkv and query_layer.device.type == "cuda" and attention_mask is not None:
              query_layer, key_layer, value_layer = (t.contiguous() for t in (query_layer, key_layer, value_layer))

          # Causality is passed through a precomputed `is_causal` flag rather than an inline conditional in
          # the SDPA call so torch.compile supports dynamic shapes and fullgraph. `tgt_len > 1` matches
          # AttentionMaskConverter.to_causal_4d, which builds no causal mask when tgt_len == 1.
          is_causal = bool(self.is_decoder and not is_cross_attention and attention_mask is None and tgt_len > 1)

          attn_output = torch.nn.functional.scaled_dot_product_attention(
              query_layer,
              key_layer,
              value_layer,
              attn_mask=attention_mask,
              dropout_p=self.dropout_prob if self.training else 0.0,
              is_causal=is_causal,
          )
          attn_output = attn_output.transpose(1, 2).reshape(bsz, tgt_len, self.all_head_size)

          return (attn_output, past_key_value) if self.is_decoder else (attn_output,)
  class BertSelfOutput(nn.Module):
      """Project the attention context, apply dropout, add the residual input and LayerNorm the sum."""

      def __init__(self, config):
          super().__init__()
          width = config.hidden_size
          self.dense = nn.Linear(width, width)
          self.LayerNorm = nn.LayerNorm(width, eps=config.layer_norm_eps)
          self.dropout = nn.Dropout(config.hidden_dropout_prob)

      def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
          """Return ``LayerNorm(dropout(dense(hidden_states)) + input_tensor)``."""
          projected = self.dropout(self.dense(hidden_states))
          return self.LayerNorm(projected + input_tensor)
  # Lookup from `config._attn_implementation` to the self-attention class BertAttention builds.
  BERT_SELF_ATTENTION_CLASSES = dict(
      eager=BertSelfAttention,
      sdpa=BertSdpaSelfAttention,
  )
  class BertAttention(nn.Module):
      """Self- (or cross-) attention followed by BertSelfOutput's projection, residual and LayerNorm."""

      def __init__(self, config, position_embedding_type=None):
          super().__init__()
          attention_cls = BERT_SELF_ATTENTION_CLASSES[config._attn_implementation]
          self.self = attention_cls(config, position_embedding_type=position_embedding_type)
          self.output = BertSelfOutput(config)
          self.pruned_heads = set()

      def prune_heads(self, heads):
          """Remove the given attention heads and record them in ``self.pruned_heads``.

          Does nothing when ``heads`` is empty.
          """
          if len(heads) == 0:
              return
          attn = self.self
          heads, index = find_pruneable_heads_and_indices(
              heads, attn.num_attention_heads, attn.attention_head_size, self.pruned_heads
          )

          # Q/K/V are pruned along the default dim; the output projection along dim=1.
          attn.query = prune_linear_layer(attn.query, index)
          attn.key = prune_linear_layer(attn.key, index)
          attn.value = prune_linear_layer(attn.value, index)
          self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)

          # Keep the head bookkeeping consistent with the shrunken layers.
          attn.num_attention_heads -= len(heads)
          attn.all_head_size = attn.attention_head_size * attn.num_attention_heads
          self.pruned_heads = self.pruned_heads.union(heads)

      def forward(
          self,
          hidden_states: torch.Tensor,
          attention_mask: Optional[torch.FloatTensor] = None,
          head_mask: Optional[torch.FloatTensor] = None,
          encoder_hidden_states: Optional[torch.FloatTensor] = None,
          encoder_attention_mask: Optional[torch.FloatTensor] = None,
          past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
          output_attentions: Optional[bool] = False,
      ) -> Tuple[torch.Tensor]:
          """Return the attention block output followed by whatever extras the inner attention returned."""
          context, *extras = self.self(
              hidden_states,
              attention_mask,
              head_mask,
              encoder_hidden_states,
              encoder_attention_mask,
              past_key_value,
              output_attentions,
          )
          # Residual connection uses the block's input hidden states.
          attention_output = self.output(context, hidden_states)
          return (attention_output, *extras)
  class BertIntermediate(nn.Module):
      """Linear expansion from hidden_size to intermediate_size followed by the configured activation."""

      def __init__(self, config):
          super().__init__()
          self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
          act = config.hidden_act
          # `hidden_act` is either a name looked up in ACT2FN or already a callable.
          self.intermediate_act_fn = ACT2FN[act] if isinstance(act, str) else act

      def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
          """Return ``activation(dense(hidden_states))``."""
          return self.intermediate_act_fn(self.dense(hidden_states))
  class BertOutput(nn.Module):
      """Project intermediate_size back to hidden_size, apply dropout, add the residual and LayerNorm."""

      def __init__(self, config):
          super().__init__()
          self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
          self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
          self.dropout = nn.Dropout(config.hidden_dropout_prob)

      def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
          """Return ``LayerNorm(dropout(dense(hidden_states)) + input_tensor)``."""
          shrunk = self.dropout(self.dense(hidden_states))
          return self.LayerNorm(shrunk + input_tensor)
class BertLayer(nn.Module):
    """A single BERT transformer block.

    Runs self-attention (``self.attention``). Then, only for decoders that receive ``encoder_hidden_states``,
    it runs cross-attention (``self.crossattention``). Finally it applies the feed-forward network
    (``self.intermediate`` followed by ``self.output``), via ``apply_chunking_to_forward``.
    """

    def __init__(self, config):
        super().__init__()
        # Chunk size and sequence dimension handed to `apply_chunking_to_forward` for the feed-forward step.
        # NOTE(review): presumably a chunk size of 0 means "no chunking" -- confirm in pytorch_utils.
        self.chunk_size_feed_forward = config.chunk_size_feed_forward
        # Assumes hidden states are (batch, seq, hidden), so dim 1 is the sequence axis -- TODO confirm with callers.
        self.seq_len_dim = 1
        self.attention = BertAttention(config)
        self.is_decoder = config.is_decoder
        self.add_cross_attention = config.add_cross_attention
        if self.add_cross_attention:
            # Cross-attention is only allowed for decoder configurations.
            if not self.is_decoder:
                raise ValueError(f"{self} should be used as a decoder model if cross attention is added")
            # The cross-attention module is always built with position_embedding_type="absolute".
            self.crossattention = BertAttention(config, position_embedding_type="absolute")
        self.intermediate = BertIntermediate(config)
        self.output = BertOutput(config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor]:
        """Apply the block to ``hidden_states``.

        Args:
            past_key_value: this layer's cache. The first two entries feed self-attention. When
                cross-attention runs, the last two entries feed cross-attention.

        Returns:
            A tuple laid out as follows:
            - ``layer_output`` (the feed-forward result) comes first.
            - Next come the extra outputs of self-attention, excluding the trailing cache entry for decoders.
            - When cross-attention ran, the extra outputs of cross-attention follow, excluding its cache entry.
            - For decoders only, the concatenated self-/cross-attention cache is the last element.

            NOTE(review): the "extra outputs" are presumably attention weights when ``output_attentions`` is
            set; BertAttention's return contract is not visible here.
        """
        # Self-attention cache: the first two entries (indices 0 and 1) of past_key_value.
        self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
        self_attention_outputs = self.attention(
            hidden_states,
            attention_mask,
            head_mask,
            output_attentions=output_attentions,
            past_key_value=self_attn_past_key_value,
        )
        attention_output = self_attention_outputs[0]

        # For decoders, the last element of the attention outputs is the self-attention cache; keep it apart.
        if self.is_decoder:
            outputs = self_attention_outputs[1:-1]
            present_key_value = self_attention_outputs[-1]
        else:
            outputs = self_attention_outputs[1:]  # add self attentions if we output attention weights

        cross_attn_present_key_value = None
        if self.is_decoder and encoder_hidden_states is not None:
            if not hasattr(self, "crossattention"):
                raise ValueError(
                    f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers"
                    " by setting `config.add_cross_attention=True`"
                )

            # Cross-attention cache: the last two entries (indices 2 and 3 of a 4-entry cache) of past_key_value.
            cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
            cross_attention_outputs = self.crossattention(
                attention_output,
                attention_mask,
                head_mask,
                encoder_hidden_states,
                encoder_attention_mask,
                cross_attn_past_key_value,
                output_attentions,
            )
            attention_output = cross_attention_outputs[0]
            outputs = outputs + cross_attention_outputs[1:-1]  # add cross attentions if we output attention weights

            # Append the cross-attention cache after the self-attention cache (entries 2 and 3 of the combined tuple).
            cross_attn_present_key_value = cross_attention_outputs[-1]
            present_key_value = present_key_value + cross_attn_present_key_value

        # Feed-forward network, possibly evaluated in chunks along `seq_len_dim`.
        layer_output = apply_chunking_to_forward(
            self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
        )
        outputs = (layer_output,) + outputs

        # if decoder, return the attn key/values as the last output
        if self.is_decoder:
            outputs = outputs + (present_key_value,)

        return outputs

    def feed_forward_chunk(self, attention_output):
        """Feed-forward on one chunk: intermediate expansion, then BertOutput with ``attention_output`` as residual."""
        intermediate_output = self.intermediate(attention_output)
        layer_output = self.output(intermediate_output, attention_output)
        return layer_output
  553. class BertEncoder(nn.Module):
  554. def __init__(self, config):
  555. super().__init__()
  556. self.config = config
  557. self.layer = nn.ModuleList([BertLayer(config) for _ in range(config.num_hidden_layers)])
  558. self.gradient_checkpointing = False
  559. def forward(
  560. self,
  561. hidden_states: torch.Tensor,
  562. attention_mask: Optional[torch.FloatTensor] = None,
  563. head_mask: Optional[torch.FloatTensor] = None,
  564. encoder_hidden_states: Optional[torch.FloatTensor] = None,
  565. encoder_attention_mask: Optional[torch.FloatTensor] = None,
  566. past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
  567. use_cache: Optional[bool] = None,
  568. output_attentions: Optional[bool] = False,
  569. output_hidden_states: Optional[bool] = False,
  570. return_dict: Optional[bool] = True,
  571. ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]:
  572. all_hidden_states = () if output_hidden_states else None
  573. all_self_attentions = () if output_attentions else None
  574. all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
  575. if self.gradient_checkpointing and self.training:
  576. if use_cache:
  577. logger.warning_once(
  578. "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
  579. )
  580. use_cache = False
  581. next_decoder_cache = () if use_cache else None
  582. for i, layer_module in enumerate(self.layer):
  583. if output_hidden_states:
  584. all_hidden_states = all_hidden_states + (hidden_states,)
  585. layer_head_mask = head_mask[i] if head_mask is not None else None
  586. past_key_value = past_key_values[i] if past_key_values is not None else None
  587. if self.gradient_checkpointing and self.training:
  588. layer_outputs = self._gradient_checkpointing_func(
  589. layer_module.__call__,
  590. hidden_states,
  591. attention_mask,
  592. layer_head_mask,
  593. encoder_hidden_states,
  594. encoder_attention_mask,
  595. past_key_value,
  596. output_attentions,
  597. )
  598. else:
  599. layer_outputs = layer_module(
  600. hidden_states,
  601. attention_mask,
  602. layer_head_mask,
  603. encoder_hidden_states,
  604. encoder_attention_mask,
  605. past_key_value,
  606. output_attentions,
  607. )
  608. hidden_states = layer_outputs[0]
  609. if use_cache:
  610. next_decoder_cache += (layer_outputs[-1],)
  611. if output_attentions:
  612. all_self_attentions = all_self_attentions + (layer_outputs[1],)
  613. if self.config.add_cross_attention:
  614. all_cross_attentions = all_cross_attentions + (layer_outputs[2],)
  615. if output_hidden_states:
  616. all_hidden_states = all_hidden_states + (hidden_states,)
  617. if not return_dict:
  618. return tuple(
  619. v
  620. for v in [
  621. hidden_states,
  622. next_decoder_cache,
  623. all_hidden_states,
  624. all_self_attentions,
  625. all_cross_attentions,
  626. ]
  627. if v is not None
  628. )
  629. return BaseModelOutputWithPastAndCrossAttentions(
  630. last_hidden_state=hidden_states,
  631. past_key_values=next_decoder_cache,
  632. hidden_states=all_hidden_states,
  633. attentions=all_self_attentions,
  634. cross_attentions=all_cross_attentions,
  635. )
  636. class BertPooler(nn.Module):
  637. def __init__(self, config):
  638. super().__init__()
  639. self.dense = nn.Linear(config.hidden_size, config.hidden_size)
  640. self.activation = nn.Tanh()
  641. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  642. # We "pool" the model by simply taking the hidden state corresponding
  643. # to the first token.
  644. first_token_tensor = hidden_states[:, 0]
  645. pooled_output = self.dense(first_token_tensor)
  646. pooled_output = self.activation(pooled_output)
  647. return pooled_output
  648. class BertPredictionHeadTransform(nn.Module):
  649. def __init__(self, config):
  650. super().__init__()
  651. self.dense = nn.Linear(config.hidden_size, config.hidden_size)
  652. if isinstance(config.hidden_act, str):
  653. self.transform_act_fn = ACT2FN[config.hidden_act]
  654. else:
  655. self.transform_act_fn = config.hidden_act
  656. self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  657. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  658. hidden_states = self.dense(hidden_states)
  659. hidden_states = self.transform_act_fn(hidden_states)
  660. hidden_states = self.LayerNorm(hidden_states)
  661. return hidden_states
  662. class BertLMPredictionHead(nn.Module):
  663. def __init__(self, config):
  664. super().__init__()
  665. self.transform = BertPredictionHeadTransform(config)
  666. # The output weights are the same as the input embeddings, but there is
  667. # an output-only bias for each token.
  668. self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
  669. self.bias = nn.Parameter(torch.zeros(config.vocab_size))
  670. # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
  671. self.decoder.bias = self.bias
  672. def _tie_weights(self):
  673. self.decoder.bias = self.bias
  674. def forward(self, hidden_states):
  675. hidden_states = self.transform(hidden_states)
  676. hidden_states = self.decoder(hidden_states)
  677. return hidden_states
  678. class BertOnlyMLMHead(nn.Module):
  679. def __init__(self, config):
  680. super().__init__()
  681. self.predictions = BertLMPredictionHead(config)
  682. def forward(self, sequence_output: torch.Tensor) -> torch.Tensor:
  683. prediction_scores = self.predictions(sequence_output)
  684. return prediction_scores
  685. class BertOnlyNSPHead(nn.Module):
  686. def __init__(self, config):
  687. super().__init__()
  688. self.seq_relationship = nn.Linear(config.hidden_size, 2)
  689. def forward(self, pooled_output):
  690. seq_relationship_score = self.seq_relationship(pooled_output)
  691. return seq_relationship_score
  692. class BertPreTrainingHeads(nn.Module):
  693. def __init__(self, config):
  694. super().__init__()
  695. self.predictions = BertLMPredictionHead(config)
  696. self.seq_relationship = nn.Linear(config.hidden_size, 2)
  697. def forward(self, sequence_output, pooled_output):
  698. prediction_scores = self.predictions(sequence_output)
  699. seq_relationship_score = self.seq_relationship(pooled_output)
  700. return prediction_scores, seq_relationship_score
  701. class BertPreTrainedModel(PreTrainedModel):
  702. """
  703. An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
  704. models.
  705. """
  706. config_class = BertConfig
  707. load_tf_weights = load_tf_weights_in_bert
  708. base_model_prefix = "bert"
  709. supports_gradient_checkpointing = True
  710. _supports_sdpa = True
  711. def _init_weights(self, module):
  712. """Initialize the weights"""
  713. if isinstance(module, nn.Linear):
  714. # Slightly different from the TF version which uses truncated_normal for initialization
  715. # cf https://github.com/pytorch/pytorch/pull/5617
  716. module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
  717. if module.bias is not None:
  718. module.bias.data.zero_()
  719. elif isinstance(module, nn.Embedding):
  720. module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
  721. if module.padding_idx is not None:
  722. module.weight.data[module.padding_idx].zero_()
  723. elif isinstance(module, nn.LayerNorm):
  724. module.bias.data.zero_()
  725. module.weight.data.fill_(1.0)
@dataclass
class BertForPreTrainingOutput(ModelOutput):
    """
    Output type of [`BertForPreTraining`].

    Args:
        loss (*optional*, returned when both `labels` and `next_sentence_label` are provided, `torch.FloatTensor` of shape `(1,)`):
            Total loss as the sum of the masked language modeling loss and the next sequence prediction
            (classification) loss.
        prediction_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
        seq_relationship_logits (`torch.FloatTensor` of shape `(batch_size, 2)`):
            Prediction scores of the next sequence prediction (classification) head (scores of True/False continuation
            before SoftMax).
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
            shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attention weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
    """

    # NOTE(review): field order presumably defines the tuple / integer-index order of ModelOutput -- keep it stable.
    loss: Optional[torch.FloatTensor] = None
    prediction_logits: torch.FloatTensor = None
    seq_relationship_logits: torch.FloatTensor = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None
  754. BERT_START_DOCSTRING = r"""
  755. This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
  756. library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
  757. etc.)
  758. This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
  759. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
  760. and behavior.
  761. Parameters:
  762. config ([`BertConfig`]): Model configuration class with all the parameters of the model.
  763. Initializing with a config file does not load the weights associated with the model, only the
  764. configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
  765. """
  766. BERT_INPUTS_DOCSTRING = r"""
  767. Args:
  768. input_ids (`torch.LongTensor` of shape `({0})`):
  769. Indices of input sequence tokens in the vocabulary.
  770. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  771. [`PreTrainedTokenizer.__call__`] for details.
  772. [What are input IDs?](../glossary#input-ids)
  773. attention_mask (`torch.FloatTensor` of shape `({0})`or `(batch_size, sequence_length, target_length)`, *optional*):
  774. Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
  775. - 1 for tokens that are **not masked**,
  776. - 0 for tokens that are **masked**.
  777. [What are attention masks?](../glossary#attention-mask)
  778. token_type_ids (`torch.LongTensor` of shape `({0})`, *optional*):
  779. Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
  780. 1]`:
  781. - 0 corresponds to a *sentence A* token,
  782. - 1 corresponds to a *sentence B* token.
  783. [What are token type IDs?](../glossary#token-type-ids)
  784. position_ids (`torch.LongTensor` of shape `({0})`, *optional*):
  785. Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
  786. config.max_position_embeddings - 1]`.
  787. [What are position IDs?](../glossary#position-ids)
  788. head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
  789. Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
  790. - 1 indicates the head is **not masked**,
  791. - 0 indicates the head is **masked**.
  792. inputs_embeds (`torch.FloatTensor` of shape `({0}, hidden_size)`, *optional*):
  793. Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
  794. is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
  795. model's internal embedding lookup matrix.
  796. output_attentions (`bool`, *optional*):
  797. Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
  798. tensors for more detail.
  799. output_hidden_states (`bool`, *optional*):
  800. Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
  801. more detail.
  802. return_dict (`bool`, *optional*):
  803. Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
  804. """
@add_start_docstrings(
    "The bare Bert Model transformer outputting raw hidden-states without any specific head on top.",
    BERT_START_DOCSTRING,
)
class BertModel(BertPreTrainedModel):
    """
    The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
    cross-attention is added between the self-attention layers, following the architecture described in [Attention is
    all you need](https://arxiv.org/abs/1706.03762) by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
    Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.

    To behave as a decoder the model needs to be initialized with the `is_decoder` argument of the configuration set
    to `True`. To be used in a Seq2Seq model, the model needs to be initialized with both `is_decoder` argument and
    `add_cross_attention` set to `True`; an `encoder_hidden_states` is then expected as an input to the forward pass.
    """

    # NOTE(review): presumably the module classes that must not be split across devices when dispatching
    # with a device map -- confirm against PreTrainedModel.
    _no_split_modules = ["BertEmbeddings", "BertLayer"]

    def __init__(self, config, add_pooling_layer=True):
        super().__init__(config)
        self.config = config

        self.embeddings = BertEmbeddings(config)
        self.encoder = BertEncoder(config)

        # Without a pooler, `pooled_output` is None in the forward outputs.
        self.pooler = BertPooler(config) if add_pooling_layer else None

        # Cached so `forward` can decide whether to build SDPA-specific 4D attention masks.
        self.attn_implementation = config._attn_implementation
        self.position_embedding_type = config.position_embedding_type

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.embeddings.word_embeddings

    def set_input_embeddings(self, value):
        self.embeddings.word_embeddings = value

    def _prune_heads(self, heads_to_prune):
        """
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
        class PreTrainedModel
        """
        for layer, heads in heads_to_prune.items():
            self.encoder.layer[layer].attention.prune_heads(heads)

    @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=BaseModelOutputWithPoolingAndCrossAttentions,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]:
        r"""
        encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
            the model is configured as a decoder.
        encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)` or `(batch_size, sequence_length, target_length)`, *optional*):
            Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
            the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.
        past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.num_hidden_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
            Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.

            If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
            don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
            `decoder_input_ids` of shape `(batch_size, sequence_length)`.
        use_cache (`bool`, *optional*):
            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
            `past_key_values`).
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Caching is only meaningful for decoders; encoders always run with use_cache=False.
        if self.config.is_decoder:
            use_cache = use_cache if use_cache is not None else self.config.use_cache
        else:
            use_cache = False

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
            input_shape = input_ids.size()
        elif inputs_embeds is not None:
            # Drop the trailing hidden dimension to get (batch_size, seq_length).
            input_shape = inputs_embeds.size()[:-1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        batch_size, seq_length = input_shape
        device = input_ids.device if input_ids is not None else inputs_embeds.device

        # past_key_values_length: number of already-cached positions, read from dim 2 of the first layer's first
        # cached tensor. This relies on the layout documented above: (batch_size, num_heads, seq_len, head_dim).
        past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0

        if token_type_ids is None:
            # Prefer the embeddings' registered `token_type_ids` buffer (sliced to seq_length and broadcast over
            # the batch) over allocating a fresh zeros tensor.
            if hasattr(self.embeddings, "token_type_ids"):
                buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length]
                buffered_token_type_ids_expanded = buffered_token_type_ids.expand(batch_size, seq_length)
                token_type_ids = buffered_token_type_ids_expanded
            else:
                token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)

        embedding_output = self.embeddings(
            input_ids=input_ids,
            position_ids=position_ids,
            token_type_ids=token_type_ids,
            inputs_embeds=inputs_embeds,
            past_key_values_length=past_key_values_length,
        )

        # Default mask attends to every position, including the cached ones.
        if attention_mask is None:
            attention_mask = torch.ones((batch_size, seq_length + past_key_values_length), device=device)

        # SDPA-specific 4D masks are only built when SDPA is the attention implementation, absolute position
        # embeddings are used, no head mask is given, and attention weights are not requested.
        use_sdpa_attention_masks = (
            self.attn_implementation == "sdpa"
            and self.position_embedding_type == "absolute"
            and head_mask is None
            and not output_attentions
        )

        # Expand the attention mask
        if use_sdpa_attention_masks and attention_mask.dim() == 2:
            # Expand the attention mask for SDPA.
            # [bsz, seq_len] -> [bsz, 1, seq_len, seq_len]
            if self.config.is_decoder:
                extended_attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
                    attention_mask,
                    input_shape,
                    embedding_output,
                    past_key_values_length,
                )
            else:
                extended_attention_mask = _prepare_4d_attention_mask_for_sdpa(
                    attention_mask, embedding_output.dtype, tgt_len=seq_length
                )
        else:
            # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
            # ourselves in which case we just need to make it broadcastable to all heads.
            extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape)

        # If a 2D or 3D attention mask is provided for the cross-attention
        # we need to make it broadcastable to [batch_size, num_heads, seq_length, encoder_seq_length]
        if self.config.is_decoder and encoder_hidden_states is not None:
            encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
            encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
            if encoder_attention_mask is None:
                encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)

            if use_sdpa_attention_masks and encoder_attention_mask.dim() == 2:
                # Expand the attention mask for SDPA, with the decoder length (seq_length) as target length.
                # NOTE(review): presumably [bsz, enc_seq_len] -> [bsz, 1, seq_len, enc_seq_len]; confirm in
                # modeling_attn_mask_utils.
                encoder_extended_attention_mask = _prepare_4d_attention_mask_for_sdpa(
                    encoder_attention_mask, embedding_output.dtype, tgt_len=seq_length
                )
            else:
                encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
        else:
            encoder_extended_attention_mask = None

        # Prepare head mask if needed
        # 1.0 in head_mask indicates we keep the head
        # attention_probs has shape bsz x n_heads x N x N
        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

        encoder_outputs = self.encoder(
            embedding_output,
            attention_mask=extended_attention_mask,
            head_mask=head_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_extended_attention_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        sequence_output = encoder_outputs[0]
        pooled_output = self.pooler(sequence_output) if self.pooler is not None else None

        if not return_dict:
            # Note: pooled_output stays in the tuple (as None) when there is no pooler, keeping indices stable.
            return (sequence_output, pooled_output) + encoder_outputs[1:]

        return BaseModelOutputWithPoolingAndCrossAttentions(
            last_hidden_state=sequence_output,
            pooler_output=pooled_output,
            past_key_values=encoder_outputs.past_key_values,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
            cross_attentions=encoder_outputs.cross_attentions,
        )
  991. @add_start_docstrings(
  992. """
  993. Bert Model with two heads on top as done during the pretraining: a `masked language modeling` head and a `next
  994. sentence prediction (classification)` head.
  995. """,
  996. BERT_START_DOCSTRING,
  997. )
  998. class BertForPreTraining(BertPreTrainedModel):
  999. _tied_weights_keys = ["predictions.decoder.bias", "cls.predictions.decoder.weight"]
  1000. def __init__(self, config):
  1001. super().__init__(config)
  1002. self.bert = BertModel(config)
  1003. self.cls = BertPreTrainingHeads(config)
  1004. # Initialize weights and apply final processing
  1005. self.post_init()
  1006. def get_output_embeddings(self):
  1007. return self.cls.predictions.decoder
  1008. def set_output_embeddings(self, new_embeddings):
  1009. self.cls.predictions.decoder = new_embeddings
  1010. self.cls.predictions.bias = new_embeddings.bias
  1011. @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
  1012. @replace_return_docstrings(output_type=BertForPreTrainingOutput, config_class=_CONFIG_FOR_DOC)
  1013. def forward(
  1014. self,
  1015. input_ids: Optional[torch.Tensor] = None,
  1016. attention_mask: Optional[torch.Tensor] = None,
  1017. token_type_ids: Optional[torch.Tensor] = None,
  1018. position_ids: Optional[torch.Tensor] = None,
  1019. head_mask: Optional[torch.Tensor] = None,
  1020. inputs_embeds: Optional[torch.Tensor] = None,
  1021. labels: Optional[torch.Tensor] = None,
  1022. next_sentence_label: Optional[torch.Tensor] = None,
  1023. output_attentions: Optional[bool] = None,
  1024. output_hidden_states: Optional[bool] = None,
  1025. return_dict: Optional[bool] = None,
  1026. ) -> Union[Tuple[torch.Tensor], BertForPreTrainingOutput]:
  1027. r"""
  1028. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  1029. Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
  1030. config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked),
  1031. the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
  1032. next_sentence_label (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  1033. Labels for computing the next sequence prediction (classification) loss. Input should be a sequence
  1034. pair (see `input_ids` docstring) Indices should be in `[0, 1]`:
  1035. - 0 indicates sequence B is a continuation of sequence A,
  1036. - 1 indicates sequence B is a random sequence.
  1037. kwargs (`Dict[str, any]`, *optional*, defaults to `{}`):
  1038. Used to hide legacy arguments that have been deprecated.
  1039. Returns:
  1040. Example:
  1041. ```python
  1042. >>> from transformers import AutoTokenizer, BertForPreTraining
  1043. >>> import torch
  1044. >>> tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-uncased")
  1045. >>> model = BertForPreTraining.from_pretrained("google-bert/bert-base-uncased")
  1046. >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
  1047. >>> outputs = model(**inputs)
  1048. >>> prediction_logits = outputs.prediction_logits
  1049. >>> seq_relationship_logits = outputs.seq_relationship_logits
  1050. ```
  1051. """
  1052. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  1053. outputs = self.bert(
  1054. input_ids,
  1055. attention_mask=attention_mask,
  1056. token_type_ids=token_type_ids,
  1057. position_ids=position_ids,
  1058. head_mask=head_mask,
  1059. inputs_embeds=inputs_embeds,
  1060. output_attentions=output_attentions,
  1061. output_hidden_states=output_hidden_states,
  1062. return_dict=return_dict,
  1063. )
  1064. sequence_output, pooled_output = outputs[:2]
  1065. prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output)
  1066. total_loss = None
  1067. if labels is not None and next_sentence_label is not None:
  1068. loss_fct = CrossEntropyLoss()
  1069. masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
  1070. next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1))
  1071. total_loss = masked_lm_loss + next_sentence_loss
  1072. if not return_dict:
  1073. output = (prediction_scores, seq_relationship_score) + outputs[2:]
  1074. return ((total_loss,) + output) if total_loss is not None else output
  1075. return BertForPreTrainingOutput(
  1076. loss=total_loss,
  1077. prediction_logits=prediction_scores,
  1078. seq_relationship_logits=seq_relationship_score,
  1079. hidden_states=outputs.hidden_states,
  1080. attentions=outputs.attentions,
  1081. )
  @add_start_docstrings(
      """Bert Model with a `language modeling` head on top for CLM fine-tuning.""", BERT_START_DOCSTRING
  )
  class BertLMHeadModel(BertPreTrainedModel, GenerationMixin):
      # BERT encoder used as a left-to-right decoder, topped with the MLM-style vocabulary projection head.
      # The config is expected to set `is_decoder=True`; a warning is logged otherwise.

      # LM-head parameters declared as tied weights (consumed by the `PreTrainedModel` base class).
      _tied_weights_keys = ["cls.predictions.decoder.bias", "cls.predictions.decoder.weight"]

      def __init__(self, config):
          super().__init__(config)

          if not config.is_decoder:
              logger.warning("If you want to use `BertLMHeadModel` as a standalone, add `is_decoder=True.`")

          # Only the token-level sequence output is used, so the pooler is not built.
          self.bert = BertModel(config, add_pooling_layer=False)
          self.cls = BertOnlyMLMHead(config)

          # Initialize weights and apply final processing
          self.post_init()

      def get_output_embeddings(self):
          """Return the LM head's decoder (vocabulary projection) layer."""
          return self.cls.predictions.decoder

      def set_output_embeddings(self, new_embeddings):
          """Install `new_embeddings` as the LM head's decoder and re-point the head's `bias` at its bias."""
          self.cls.predictions.decoder = new_embeddings
          self.cls.predictions.bias = new_embeddings.bias

      @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
      @add_code_sample_docstrings(
          checkpoint=_CHECKPOINT_FOR_DOC,
          output_type=CausalLMOutputWithCrossAttentions,
          config_class=_CONFIG_FOR_DOC,
      )
      def forward(
          self,
          input_ids: Optional[torch.Tensor] = None,
          attention_mask: Optional[torch.Tensor] = None,
          token_type_ids: Optional[torch.Tensor] = None,
          position_ids: Optional[torch.Tensor] = None,
          head_mask: Optional[torch.Tensor] = None,
          inputs_embeds: Optional[torch.Tensor] = None,
          encoder_hidden_states: Optional[torch.Tensor] = None,
          encoder_attention_mask: Optional[torch.Tensor] = None,
          labels: Optional[torch.Tensor] = None,
          past_key_values: Optional[List[torch.Tensor]] = None,
          use_cache: Optional[bool] = None,
          output_attentions: Optional[bool] = None,
          output_hidden_states: Optional[bool] = None,
          return_dict: Optional[bool] = None,
      ) -> Union[Tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]:
          r"""
          encoder_hidden_states  (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
              Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
              the model is configured as a decoder.
          encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
              Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
              the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`:

              - 1 for tokens that are **not masked**,
              - 0 for tokens that are **masked**.
          labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
              Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in
              `[-100, 0, ..., config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are
              ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
          past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
              Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.

              If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that
              don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
              `input_ids` of shape `(batch_size, sequence_length)`.
          use_cache (`bool`, *optional*):
              If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
              `past_key_values`).
          """
          return_dict = return_dict if return_dict is not None else self.config.use_return_dict
          if labels is not None:
              # A loss is being computed: do not build or return a key/value cache.
              use_cache = False

          outputs = self.bert(
              input_ids,
              attention_mask=attention_mask,
              token_type_ids=token_type_ids,
              position_ids=position_ids,
              head_mask=head_mask,
              inputs_embeds=inputs_embeds,
              encoder_hidden_states=encoder_hidden_states,
              encoder_attention_mask=encoder_attention_mask,
              past_key_values=past_key_values,
              use_cache=use_cache,
              output_attentions=output_attentions,
              output_hidden_states=output_hidden_states,
              return_dict=return_dict,
          )

          sequence_output = outputs[0]
          prediction_scores = self.cls(sequence_output)

          lm_loss = None
          if labels is not None:
              # Next-token prediction: the scores at position t are compared with the label at t + 1,
              # so drop the last prediction and the first label.
              shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous()
              shifted_labels = labels[:, 1:].contiguous()
              loss_fct = CrossEntropyLoss()
              lm_loss = loss_fct(shifted_prediction_scores.view(-1, self.config.vocab_size), shifted_labels.view(-1))

          if not return_dict:
              # outputs[1] is the (absent) pooler slot; pass through everything after it.
              output = (prediction_scores,) + outputs[2:]
              return ((lm_loss,) + output) if lm_loss is not None else output

          return CausalLMOutputWithCrossAttentions(
              loss=lm_loss,
              logits=prediction_scores,
              past_key_values=outputs.past_key_values,
              hidden_states=outputs.hidden_states,
              attentions=outputs.attentions,
              cross_attentions=outputs.cross_attentions,
          )

      def _reorder_cache(self, past_key_values, beam_idx):
          """Reorder every cached tensor along dim 0 (batch) so the cache follows the beams chosen by `beam_idx`."""
          reordered_past = ()
          for layer_past in past_key_values:
              reordered_past += (
                  tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
              )
          return reordered_past
  @add_start_docstrings("""Bert Model with a `language modeling` head on top.""", BERT_START_DOCSTRING)
  class BertForMaskedLM(BertPreTrainedModel):
      # LM-head parameters declared as tied weights; named with the same `cls.` prefix as in BertLMHeadModel,
      # since the head lives at `self.cls` in both models.
      _tied_weights_keys = ["cls.predictions.decoder.bias", "cls.predictions.decoder.weight"]

      def __init__(self, config):
          super().__init__(config)

          if config.is_decoder:
              logger.warning(
                  "If you want to use `BertForMaskedLM` make sure `config.is_decoder=False` for "
                  "bi-directional self-attention."
              )

          # Only the token-level sequence output is used, so the pooler is not built.
          self.bert = BertModel(config, add_pooling_layer=False)
          self.cls = BertOnlyMLMHead(config)

          # Initialize weights and apply final processing
          self.post_init()

      def get_output_embeddings(self):
          """Return the MLM head's decoder (vocabulary projection) layer."""
          return self.cls.predictions.decoder

      def set_output_embeddings(self, new_embeddings):
          """Install `new_embeddings` as the MLM head's decoder and re-point the head's `bias` at its bias."""
          self.cls.predictions.decoder = new_embeddings
          self.cls.predictions.bias = new_embeddings.bias

      @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
      @add_code_sample_docstrings(
          checkpoint=_CHECKPOINT_FOR_DOC,
          output_type=MaskedLMOutput,
          config_class=_CONFIG_FOR_DOC,
          expected_output="'paris'",
          expected_loss=0.88,
      )
      def forward(
          self,
          input_ids: Optional[torch.Tensor] = None,
          attention_mask: Optional[torch.Tensor] = None,
          token_type_ids: Optional[torch.Tensor] = None,
          position_ids: Optional[torch.Tensor] = None,
          head_mask: Optional[torch.Tensor] = None,
          inputs_embeds: Optional[torch.Tensor] = None,
          encoder_hidden_states: Optional[torch.Tensor] = None,
          encoder_attention_mask: Optional[torch.Tensor] = None,
          labels: Optional[torch.Tensor] = None,
          output_attentions: Optional[bool] = None,
          output_hidden_states: Optional[bool] = None,
          return_dict: Optional[bool] = None,
      ) -> Union[Tuple[torch.Tensor], MaskedLMOutput]:
          r"""
          labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
              Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
              config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
              loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
          """
          return_dict = return_dict if return_dict is not None else self.config.use_return_dict

          outputs = self.bert(
              input_ids,
              attention_mask=attention_mask,
              token_type_ids=token_type_ids,
              position_ids=position_ids,
              head_mask=head_mask,
              inputs_embeds=inputs_embeds,
              encoder_hidden_states=encoder_hidden_states,
              encoder_attention_mask=encoder_attention_mask,
              output_attentions=output_attentions,
              output_hidden_states=output_hidden_states,
              return_dict=return_dict,
          )

          sequence_output = outputs[0]
          prediction_scores = self.cls(sequence_output)

          masked_lm_loss = None
          if labels is not None:
              # CrossEntropyLoss skips targets equal to its default ignore_index (-100): the non-masked positions.
              loss_fct = CrossEntropyLoss()
              masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))

          if not return_dict:
              # outputs[1] is the (absent) pooler slot; pass through everything after it.
              output = (prediction_scores,) + outputs[2:]
              return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output

          return MaskedLMOutput(
              loss=masked_lm_loss,
              logits=prediction_scores,
              hidden_states=outputs.hidden_states,
              attentions=outputs.attentions,
          )

      def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **model_kwargs):
          """Append one PAD "dummy" token to every sequence so the model scores a fresh trailing position.

          Args:
              input_ids: token ids; dim 0 is the batch, the new token is appended along dim 1.
              attention_mask: optional mask aligned with `input_ids`. When omitted, every existing token is
                  treated as attended (all ones).
              **model_kwargs: accepted for interface compatibility and ignored.

          Returns:
              dict with the extended `input_ids` and `attention_mask`.

          Raises:
              ValueError: if `config.pad_token_id` is not set (the dummy token would be undefined).
          """
          input_shape = input_ids.shape
          effective_batch_size = input_shape[0]

          # add a dummy token
          if self.config.pad_token_id is None:
              raise ValueError("The PAD token should be defined for generation")

          if attention_mask is None:
              # The default of None previously crashed on `.new_zeros`; treat every given token as attended.
              attention_mask = input_ids.new_ones(input_shape)
          # The appended dummy position gets mask value 0.
          attention_mask = torch.cat([attention_mask, attention_mask.new_zeros((attention_mask.shape[0], 1))], dim=-1)
          dummy_token = torch.full(
              (effective_batch_size, 1), self.config.pad_token_id, dtype=torch.long, device=input_ids.device
          )
          input_ids = torch.cat([input_ids, dummy_token], dim=1)

          return {"input_ids": input_ids, "attention_mask": attention_mask}
  @add_start_docstrings(
      """Bert Model with a `next sentence prediction (classification)` head on top.""",
      BERT_START_DOCSTRING,
  )
  class BertForNextSentencePrediction(BertPreTrainedModel):
      def __init__(self, config):
          super().__init__(config)

          # The full encoder (pooler included): the NSP head consumes the pooled output.
          self.bert = BertModel(config)
          self.cls = BertOnlyNSPHead(config)

          # Initialize weights and apply final processing
          self.post_init()

      @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
      @replace_return_docstrings(output_type=NextSentencePredictorOutput, config_class=_CONFIG_FOR_DOC)
      def forward(
          self,
          input_ids: Optional[torch.Tensor] = None,
          attention_mask: Optional[torch.Tensor] = None,
          token_type_ids: Optional[torch.Tensor] = None,
          position_ids: Optional[torch.Tensor] = None,
          head_mask: Optional[torch.Tensor] = None,
          inputs_embeds: Optional[torch.Tensor] = None,
          labels: Optional[torch.Tensor] = None,
          output_attentions: Optional[bool] = None,
          output_hidden_states: Optional[bool] = None,
          return_dict: Optional[bool] = None,
          **kwargs,
      ) -> Union[Tuple[torch.Tensor], NextSentencePredictorOutput]:
          r"""
          labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
              Targets for the next sequence prediction (classification) loss. The input should be a sequence pair
              (see `input_ids` docstring). Indices should be in `[0, 1]`:

              - 0 indicates sequence B is a continuation of sequence A,
              - 1 indicates sequence B is a random sequence.

          Returns:

          Example:

          ```python
          >>> from transformers import AutoTokenizer, BertForNextSentencePrediction
          >>> import torch

          >>> tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-uncased")
          >>> model = BertForNextSentencePrediction.from_pretrained("google-bert/bert-base-uncased")

          >>> prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced."
          >>> next_sentence = "The sky is blue due to the shorter wavelength of blue light."
          >>> encoding = tokenizer(prompt, next_sentence, return_tensors="pt")

          >>> outputs = model(**encoding, labels=torch.LongTensor([1]))
          >>> logits = outputs.logits
          >>> assert logits[0, 0] < logits[0, 1]  # next sentence was random
          ```
          """
          # Backward compatibility: the target used to be passed as `next_sentence_label`.
          if "next_sentence_label" in kwargs:
              warnings.warn(
                  "The `next_sentence_label` argument is deprecated and will be removed in a future version, use"
                  " `labels` instead.",
                  FutureWarning,
              )
              labels = kwargs.pop("next_sentence_label")

          if return_dict is None:
              return_dict = self.config.use_return_dict

          bert_outputs = self.bert(
              input_ids,
              attention_mask=attention_mask,
              token_type_ids=token_type_ids,
              position_ids=position_ids,
              head_mask=head_mask,
              inputs_embeds=inputs_embeds,
              output_attentions=output_attentions,
              output_hidden_states=output_hidden_states,
              return_dict=return_dict,
          )

          # Two-way relationship scores computed from the pooled output (second encoder output).
          relationship_logits = self.cls(bert_outputs[1])

          nsp_loss = None
          if labels is not None:
              nsp_loss = CrossEntropyLoss()(relationship_logits.view(-1, 2), labels.view(-1))

          if return_dict:
              return NextSentencePredictorOutput(
                  loss=nsp_loss,
                  logits=relationship_logits,
                  hidden_states=bert_outputs.hidden_states,
                  attentions=bert_outputs.attentions,
              )

          # Tuple form: logits, then whatever the encoder returned after its pooled output.
          legacy_output = (relationship_logits,) + bert_outputs[2:]
          return legacy_output if nsp_loss is None else (nsp_loss,) + legacy_output
  @add_start_docstrings(
      """
      Bert Model transformer with a sequence classification/regression head on top (a linear layer on top of the pooled
      output) e.g. for GLUE tasks.
      """,
      BERT_START_DOCSTRING,
  )
  class BertForSequenceClassification(BertPreTrainedModel):
      def __init__(self, config):
          super().__init__(config)
          self.num_labels = config.num_labels
          self.config = config

          self.bert = BertModel(config)
          # Use the classifier-specific dropout if configured, else the encoder's hidden dropout.
          classifier_dropout = (
              config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
          )
          self.dropout = nn.Dropout(classifier_dropout)
          self.classifier = nn.Linear(config.hidden_size, config.num_labels)

          # Initialize weights and apply final processing
          self.post_init()

      @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
      @add_code_sample_docstrings(
          checkpoint=_CHECKPOINT_FOR_SEQUENCE_CLASSIFICATION,
          output_type=SequenceClassifierOutput,
          config_class=_CONFIG_FOR_DOC,
          expected_output=_SEQ_CLASS_EXPECTED_OUTPUT,
          expected_loss=_SEQ_CLASS_EXPECTED_LOSS,
      )
      def forward(
          self,
          input_ids: Optional[torch.Tensor] = None,
          attention_mask: Optional[torch.Tensor] = None,
          token_type_ids: Optional[torch.Tensor] = None,
          position_ids: Optional[torch.Tensor] = None,
          head_mask: Optional[torch.Tensor] = None,
          inputs_embeds: Optional[torch.Tensor] = None,
          labels: Optional[torch.Tensor] = None,
          output_attentions: Optional[bool] = None,
          output_hidden_states: Optional[bool] = None,
          return_dict: Optional[bool] = None,
      ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]:
          r"""
          labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
              Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
              config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
              `config.num_labels > 1` a classification loss is computed (Cross-Entropy). For the regression and
              multi-label losses, non-floating-point labels are cast to the logits' dtype.
          """
          return_dict = return_dict if return_dict is not None else self.config.use_return_dict

          outputs = self.bert(
              input_ids,
              attention_mask=attention_mask,
              token_type_ids=token_type_ids,
              position_ids=position_ids,
              head_mask=head_mask,
              inputs_embeds=inputs_embeds,
              output_attentions=output_attentions,
              output_hidden_states=output_hidden_states,
              return_dict=return_dict,
          )

          pooled_output = outputs[1]
          pooled_output = self.dropout(pooled_output)
          logits = self.classifier(pooled_output)

          loss = None
          if labels is not None:
              if self.config.problem_type is None:
                  # Infer the task from head size and label dtype; the result is stored on the config.
                  if self.num_labels == 1:
                      self.config.problem_type = "regression"
                  elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                      self.config.problem_type = "single_label_classification"
                  else:
                      self.config.problem_type = "multi_label_classification"

              if self.config.problem_type == "regression":
                  loss_fct = MSELoss()
                  # MSELoss needs floating-point targets; integer-valued regression labels would otherwise fail.
                  float_labels = labels if labels.is_floating_point() else labels.to(logits.dtype)
                  if self.num_labels == 1:
                      loss = loss_fct(logits.squeeze(), float_labels.squeeze())
                  else:
                      loss = loss_fct(logits, float_labels)
              elif self.config.problem_type == "single_label_classification":
                  loss_fct = CrossEntropyLoss()
                  loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
              elif self.config.problem_type == "multi_label_classification":
                  loss_fct = BCEWithLogitsLoss()
                  # BCEWithLogitsLoss needs floating-point multi-hot targets; cast integer/bool 0-1 labels.
                  float_labels = labels if labels.is_floating_point() else labels.to(logits.dtype)
                  loss = loss_fct(logits, float_labels)

          if not return_dict:
              output = (logits,) + outputs[2:]
              return ((loss,) + output) if loss is not None else output

          return SequenceClassifierOutput(
              loss=loss,
              logits=logits,
              hidden_states=outputs.hidden_states,
              attentions=outputs.attentions,
          )
  @add_start_docstrings(
      """
      Bert Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a
      softmax) e.g. for RocStories/SWAG tasks.
      """,
      BERT_START_DOCSTRING,
  )
  class BertForMultipleChoice(BertPreTrainedModel):
      def __init__(self, config):
          super().__init__(config)

          self.bert = BertModel(config)
          # Use the classifier-specific dropout if configured, else the encoder's hidden dropout.
          classifier_dropout = (
              config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
          )
          self.dropout = nn.Dropout(classifier_dropout)
          # One score per (example, choice) pair; choices compete via the softmax inside CrossEntropyLoss.
          self.classifier = nn.Linear(config.hidden_size, 1)

          # Initialize weights and apply final processing
          self.post_init()

      @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length"))
      @add_code_sample_docstrings(
          checkpoint=_CHECKPOINT_FOR_DOC,
          output_type=MultipleChoiceModelOutput,
          config_class=_CONFIG_FOR_DOC,
      )
      def forward(
          self,
          input_ids: Optional[torch.Tensor] = None,
          attention_mask: Optional[torch.Tensor] = None,
          token_type_ids: Optional[torch.Tensor] = None,
          position_ids: Optional[torch.Tensor] = None,
          head_mask: Optional[torch.Tensor] = None,
          inputs_embeds: Optional[torch.Tensor] = None,
          labels: Optional[torch.Tensor] = None,
          output_attentions: Optional[bool] = None,
          output_hidden_states: Optional[bool] = None,
          return_dict: Optional[bool] = None,
      ) -> Union[Tuple[torch.Tensor], MultipleChoiceModelOutput]:
          r"""
          labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
              Labels for computing the multiple choice classification loss. Indices should be in `[0, ...,
              num_choices-1]` where `num_choices` is the size of the second dimension of the input tensors. (See
              `input_ids` above)
          """
          return_dict = return_dict if return_dict is not None else self.config.use_return_dict

          # The choice count comes from dim 1 of whichever input was given; fail clearly if neither was.
          if input_ids is not None:
              num_choices = input_ids.shape[1]
          elif inputs_embeds is not None:
              num_choices = inputs_embeds.shape[1]
          else:
              raise ValueError("You have to specify either input_ids or inputs_embeds")

          # Fold the choice dimension into the batch so the encoder sees (batch * num_choices, seq_len[, hidden]).
          input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
          attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
          token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
          position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
          inputs_embeds = (
              inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
              if inputs_embeds is not None
              else None
          )

          outputs = self.bert(
              input_ids,
              attention_mask=attention_mask,
              token_type_ids=token_type_ids,
              position_ids=position_ids,
              head_mask=head_mask,
              inputs_embeds=inputs_embeds,
              output_attentions=output_attentions,
              output_hidden_states=output_hidden_states,
              return_dict=return_dict,
          )

          pooled_output = outputs[1]
          pooled_output = self.dropout(pooled_output)
          logits = self.classifier(pooled_output)
          # Unfold back to one row of choice scores per example: (batch, num_choices).
          reshaped_logits = logits.view(-1, num_choices)

          loss = None
          if labels is not None:
              loss_fct = CrossEntropyLoss()
              loss = loss_fct(reshaped_logits, labels)

          if not return_dict:
              output = (reshaped_logits,) + outputs[2:]
              return ((loss,) + output) if loss is not None else output

          return MultipleChoiceModelOutput(
              loss=loss,
              logits=reshaped_logits,
              hidden_states=outputs.hidden_states,
              attentions=outputs.attentions,
          )
  @add_start_docstrings(
      """
      Bert Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for
      Named-Entity-Recognition (NER) tasks.
      """,
      BERT_START_DOCSTRING,
  )
  class BertForTokenClassification(BertPreTrainedModel):
      def __init__(self, config):
          super().__init__(config)
          self.num_labels = config.num_labels

          # Per-token predictions only need the sequence output, so the pooler is skipped.
          self.bert = BertModel(config, add_pooling_layer=False)
          if config.classifier_dropout is None:
              dropout_rate = config.hidden_dropout_prob
          else:
              dropout_rate = config.classifier_dropout
          self.dropout = nn.Dropout(dropout_rate)
          self.classifier = nn.Linear(config.hidden_size, config.num_labels)

          # Initialize weights and apply final processing
          self.post_init()

      @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
      @add_code_sample_docstrings(
          checkpoint=_CHECKPOINT_FOR_TOKEN_CLASSIFICATION,
          output_type=TokenClassifierOutput,
          config_class=_CONFIG_FOR_DOC,
          expected_output=_TOKEN_CLASS_EXPECTED_OUTPUT,
          expected_loss=_TOKEN_CLASS_EXPECTED_LOSS,
      )
      def forward(
          self,
          input_ids: Optional[torch.Tensor] = None,
          attention_mask: Optional[torch.Tensor] = None,
          token_type_ids: Optional[torch.Tensor] = None,
          position_ids: Optional[torch.Tensor] = None,
          head_mask: Optional[torch.Tensor] = None,
          inputs_embeds: Optional[torch.Tensor] = None,
          labels: Optional[torch.Tensor] = None,
          output_attentions: Optional[bool] = None,
          output_hidden_states: Optional[bool] = None,
          return_dict: Optional[bool] = None,
      ) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]:
          r"""
          labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
              Per-token class targets for the token classification loss. Indices should be in
              `[0, ..., config.num_labels - 1]`.
          """
          if return_dict is None:
              return_dict = self.config.use_return_dict

          bert_outputs = self.bert(
              input_ids,
              attention_mask=attention_mask,
              token_type_ids=token_type_ids,
              position_ids=position_ids,
              head_mask=head_mask,
              inputs_embeds=inputs_embeds,
              output_attentions=output_attentions,
              output_hidden_states=output_hidden_states,
              return_dict=return_dict,
          )

          # Dropout, then project every token's hidden state to the label space.
          token_logits = self.classifier(self.dropout(bert_outputs[0]))

          loss = None
          if labels is not None:
              # Flatten batch and sequence so each token is one classification example.
              loss = CrossEntropyLoss()(token_logits.view(-1, self.num_labels), labels.view(-1))

          if return_dict:
              return TokenClassifierOutput(
                  loss=loss,
                  logits=token_logits,
                  hidden_states=bert_outputs.hidden_states,
                  attentions=bert_outputs.attentions,
              )

          tail = (token_logits,) + bert_outputs[2:]
          return tail if loss is None else (loss,) + tail
  @add_start_docstrings(
      """
      Bert Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear
      layers on top of the hidden-states output to compute `span start logits` and `span end logits`).
      """,
      BERT_START_DOCSTRING,
  )
  class BertForQuestionAnswering(BertPreTrainedModel):
      def __init__(self, config):
          super().__init__(config)
          self.num_labels = config.num_labels

          # Span scores are computed per token, so the pooler is skipped.
          self.bert = BertModel(config, add_pooling_layer=False)
          # forward() splits the output into exactly two channels (start, end); num_labels is presumably 2.
          self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)

          # Initialize weights and apply final processing
          self.post_init()

      @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
      @add_code_sample_docstrings(
          checkpoint=_CHECKPOINT_FOR_QA,
          output_type=QuestionAnsweringModelOutput,
          config_class=_CONFIG_FOR_DOC,
          qa_target_start_index=_QA_TARGET_START_INDEX,
          qa_target_end_index=_QA_TARGET_END_INDEX,
          expected_output=_QA_EXPECTED_OUTPUT,
          expected_loss=_QA_EXPECTED_LOSS,
      )
      def forward(
          self,
          input_ids: Optional[torch.Tensor] = None,
          attention_mask: Optional[torch.Tensor] = None,
          token_type_ids: Optional[torch.Tensor] = None,
          position_ids: Optional[torch.Tensor] = None,
          head_mask: Optional[torch.Tensor] = None,
          inputs_embeds: Optional[torch.Tensor] = None,
          start_positions: Optional[torch.Tensor] = None,
          end_positions: Optional[torch.Tensor] = None,
          output_attentions: Optional[bool] = None,
          output_hidden_states: Optional[bool] = None,
          return_dict: Optional[bool] = None,
      ) -> Union[Tuple[torch.Tensor], QuestionAnsweringModelOutput]:
          r"""
          start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
              Index of the first token of the labelled answer span, used for the span-start loss.
              Positions are clamped to the length of the sequence (`sequence_length`). Positions beyond the end of
              the sequence are not taken into account for computing the loss.
          end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
              Index of the last token of the labelled answer span, used for the span-end loss.
              Positions are clamped to the length of the sequence (`sequence_length`). Positions beyond the end of
              the sequence are not taken into account for computing the loss.
          """
          if return_dict is None:
              return_dict = self.config.use_return_dict

          bert_outputs = self.bert(
              input_ids,
              attention_mask=attention_mask,
              token_type_ids=token_type_ids,
              position_ids=position_ids,
              head_mask=head_mask,
              inputs_embeds=inputs_embeds,
              output_attentions=output_attentions,
              output_hidden_states=output_hidden_states,
              return_dict=return_dict,
          )

          # Score every token, then separate the last dimension into the start and end channels.
          span_logits = self.qa_outputs(bert_outputs[0])
          start_logits, end_logits = (channel.squeeze(-1).contiguous() for channel in span_logits.split(1, dim=-1))

          total_loss = None
          if start_positions is not None and end_positions is not None:
              # Targets may arrive with an extra trailing dimension (e.g. after multi-GPU gathering); drop it.
              if start_positions.dim() > 1:
                  start_positions = start_positions.squeeze(-1)
              if end_positions.dim() > 1:
                  end_positions = end_positions.squeeze(-1)

              # Positions past the end are clamped to `seq_len`, which the loss is told to ignore.
              seq_len = start_logits.size(1)
              start_positions = start_positions.clamp(0, seq_len)
              end_positions = end_positions.clamp(0, seq_len)

              span_loss = CrossEntropyLoss(ignore_index=seq_len)
              total_loss = (span_loss(start_logits, start_positions) + span_loss(end_logits, end_positions)) / 2

          if return_dict:
              return QuestionAnsweringModelOutput(
                  loss=total_loss,
                  start_logits=start_logits,
                  end_logits=end_logits,
                  hidden_states=bert_outputs.hidden_states,
                  attentions=bert_outputs.attentions,
              )

          tail = (start_logits, end_logits) + bert_outputs[2:]
          return tail if total_loss is None else (total_loss,) + tail