modeling_tf_bert.py

# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""TF 2.0 BERT model."""
from __future__ import annotations

import math
import warnings
from dataclasses import dataclass
from typing import Dict, Optional, Tuple, Union

import numpy as np
import tensorflow as tf

from ...activations_tf import get_tf_activation
from ...modeling_tf_outputs import (
    TFBaseModelOutputWithPastAndCrossAttentions,
    TFBaseModelOutputWithPoolingAndCrossAttentions,
    TFCausalLMOutputWithCrossAttentions,
    TFMaskedLMOutput,
    TFMultipleChoiceModelOutput,
    TFNextSentencePredictorOutput,
    TFQuestionAnsweringModelOutput,
    TFSequenceClassifierOutput,
    TFTokenClassifierOutput,
)
from ...modeling_tf_utils import (
    TFCausalLanguageModelingLoss,
    TFMaskedLanguageModelingLoss,
    TFModelInputType,
    TFMultipleChoiceLoss,
    TFNextSentencePredictionLoss,
    TFPreTrainedModel,
    TFQuestionAnsweringLoss,
    TFSequenceClassificationLoss,
    TFTokenClassificationLoss,
    get_initializer,
    keras,
    keras_serializable,
    unpack_inputs,
)
from ...tf_utils import check_embeddings_within_bounds, shape_list, stable_softmax
from ...utils import (
    ModelOutput,
    add_code_sample_docstrings,
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    logging,
    replace_return_docstrings,
)
from .configuration_bert import BertConfig
logger = logging.get_logger(__name__)

_CHECKPOINT_FOR_DOC = "google-bert/bert-base-uncased"
_CONFIG_FOR_DOC = "BertConfig"

# TokenClassification docstring
_CHECKPOINT_FOR_TOKEN_CLASSIFICATION = "dbmdz/bert-large-cased-finetuned-conll03-english"
_TOKEN_CLASS_EXPECTED_OUTPUT = (
    "['O', 'I-ORG', 'I-ORG', 'I-ORG', 'O', 'O', 'O', 'O', 'O', 'I-LOC', 'O', 'I-LOC', 'I-LOC'] "
)
_TOKEN_CLASS_EXPECTED_LOSS = 0.01

# QuestionAnswering docstring
_CHECKPOINT_FOR_QA = "ydshieh/bert-base-cased-squad2"
_QA_EXPECTED_OUTPUT = "'a nice puppet'"
_QA_EXPECTED_LOSS = 7.41
_QA_TARGET_START_INDEX = 14
_QA_TARGET_END_INDEX = 15

# SequenceClassification docstring
_CHECKPOINT_FOR_SEQUENCE_CLASSIFICATION = "ydshieh/bert-base-uncased-yelp-polarity"
_SEQ_CLASS_EXPECTED_OUTPUT = "'LABEL_1'"
_SEQ_CLASS_EXPECTED_LOSS = 0.01
class TFBertPreTrainingLoss:
    """
    Loss function suitable for BERT-like pretraining, that is, the task of pretraining a language model by combining
    NSP + MLM. .. note:: Any label of -100 will be ignored (along with the corresponding logits) in the loss
    computation.
    """

    def hf_compute_loss(self, labels: tf.Tensor, logits: tf.Tensor) -> tf.Tensor:
        loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True, reduction=keras.losses.Reduction.NONE)

        # Clip negative labels to zero here to avoid NaNs and errors - those positions will get masked later anyway
        unmasked_lm_losses = loss_fn(y_true=tf.nn.relu(labels["labels"]), y_pred=logits[0])
        # make sure only labels that are not equal to -100
        # are taken into account for the loss computation
        lm_loss_mask = tf.cast(labels["labels"] != -100, dtype=unmasked_lm_losses.dtype)
        masked_lm_losses = unmasked_lm_losses * lm_loss_mask
        reduced_masked_lm_loss = tf.reduce_sum(masked_lm_losses) / tf.reduce_sum(lm_loss_mask)

        # Clip negative labels to zero here to avoid NaNs and errors - those positions will get masked later anyway
        unmasked_ns_loss = loss_fn(y_true=tf.nn.relu(labels["next_sentence_label"]), y_pred=logits[1])
        ns_loss_mask = tf.cast(labels["next_sentence_label"] != -100, dtype=unmasked_ns_loss.dtype)
        masked_ns_loss = unmasked_ns_loss * ns_loss_mask

        reduced_masked_ns_loss = tf.reduce_sum(masked_ns_loss) / tf.reduce_sum(ns_loss_mask)

        return tf.reshape(reduced_masked_lm_loss + reduced_masked_ns_loss, (1,))
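# Illustrative sketch (not part of the original file): how the -100 sentinel masks positions out of the
# pretraining loss above. The tensor values below are made up for the example.
#
#   labels = {"labels": tf.constant([[5, -100, 7]]), "next_sentence_label": tf.constant([1])}
#   lm_loss_mask = tf.cast(labels["labels"] != -100, tf.float32)   # -> [[1., 0., 1.]]
#   # per-token losses at -100 positions are multiplied by 0 and excluded from the mean,
#   # so only the first and third tokens contribute to reduced_masked_lm_loss.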
class TFBertEmbeddings(keras.layers.Layer):
    """Construct the embeddings from word, position and token_type embeddings."""

    def __init__(self, config: BertConfig, **kwargs):
        super().__init__(**kwargs)

        self.config = config
        self.hidden_size = config.hidden_size
        self.max_position_embeddings = config.max_position_embeddings
        self.initializer_range = config.initializer_range
        self.LayerNorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
        self.dropout = keras.layers.Dropout(rate=config.hidden_dropout_prob)

    def build(self, input_shape=None):
        with tf.name_scope("word_embeddings"):
            self.weight = self.add_weight(
                name="weight",
                shape=[self.config.vocab_size, self.hidden_size],
                initializer=get_initializer(self.initializer_range),
            )

        with tf.name_scope("token_type_embeddings"):
            self.token_type_embeddings = self.add_weight(
                name="embeddings",
                shape=[self.config.type_vocab_size, self.hidden_size],
                initializer=get_initializer(self.initializer_range),
            )

        with tf.name_scope("position_embeddings"):
            self.position_embeddings = self.add_weight(
                name="embeddings",
                shape=[self.max_position_embeddings, self.hidden_size],
                initializer=get_initializer(self.initializer_range),
            )

        if self.built:
            return
        self.built = True
        if getattr(self, "LayerNorm", None) is not None:
            with tf.name_scope(self.LayerNorm.name):
                self.LayerNorm.build([None, None, self.config.hidden_size])

    def call(
        self,
        input_ids: tf.Tensor = None,
        position_ids: tf.Tensor = None,
        token_type_ids: tf.Tensor = None,
        inputs_embeds: tf.Tensor = None,
        past_key_values_length=0,
        training: bool = False,
    ) -> tf.Tensor:
        """
        Applies embedding based on inputs tensor.

        Returns:
            final_embeddings (`tf.Tensor`): output embedding tensor.
        """
        if input_ids is None and inputs_embeds is None:
            raise ValueError("Need to provide either `input_ids` or `inputs_embeds`.")

        if input_ids is not None:
            check_embeddings_within_bounds(input_ids, self.config.vocab_size)
            inputs_embeds = tf.gather(params=self.weight, indices=input_ids)

        input_shape = shape_list(inputs_embeds)[:-1]

        if token_type_ids is None:
            token_type_ids = tf.fill(dims=input_shape, value=0)

        if position_ids is None:
            position_ids = tf.expand_dims(
                tf.range(start=past_key_values_length, limit=input_shape[1] + past_key_values_length), axis=0
            )

        position_embeds = tf.gather(params=self.position_embeddings, indices=position_ids)
        token_type_embeds = tf.gather(params=self.token_type_embeddings, indices=token_type_ids)
        final_embeddings = inputs_embeds + position_embeds + token_type_embeds
        final_embeddings = self.LayerNorm(inputs=final_embeddings)
        final_embeddings = self.dropout(inputs=final_embeddings, training=training)

        return final_embeddings
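# Illustrative note (not part of the original file): for a single position i the layer above computes
#
#   embeddings[i] = word_embeddings[input_ids[i]]
#                   + position_embeddings[past_key_values_length + i]
#                   + token_type_embeddings[token_type_ids[i]]
#
# followed by LayerNorm and dropout; each lookup returns a vector of size `hidden_size`.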
class TFBertSelfAttention(keras.layers.Layer):
    def __init__(self, config: BertConfig, **kwargs):
        super().__init__(**kwargs)

        if config.hidden_size % config.num_attention_heads != 0:
            raise ValueError(
                f"The hidden size ({config.hidden_size}) is not a multiple of the number "
                f"of attention heads ({config.num_attention_heads})"
            )

        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size
        self.sqrt_att_head_size = math.sqrt(self.attention_head_size)

        self.query = keras.layers.Dense(
            units=self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="query"
        )
        self.key = keras.layers.Dense(
            units=self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="key"
        )
        self.value = keras.layers.Dense(
            units=self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="value"
        )
        self.dropout = keras.layers.Dropout(rate=config.attention_probs_dropout_prob)

        self.is_decoder = config.is_decoder
        self.config = config

    def transpose_for_scores(self, tensor: tf.Tensor, batch_size: int) -> tf.Tensor:
        # Reshape from [batch_size, seq_length, all_head_size] to [batch_size, seq_length, num_attention_heads, attention_head_size]
        tensor = tf.reshape(tensor=tensor, shape=(batch_size, -1, self.num_attention_heads, self.attention_head_size))

        # Transpose the tensor from [batch_size, seq_length, num_attention_heads, attention_head_size] to [batch_size, num_attention_heads, seq_length, attention_head_size]
        return tf.transpose(tensor, perm=[0, 2, 1, 3])
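    # Illustrative note (not part of the original file): with hidden_size=768 and num_attention_heads=12
    # (BERT-base values), transpose_for_scores maps a projection of shape [batch, seq_len, 768] to
    # [batch, seq_len, 12, 64] and then to [batch, 12, seq_len, 64], so each of the 12 heads attends
    # over its own 64-dimensional slice.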
    def call(
        self,
        hidden_states: tf.Tensor,
        attention_mask: tf.Tensor,
        head_mask: tf.Tensor,
        encoder_hidden_states: tf.Tensor,
        encoder_attention_mask: tf.Tensor,
        past_key_value: Tuple[tf.Tensor],
        output_attentions: bool,
        training: bool = False,
    ) -> Tuple[tf.Tensor]:
        batch_size = shape_list(hidden_states)[0]
        mixed_query_layer = self.query(inputs=hidden_states)

        # If this is instantiated as a cross-attention module, the keys
        # and values come from an encoder; the attention mask needs to be
        # such that the encoder's padding tokens are not attended to.
        is_cross_attention = encoder_hidden_states is not None

        if is_cross_attention and past_key_value is not None:
            # reuse k,v, cross_attentions
            key_layer = past_key_value[0]
            value_layer = past_key_value[1]
            attention_mask = encoder_attention_mask
        elif is_cross_attention:
            key_layer = self.transpose_for_scores(self.key(inputs=encoder_hidden_states), batch_size)
            value_layer = self.transpose_for_scores(self.value(inputs=encoder_hidden_states), batch_size)
            attention_mask = encoder_attention_mask
        elif past_key_value is not None:
            key_layer = self.transpose_for_scores(self.key(inputs=hidden_states), batch_size)
            value_layer = self.transpose_for_scores(self.value(inputs=hidden_states), batch_size)
            key_layer = tf.concat([past_key_value[0], key_layer], axis=2)
            value_layer = tf.concat([past_key_value[1], value_layer], axis=2)
        else:
            key_layer = self.transpose_for_scores(self.key(inputs=hidden_states), batch_size)
            value_layer = self.transpose_for_scores(self.value(inputs=hidden_states), batch_size)

        query_layer = self.transpose_for_scores(mixed_query_layer, batch_size)

        if self.is_decoder:
            # if cross_attention save Tuple(tf.Tensor, tf.Tensor) of all cross attention key/value_states.
            # Further calls to cross_attention layer can then reuse all cross-attention
            # key/value_states (first "if" case)
            # if uni-directional self-attention (decoder) save Tuple(tf.Tensor, tf.Tensor) of
            # all previous decoder key/value_states. Further calls to uni-directional self-attention
            # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
            # if encoder bi-directional self-attention `past_key_value` is always `None`
            past_key_value = (key_layer, value_layer)

        # Take the dot product between "query" and "key" to get the raw attention scores.
        # (batch size, num_heads, seq_len_q, seq_len_k)
        attention_scores = tf.matmul(query_layer, key_layer, transpose_b=True)
        dk = tf.cast(self.sqrt_att_head_size, dtype=attention_scores.dtype)
        attention_scores = tf.divide(attention_scores, dk)

        if attention_mask is not None:
            # Apply the attention mask (precomputed for all layers in the TFBertModel call() function)
            attention_scores = tf.add(attention_scores, attention_mask)

        # Normalize the attention scores to probabilities.
        attention_probs = stable_softmax(logits=attention_scores, axis=-1)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs = self.dropout(inputs=attention_probs, training=training)

        # Mask heads if we want to
        if head_mask is not None:
            attention_probs = tf.multiply(attention_probs, head_mask)

        attention_output = tf.matmul(attention_probs, value_layer)
        attention_output = tf.transpose(attention_output, perm=[0, 2, 1, 3])

        # (batch_size, seq_len_q, all_head_size)
        attention_output = tf.reshape(tensor=attention_output, shape=(batch_size, -1, self.all_head_size))
        outputs = (attention_output, attention_probs) if output_attentions else (attention_output,)

        if self.is_decoder:
            outputs = outputs + (past_key_value,)
        return outputs
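    # Illustrative note (not part of the original file): the block above is scaled dot-product attention,
    #
    #   Attention(Q, K, V) = softmax(Q @ K^T / sqrt(attention_head_size) + mask) @ V
    #
    # where the additive `attention_mask` holds 0.0 for visible positions and -10000.0 for masked ones.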
    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        if getattr(self, "query", None) is not None:
            with tf.name_scope(self.query.name):
                self.query.build([None, None, self.config.hidden_size])
        if getattr(self, "key", None) is not None:
            with tf.name_scope(self.key.name):
                self.key.build([None, None, self.config.hidden_size])
        if getattr(self, "value", None) is not None:
            with tf.name_scope(self.value.name):
                self.value.build([None, None, self.config.hidden_size])
class TFBertSelfOutput(keras.layers.Layer):
    def __init__(self, config: BertConfig, **kwargs):
        super().__init__(**kwargs)

        self.dense = keras.layers.Dense(
            units=config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
        )
        self.LayerNorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
        self.dropout = keras.layers.Dropout(rate=config.hidden_dropout_prob)
        self.config = config

    def call(self, hidden_states: tf.Tensor, input_tensor: tf.Tensor, training: bool = False) -> tf.Tensor:
        hidden_states = self.dense(inputs=hidden_states)
        hidden_states = self.dropout(inputs=hidden_states, training=training)
        hidden_states = self.LayerNorm(inputs=hidden_states + input_tensor)

        return hidden_states

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        if getattr(self, "dense", None) is not None:
            with tf.name_scope(self.dense.name):
                self.dense.build([None, None, self.config.hidden_size])
        if getattr(self, "LayerNorm", None) is not None:
            with tf.name_scope(self.LayerNorm.name):
                self.LayerNorm.build([None, None, self.config.hidden_size])
class TFBertAttention(keras.layers.Layer):
    def __init__(self, config: BertConfig, **kwargs):
        super().__init__(**kwargs)

        self.self_attention = TFBertSelfAttention(config, name="self")
        self.dense_output = TFBertSelfOutput(config, name="output")

    def prune_heads(self, heads):
        raise NotImplementedError

    def call(
        self,
        input_tensor: tf.Tensor,
        attention_mask: tf.Tensor,
        head_mask: tf.Tensor,
        encoder_hidden_states: tf.Tensor,
        encoder_attention_mask: tf.Tensor,
        past_key_value: Tuple[tf.Tensor],
        output_attentions: bool,
        training: bool = False,
    ) -> Tuple[tf.Tensor]:
        self_outputs = self.self_attention(
            hidden_states=input_tensor,
            attention_mask=attention_mask,
            head_mask=head_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            training=training,
        )
        attention_output = self.dense_output(
            hidden_states=self_outputs[0], input_tensor=input_tensor, training=training
        )
        # add attentions (possibly with past_key_value) if we output them
        outputs = (attention_output,) + self_outputs[1:]

        return outputs

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        if getattr(self, "self_attention", None) is not None:
            with tf.name_scope(self.self_attention.name):
                self.self_attention.build(None)
        if getattr(self, "dense_output", None) is not None:
            with tf.name_scope(self.dense_output.name):
                self.dense_output.build(None)
class TFBertIntermediate(keras.layers.Layer):
    def __init__(self, config: BertConfig, **kwargs):
        super().__init__(**kwargs)

        self.dense = keras.layers.Dense(
            units=config.intermediate_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
        )

        if isinstance(config.hidden_act, str):
            self.intermediate_act_fn = get_tf_activation(config.hidden_act)
        else:
            self.intermediate_act_fn = config.hidden_act
        self.config = config

    def call(self, hidden_states: tf.Tensor) -> tf.Tensor:
        hidden_states = self.dense(inputs=hidden_states)
        hidden_states = self.intermediate_act_fn(hidden_states)

        return hidden_states

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        if getattr(self, "dense", None) is not None:
            with tf.name_scope(self.dense.name):
                self.dense.build([None, None, self.config.hidden_size])
class TFBertOutput(keras.layers.Layer):
    def __init__(self, config: BertConfig, **kwargs):
        super().__init__(**kwargs)

        self.dense = keras.layers.Dense(
            units=config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
        )
        self.LayerNorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
        self.dropout = keras.layers.Dropout(rate=config.hidden_dropout_prob)
        self.config = config

    def call(self, hidden_states: tf.Tensor, input_tensor: tf.Tensor, training: bool = False) -> tf.Tensor:
        hidden_states = self.dense(inputs=hidden_states)
        hidden_states = self.dropout(inputs=hidden_states, training=training)
        hidden_states = self.LayerNorm(inputs=hidden_states + input_tensor)

        return hidden_states

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        if getattr(self, "dense", None) is not None:
            with tf.name_scope(self.dense.name):
                self.dense.build([None, None, self.config.intermediate_size])
        if getattr(self, "LayerNorm", None) is not None:
            with tf.name_scope(self.LayerNorm.name):
                self.LayerNorm.build([None, None, self.config.hidden_size])
class TFBertLayer(keras.layers.Layer):
    def __init__(self, config: BertConfig, **kwargs):
        super().__init__(**kwargs)

        self.attention = TFBertAttention(config, name="attention")
        self.is_decoder = config.is_decoder
        self.add_cross_attention = config.add_cross_attention
        if self.add_cross_attention:
            if not self.is_decoder:
                raise ValueError(f"{self} should be used as a decoder model if cross attention is added")
            self.crossattention = TFBertAttention(config, name="crossattention")
        self.intermediate = TFBertIntermediate(config, name="intermediate")
        self.bert_output = TFBertOutput(config, name="output")

    def call(
        self,
        hidden_states: tf.Tensor,
        attention_mask: tf.Tensor,
        head_mask: tf.Tensor,
        encoder_hidden_states: tf.Tensor | None,
        encoder_attention_mask: tf.Tensor | None,
        past_key_value: Tuple[tf.Tensor] | None,
        output_attentions: bool,
        training: bool = False,
    ) -> Tuple[tf.Tensor]:
        # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
        self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
        self_attention_outputs = self.attention(
            input_tensor=hidden_states,
            attention_mask=attention_mask,
            head_mask=head_mask,
            encoder_hidden_states=None,
            encoder_attention_mask=None,
            past_key_value=self_attn_past_key_value,
            output_attentions=output_attentions,
            training=training,
        )
        attention_output = self_attention_outputs[0]

        # if decoder, the last output is tuple of self-attn cache
        if self.is_decoder:
            outputs = self_attention_outputs[1:-1]
            present_key_value = self_attention_outputs[-1]
        else:
            outputs = self_attention_outputs[1:]  # add self attentions if we output attention weights

        cross_attn_present_key_value = None
        if self.is_decoder and encoder_hidden_states is not None:
            if not hasattr(self, "crossattention"):
                raise ValueError(
                    f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers"
                    " by setting `config.add_cross_attention=True`"
                )

            # cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple
            cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
            cross_attention_outputs = self.crossattention(
                input_tensor=attention_output,
                attention_mask=attention_mask,
                head_mask=head_mask,
                encoder_hidden_states=encoder_hidden_states,
                encoder_attention_mask=encoder_attention_mask,
                past_key_value=cross_attn_past_key_value,
                output_attentions=output_attentions,
                training=training,
            )
            attention_output = cross_attention_outputs[0]
            outputs = outputs + cross_attention_outputs[1:-1]  # add cross attentions if we output attention weights

            # add cross-attn cache to positions 3,4 of present_key_value tuple
            cross_attn_present_key_value = cross_attention_outputs[-1]
            present_key_value = present_key_value + cross_attn_present_key_value

        intermediate_output = self.intermediate(hidden_states=attention_output)
        layer_output = self.bert_output(
            hidden_states=intermediate_output, input_tensor=attention_output, training=training
        )
        outputs = (layer_output,) + outputs  # add attentions if we output them

        # if decoder, return the attn key/values as the last output
        if self.is_decoder:
            outputs = outputs + (present_key_value,)

        return outputs

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        if getattr(self, "attention", None) is not None:
            with tf.name_scope(self.attention.name):
                self.attention.build(None)
        if getattr(self, "intermediate", None) is not None:
            with tf.name_scope(self.intermediate.name):
                self.intermediate.build(None)
        if getattr(self, "bert_output", None) is not None:
            with tf.name_scope(self.bert_output.name):
                self.bert_output.build(None)
        if getattr(self, "crossattention", None) is not None:
            with tf.name_scope(self.crossattention.name):
                self.crossattention.build(None)
class TFBertEncoder(keras.layers.Layer):
    def __init__(self, config: BertConfig, **kwargs):
        super().__init__(**kwargs)
        self.config = config
        self.layer = [TFBertLayer(config, name=f"layer_._{i}") for i in range(config.num_hidden_layers)]

    def call(
        self,
        hidden_states: tf.Tensor,
        attention_mask: tf.Tensor,
        head_mask: tf.Tensor,
        encoder_hidden_states: tf.Tensor | None,
        encoder_attention_mask: tf.Tensor | None,
        past_key_values: Tuple[Tuple[tf.Tensor]] | None,
        use_cache: Optional[bool],
        output_attentions: bool,
        output_hidden_states: bool,
        return_dict: bool,
        training: bool = False,
    ) -> Union[TFBaseModelOutputWithPastAndCrossAttentions, Tuple[tf.Tensor]]:
        all_hidden_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None
        all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None

        next_decoder_cache = () if use_cache else None
        for i, layer_module in enumerate(self.layer):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            past_key_value = past_key_values[i] if past_key_values is not None else None

            layer_outputs = layer_module(
                hidden_states=hidden_states,
                attention_mask=attention_mask,
                head_mask=head_mask[i],
                encoder_hidden_states=encoder_hidden_states,
                encoder_attention_mask=encoder_attention_mask,
                past_key_value=past_key_value,
                output_attentions=output_attentions,
                training=training,
            )
            hidden_states = layer_outputs[0]

            if use_cache:
                next_decoder_cache += (layer_outputs[-1],)

            if output_attentions:
                all_attentions = all_attentions + (layer_outputs[1],)
                if self.config.add_cross_attention and encoder_hidden_states is not None:
                    all_cross_attentions = all_cross_attentions + (layer_outputs[2],)

        # Add last layer
        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(
                v for v in [hidden_states, all_hidden_states, all_attentions, all_cross_attentions] if v is not None
            )

        return TFBaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=next_decoder_cache,
            hidden_states=all_hidden_states,
            attentions=all_attentions,
            cross_attentions=all_cross_attentions,
        )

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        if getattr(self, "layer", None) is not None:
            for layer in self.layer:
                with tf.name_scope(layer.name):
                    layer.build(None)
class TFBertPooler(keras.layers.Layer):
    def __init__(self, config: BertConfig, **kwargs):
        super().__init__(**kwargs)

        self.dense = keras.layers.Dense(
            units=config.hidden_size,
            kernel_initializer=get_initializer(config.initializer_range),
            activation="tanh",
            name="dense",
        )
        self.config = config

    def call(self, hidden_states: tf.Tensor) -> tf.Tensor:
        # We "pool" the model by simply taking the hidden state corresponding
        # to the first token.
        first_token_tensor = hidden_states[:, 0]
        pooled_output = self.dense(inputs=first_token_tensor)

        return pooled_output

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        if getattr(self, "dense", None) is not None:
            with tf.name_scope(self.dense.name):
                self.dense.build([None, None, self.config.hidden_size])
class TFBertPredictionHeadTransform(keras.layers.Layer):
    def __init__(self, config: BertConfig, **kwargs):
        super().__init__(**kwargs)

        self.dense = keras.layers.Dense(
            units=config.hidden_size,
            kernel_initializer=get_initializer(config.initializer_range),
            name="dense",
        )

        if isinstance(config.hidden_act, str):
            self.transform_act_fn = get_tf_activation(config.hidden_act)
        else:
            self.transform_act_fn = config.hidden_act

        self.LayerNorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
        self.config = config

    def call(self, hidden_states: tf.Tensor) -> tf.Tensor:
        hidden_states = self.dense(inputs=hidden_states)
        hidden_states = self.transform_act_fn(hidden_states)
        hidden_states = self.LayerNorm(inputs=hidden_states)

        return hidden_states

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        if getattr(self, "dense", None) is not None:
            with tf.name_scope(self.dense.name):
                self.dense.build([None, None, self.config.hidden_size])
        if getattr(self, "LayerNorm", None) is not None:
            with tf.name_scope(self.LayerNorm.name):
                self.LayerNorm.build([None, None, self.config.hidden_size])
class TFBertLMPredictionHead(keras.layers.Layer):
    def __init__(self, config: BertConfig, input_embeddings: keras.layers.Layer, **kwargs):
        super().__init__(**kwargs)

        self.config = config
        self.hidden_size = config.hidden_size

        self.transform = TFBertPredictionHeadTransform(config, name="transform")

        # The output weights are the same as the input embeddings, but there is
        # an output-only bias for each token.
        self.input_embeddings = input_embeddings

    def build(self, input_shape=None):
        self.bias = self.add_weight(shape=(self.config.vocab_size,), initializer="zeros", trainable=True, name="bias")

        if self.built:
            return
        self.built = True
        if getattr(self, "transform", None) is not None:
            with tf.name_scope(self.transform.name):
                self.transform.build(None)

    def get_output_embeddings(self) -> keras.layers.Layer:
        return self.input_embeddings

    def set_output_embeddings(self, value: tf.Variable):
        self.input_embeddings.weight = value
        self.input_embeddings.vocab_size = shape_list(value)[0]

    def get_bias(self) -> Dict[str, tf.Variable]:
        return {"bias": self.bias}

    def set_bias(self, value: tf.Variable):
        self.bias = value["bias"]
        self.config.vocab_size = shape_list(value["bias"])[0]

    def call(self, hidden_states: tf.Tensor) -> tf.Tensor:
        hidden_states = self.transform(hidden_states=hidden_states)
        seq_length = shape_list(hidden_states)[1]
        hidden_states = tf.reshape(tensor=hidden_states, shape=[-1, self.hidden_size])
        hidden_states = tf.matmul(a=hidden_states, b=self.input_embeddings.weight, transpose_b=True)
        hidden_states = tf.reshape(tensor=hidden_states, shape=[-1, seq_length, self.config.vocab_size])
        hidden_states = tf.nn.bias_add(value=hidden_states, bias=self.bias)

        return hidden_states
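# Illustrative note (not part of the original file): the prediction head above ties its output projection to the
# word-embedding matrix, so logits are computed as
#
#   logits = hidden_states @ word_embeddings.T + bias        # shape [batch, seq_len, vocab_size]
#
# which is why the layer stores `input_embeddings` instead of creating its own Dense kernel.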
class TFBertMLMHead(keras.layers.Layer):
    def __init__(self, config: BertConfig, input_embeddings: keras.layers.Layer, **kwargs):
        super().__init__(**kwargs)

        self.predictions = TFBertLMPredictionHead(config, input_embeddings, name="predictions")

    def call(self, sequence_output: tf.Tensor) -> tf.Tensor:
        prediction_scores = self.predictions(hidden_states=sequence_output)

        return prediction_scores

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        if getattr(self, "predictions", None) is not None:
            with tf.name_scope(self.predictions.name):
                self.predictions.build(None)
class TFBertNSPHead(keras.layers.Layer):
    def __init__(self, config: BertConfig, **kwargs):
        super().__init__(**kwargs)

        self.seq_relationship = keras.layers.Dense(
            units=2,
            kernel_initializer=get_initializer(config.initializer_range),
            name="seq_relationship",
        )
        self.config = config

    def call(self, pooled_output: tf.Tensor) -> tf.Tensor:
        seq_relationship_score = self.seq_relationship(inputs=pooled_output)

        return seq_relationship_score

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        if getattr(self, "seq_relationship", None) is not None:
            with tf.name_scope(self.seq_relationship.name):
                self.seq_relationship.build([None, None, self.config.hidden_size])
@keras_serializable
class TFBertMainLayer(keras.layers.Layer):
    config_class = BertConfig

    def __init__(self, config: BertConfig, add_pooling_layer: bool = True, **kwargs):
        super().__init__(**kwargs)

        self.config = config
        self.is_decoder = config.is_decoder

        self.embeddings = TFBertEmbeddings(config, name="embeddings")
        self.encoder = TFBertEncoder(config, name="encoder")
        self.pooler = TFBertPooler(config, name="pooler") if add_pooling_layer else None

    def get_input_embeddings(self) -> keras.layers.Layer:
        return self.embeddings

    def set_input_embeddings(self, value: tf.Variable):
        self.embeddings.weight = value
        self.embeddings.vocab_size = shape_list(value)[0]

    def _prune_heads(self, heads_to_prune):
        """
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
        class PreTrainedModel
        """
        raise NotImplementedError

    @unpack_inputs
    def call(
        self,
        input_ids: TFModelInputType | None = None,
        attention_mask: np.ndarray | tf.Tensor | None = None,
        token_type_ids: np.ndarray | tf.Tensor | None = None,
        position_ids: np.ndarray | tf.Tensor | None = None,
        head_mask: np.ndarray | tf.Tensor | None = None,
        inputs_embeds: np.ndarray | tf.Tensor | None = None,
        encoder_hidden_states: np.ndarray | tf.Tensor | None = None,
        encoder_attention_mask: np.ndarray | tf.Tensor | None = None,
        past_key_values: Optional[Tuple[Tuple[Union[np.ndarray, tf.Tensor]]]] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        training: bool = False,
    ) -> Union[TFBaseModelOutputWithPoolingAndCrossAttentions, Tuple[tf.Tensor]]:
        if not self.config.is_decoder:
            use_cache = False

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            input_shape = shape_list(input_ids)
        elif inputs_embeds is not None:
            input_shape = shape_list(inputs_embeds)[:-1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        batch_size, seq_length = input_shape

        if past_key_values is None:
            past_key_values_length = 0
            past_key_values = [None] * len(self.encoder.layer)
        else:
            past_key_values_length = shape_list(past_key_values[0][0])[-2]

        if attention_mask is None:
            attention_mask = tf.fill(dims=(batch_size, seq_length + past_key_values_length), value=1)

        if token_type_ids is None:
            token_type_ids = tf.fill(dims=input_shape, value=0)

        embedding_output = self.embeddings(
            input_ids=input_ids,
            position_ids=position_ids,
            token_type_ids=token_type_ids,
            inputs_embeds=inputs_embeds,
            past_key_values_length=past_key_values_length,
            training=training,
        )

        # We create a 3D attention mask from a 2D tensor mask.
        # Sizes are [batch_size, 1, 1, to_seq_length]
        # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
        # This attention mask is simpler than the triangular masking of causal attention
        # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
        attention_mask_shape = shape_list(attention_mask)

        mask_seq_length = seq_length + past_key_values_length
        # Copied from `modeling_tf_t5.py`
        # Provided a padding mask of dimensions [batch_size, mask_seq_length]
        # - if the model is a decoder, apply a causal mask in addition to the padding mask
        # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, mask_seq_length, mask_seq_length]
        if self.is_decoder:
            seq_ids = tf.range(mask_seq_length)
            causal_mask = tf.less_equal(
                tf.tile(seq_ids[None, None, :], (batch_size, mask_seq_length, 1)),
                seq_ids[None, :, None],
            )
            causal_mask = tf.cast(causal_mask, dtype=attention_mask.dtype)
            extended_attention_mask = causal_mask * attention_mask[:, None, :]
            attention_mask_shape = shape_list(extended_attention_mask)
            extended_attention_mask = tf.reshape(
                extended_attention_mask, (attention_mask_shape[0], 1, attention_mask_shape[1], attention_mask_shape[2])
            )
            if past_key_values[0] is not None:
                # attention_mask needs to be sliced to the shape `[batch_size, 1, from_seq_length - cached_seq_length, to_seq_length]`
                extended_attention_mask = extended_attention_mask[:, :, -seq_length:, :]
        else:
            extended_attention_mask = tf.reshape(
                attention_mask, (attention_mask_shape[0], 1, 1, attention_mask_shape[1])
            )

        # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
        # masked positions, this operation will create a tensor which is 0.0 for
        # positions we want to attend and -10000.0 for masked positions.
        # Since we are adding it to the raw scores before the softmax, this is
        # effectively the same as removing these entirely.
        extended_attention_mask = tf.cast(extended_attention_mask, dtype=embedding_output.dtype)
        one_cst = tf.constant(1.0, dtype=embedding_output.dtype)
        ten_thousand_cst = tf.constant(-10000.0, dtype=embedding_output.dtype)
        extended_attention_mask = tf.multiply(tf.subtract(one_cst, extended_attention_mask), ten_thousand_cst)
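        # Illustrative note (not part of the original file): for a padding mask of [1, 1, 0] the line above gives
        #   (1.0 - [1., 1., 0.]) * -10000.0  ->  [0., 0., -10000.]
        # so padded positions receive a large negative bias before the softmax and are effectively ignored.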
        # Copied from `modeling_tf_t5.py` with -1e9 -> -10000
        if self.is_decoder and encoder_attention_mask is not None:
            # If a 2D or 3D attention mask is provided for the cross-attention,
            # we need to make it broadcastable to [batch_size, num_heads, mask_seq_length, mask_seq_length]
            encoder_attention_mask = tf.cast(encoder_attention_mask, dtype=extended_attention_mask.dtype)
            num_dims_encoder_attention_mask = len(shape_list(encoder_attention_mask))
            if num_dims_encoder_attention_mask == 3:
                encoder_extended_attention_mask = encoder_attention_mask[:, None, :, :]
            if num_dims_encoder_attention_mask == 2:
                encoder_extended_attention_mask = encoder_attention_mask[:, None, None, :]

            # T5 has a mask that can compare sequence ids, we can simulate this here with this transposition
            # Cf. https://github.com/tensorflow/mesh/blob/8d2465e9bc93129b913b5ccc6a59aa97abd96ec6/mesh_tensorflow/transformer/transformer_layers.py#L270
            # encoder_extended_attention_mask = tf.math.equal(encoder_extended_attention_mask,
            #                                                 tf.transpose(encoder_extended_attention_mask, perm=(-1, -2)))
            encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * -10000.0
        else:
            encoder_extended_attention_mask = None

        # Prepare head mask if needed
        # 1.0 in head_mask indicate we keep the head
        # attention_probs has shape bsz x n_heads x N x N
        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
        if head_mask is not None:
            raise NotImplementedError
        else:
            head_mask = [None] * self.config.num_hidden_layers

        encoder_outputs = self.encoder(
            hidden_states=embedding_output,
            attention_mask=extended_attention_mask,
            head_mask=head_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_extended_attention_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            training=training,
        )

        sequence_output = encoder_outputs[0]
        pooled_output = self.pooler(hidden_states=sequence_output) if self.pooler is not None else None

        if not return_dict:
            return (
                sequence_output,
                pooled_output,
            ) + encoder_outputs[1:]

        return TFBaseModelOutputWithPoolingAndCrossAttentions(
            last_hidden_state=sequence_output,
            pooler_output=pooled_output,
            past_key_values=encoder_outputs.past_key_values,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
            cross_attentions=encoder_outputs.cross_attentions,
        )

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        if getattr(self, "embeddings", None) is not None:
            with tf.name_scope(self.embeddings.name):
                self.embeddings.build(None)
        if getattr(self, "encoder", None) is not None:
            with tf.name_scope(self.encoder.name):
                self.encoder.build(None)
        if getattr(self, "pooler", None) is not None:
            with tf.name_scope(self.pooler.name):
                self.pooler.build(None)
class TFBertPreTrainedModel(TFPreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = BertConfig
    base_model_prefix = "bert"
@dataclass
class TFBertForPreTrainingOutput(ModelOutput):
    """
    Output type of [`TFBertForPreTraining`].

    Args:
        prediction_logits (`tf.Tensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
        seq_relationship_logits (`tf.Tensor` of shape `(batch_size, 2)`):
            Prediction scores of the next sequence prediction (classification) head (scores of True/False continuation
            before SoftMax).
        hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape
            `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
    """

    loss: tf.Tensor | None = None
    prediction_logits: tf.Tensor = None
    seq_relationship_logits: tf.Tensor = None
    hidden_states: Optional[Union[Tuple[tf.Tensor], tf.Tensor]] = None
    attentions: Optional[Union[Tuple[tf.Tensor], tf.Tensor]] = None
BERT_START_DOCSTRING = r"""

    This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
    etc.)

    This model is also a [keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model) subclass. Use it
    as a regular TF 2.0 Keras Model and refer to the TF 2.0 documentation for all matters related to general usage and
    behavior.

    <Tip>

    TensorFlow models and layers in `transformers` accept two formats as input:

    - having all inputs as keyword arguments (like PyTorch models), or
    - having all inputs as a list, tuple or dict in the first positional argument.

    The reason the second format is supported is that Keras methods prefer this format when passing inputs to models
    and layers. Because of this support, when using methods like `model.fit()` things should "just work" for you - just
    pass your inputs and labels in any format that `model.fit()` supports! If, however, you want to use the second
    format outside of Keras methods like `fit()` and `predict()`, such as when creating your own layers or models with
    the Keras `Functional` API, there are three possibilities you can use to gather all the input Tensors in the first
    positional argument:

    - a single Tensor with `input_ids` only and nothing else: `model(input_ids)`
    - a list of varying length with one or several input Tensors IN THE ORDER given in the docstring:
      `model([input_ids, attention_mask])` or `model([input_ids, attention_mask, token_type_ids])`
    - a dictionary with one or several input Tensors associated to the input names given in the docstring:
      `model({"input_ids": input_ids, "token_type_ids": token_type_ids})`

    Note that when creating models and layers with
    [subclassing](https://keras.io/guides/making_new_layers_and_models_via_subclassing/) then you don't need to worry
    about any of this, as you can just pass inputs like you would to any other Python function!

    </Tip>

    Args:
        config ([`BertConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~TFPreTrainedModel.from_pretrained`] method to load the model weights.
"""
BERT_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`np.ndarray`, `tf.Tensor`, `List[tf.Tensor]`, `Dict[str, tf.Tensor]` or `Dict[str, np.ndarray]`, and each example must have the shape `({0})`):
            Indices of input sequence tokens in the vocabulary.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.__call__`] and
            [`PreTrainedTokenizer.encode`] for details.

            [What are input IDs?](../glossary#input-ids)
        attention_mask (`np.ndarray` or `tf.Tensor` of shape `({0})`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)
        token_type_ids (`np.ndarray` or `tf.Tensor` of shape `({0})`, *optional*):
            Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
            1]`:

            - 0 corresponds to a *sentence A* token,
            - 1 corresponds to a *sentence B* token.

            [What are token type IDs?](../glossary#token-type-ids)
        position_ids (`np.ndarray` or `tf.Tensor` of shape `({0})`, *optional*):
            Indices of positions of each input sequence token in the position embeddings. Selected in the range `[0,
            config.max_position_embeddings - 1]`.

            [What are position IDs?](../glossary#position-ids)
        head_mask (`np.ndarray` or `tf.Tensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
            Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.
        inputs_embeds (`np.ndarray` or `tf.Tensor` of shape `({0}, hidden_size)`, *optional*):
            Optionally, instead of passing `input_ids`, you can choose to directly pass an embedded representation.
            This is useful if you want more control over how to convert `input_ids` indices into associated vectors
            than the model's internal embedding lookup matrix.
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail. This argument can be used only in eager mode; in graph mode the value in the
            config will be used instead.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail. This argument can be used only in eager mode; in graph mode the value in the config will be
            used instead.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. This argument can be used in
            eager mode; in graph mode the value will always be set to `True`.
        training (`bool`, *optional*, defaults to `False`):
            Whether or not to use the model in training mode (some modules, like dropout, have different behaviors
            between training and evaluation).
"""
  951. @add_start_docstrings(
  952. "The bare Bert Model transformer outputting raw hidden-states without any specific head on top.",
  953. BERT_START_DOCSTRING,
  954. )
  955. class TFBertModel(TFBertPreTrainedModel):
  956. def __init__(self, config: BertConfig, add_pooling_layer: bool = True, *inputs, **kwargs):
  957. super().__init__(config, *inputs, **kwargs)
  958. self.bert = TFBertMainLayer(config, add_pooling_layer, name="bert")
  959. @unpack_inputs
  960. @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
  961. @add_code_sample_docstrings(
  962. checkpoint=_CHECKPOINT_FOR_DOC,
  963. output_type=TFBaseModelOutputWithPoolingAndCrossAttentions,
  964. config_class=_CONFIG_FOR_DOC,
  965. )
  966. def call(
  967. self,
  968. input_ids: TFModelInputType | None = None,
  969. attention_mask: np.ndarray | tf.Tensor | None = None,
  970. token_type_ids: np.ndarray | tf.Tensor | None = None,
  971. position_ids: np.ndarray | tf.Tensor | None = None,
  972. head_mask: np.ndarray | tf.Tensor | None = None,
  973. inputs_embeds: np.ndarray | tf.Tensor | None = None,
  974. encoder_hidden_states: np.ndarray | tf.Tensor | None = None,
  975. encoder_attention_mask: np.ndarray | tf.Tensor | None = None,
  976. past_key_values: Optional[Tuple[Tuple[Union[np.ndarray, tf.Tensor]]]] = None,
  977. use_cache: Optional[bool] = None,
  978. output_attentions: Optional[bool] = None,
  979. output_hidden_states: Optional[bool] = None,
  980. return_dict: Optional[bool] = None,
  981. training: Optional[bool] = False,
  982. ) -> Union[TFBaseModelOutputWithPoolingAndCrossAttentions, Tuple[tf.Tensor]]:
        r"""
        encoder_hidden_states (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
            the model is configured as a decoder.
        encoder_attention_mask (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
            the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.
        past_key_values (`Tuple[Tuple[tf.Tensor]]` of length `config.n_layers`, *optional*):
            Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
            If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
            don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
            `decoder_input_ids` of shape `(batch_size, sequence_length)`.
        use_cache (`bool`, *optional*, defaults to `True`):
            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
            `past_key_values`). Set to `False` during training and to `True` during generation.
        """
  1001. outputs = self.bert(
  1002. input_ids=input_ids,
  1003. attention_mask=attention_mask,
  1004. token_type_ids=token_type_ids,
  1005. position_ids=position_ids,
  1006. head_mask=head_mask,
  1007. inputs_embeds=inputs_embeds,
  1008. encoder_hidden_states=encoder_hidden_states,
  1009. encoder_attention_mask=encoder_attention_mask,
  1010. past_key_values=past_key_values,
  1011. use_cache=use_cache,
  1012. output_attentions=output_attentions,
  1013. output_hidden_states=output_hidden_states,
  1014. return_dict=return_dict,
  1015. training=training,
  1016. )
  1017. return outputs
  1018. def build(self, input_shape=None):
  1019. if self.built:
  1020. return
  1021. self.built = True
  1022. if getattr(self, "bert", None) is not None:
  1023. with tf.name_scope(self.bert.name):
  1024. self.bert.build(None)
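
# A minimal usage sketch for TFBertModel as a feature extractor, assuming the `google-bert/bert-base-uncased`
# checkpoint used in this file's docstring examples. The helper name is illustrative and the function is never
# called at import time.
def _tf_bert_model_usage_sketch():
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-uncased")
    model = TFBertModel.from_pretrained("google-bert/bert-base-uncased")

    inputs = tokenizer("Hello, my dog is cute", return_tensors="tf")
    outputs = model(inputs)

    last_hidden_state = outputs.last_hidden_state  # (batch_size, sequence_length, hidden_size)
    pooled_output = outputs.pooler_output  # (batch_size, hidden_size), built from the [CLS] token
    return last_hidden_state, pooled_output
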
  1025. @add_start_docstrings(
  1026. """
  1027. Bert Model with two heads on top as done during the pretraining:
  1028. a `masked language modeling` head and a `next sentence prediction (classification)` head.
  1029. """,
  1030. BERT_START_DOCSTRING,
  1031. )
  1032. class TFBertForPreTraining(TFBertPreTrainedModel, TFBertPreTrainingLoss):
    # names with a '.' represent the authorized unexpected/missing layers when a TF model is loaded from a PT model
  1034. _keys_to_ignore_on_load_unexpected = [
  1035. r"position_ids",
  1036. r"cls.predictions.decoder.weight",
  1037. r"cls.predictions.decoder.bias",
  1038. ]
  1039. def __init__(self, config: BertConfig, *inputs, **kwargs):
  1040. super().__init__(config, *inputs, **kwargs)
  1041. self.bert = TFBertMainLayer(config, name="bert")
  1042. self.nsp = TFBertNSPHead(config, name="nsp___cls")
  1043. self.mlm = TFBertMLMHead(config, input_embeddings=self.bert.embeddings, name="mlm___cls")
  1044. def get_lm_head(self) -> keras.layers.Layer:
  1045. return self.mlm.predictions
  1046. def get_prefix_bias_name(self) -> str:
  1047. warnings.warn("The method get_prefix_bias_name is deprecated. Please use `get_bias` instead.", FutureWarning)
  1048. return self.name + "/" + self.mlm.name + "/" + self.mlm.predictions.name
  1049. @unpack_inputs
  1050. @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
  1051. @replace_return_docstrings(output_type=TFBertForPreTrainingOutput, config_class=_CONFIG_FOR_DOC)
  1052. def call(
  1053. self,
  1054. input_ids: TFModelInputType | None = None,
  1055. attention_mask: np.ndarray | tf.Tensor | None = None,
  1056. token_type_ids: np.ndarray | tf.Tensor | None = None,
  1057. position_ids: np.ndarray | tf.Tensor | None = None,
  1058. head_mask: np.ndarray | tf.Tensor | None = None,
  1059. inputs_embeds: np.ndarray | tf.Tensor | None = None,
  1060. output_attentions: Optional[bool] = None,
  1061. output_hidden_states: Optional[bool] = None,
  1062. return_dict: Optional[bool] = None,
  1063. labels: np.ndarray | tf.Tensor | None = None,
  1064. next_sentence_label: np.ndarray | tf.Tensor | None = None,
  1065. training: Optional[bool] = False,
  1066. ) -> Union[TFBertForPreTrainingOutput, Tuple[tf.Tensor]]:
        r"""
        labels (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
            config.vocab_size]` (see the `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked); the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
        next_sentence_label (`tf.Tensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the next sequence prediction (classification) loss. Input should be a sequence pair
            (see the `input_ids` docstring). Indices should be in `[0, 1]`:

            - 0 indicates sequence B is a continuation of sequence A,
            - 1 indicates sequence B is a random sequence.
        kwargs (`Dict[str, any]`, *optional*, defaults to `{}`):
            Used to hide legacy arguments that have been deprecated.
  1079. Return:
  1080. Examples:
  1081. ```python
  1082. >>> import tensorflow as tf
  1083. >>> from transformers import AutoTokenizer, TFBertForPreTraining
  1084. >>> tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-uncased")
  1085. >>> model = TFBertForPreTraining.from_pretrained("google-bert/bert-base-uncased")
  1086. >>> input_ids = tokenizer("Hello, my dog is cute", add_special_tokens=True, return_tensors="tf")
  1087. >>> # Batch size 1
  1088. >>> outputs = model(input_ids)
  1089. >>> prediction_logits, seq_relationship_logits = outputs[:2]
  1090. ```"""
  1091. outputs = self.bert(
  1092. input_ids=input_ids,
  1093. attention_mask=attention_mask,
  1094. token_type_ids=token_type_ids,
  1095. position_ids=position_ids,
  1096. head_mask=head_mask,
  1097. inputs_embeds=inputs_embeds,
  1098. output_attentions=output_attentions,
  1099. output_hidden_states=output_hidden_states,
  1100. return_dict=return_dict,
  1101. training=training,
  1102. )
  1103. sequence_output, pooled_output = outputs[:2]
  1104. prediction_scores = self.mlm(sequence_output=sequence_output, training=training)
  1105. seq_relationship_score = self.nsp(pooled_output=pooled_output)
  1106. total_loss = None
  1107. if labels is not None and next_sentence_label is not None:
  1108. d_labels = {"labels": labels}
  1109. d_labels["next_sentence_label"] = next_sentence_label
  1110. total_loss = self.hf_compute_loss(labels=d_labels, logits=(prediction_scores, seq_relationship_score))
  1111. if not return_dict:
  1112. output = (prediction_scores, seq_relationship_score) + outputs[2:]
  1113. return ((total_loss,) + output) if total_loss is not None else output
  1114. return TFBertForPreTrainingOutput(
  1115. loss=total_loss,
  1116. prediction_logits=prediction_scores,
  1117. seq_relationship_logits=seq_relationship_score,
  1118. hidden_states=outputs.hidden_states,
  1119. attentions=outputs.attentions,
  1120. )
  1121. def build(self, input_shape=None):
  1122. if self.built:
  1123. return
  1124. self.built = True
  1125. if getattr(self, "bert", None) is not None:
  1126. with tf.name_scope(self.bert.name):
  1127. self.bert.build(None)
  1128. if getattr(self, "nsp", None) is not None:
  1129. with tf.name_scope(self.nsp.name):
  1130. self.nsp.build(None)
  1131. if getattr(self, "mlm", None) is not None:
  1132. with tf.name_scope(self.mlm.name):
  1133. self.mlm.build(None)
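
# A minimal usage sketch for TFBertForPreTraining, showing the outputs of its two heads. The checkpoint name
# reuses the one from the class docstring above; the helper name is illustrative and the function is never
# called at import time.
def _tf_bert_for_pre_training_usage_sketch():
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-uncased")
    model = TFBertForPreTraining.from_pretrained("google-bert/bert-base-uncased")

    inputs = tokenizer("Hello, my dog is cute", return_tensors="tf")
    outputs = model(inputs)

    prediction_logits = outputs.prediction_logits  # MLM head: (batch_size, sequence_length, vocab_size)
    seq_relationship_logits = outputs.seq_relationship_logits  # NSP head: (batch_size, 2)
    return prediction_logits, seq_relationship_logits
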
  1134. @add_start_docstrings("""Bert Model with a `language modeling` head on top.""", BERT_START_DOCSTRING)
  1135. class TFBertForMaskedLM(TFBertPreTrainedModel, TFMaskedLanguageModelingLoss):
    # names with a '.' represent the authorized unexpected/missing layers when a TF model is loaded from a PT model
  1137. _keys_to_ignore_on_load_unexpected = [
  1138. r"pooler",
  1139. r"cls.seq_relationship",
  1140. r"cls.predictions.decoder.weight",
  1141. r"nsp___cls",
  1142. ]
  1143. def __init__(self, config: BertConfig, *inputs, **kwargs):
  1144. super().__init__(config, *inputs, **kwargs)
  1145. if config.is_decoder:
  1146. logger.warning(
  1147. "If you want to use `TFBertForMaskedLM` make sure `config.is_decoder=False` for "
  1148. "bi-directional self-attention."
  1149. )
  1150. self.bert = TFBertMainLayer(config, add_pooling_layer=False, name="bert")
  1151. self.mlm = TFBertMLMHead(config, input_embeddings=self.bert.embeddings, name="mlm___cls")
  1152. def get_lm_head(self) -> keras.layers.Layer:
  1153. return self.mlm.predictions
  1154. def get_prefix_bias_name(self) -> str:
  1155. warnings.warn("The method get_prefix_bias_name is deprecated. Please use `get_bias` instead.", FutureWarning)
  1156. return self.name + "/" + self.mlm.name + "/" + self.mlm.predictions.name
  1157. @unpack_inputs
  1158. @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
  1159. @add_code_sample_docstrings(
  1160. checkpoint=_CHECKPOINT_FOR_DOC,
  1161. output_type=TFMaskedLMOutput,
  1162. config_class=_CONFIG_FOR_DOC,
  1163. expected_output="'paris'",
  1164. expected_loss=0.88,
  1165. )
  1166. def call(
  1167. self,
  1168. input_ids: TFModelInputType | None = None,
  1169. attention_mask: np.ndarray | tf.Tensor | None = None,
  1170. token_type_ids: np.ndarray | tf.Tensor | None = None,
  1171. position_ids: np.ndarray | tf.Tensor | None = None,
  1172. head_mask: np.ndarray | tf.Tensor | None = None,
  1173. inputs_embeds: np.ndarray | tf.Tensor | None = None,
  1174. output_attentions: Optional[bool] = None,
  1175. output_hidden_states: Optional[bool] = None,
  1176. return_dict: Optional[bool] = None,
  1177. labels: np.ndarray | tf.Tensor | None = None,
  1178. training: Optional[bool] = False,
  1179. ) -> Union[TFMaskedLMOutput, Tuple[tf.Tensor]]:
        r"""
        labels (`tf.Tensor` or `np.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
            config.vocab_size]` (see the `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked); the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
        """
  1186. outputs = self.bert(
  1187. input_ids=input_ids,
  1188. attention_mask=attention_mask,
  1189. token_type_ids=token_type_ids,
  1190. position_ids=position_ids,
  1191. head_mask=head_mask,
  1192. inputs_embeds=inputs_embeds,
  1193. output_attentions=output_attentions,
  1194. output_hidden_states=output_hidden_states,
  1195. return_dict=return_dict,
  1196. training=training,
  1197. )
  1198. sequence_output = outputs[0]
  1199. prediction_scores = self.mlm(sequence_output=sequence_output, training=training)
  1200. loss = None if labels is None else self.hf_compute_loss(labels=labels, logits=prediction_scores)
  1201. if not return_dict:
  1202. output = (prediction_scores,) + outputs[2:]
  1203. return ((loss,) + output) if loss is not None else output
  1204. return TFMaskedLMOutput(
  1205. loss=loss,
  1206. logits=prediction_scores,
  1207. hidden_states=outputs.hidden_states,
  1208. attentions=outputs.attentions,
  1209. )
  1210. def build(self, input_shape=None):
  1211. if self.built:
  1212. return
  1213. self.built = True
  1214. if getattr(self, "bert", None) is not None:
  1215. with tf.name_scope(self.bert.name):
  1216. self.bert.build(None)
  1217. if getattr(self, "mlm", None) is not None:
  1218. with tf.name_scope(self.mlm.name):
  1219. self.mlm.build(None)
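
# A minimal fill-mask sketch for TFBertForMaskedLM, assuming the `google-bert/bert-base-uncased` checkpoint
# used in the docstrings above. The helper name is illustrative and the function is never called at import
# time.
def _tf_bert_for_masked_lm_usage_sketch():
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-uncased")
    model = TFBertForMaskedLM.from_pretrained("google-bert/bert-base-uncased")

    inputs = tokenizer("The capital of France is [MASK].", return_tensors="tf")
    logits = model(inputs).logits  # (batch_size, sequence_length, vocab_size)

    # locate the [MASK] position in the single example and take the highest-scoring vocabulary id there
    mask_index = int(tf.where(inputs["input_ids"][0] == tokenizer.mask_token_id)[0, 0])
    predicted_id = int(tf.argmax(logits[0, mask_index]))
    return tokenizer.decode([predicted_id])  # expected to decode to something like "paris"
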
  1220. class TFBertLMHeadModel(TFBertPreTrainedModel, TFCausalLanguageModelingLoss):
    # names with a '.' represent the authorized unexpected/missing layers when a TF model is loaded from a PT model
  1222. _keys_to_ignore_on_load_unexpected = [
  1223. r"pooler",
  1224. r"cls.seq_relationship",
  1225. r"cls.predictions.decoder.weight",
  1226. r"nsp___cls",
  1227. ]
  1228. def __init__(self, config: BertConfig, *inputs, **kwargs):
  1229. super().__init__(config, *inputs, **kwargs)
        if not config.is_decoder:
            logger.warning("If you want to use `TFBertLMHeadModel` as a standalone, add `is_decoder=True`.")
  1232. self.bert = TFBertMainLayer(config, add_pooling_layer=False, name="bert")
  1233. self.mlm = TFBertMLMHead(config, input_embeddings=self.bert.embeddings, name="mlm___cls")
  1234. def get_lm_head(self) -> keras.layers.Layer:
  1235. return self.mlm.predictions
  1236. def get_prefix_bias_name(self) -> str:
  1237. warnings.warn("The method get_prefix_bias_name is deprecated. Please use `get_bias` instead.", FutureWarning)
  1238. return self.name + "/" + self.mlm.name + "/" + self.mlm.predictions.name
  1239. def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attention_mask=None, **model_kwargs):
  1240. input_shape = input_ids.shape
  1241. # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
  1242. if attention_mask is None:
  1243. attention_mask = tf.ones(input_shape)
  1244. # cut decoder_input_ids if past is used
  1245. if past_key_values is not None:
  1246. input_ids = input_ids[:, -1:]
  1247. return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past_key_values}
  1248. @unpack_inputs
  1249. @add_code_sample_docstrings(
  1250. checkpoint=_CHECKPOINT_FOR_DOC,
  1251. output_type=TFCausalLMOutputWithCrossAttentions,
  1252. config_class=_CONFIG_FOR_DOC,
  1253. )
  1254. def call(
  1255. self,
  1256. input_ids: TFModelInputType | None = None,
  1257. attention_mask: np.ndarray | tf.Tensor | None = None,
  1258. token_type_ids: np.ndarray | tf.Tensor | None = None,
  1259. position_ids: np.ndarray | tf.Tensor | None = None,
  1260. head_mask: np.ndarray | tf.Tensor | None = None,
  1261. inputs_embeds: np.ndarray | tf.Tensor | None = None,
  1262. encoder_hidden_states: np.ndarray | tf.Tensor | None = None,
  1263. encoder_attention_mask: np.ndarray | tf.Tensor | None = None,
  1264. past_key_values: Optional[Tuple[Tuple[Union[np.ndarray, tf.Tensor]]]] = None,
  1265. use_cache: Optional[bool] = None,
  1266. output_attentions: Optional[bool] = None,
  1267. output_hidden_states: Optional[bool] = None,
  1268. return_dict: Optional[bool] = None,
  1269. labels: np.ndarray | tf.Tensor | None = None,
  1270. training: Optional[bool] = False,
  1271. **kwargs,
  1272. ) -> Union[TFCausalLMOutputWithCrossAttentions, Tuple[tf.Tensor]]:
        r"""
        encoder_hidden_states (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
            the model is configured as a decoder.
        encoder_attention_mask (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
            the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.
        past_key_values (`Tuple[Tuple[tf.Tensor]]` of length `config.n_layers`, *optional*):
            Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
            If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
            don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
            `decoder_input_ids` of shape `(batch_size, sequence_length)`.
        use_cache (`bool`, *optional*, defaults to `True`):
            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
            `past_key_values`). Set to `False` during training and to `True` during generation.
        labels (`tf.Tensor` or `np.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the cross entropy classification loss. Indices should be in `[0, ...,
            config.vocab_size - 1]`.
        """
  1294. outputs = self.bert(
  1295. input_ids=input_ids,
  1296. attention_mask=attention_mask,
  1297. token_type_ids=token_type_ids,
  1298. position_ids=position_ids,
  1299. head_mask=head_mask,
  1300. inputs_embeds=inputs_embeds,
  1301. encoder_hidden_states=encoder_hidden_states,
  1302. encoder_attention_mask=encoder_attention_mask,
  1303. past_key_values=past_key_values,
  1304. use_cache=use_cache,
  1305. output_attentions=output_attentions,
  1306. output_hidden_states=output_hidden_states,
  1307. return_dict=return_dict,
  1308. training=training,
  1309. )
  1310. sequence_output = outputs[0]
  1311. logits = self.mlm(sequence_output=sequence_output, training=training)
  1312. loss = None
  1313. if labels is not None:
  1314. # shift labels to the left and cut last logit token
  1315. shifted_logits = logits[:, :-1]
  1316. labels = labels[:, 1:]
  1317. loss = self.hf_compute_loss(labels=labels, logits=shifted_logits)
  1318. if not return_dict:
  1319. output = (logits,) + outputs[2:]
  1320. return ((loss,) + output) if loss is not None else output
  1321. return TFCausalLMOutputWithCrossAttentions(
  1322. loss=loss,
  1323. logits=logits,
  1324. past_key_values=outputs.past_key_values,
  1325. hidden_states=outputs.hidden_states,
  1326. attentions=outputs.attentions,
  1327. cross_attentions=outputs.cross_attentions,
  1328. )
  1329. def build(self, input_shape=None):
  1330. if self.built:
  1331. return
  1332. self.built = True
  1333. if getattr(self, "bert", None) is not None:
  1334. with tf.name_scope(self.bert.name):
  1335. self.bert.build(None)
  1336. if getattr(self, "mlm", None) is not None:
  1337. with tf.name_scope(self.mlm.name):
  1338. self.mlm.build(None)
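
# A minimal causal-LM sketch for TFBertLMHeadModel. `is_decoder=True` gives the causal attention mask this
# head expects, and passing `labels=input_ids` works because the labels are shifted internally (see the loss
# computation above). Checkpoint and helper name are illustrative; the function is never called at import time.
def _tf_bert_lm_head_model_usage_sketch():
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-uncased")
    model = TFBertLMHeadModel.from_pretrained("google-bert/bert-base-uncased", is_decoder=True)

    inputs = tokenizer("Hello, my dog is cute", return_tensors="tf")
    outputs = model(inputs, labels=inputs["input_ids"])

    return outputs.loss, outputs.logits  # logits: (batch_size, sequence_length, vocab_size)
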
  1339. @add_start_docstrings(
  1340. """Bert Model with a `next sentence prediction (classification)` head on top.""",
  1341. BERT_START_DOCSTRING,
  1342. )
  1343. class TFBertForNextSentencePrediction(TFBertPreTrainedModel, TFNextSentencePredictionLoss):
    # names with a '.' represent the authorized unexpected/missing layers when a TF model is loaded from a PT model
  1345. _keys_to_ignore_on_load_unexpected = [r"mlm___cls", r"cls.predictions"]
  1346. def __init__(self, config: BertConfig, *inputs, **kwargs):
  1347. super().__init__(config, *inputs, **kwargs)
  1348. self.bert = TFBertMainLayer(config, name="bert")
  1349. self.nsp = TFBertNSPHead(config, name="nsp___cls")
  1350. @unpack_inputs
  1351. @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
  1352. @replace_return_docstrings(output_type=TFNextSentencePredictorOutput, config_class=_CONFIG_FOR_DOC)
  1353. def call(
  1354. self,
  1355. input_ids: TFModelInputType | None = None,
  1356. attention_mask: np.ndarray | tf.Tensor | None = None,
  1357. token_type_ids: np.ndarray | tf.Tensor | None = None,
  1358. position_ids: np.ndarray | tf.Tensor | None = None,
  1359. head_mask: np.ndarray | tf.Tensor | None = None,
  1360. inputs_embeds: np.ndarray | tf.Tensor | None = None,
  1361. output_attentions: Optional[bool] = None,
  1362. output_hidden_states: Optional[bool] = None,
  1363. return_dict: Optional[bool] = None,
  1364. next_sentence_label: np.ndarray | tf.Tensor | None = None,
  1365. training: Optional[bool] = False,
  1366. ) -> Union[TFNextSentencePredictorOutput, Tuple[tf.Tensor]]:
  1367. r"""
  1368. Return:
  1369. Examples:
  1370. ```python
  1371. >>> import tensorflow as tf
  1372. >>> from transformers import AutoTokenizer, TFBertForNextSentencePrediction
  1373. >>> tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-uncased")
  1374. >>> model = TFBertForNextSentencePrediction.from_pretrained("google-bert/bert-base-uncased")
  1375. >>> prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced."
  1376. >>> next_sentence = "The sky is blue due to the shorter wavelength of blue light."
  1377. >>> encoding = tokenizer(prompt, next_sentence, return_tensors="tf")
  1378. >>> logits = model(encoding["input_ids"], token_type_ids=encoding["token_type_ids"])[0]
  1379. >>> assert logits[0][0] < logits[0][1] # the next sentence was random
  1380. ```"""
  1381. outputs = self.bert(
  1382. input_ids=input_ids,
  1383. attention_mask=attention_mask,
  1384. token_type_ids=token_type_ids,
  1385. position_ids=position_ids,
  1386. head_mask=head_mask,
  1387. inputs_embeds=inputs_embeds,
  1388. output_attentions=output_attentions,
  1389. output_hidden_states=output_hidden_states,
  1390. return_dict=return_dict,
  1391. training=training,
  1392. )
  1393. pooled_output = outputs[1]
  1394. seq_relationship_scores = self.nsp(pooled_output=pooled_output)
  1395. next_sentence_loss = (
  1396. None
  1397. if next_sentence_label is None
  1398. else self.hf_compute_loss(labels=next_sentence_label, logits=seq_relationship_scores)
  1399. )
  1400. if not return_dict:
  1401. output = (seq_relationship_scores,) + outputs[2:]
  1402. return ((next_sentence_loss,) + output) if next_sentence_loss is not None else output
  1403. return TFNextSentencePredictorOutput(
  1404. loss=next_sentence_loss,
  1405. logits=seq_relationship_scores,
  1406. hidden_states=outputs.hidden_states,
  1407. attentions=outputs.attentions,
  1408. )
  1409. def build(self, input_shape=None):
  1410. if self.built:
  1411. return
  1412. self.built = True
  1413. if getattr(self, "bert", None) is not None:
  1414. with tf.name_scope(self.bert.name):
  1415. self.bert.build(None)
  1416. if getattr(self, "nsp", None) is not None:
  1417. with tf.name_scope(self.nsp.name):
  1418. self.nsp.build(None)
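
# A minimal sketch for TFBertForNextSentencePrediction that also passes `next_sentence_label` to obtain a
# loss, complementing the logits-only example in the docstring above. Sentences, checkpoint, and helper name
# are illustrative; the function is never called at import time.
def _tf_bert_for_next_sentence_prediction_usage_sketch():
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-uncased")
    model = TFBertForNextSentencePrediction.from_pretrained("google-bert/bert-base-uncased")

    first = "The cat sat on the mat."
    second = "It then curled up and fell asleep."
    encoding = tokenizer(first, second, return_tensors="tf")

    # label 0 means "sentence B continues sentence A", label 1 means "sentence B is random"
    outputs = model(encoding, next_sentence_label=tf.constant([0]))
    return outputs.loss, outputs.logits  # logits: (batch_size, 2)
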
  1419. @add_start_docstrings(
  1420. """
  1421. Bert Model transformer with a sequence classification/regression head on top (a linear layer on top of the pooled
  1422. output) e.g. for GLUE tasks.
  1423. """,
  1424. BERT_START_DOCSTRING,
  1425. )
  1426. class TFBertForSequenceClassification(TFBertPreTrainedModel, TFSequenceClassificationLoss):
    # names with a '.' represent the authorized unexpected/missing layers when a TF model is loaded from a PT model
  1428. _keys_to_ignore_on_load_unexpected = [r"mlm___cls", r"nsp___cls", r"cls.predictions", r"cls.seq_relationship"]
  1429. _keys_to_ignore_on_load_missing = [r"dropout"]
  1430. def __init__(self, config: BertConfig, *inputs, **kwargs):
  1431. super().__init__(config, *inputs, **kwargs)
  1432. self.num_labels = config.num_labels
  1433. self.bert = TFBertMainLayer(config, name="bert")
  1434. classifier_dropout = (
  1435. config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
  1436. )
  1437. self.dropout = keras.layers.Dropout(rate=classifier_dropout)
  1438. self.classifier = keras.layers.Dense(
  1439. units=config.num_labels,
  1440. kernel_initializer=get_initializer(config.initializer_range),
  1441. name="classifier",
  1442. )
  1443. self.config = config
  1444. @unpack_inputs
  1445. @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
  1446. @add_code_sample_docstrings(
  1447. checkpoint=_CHECKPOINT_FOR_SEQUENCE_CLASSIFICATION,
  1448. output_type=TFSequenceClassifierOutput,
  1449. config_class=_CONFIG_FOR_DOC,
  1450. expected_output=_SEQ_CLASS_EXPECTED_OUTPUT,
  1451. expected_loss=_SEQ_CLASS_EXPECTED_LOSS,
  1452. )
  1453. def call(
  1454. self,
  1455. input_ids: TFModelInputType | None = None,
  1456. attention_mask: np.ndarray | tf.Tensor | None = None,
  1457. token_type_ids: np.ndarray | tf.Tensor | None = None,
  1458. position_ids: np.ndarray | tf.Tensor | None = None,
  1459. head_mask: np.ndarray | tf.Tensor | None = None,
  1460. inputs_embeds: np.ndarray | tf.Tensor | None = None,
  1461. output_attentions: Optional[bool] = None,
  1462. output_hidden_states: Optional[bool] = None,
  1463. return_dict: Optional[bool] = None,
  1464. labels: np.ndarray | tf.Tensor | None = None,
  1465. training: Optional[bool] = False,
  1466. ) -> Union[TFSequenceClassifierOutput, Tuple[tf.Tensor]]:
        r"""
        labels (`tf.Tensor` or `np.ndarray` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (mean-square loss); if
            `config.num_labels > 1` a classification loss is computed (cross-entropy).
        """
  1473. outputs = self.bert(
  1474. input_ids=input_ids,
  1475. attention_mask=attention_mask,
  1476. token_type_ids=token_type_ids,
  1477. position_ids=position_ids,
  1478. head_mask=head_mask,
  1479. inputs_embeds=inputs_embeds,
  1480. output_attentions=output_attentions,
  1481. output_hidden_states=output_hidden_states,
  1482. return_dict=return_dict,
  1483. training=training,
  1484. )
  1485. pooled_output = outputs[1]
  1486. pooled_output = self.dropout(inputs=pooled_output, training=training)
  1487. logits = self.classifier(inputs=pooled_output)
  1488. loss = None if labels is None else self.hf_compute_loss(labels=labels, logits=logits)
  1489. if not return_dict:
  1490. output = (logits,) + outputs[2:]
  1491. return ((loss,) + output) if loss is not None else output
  1492. return TFSequenceClassifierOutput(
  1493. loss=loss,
  1494. logits=logits,
  1495. hidden_states=outputs.hidden_states,
  1496. attentions=outputs.attentions,
  1497. )
  1498. def build(self, input_shape=None):
  1499. if self.built:
  1500. return
  1501. self.built = True
  1502. if getattr(self, "bert", None) is not None:
  1503. with tf.name_scope(self.bert.name):
  1504. self.bert.build(None)
  1505. if getattr(self, "classifier", None) is not None:
  1506. with tf.name_scope(self.classifier.name):
  1507. self.classifier.build([None, None, self.config.hidden_size])
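
# A minimal single-sentence classification sketch for TFBertForSequenceClassification. Loading the base
# checkpoint with `num_labels=2` is an illustrative assumption and yields a randomly initialized classifier;
# a fine-tuned checkpoint such as _CHECKPOINT_FOR_SEQUENCE_CLASSIFICATION above is what you would use in
# practice. The helper is never called at import time.
def _tf_bert_for_sequence_classification_usage_sketch():
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-uncased")
    model = TFBertForSequenceClassification.from_pretrained("google-bert/bert-base-uncased", num_labels=2)

    inputs = tokenizer("This movie was great!", return_tensors="tf")
    logits = model(inputs).logits  # (batch_size, num_labels)

    predicted_class_id = int(tf.argmax(logits, axis=-1)[0])
    return predicted_class_id
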
  1508. @add_start_docstrings(
  1509. """
  1510. Bert Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a
  1511. softmax) e.g. for RocStories/SWAG tasks.
  1512. """,
  1513. BERT_START_DOCSTRING,
  1514. )
  1515. class TFBertForMultipleChoice(TFBertPreTrainedModel, TFMultipleChoiceLoss):
    # names with a '.' represent the authorized unexpected/missing layers when a TF model is loaded from a PT model
  1517. _keys_to_ignore_on_load_unexpected = [r"mlm___cls", r"nsp___cls", r"cls.predictions", r"cls.seq_relationship"]
  1518. _keys_to_ignore_on_load_missing = [r"dropout"]
  1519. def __init__(self, config: BertConfig, *inputs, **kwargs):
  1520. super().__init__(config, *inputs, **kwargs)
  1521. self.bert = TFBertMainLayer(config, name="bert")
  1522. self.dropout = keras.layers.Dropout(rate=config.hidden_dropout_prob)
  1523. self.classifier = keras.layers.Dense(
  1524. units=1, kernel_initializer=get_initializer(config.initializer_range), name="classifier"
  1525. )
  1526. self.config = config
  1527. @unpack_inputs
  1528. @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length"))
  1529. @add_code_sample_docstrings(
  1530. checkpoint=_CHECKPOINT_FOR_DOC,
  1531. output_type=TFMultipleChoiceModelOutput,
  1532. config_class=_CONFIG_FOR_DOC,
  1533. )
  1534. def call(
  1535. self,
  1536. input_ids: TFModelInputType | None = None,
  1537. attention_mask: np.ndarray | tf.Tensor | None = None,
  1538. token_type_ids: np.ndarray | tf.Tensor | None = None,
  1539. position_ids: np.ndarray | tf.Tensor | None = None,
  1540. head_mask: np.ndarray | tf.Tensor | None = None,
  1541. inputs_embeds: np.ndarray | tf.Tensor | None = None,
  1542. output_attentions: Optional[bool] = None,
  1543. output_hidden_states: Optional[bool] = None,
  1544. return_dict: Optional[bool] = None,
  1545. labels: np.ndarray | tf.Tensor | None = None,
  1546. training: Optional[bool] = False,
  1547. ) -> Union[TFMultipleChoiceModelOutput, Tuple[tf.Tensor]]:
        r"""
        labels (`tf.Tensor` or `np.ndarray` of shape `(batch_size,)`, *optional*):
            Labels for computing the multiple choice classification loss. Indices should be in `[0, ...,
            num_choices - 1]`, where `num_choices` is the size of the second dimension of the input tensors (see
            `input_ids` above).
        """
  1553. if input_ids is not None:
  1554. num_choices = shape_list(input_ids)[1]
  1555. seq_length = shape_list(input_ids)[2]
  1556. else:
  1557. num_choices = shape_list(inputs_embeds)[1]
  1558. seq_length = shape_list(inputs_embeds)[2]
  1559. flat_input_ids = tf.reshape(tensor=input_ids, shape=(-1, seq_length)) if input_ids is not None else None
  1560. flat_attention_mask = (
  1561. tf.reshape(tensor=attention_mask, shape=(-1, seq_length)) if attention_mask is not None else None
  1562. )
  1563. flat_token_type_ids = (
  1564. tf.reshape(tensor=token_type_ids, shape=(-1, seq_length)) if token_type_ids is not None else None
  1565. )
  1566. flat_position_ids = (
  1567. tf.reshape(tensor=position_ids, shape=(-1, seq_length)) if position_ids is not None else None
  1568. )
  1569. flat_inputs_embeds = (
  1570. tf.reshape(tensor=inputs_embeds, shape=(-1, seq_length, shape_list(inputs_embeds)[3]))
  1571. if inputs_embeds is not None
  1572. else None
  1573. )
  1574. outputs = self.bert(
  1575. input_ids=flat_input_ids,
  1576. attention_mask=flat_attention_mask,
  1577. token_type_ids=flat_token_type_ids,
  1578. position_ids=flat_position_ids,
  1579. head_mask=head_mask,
  1580. inputs_embeds=flat_inputs_embeds,
  1581. output_attentions=output_attentions,
  1582. output_hidden_states=output_hidden_states,
  1583. return_dict=return_dict,
  1584. training=training,
  1585. )
  1586. pooled_output = outputs[1]
  1587. pooled_output = self.dropout(inputs=pooled_output, training=training)
  1588. logits = self.classifier(inputs=pooled_output)
  1589. reshaped_logits = tf.reshape(tensor=logits, shape=(-1, num_choices))
  1590. loss = None if labels is None else self.hf_compute_loss(labels=labels, logits=reshaped_logits)
  1591. if not return_dict:
  1592. output = (reshaped_logits,) + outputs[2:]
  1593. return ((loss,) + output) if loss is not None else output
  1594. return TFMultipleChoiceModelOutput(
  1595. loss=loss,
  1596. logits=reshaped_logits,
  1597. hidden_states=outputs.hidden_states,
  1598. attentions=outputs.attentions,
  1599. )
  1600. def build(self, input_shape=None):
  1601. if self.built:
  1602. return
  1603. self.built = True
  1604. if getattr(self, "bert", None) is not None:
  1605. with tf.name_scope(self.bert.name):
  1606. self.bert.build(None)
  1607. if getattr(self, "classifier", None) is not None:
  1608. with tf.name_scope(self.classifier.name):
  1609. self.classifier.build([None, None, self.config.hidden_size])
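
# A minimal sketch of the `(batch_size, num_choices, sequence_length)` input layout that
# TFBertForMultipleChoice flattens internally. Prompt, choices, checkpoint, and helper name are illustrative;
# the function is never called at import time.
def _tf_bert_for_multiple_choice_usage_sketch():
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-uncased")
    model = TFBertForMultipleChoice.from_pretrained("google-bert/bert-base-uncased")

    prompt = "The weather today is"
    choices = ["sunny and warm.", "a kind of cheese."]

    # tokenize the prompt against each choice, then add the batch dimension: (1, num_choices, seq_len)
    encoding = tokenizer([prompt, prompt], choices, return_tensors="tf", padding=True)
    inputs = {key: tf.expand_dims(value, 0) for key, value in encoding.items()}

    logits = model(inputs).logits  # (batch_size, num_choices)
    best_choice = int(tf.argmax(logits, axis=-1)[0])
    return choices[best_choice]
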
  1610. @add_start_docstrings(
  1611. """
  1612. Bert Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for
  1613. Named-Entity-Recognition (NER) tasks.
  1614. """,
  1615. BERT_START_DOCSTRING,
  1616. )
  1617. class TFBertForTokenClassification(TFBertPreTrainedModel, TFTokenClassificationLoss):
    # names with a '.' represent the authorized unexpected/missing layers when a TF model is loaded from a PT model
  1619. _keys_to_ignore_on_load_unexpected = [
  1620. r"pooler",
  1621. r"mlm___cls",
  1622. r"nsp___cls",
  1623. r"cls.predictions",
  1624. r"cls.seq_relationship",
  1625. ]
  1626. _keys_to_ignore_on_load_missing = [r"dropout"]
  1627. def __init__(self, config: BertConfig, *inputs, **kwargs):
  1628. super().__init__(config, *inputs, **kwargs)
  1629. self.num_labels = config.num_labels
  1630. self.bert = TFBertMainLayer(config, add_pooling_layer=False, name="bert")
  1631. classifier_dropout = (
  1632. config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
  1633. )
  1634. self.dropout = keras.layers.Dropout(rate=classifier_dropout)
  1635. self.classifier = keras.layers.Dense(
  1636. units=config.num_labels,
  1637. kernel_initializer=get_initializer(config.initializer_range),
  1638. name="classifier",
  1639. )
  1640. self.config = config
  1641. @unpack_inputs
  1642. @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
  1643. @add_code_sample_docstrings(
  1644. checkpoint=_CHECKPOINT_FOR_TOKEN_CLASSIFICATION,
  1645. output_type=TFTokenClassifierOutput,
  1646. config_class=_CONFIG_FOR_DOC,
  1647. expected_output=_TOKEN_CLASS_EXPECTED_OUTPUT,
  1648. expected_loss=_TOKEN_CLASS_EXPECTED_LOSS,
  1649. )
  1650. def call(
  1651. self,
  1652. input_ids: TFModelInputType | None = None,
  1653. attention_mask: np.ndarray | tf.Tensor | None = None,
  1654. token_type_ids: np.ndarray | tf.Tensor | None = None,
  1655. position_ids: np.ndarray | tf.Tensor | None = None,
  1656. head_mask: np.ndarray | tf.Tensor | None = None,
  1657. inputs_embeds: np.ndarray | tf.Tensor | None = None,
  1658. output_attentions: Optional[bool] = None,
  1659. output_hidden_states: Optional[bool] = None,
  1660. return_dict: Optional[bool] = None,
  1661. labels: np.ndarray | tf.Tensor | None = None,
  1662. training: Optional[bool] = False,
  1663. ) -> Union[TFTokenClassifierOutput, Tuple[tf.Tensor]]:
  1664. r"""
  1665. labels (`tf.Tensor` or `np.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
  1666. Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
  1667. """
  1668. outputs = self.bert(
  1669. input_ids=input_ids,
  1670. attention_mask=attention_mask,
  1671. token_type_ids=token_type_ids,
  1672. position_ids=position_ids,
  1673. head_mask=head_mask,
  1674. inputs_embeds=inputs_embeds,
  1675. output_attentions=output_attentions,
  1676. output_hidden_states=output_hidden_states,
  1677. return_dict=return_dict,
  1678. training=training,
  1679. )
  1680. sequence_output = outputs[0]
  1681. sequence_output = self.dropout(inputs=sequence_output, training=training)
  1682. logits = self.classifier(inputs=sequence_output)
  1683. loss = None if labels is None else self.hf_compute_loss(labels=labels, logits=logits)
  1684. if not return_dict:
  1685. output = (logits,) + outputs[2:]
  1686. return ((loss,) + output) if loss is not None else output
  1687. return TFTokenClassifierOutput(
  1688. loss=loss,
  1689. logits=logits,
  1690. hidden_states=outputs.hidden_states,
  1691. attentions=outputs.attentions,
  1692. )
  1693. def build(self, input_shape=None):
  1694. if self.built:
  1695. return
  1696. self.built = True
  1697. if getattr(self, "bert", None) is not None:
  1698. with tf.name_scope(self.bert.name):
  1699. self.bert.build(None)
  1700. if getattr(self, "classifier", None) is not None:
  1701. with tf.name_scope(self.classifier.name):
  1702. self.classifier.build([None, None, self.config.hidden_size])
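
# A minimal token-classification (NER-style) sketch for TFBertForTokenClassification. `num_labels=9` is an
# arbitrary illustrative choice; in practice you would load a fine-tuned checkpoint such as
# _CHECKPOINT_FOR_TOKEN_CLASSIFICATION above. The helper is never called at import time.
def _tf_bert_for_token_classification_usage_sketch():
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-uncased")
    model = TFBertForTokenClassification.from_pretrained("google-bert/bert-base-uncased", num_labels=9)

    inputs = tokenizer("HuggingFace is based in New York City", return_tensors="tf")
    logits = model(inputs).logits  # (batch_size, sequence_length, num_labels)

    predicted_label_ids = tf.argmax(logits, axis=-1)  # one label id per token
    return predicted_label_ids
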
  1703. @add_start_docstrings(
  1704. """
  1705. Bert Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear
  1706. layer on top of the hidden-states output to compute `span start logits` and `span end logits`).
  1707. """,
  1708. BERT_START_DOCSTRING,
  1709. )
  1710. class TFBertForQuestionAnswering(TFBertPreTrainedModel, TFQuestionAnsweringLoss):
    # names with a '.' represent the authorized unexpected/missing layers when a TF model is loaded from a PT model
  1712. _keys_to_ignore_on_load_unexpected = [
  1713. r"pooler",
  1714. r"mlm___cls",
  1715. r"nsp___cls",
  1716. r"cls.predictions",
  1717. r"cls.seq_relationship",
  1718. ]
  1719. def __init__(self, config: BertConfig, *inputs, **kwargs):
  1720. super().__init__(config, *inputs, **kwargs)
  1721. self.num_labels = config.num_labels
  1722. self.bert = TFBertMainLayer(config, add_pooling_layer=False, name="bert")
  1723. self.qa_outputs = keras.layers.Dense(
  1724. units=config.num_labels,
  1725. kernel_initializer=get_initializer(config.initializer_range),
  1726. name="qa_outputs",
  1727. )
  1728. self.config = config
  1729. @unpack_inputs
  1730. @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
  1731. @add_code_sample_docstrings(
  1732. checkpoint=_CHECKPOINT_FOR_QA,
  1733. output_type=TFQuestionAnsweringModelOutput,
  1734. config_class=_CONFIG_FOR_DOC,
  1735. qa_target_start_index=_QA_TARGET_START_INDEX,
  1736. qa_target_end_index=_QA_TARGET_END_INDEX,
  1737. expected_output=_QA_EXPECTED_OUTPUT,
  1738. expected_loss=_QA_EXPECTED_LOSS,
  1739. )
  1740. def call(
  1741. self,
  1742. input_ids: TFModelInputType | None = None,
  1743. attention_mask: np.ndarray | tf.Tensor | None = None,
  1744. token_type_ids: np.ndarray | tf.Tensor | None = None,
  1745. position_ids: np.ndarray | tf.Tensor | None = None,
  1746. head_mask: np.ndarray | tf.Tensor | None = None,
  1747. inputs_embeds: np.ndarray | tf.Tensor | None = None,
  1748. output_attentions: Optional[bool] = None,
  1749. output_hidden_states: Optional[bool] = None,
  1750. return_dict: Optional[bool] = None,
  1751. start_positions: np.ndarray | tf.Tensor | None = None,
  1752. end_positions: np.ndarray | tf.Tensor | None = None,
  1753. training: Optional[bool] = False,
  1754. ) -> Union[TFQuestionAnsweringModelOutput, Tuple[tf.Tensor]]:
        r"""
        start_positions (`tf.Tensor` or `np.ndarray` of shape `(batch_size,)`, *optional*):
            Labels for the position (index) of the start of the labelled span for computing the token classification
            loss. Positions are clamped to the length of the sequence (`sequence_length`). Positions outside of the
            sequence are not taken into account for computing the loss.
        end_positions (`tf.Tensor` or `np.ndarray` of shape `(batch_size,)`, *optional*):
            Labels for the position (index) of the end of the labelled span for computing the token classification
            loss. Positions are clamped to the length of the sequence (`sequence_length`). Positions outside of the
            sequence are not taken into account for computing the loss.
        """
  1765. outputs = self.bert(
  1766. input_ids=input_ids,
  1767. attention_mask=attention_mask,
  1768. token_type_ids=token_type_ids,
  1769. position_ids=position_ids,
  1770. head_mask=head_mask,
  1771. inputs_embeds=inputs_embeds,
  1772. output_attentions=output_attentions,
  1773. output_hidden_states=output_hidden_states,
  1774. return_dict=return_dict,
  1775. training=training,
  1776. )
  1777. sequence_output = outputs[0]
  1778. logits = self.qa_outputs(inputs=sequence_output)
  1779. start_logits, end_logits = tf.split(value=logits, num_or_size_splits=2, axis=-1)
  1780. start_logits = tf.squeeze(input=start_logits, axis=-1)
  1781. end_logits = tf.squeeze(input=end_logits, axis=-1)
  1782. loss = None
  1783. if start_positions is not None and end_positions is not None:
  1784. labels = {"start_position": start_positions}
  1785. labels["end_position"] = end_positions
  1786. loss = self.hf_compute_loss(labels=labels, logits=(start_logits, end_logits))
  1787. if not return_dict:
  1788. output = (start_logits, end_logits) + outputs[2:]
  1789. return ((loss,) + output) if loss is not None else output
  1790. return TFQuestionAnsweringModelOutput(
  1791. loss=loss,
  1792. start_logits=start_logits,
  1793. end_logits=end_logits,
  1794. hidden_states=outputs.hidden_states,
  1795. attentions=outputs.attentions,
  1796. )
  1797. def build(self, input_shape=None):
  1798. if self.built:
  1799. return
  1800. self.built = True
  1801. if getattr(self, "bert", None) is not None:
  1802. with tf.name_scope(self.bert.name):
  1803. self.bert.build(None)
  1804. if getattr(self, "qa_outputs", None) is not None:
  1805. with tf.name_scope(self.qa_outputs.name):
  1806. self.qa_outputs.build([None, None, self.config.hidden_size])
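
# A minimal extractive question-answering sketch for TFBertForQuestionAnswering. The base checkpoint only has
# randomly initialized span heads, so a SQuAD fine-tuned checkpoint such as _CHECKPOINT_FOR_QA above is what
# you would use in practice; question, context, and helper name are illustrative. The function is never called
# at import time.
def _tf_bert_for_question_answering_usage_sketch():
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-uncased")
    model = TFBertForQuestionAnswering.from_pretrained("google-bert/bert-base-uncased")

    question = "Who wrote BERT?"
    context = "BERT was written by researchers at Google."
    inputs = tokenizer(question, context, return_tensors="tf")

    outputs = model(inputs)
    start_index = int(tf.argmax(outputs.start_logits, axis=-1)[0])
    end_index = int(tf.argmax(outputs.end_logits, axis=-1)[0])

    answer_ids = inputs["input_ids"][0, start_index : end_index + 1]
    return tokenizer.decode(answer_ids)
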