modeling_tf_mobilebert.py 82 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
7177817791780178117821783178417851786178717881789179017911792179317941795179617971798179918001801180218031804180518061807180818091810181118121813181418151816181718181819182018211822182318241825182618271828182918301831183218331834183518361837183818391840184118421843184418451846184718481849185018511852185318541855185618571858185918601861186218631864186518661867186818691870187118721873187418751876187718781879188018811882188318841885188618871888188918901891189218931894189518961897189818991900190119021903190419051906190719081909191019111912191319141915191619171918191919201921192219231924192519261927192819291930193119321933193419351936193719381939194019411942194319441945194619471948194919501951195219531954195519561957195819591960196119621963196419651966
  1. # coding=utf-8
  2. # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
  3. # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. """TF 2.0 MobileBERT model."""
  17. from __future__ import annotations
  18. import warnings
  19. from dataclasses import dataclass
  20. from typing import Optional, Tuple, Union
  21. import numpy as np
  22. import tensorflow as tf
  23. from ...activations_tf import get_tf_activation
  24. from ...modeling_tf_outputs import (
  25. TFBaseModelOutput,
  26. TFBaseModelOutputWithPooling,
  27. TFMaskedLMOutput,
  28. TFMultipleChoiceModelOutput,
  29. TFNextSentencePredictorOutput,
  30. TFQuestionAnsweringModelOutput,
  31. TFSequenceClassifierOutput,
  32. TFTokenClassifierOutput,
  33. )
  34. from ...modeling_tf_utils import (
  35. TFMaskedLanguageModelingLoss,
  36. TFModelInputType,
  37. TFMultipleChoiceLoss,
  38. TFNextSentencePredictionLoss,
  39. TFPreTrainedModel,
  40. TFQuestionAnsweringLoss,
  41. TFSequenceClassificationLoss,
  42. TFTokenClassificationLoss,
  43. get_initializer,
  44. keras,
  45. keras_serializable,
  46. unpack_inputs,
  47. )
  48. from ...tf_utils import check_embeddings_within_bounds, shape_list, stable_softmax
  49. from ...utils import (
  50. ModelOutput,
  51. add_code_sample_docstrings,
  52. add_start_docstrings,
  53. add_start_docstrings_to_model_forward,
  54. logging,
  55. replace_return_docstrings,
  56. )
  57. from .configuration_mobilebert import MobileBertConfig
  58. logger = logging.get_logger(__name__)
  59. _CHECKPOINT_FOR_DOC = "google/mobilebert-uncased"
  60. _CONFIG_FOR_DOC = "MobileBertConfig"
  61. # TokenClassification docstring
  62. _CHECKPOINT_FOR_TOKEN_CLASSIFICATION = "vumichien/mobilebert-finetuned-ner"
  63. _TOKEN_CLASS_EXPECTED_OUTPUT = "['I-ORG', 'I-ORG', 'O', 'O', 'O', 'O', 'O', 'I-LOC', 'O', 'I-LOC', 'I-LOC']"
  64. _TOKEN_CLASS_EXPECTED_LOSS = 0.03
  65. # QuestionAnswering docstring
  66. _CHECKPOINT_FOR_QA = "vumichien/mobilebert-uncased-squad-v2"
  67. _QA_EXPECTED_OUTPUT = "'a nice puppet'"
  68. _QA_EXPECTED_LOSS = 3.98
  69. _QA_TARGET_START_INDEX = 12
  70. _QA_TARGET_END_INDEX = 13
  71. # SequenceClassification docstring
  72. _CHECKPOINT_FOR_SEQUENCE_CLASSIFICATION = "vumichien/emo-mobilebert"
  73. _SEQ_CLASS_EXPECTED_OUTPUT = "'others'"
  74. _SEQ_CLASS_EXPECTED_LOSS = "4.72"
  75. # Copied from transformers.models.bert.modeling_tf_bert.TFBertPreTrainingLoss
  76. class TFMobileBertPreTrainingLoss:
  77. """
  78. Loss function suitable for BERT-like pretraining, that is, the task of pretraining a language model by combining
  79. NSP + MLM. .. note:: Any label of -100 will be ignored (along with the corresponding logits) in the loss
  80. computation.
  81. """
  82. def hf_compute_loss(self, labels: tf.Tensor, logits: tf.Tensor) -> tf.Tensor:
  83. loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True, reduction=keras.losses.Reduction.NONE)
  84. # Clip negative labels to zero here to avoid NaNs and errors - those positions will get masked later anyway
  85. unmasked_lm_losses = loss_fn(y_true=tf.nn.relu(labels["labels"]), y_pred=logits[0])
  86. # make sure only labels that are not equal to -100
  87. # are taken into account for the loss computation
  88. lm_loss_mask = tf.cast(labels["labels"] != -100, dtype=unmasked_lm_losses.dtype)
  89. masked_lm_losses = unmasked_lm_losses * lm_loss_mask
  90. reduced_masked_lm_loss = tf.reduce_sum(masked_lm_losses) / tf.reduce_sum(lm_loss_mask)
  91. # Clip negative labels to zero here to avoid NaNs and errors - those positions will get masked later anyway
  92. unmasked_ns_loss = loss_fn(y_true=tf.nn.relu(labels["next_sentence_label"]), y_pred=logits[1])
  93. ns_loss_mask = tf.cast(labels["next_sentence_label"] != -100, dtype=unmasked_ns_loss.dtype)
  94. masked_ns_loss = unmasked_ns_loss * ns_loss_mask
  95. reduced_masked_ns_loss = tf.reduce_sum(masked_ns_loss) / tf.reduce_sum(ns_loss_mask)
  96. return tf.reshape(reduced_masked_lm_loss + reduced_masked_ns_loss, (1,))
  97. class TFMobileBertIntermediate(keras.layers.Layer):
  98. def __init__(self, config, **kwargs):
  99. super().__init__(**kwargs)
  100. self.dense = keras.layers.Dense(config.intermediate_size, name="dense")
  101. if isinstance(config.hidden_act, str):
  102. self.intermediate_act_fn = get_tf_activation(config.hidden_act)
  103. else:
  104. self.intermediate_act_fn = config.hidden_act
  105. self.config = config
  106. def call(self, hidden_states):
  107. hidden_states = self.dense(hidden_states)
  108. hidden_states = self.intermediate_act_fn(hidden_states)
  109. return hidden_states
  110. def build(self, input_shape=None):
  111. if self.built:
  112. return
  113. self.built = True
  114. if getattr(self, "dense", None) is not None:
  115. with tf.name_scope(self.dense.name):
  116. self.dense.build([None, None, self.config.true_hidden_size])
  117. class TFLayerNorm(keras.layers.LayerNormalization):
  118. def __init__(self, feat_size, *args, **kwargs):
  119. self.feat_size = feat_size
  120. super().__init__(*args, **kwargs)
  121. def build(self, input_shape=None):
  122. super().build([None, None, self.feat_size])
  123. class TFNoNorm(keras.layers.Layer):
  124. def __init__(self, feat_size, epsilon=None, **kwargs):
  125. super().__init__(**kwargs)
  126. self.feat_size = feat_size
  127. def build(self, input_shape):
  128. self.bias = self.add_weight("bias", shape=[self.feat_size], initializer="zeros")
  129. self.weight = self.add_weight("weight", shape=[self.feat_size], initializer="ones")
  130. super().build(input_shape)
  131. def call(self, inputs: tf.Tensor):
  132. return inputs * self.weight + self.bias
  133. NORM2FN = {"layer_norm": TFLayerNorm, "no_norm": TFNoNorm}
  134. class TFMobileBertEmbeddings(keras.layers.Layer):
  135. """Construct the embeddings from word, position and token_type embeddings."""
  136. def __init__(self, config, **kwargs):
  137. super().__init__(**kwargs)
  138. self.trigram_input = config.trigram_input
  139. self.embedding_size = config.embedding_size
  140. self.config = config
  141. self.hidden_size = config.hidden_size
  142. self.max_position_embeddings = config.max_position_embeddings
  143. self.initializer_range = config.initializer_range
  144. self.embedding_transformation = keras.layers.Dense(config.hidden_size, name="embedding_transformation")
  145. # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
  146. # any TensorFlow checkpoint file
  147. self.LayerNorm = NORM2FN[config.normalization_type](
  148. config.hidden_size, epsilon=config.layer_norm_eps, name="LayerNorm"
  149. )
  150. self.dropout = keras.layers.Dropout(rate=config.hidden_dropout_prob)
  151. self.embedded_input_size = self.embedding_size * (3 if self.trigram_input else 1)
  152. def build(self, input_shape=None):
  153. with tf.name_scope("word_embeddings"):
  154. self.weight = self.add_weight(
  155. name="weight",
  156. shape=[self.config.vocab_size, self.embedding_size],
  157. initializer=get_initializer(initializer_range=self.initializer_range),
  158. )
  159. with tf.name_scope("token_type_embeddings"):
  160. self.token_type_embeddings = self.add_weight(
  161. name="embeddings",
  162. shape=[self.config.type_vocab_size, self.hidden_size],
  163. initializer=get_initializer(initializer_range=self.initializer_range),
  164. )
  165. with tf.name_scope("position_embeddings"):
  166. self.position_embeddings = self.add_weight(
  167. name="embeddings",
  168. shape=[self.max_position_embeddings, self.hidden_size],
  169. initializer=get_initializer(initializer_range=self.initializer_range),
  170. )
  171. if self.built:
  172. return
  173. self.built = True
  174. if getattr(self, "embedding_transformation", None) is not None:
  175. with tf.name_scope(self.embedding_transformation.name):
  176. self.embedding_transformation.build([None, None, self.embedded_input_size])
  177. if getattr(self, "LayerNorm", None) is not None:
  178. with tf.name_scope(self.LayerNorm.name):
  179. self.LayerNorm.build(None)
  180. def call(self, input_ids=None, position_ids=None, token_type_ids=None, inputs_embeds=None, training=False):
  181. """
  182. Applies embedding based on inputs tensor.
  183. Returns:
  184. final_embeddings (`tf.Tensor`): output embedding tensor.
  185. """
  186. assert not (input_ids is None and inputs_embeds is None)
  187. if input_ids is not None:
  188. check_embeddings_within_bounds(input_ids, self.config.vocab_size)
  189. inputs_embeds = tf.gather(params=self.weight, indices=input_ids)
  190. input_shape = shape_list(inputs_embeds)[:-1]
  191. if token_type_ids is None:
  192. token_type_ids = tf.fill(dims=input_shape, value=0)
  193. if self.trigram_input:
  194. # From the paper MobileBERT: a Compact Task-Agnostic BERT for Resource-Limited
  195. # Devices (https://arxiv.org/abs/2004.02984)
  196. #
  197. # The embedding table in BERT models accounts for a substantial proportion of model size. To compress
  198. # the embedding layer, we reduce the embedding dimension to 128 in MobileBERT.
  199. # Then, we apply a 1D convolution with kernel size 3 on the raw token embedding to produce a 512
  200. # dimensional output.
  201. inputs_embeds = tf.concat(
  202. [
  203. tf.pad(inputs_embeds[:, 1:], ((0, 0), (0, 1), (0, 0))),
  204. inputs_embeds,
  205. tf.pad(inputs_embeds[:, :-1], ((0, 0), (1, 0), (0, 0))),
  206. ],
  207. axis=2,
  208. )
  209. if self.trigram_input or self.embedding_size != self.hidden_size:
  210. inputs_embeds = self.embedding_transformation(inputs_embeds)
  211. if position_ids is None:
  212. position_ids = tf.expand_dims(tf.range(start=0, limit=input_shape[-1]), axis=0)
  213. position_embeds = tf.gather(params=self.position_embeddings, indices=position_ids)
  214. token_type_embeds = tf.gather(params=self.token_type_embeddings, indices=token_type_ids)
  215. final_embeddings = inputs_embeds + position_embeds + token_type_embeds
  216. final_embeddings = self.LayerNorm(inputs=final_embeddings)
  217. final_embeddings = self.dropout(inputs=final_embeddings, training=training)
  218. return final_embeddings
  219. class TFMobileBertSelfAttention(keras.layers.Layer):
  220. def __init__(self, config, **kwargs):
  221. super().__init__(**kwargs)
  222. if config.hidden_size % config.num_attention_heads != 0:
  223. raise ValueError(
  224. f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
  225. f"heads ({config.num_attention_heads}"
  226. )
  227. self.num_attention_heads = config.num_attention_heads
  228. self.output_attentions = config.output_attentions
  229. assert config.hidden_size % config.num_attention_heads == 0
  230. self.attention_head_size = int(config.true_hidden_size / config.num_attention_heads)
  231. self.all_head_size = self.num_attention_heads * self.attention_head_size
  232. self.query = keras.layers.Dense(
  233. self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="query"
  234. )
  235. self.key = keras.layers.Dense(
  236. self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="key"
  237. )
  238. self.value = keras.layers.Dense(
  239. self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="value"
  240. )
  241. self.dropout = keras.layers.Dropout(config.attention_probs_dropout_prob)
  242. self.config = config
  243. def transpose_for_scores(self, x, batch_size):
  244. # Reshape from [batch_size, seq_length, all_head_size] to [batch_size, seq_length, num_attention_heads, attention_head_size]
  245. x = tf.reshape(x, (batch_size, -1, self.num_attention_heads, self.attention_head_size))
  246. return tf.transpose(x, perm=[0, 2, 1, 3])
  247. def call(
  248. self, query_tensor, key_tensor, value_tensor, attention_mask, head_mask, output_attentions, training=False
  249. ):
  250. batch_size = shape_list(attention_mask)[0]
  251. mixed_query_layer = self.query(query_tensor)
  252. mixed_key_layer = self.key(key_tensor)
  253. mixed_value_layer = self.value(value_tensor)
  254. query_layer = self.transpose_for_scores(mixed_query_layer, batch_size)
  255. key_layer = self.transpose_for_scores(mixed_key_layer, batch_size)
  256. value_layer = self.transpose_for_scores(mixed_value_layer, batch_size)
  257. # Take the dot product between "query" and "key" to get the raw attention scores.
  258. attention_scores = tf.matmul(
  259. query_layer, key_layer, transpose_b=True
  260. ) # (batch size, num_heads, seq_len_q, seq_len_k)
  261. dk = tf.cast(shape_list(key_layer)[-1], dtype=attention_scores.dtype) # scale attention_scores
  262. attention_scores = attention_scores / tf.math.sqrt(dk)
  263. if attention_mask is not None:
  264. # Apply the attention mask is (precomputed for all layers in TFMobileBertModel call() function)
  265. attention_mask = tf.cast(attention_mask, dtype=attention_scores.dtype)
  266. attention_scores = attention_scores + attention_mask
  267. # Normalize the attention scores to probabilities.
  268. attention_probs = stable_softmax(attention_scores, axis=-1)
  269. # This is actually dropping out entire tokens to attend to, which might
  270. # seem a bit unusual, but is taken from the original Transformer paper.
  271. attention_probs = self.dropout(attention_probs, training=training)
  272. # Mask heads if we want to
  273. if head_mask is not None:
  274. attention_probs = attention_probs * head_mask
  275. context_layer = tf.matmul(attention_probs, value_layer)
  276. context_layer = tf.transpose(context_layer, perm=[0, 2, 1, 3])
  277. context_layer = tf.reshape(
  278. context_layer, (batch_size, -1, self.all_head_size)
  279. ) # (batch_size, seq_len_q, all_head_size)
  280. outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
  281. return outputs
  282. def build(self, input_shape=None):
  283. if self.built:
  284. return
  285. self.built = True
  286. if getattr(self, "query", None) is not None:
  287. with tf.name_scope(self.query.name):
  288. self.query.build([None, None, self.config.true_hidden_size])
  289. if getattr(self, "key", None) is not None:
  290. with tf.name_scope(self.key.name):
  291. self.key.build([None, None, self.config.true_hidden_size])
  292. if getattr(self, "value", None) is not None:
  293. with tf.name_scope(self.value.name):
  294. self.value.build(
  295. [
  296. None,
  297. None,
  298. self.config.true_hidden_size
  299. if self.config.use_bottleneck_attention
  300. else self.config.hidden_size,
  301. ]
  302. )
  303. class TFMobileBertSelfOutput(keras.layers.Layer):
  304. def __init__(self, config, **kwargs):
  305. super().__init__(**kwargs)
  306. self.use_bottleneck = config.use_bottleneck
  307. self.dense = keras.layers.Dense(
  308. config.true_hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
  309. )
  310. self.LayerNorm = NORM2FN[config.normalization_type](
  311. config.true_hidden_size, epsilon=config.layer_norm_eps, name="LayerNorm"
  312. )
  313. if not self.use_bottleneck:
  314. self.dropout = keras.layers.Dropout(config.hidden_dropout_prob)
  315. self.config = config
  316. def call(self, hidden_states, residual_tensor, training=False):
  317. hidden_states = self.dense(hidden_states)
  318. if not self.use_bottleneck:
  319. hidden_states = self.dropout(hidden_states, training=training)
  320. hidden_states = self.LayerNorm(hidden_states + residual_tensor)
  321. return hidden_states
  322. def build(self, input_shape=None):
  323. if self.built:
  324. return
  325. self.built = True
  326. if getattr(self, "dense", None) is not None:
  327. with tf.name_scope(self.dense.name):
  328. self.dense.build([None, None, self.config.true_hidden_size])
  329. if getattr(self, "LayerNorm", None) is not None:
  330. with tf.name_scope(self.LayerNorm.name):
  331. self.LayerNorm.build(None)
  332. class TFMobileBertAttention(keras.layers.Layer):
  333. def __init__(self, config, **kwargs):
  334. super().__init__(**kwargs)
  335. self.self = TFMobileBertSelfAttention(config, name="self")
  336. self.mobilebert_output = TFMobileBertSelfOutput(config, name="output")
  337. def prune_heads(self, heads):
  338. raise NotImplementedError
  339. def call(
  340. self,
  341. query_tensor,
  342. key_tensor,
  343. value_tensor,
  344. layer_input,
  345. attention_mask,
  346. head_mask,
  347. output_attentions,
  348. training=False,
  349. ):
  350. self_outputs = self.self(
  351. query_tensor, key_tensor, value_tensor, attention_mask, head_mask, output_attentions, training=training
  352. )
  353. attention_output = self.mobilebert_output(self_outputs[0], layer_input, training=training)
  354. outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
  355. return outputs
  356. def build(self, input_shape=None):
  357. if self.built:
  358. return
  359. self.built = True
  360. if getattr(self, "self", None) is not None:
  361. with tf.name_scope(self.self.name):
  362. self.self.build(None)
  363. if getattr(self, "mobilebert_output", None) is not None:
  364. with tf.name_scope(self.mobilebert_output.name):
  365. self.mobilebert_output.build(None)
  366. class TFOutputBottleneck(keras.layers.Layer):
  367. def __init__(self, config, **kwargs):
  368. super().__init__(**kwargs)
  369. self.dense = keras.layers.Dense(config.hidden_size, name="dense")
  370. self.LayerNorm = NORM2FN[config.normalization_type](
  371. config.hidden_size, epsilon=config.layer_norm_eps, name="LayerNorm"
  372. )
  373. self.dropout = keras.layers.Dropout(config.hidden_dropout_prob)
  374. self.config = config
  375. def call(self, hidden_states, residual_tensor, training=False):
  376. layer_outputs = self.dense(hidden_states)
  377. layer_outputs = self.dropout(layer_outputs, training=training)
  378. layer_outputs = self.LayerNorm(layer_outputs + residual_tensor)
  379. return layer_outputs
  380. def build(self, input_shape=None):
  381. if self.built:
  382. return
  383. self.built = True
  384. if getattr(self, "dense", None) is not None:
  385. with tf.name_scope(self.dense.name):
  386. self.dense.build([None, None, self.config.true_hidden_size])
  387. if getattr(self, "LayerNorm", None) is not None:
  388. with tf.name_scope(self.LayerNorm.name):
  389. self.LayerNorm.build(None)
  390. class TFMobileBertOutput(keras.layers.Layer):
  391. def __init__(self, config, **kwargs):
  392. super().__init__(**kwargs)
  393. self.use_bottleneck = config.use_bottleneck
  394. self.dense = keras.layers.Dense(
  395. config.true_hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
  396. )
  397. self.LayerNorm = NORM2FN[config.normalization_type](
  398. config.true_hidden_size, epsilon=config.layer_norm_eps, name="LayerNorm"
  399. )
  400. if not self.use_bottleneck:
  401. self.dropout = keras.layers.Dropout(config.hidden_dropout_prob)
  402. else:
  403. self.bottleneck = TFOutputBottleneck(config, name="bottleneck")
  404. self.config = config
  405. def call(self, hidden_states, residual_tensor_1, residual_tensor_2, training=False):
  406. hidden_states = self.dense(hidden_states)
  407. if not self.use_bottleneck:
  408. hidden_states = self.dropout(hidden_states, training=training)
  409. hidden_states = self.LayerNorm(hidden_states + residual_tensor_1)
  410. else:
  411. hidden_states = self.LayerNorm(hidden_states + residual_tensor_1)
  412. hidden_states = self.bottleneck(hidden_states, residual_tensor_2)
  413. return hidden_states
  414. def build(self, input_shape=None):
  415. if self.built:
  416. return
  417. self.built = True
  418. if getattr(self, "dense", None) is not None:
  419. with tf.name_scope(self.dense.name):
  420. self.dense.build([None, None, self.config.intermediate_size])
  421. if getattr(self, "LayerNorm", None) is not None:
  422. with tf.name_scope(self.LayerNorm.name):
  423. self.LayerNorm.build(None)
  424. if getattr(self, "bottleneck", None) is not None:
  425. with tf.name_scope(self.bottleneck.name):
  426. self.bottleneck.build(None)
  427. class TFBottleneckLayer(keras.layers.Layer):
  428. def __init__(self, config, **kwargs):
  429. super().__init__(**kwargs)
  430. self.dense = keras.layers.Dense(config.intra_bottleneck_size, name="dense")
  431. self.LayerNorm = NORM2FN[config.normalization_type](
  432. config.intra_bottleneck_size, epsilon=config.layer_norm_eps, name="LayerNorm"
  433. )
  434. self.config = config
  435. def call(self, inputs):
  436. hidden_states = self.dense(inputs)
  437. hidden_states = self.LayerNorm(hidden_states)
  438. return hidden_states
  439. def build(self, input_shape=None):
  440. if self.built:
  441. return
  442. self.built = True
  443. if getattr(self, "dense", None) is not None:
  444. with tf.name_scope(self.dense.name):
  445. self.dense.build([None, None, self.config.hidden_size])
  446. if getattr(self, "LayerNorm", None) is not None:
  447. with tf.name_scope(self.LayerNorm.name):
  448. self.LayerNorm.build(None)
  449. class TFBottleneck(keras.layers.Layer):
  450. def __init__(self, config, **kwargs):
  451. super().__init__(**kwargs)
  452. self.key_query_shared_bottleneck = config.key_query_shared_bottleneck
  453. self.use_bottleneck_attention = config.use_bottleneck_attention
  454. self.bottleneck_input = TFBottleneckLayer(config, name="input")
  455. if self.key_query_shared_bottleneck:
  456. self.attention = TFBottleneckLayer(config, name="attention")
  457. def call(self, hidden_states):
  458. # This method can return three different tuples of values. These different values make use of bottlenecks,
  459. # which are linear layers used to project the hidden states to a lower-dimensional vector, reducing memory
  460. # usage. These linear layer have weights that are learned during training.
  461. #
  462. # If `config.use_bottleneck_attention`, it will return the result of the bottleneck layer four times for the
  463. # key, query, value, and "layer input" to be used by the attention layer.
  464. # This bottleneck is used to project the hidden. This last layer input will be used as a residual tensor
  465. # in the attention self output, after the attention scores have been computed.
  466. #
  467. # If not `config.use_bottleneck_attention` and `config.key_query_shared_bottleneck`, this will return
  468. # four values, three of which have been passed through a bottleneck: the query and key, passed through the same
  469. # bottleneck, and the residual layer to be applied in the attention self output, through another bottleneck.
  470. #
  471. # Finally, in the last case, the values for the query, key and values are the hidden states without bottleneck,
  472. # and the residual layer will be this value passed through a bottleneck.
  473. bottlenecked_hidden_states = self.bottleneck_input(hidden_states)
  474. if self.use_bottleneck_attention:
  475. return (bottlenecked_hidden_states,) * 4
  476. elif self.key_query_shared_bottleneck:
  477. shared_attention_input = self.attention(hidden_states)
  478. return (shared_attention_input, shared_attention_input, hidden_states, bottlenecked_hidden_states)
  479. else:
  480. return (hidden_states, hidden_states, hidden_states, bottlenecked_hidden_states)
  481. def build(self, input_shape=None):
  482. if self.built:
  483. return
  484. self.built = True
  485. if getattr(self, "bottleneck_input", None) is not None:
  486. with tf.name_scope(self.bottleneck_input.name):
  487. self.bottleneck_input.build(None)
  488. if getattr(self, "attention", None) is not None:
  489. with tf.name_scope(self.attention.name):
  490. self.attention.build(None)
  491. class TFFFNOutput(keras.layers.Layer):
  492. def __init__(self, config, **kwargs):
  493. super().__init__(**kwargs)
  494. self.dense = keras.layers.Dense(config.true_hidden_size, name="dense")
  495. self.LayerNorm = NORM2FN[config.normalization_type](
  496. config.true_hidden_size, epsilon=config.layer_norm_eps, name="LayerNorm"
  497. )
  498. self.config = config
  499. def call(self, hidden_states, residual_tensor):
  500. hidden_states = self.dense(hidden_states)
  501. hidden_states = self.LayerNorm(hidden_states + residual_tensor)
  502. return hidden_states
  503. def build(self, input_shape=None):
  504. if self.built:
  505. return
  506. self.built = True
  507. if getattr(self, "dense", None) is not None:
  508. with tf.name_scope(self.dense.name):
  509. self.dense.build([None, None, self.config.intermediate_size])
  510. if getattr(self, "LayerNorm", None) is not None:
  511. with tf.name_scope(self.LayerNorm.name):
  512. self.LayerNorm.build(None)
  513. class TFFFNLayer(keras.layers.Layer):
  514. def __init__(self, config, **kwargs):
  515. super().__init__(**kwargs)
  516. self.intermediate = TFMobileBertIntermediate(config, name="intermediate")
  517. self.mobilebert_output = TFFFNOutput(config, name="output")
  518. def call(self, hidden_states):
  519. intermediate_output = self.intermediate(hidden_states)
  520. layer_outputs = self.mobilebert_output(intermediate_output, hidden_states)
  521. return layer_outputs
  522. def build(self, input_shape=None):
  523. if self.built:
  524. return
  525. self.built = True
  526. if getattr(self, "intermediate", None) is not None:
  527. with tf.name_scope(self.intermediate.name):
  528. self.intermediate.build(None)
  529. if getattr(self, "mobilebert_output", None) is not None:
  530. with tf.name_scope(self.mobilebert_output.name):
  531. self.mobilebert_output.build(None)
  532. class TFMobileBertLayer(keras.layers.Layer):
  533. def __init__(self, config, **kwargs):
  534. super().__init__(**kwargs)
  535. self.use_bottleneck = config.use_bottleneck
  536. self.num_feedforward_networks = config.num_feedforward_networks
  537. self.attention = TFMobileBertAttention(config, name="attention")
  538. self.intermediate = TFMobileBertIntermediate(config, name="intermediate")
  539. self.mobilebert_output = TFMobileBertOutput(config, name="output")
  540. if self.use_bottleneck:
  541. self.bottleneck = TFBottleneck(config, name="bottleneck")
  542. if config.num_feedforward_networks > 1:
  543. self.ffn = [TFFFNLayer(config, name=f"ffn.{i}") for i in range(config.num_feedforward_networks - 1)]
  544. def call(self, hidden_states, attention_mask, head_mask, output_attentions, training=False):
  545. if self.use_bottleneck:
  546. query_tensor, key_tensor, value_tensor, layer_input = self.bottleneck(hidden_states)
  547. else:
  548. query_tensor, key_tensor, value_tensor, layer_input = [hidden_states] * 4
  549. attention_outputs = self.attention(
  550. query_tensor,
  551. key_tensor,
  552. value_tensor,
  553. layer_input,
  554. attention_mask,
  555. head_mask,
  556. output_attentions,
  557. training=training,
  558. )
  559. attention_output = attention_outputs[0]
  560. s = (attention_output,)
  561. if self.num_feedforward_networks != 1:
  562. for i, ffn_module in enumerate(self.ffn):
  563. attention_output = ffn_module(attention_output)
  564. s += (attention_output,)
  565. intermediate_output = self.intermediate(attention_output)
  566. layer_output = self.mobilebert_output(intermediate_output, attention_output, hidden_states, training=training)
  567. outputs = (
  568. (layer_output,)
  569. + attention_outputs[1:]
  570. + (
  571. tf.constant(0),
  572. query_tensor,
  573. key_tensor,
  574. value_tensor,
  575. layer_input,
  576. attention_output,
  577. intermediate_output,
  578. )
  579. + s
  580. ) # add attentions if we output them
  581. return outputs
  582. def build(self, input_shape=None):
  583. if self.built:
  584. return
  585. self.built = True
  586. if getattr(self, "attention", None) is not None:
  587. with tf.name_scope(self.attention.name):
  588. self.attention.build(None)
  589. if getattr(self, "intermediate", None) is not None:
  590. with tf.name_scope(self.intermediate.name):
  591. self.intermediate.build(None)
  592. if getattr(self, "mobilebert_output", None) is not None:
  593. with tf.name_scope(self.mobilebert_output.name):
  594. self.mobilebert_output.build(None)
  595. if getattr(self, "bottleneck", None) is not None:
  596. with tf.name_scope(self.bottleneck.name):
  597. self.bottleneck.build(None)
  598. if getattr(self, "ffn", None) is not None:
  599. for layer in self.ffn:
  600. with tf.name_scope(layer.name):
  601. layer.build(None)
  602. class TFMobileBertEncoder(keras.layers.Layer):
  603. def __init__(self, config, **kwargs):
  604. super().__init__(**kwargs)
  605. self.output_attentions = config.output_attentions
  606. self.output_hidden_states = config.output_hidden_states
  607. self.layer = [TFMobileBertLayer(config, name=f"layer_._{i}") for i in range(config.num_hidden_layers)]
  608. def call(
  609. self,
  610. hidden_states,
  611. attention_mask,
  612. head_mask,
  613. output_attentions,
  614. output_hidden_states,
  615. return_dict,
  616. training=False,
  617. ):
  618. all_hidden_states = () if output_hidden_states else None
  619. all_attentions = () if output_attentions else None
  620. for i, layer_module in enumerate(self.layer):
  621. if output_hidden_states:
  622. all_hidden_states = all_hidden_states + (hidden_states,)
  623. layer_outputs = layer_module(
  624. hidden_states, attention_mask, head_mask[i], output_attentions, training=training
  625. )
  626. hidden_states = layer_outputs[0]
  627. if output_attentions:
  628. all_attentions = all_attentions + (layer_outputs[1],)
  629. # Add last layer
  630. if output_hidden_states:
  631. all_hidden_states = all_hidden_states + (hidden_states,)
  632. if not return_dict:
  633. return tuple(v for v in [hidden_states, all_hidden_states, all_attentions] if v is not None)
  634. return TFBaseModelOutput(
  635. last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions
  636. )
  637. def build(self, input_shape=None):
  638. if self.built:
  639. return
  640. self.built = True
  641. if getattr(self, "layer", None) is not None:
  642. for layer in self.layer:
  643. with tf.name_scope(layer.name):
  644. layer.build(None)
  645. class TFMobileBertPooler(keras.layers.Layer):
  646. def __init__(self, config, **kwargs):
  647. super().__init__(**kwargs)
  648. self.do_activate = config.classifier_activation
  649. if self.do_activate:
  650. self.dense = keras.layers.Dense(
  651. config.hidden_size,
  652. kernel_initializer=get_initializer(config.initializer_range),
  653. activation="tanh",
  654. name="dense",
  655. )
  656. self.config = config
  657. def call(self, hidden_states):
  658. # We "pool" the model by simply taking the hidden state corresponding
  659. # to the first token.
  660. first_token_tensor = hidden_states[:, 0]
  661. if not self.do_activate:
  662. return first_token_tensor
  663. else:
  664. pooled_output = self.dense(first_token_tensor)
  665. return pooled_output
  666. def build(self, input_shape=None):
  667. if self.built:
  668. return
  669. self.built = True
  670. if getattr(self, "dense", None) is not None:
  671. with tf.name_scope(self.dense.name):
  672. self.dense.build([None, None, self.config.hidden_size])
  673. class TFMobileBertPredictionHeadTransform(keras.layers.Layer):
  674. def __init__(self, config, **kwargs):
  675. super().__init__(**kwargs)
  676. self.dense = keras.layers.Dense(
  677. config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
  678. )
  679. if isinstance(config.hidden_act, str):
  680. self.transform_act_fn = get_tf_activation(config.hidden_act)
  681. else:
  682. self.transform_act_fn = config.hidden_act
  683. self.LayerNorm = NORM2FN["layer_norm"](config.hidden_size, epsilon=config.layer_norm_eps, name="LayerNorm")
  684. self.config = config
  685. def call(self, hidden_states):
  686. hidden_states = self.dense(hidden_states)
  687. hidden_states = self.transform_act_fn(hidden_states)
  688. hidden_states = self.LayerNorm(hidden_states)
  689. return hidden_states
  690. def build(self, input_shape=None):
  691. if self.built:
  692. return
  693. self.built = True
  694. if getattr(self, "dense", None) is not None:
  695. with tf.name_scope(self.dense.name):
  696. self.dense.build([None, None, self.config.hidden_size])
  697. if getattr(self, "LayerNorm", None) is not None:
  698. with tf.name_scope(self.LayerNorm.name):
  699. self.LayerNorm.build(None)
  700. class TFMobileBertLMPredictionHead(keras.layers.Layer):
  701. def __init__(self, config, **kwargs):
  702. super().__init__(**kwargs)
  703. self.transform = TFMobileBertPredictionHeadTransform(config, name="transform")
  704. self.config = config
  705. def build(self, input_shape=None):
  706. self.bias = self.add_weight(shape=(self.config.vocab_size,), initializer="zeros", trainable=True, name="bias")
  707. self.dense = self.add_weight(
  708. shape=(self.config.hidden_size - self.config.embedding_size, self.config.vocab_size),
  709. initializer="zeros",
  710. trainable=True,
  711. name="dense/weight",
  712. )
  713. self.decoder = self.add_weight(
  714. shape=(self.config.vocab_size, self.config.embedding_size),
  715. initializer="zeros",
  716. trainable=True,
  717. name="decoder/weight",
  718. )
  719. if self.built:
  720. return
  721. self.built = True
  722. if getattr(self, "transform", None) is not None:
  723. with tf.name_scope(self.transform.name):
  724. self.transform.build(None)
  725. def get_output_embeddings(self):
  726. return self
  727. def set_output_embeddings(self, value):
  728. self.decoder = value
  729. self.config.vocab_size = shape_list(value)[0]
  730. def get_bias(self):
  731. return {"bias": self.bias}
  732. def set_bias(self, value):
  733. self.bias = value["bias"]
  734. self.config.vocab_size = shape_list(value["bias"])[0]
  735. def call(self, hidden_states):
  736. hidden_states = self.transform(hidden_states)
  737. hidden_states = tf.matmul(hidden_states, tf.concat([tf.transpose(self.decoder), self.dense], axis=0))
  738. hidden_states = hidden_states + self.bias
  739. return hidden_states
  740. class TFMobileBertMLMHead(keras.layers.Layer):
  741. def __init__(self, config, **kwargs):
  742. super().__init__(**kwargs)
  743. self.predictions = TFMobileBertLMPredictionHead(config, name="predictions")
  744. def call(self, sequence_output):
  745. prediction_scores = self.predictions(sequence_output)
  746. return prediction_scores
  747. def build(self, input_shape=None):
  748. if self.built:
  749. return
  750. self.built = True
  751. if getattr(self, "predictions", None) is not None:
  752. with tf.name_scope(self.predictions.name):
  753. self.predictions.build(None)
  754. @keras_serializable
  755. class TFMobileBertMainLayer(keras.layers.Layer):
  756. config_class = MobileBertConfig
  757. def __init__(self, config, add_pooling_layer=True, **kwargs):
  758. super().__init__(**kwargs)
  759. self.config = config
  760. self.num_hidden_layers = config.num_hidden_layers
  761. self.output_attentions = config.output_attentions
  762. self.output_hidden_states = config.output_hidden_states
  763. self.return_dict = config.use_return_dict
  764. self.embeddings = TFMobileBertEmbeddings(config, name="embeddings")
  765. self.encoder = TFMobileBertEncoder(config, name="encoder")
  766. self.pooler = TFMobileBertPooler(config, name="pooler") if add_pooling_layer else None
  767. def get_input_embeddings(self):
  768. return self.embeddings
  769. def set_input_embeddings(self, value):
  770. self.embeddings.weight = value
  771. self.embeddings.vocab_size = shape_list(value)[0]
  772. def _prune_heads(self, heads_to_prune):
  773. """
  774. Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
  775. class PreTrainedModel
  776. """
  777. raise NotImplementedError
  778. @unpack_inputs
  779. def call(
  780. self,
  781. input_ids=None,
  782. attention_mask=None,
  783. token_type_ids=None,
  784. position_ids=None,
  785. head_mask=None,
  786. inputs_embeds=None,
  787. output_attentions=None,
  788. output_hidden_states=None,
  789. return_dict=None,
  790. training=False,
  791. ):
  792. if input_ids is not None and inputs_embeds is not None:
  793. raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
  794. elif input_ids is not None:
  795. input_shape = shape_list(input_ids)
  796. elif inputs_embeds is not None:
  797. input_shape = shape_list(inputs_embeds)[:-1]
  798. else:
  799. raise ValueError("You have to specify either input_ids or inputs_embeds")
  800. if attention_mask is None:
  801. attention_mask = tf.fill(input_shape, 1)
  802. if token_type_ids is None:
  803. token_type_ids = tf.fill(input_shape, 0)
  804. embedding_output = self.embeddings(input_ids, position_ids, token_type_ids, inputs_embeds, training=training)
  805. # We create a 3D attention mask from a 2D tensor mask.
  806. # Sizes are [batch_size, 1, 1, to_seq_length]
  807. # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
  808. # this attention mask is more simple than the triangular masking of causal attention
  809. # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
  810. extended_attention_mask = tf.reshape(attention_mask, (input_shape[0], 1, 1, input_shape[1]))
  811. # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
  812. # masked positions, this operation will create a tensor which is 0.0 for
  813. # positions we want to attend and -10000.0 for masked positions.
  814. # Since we are adding it to the raw scores before the softmax, this is
  815. # effectively the same as removing these entirely.
  816. extended_attention_mask = tf.cast(extended_attention_mask, dtype=embedding_output.dtype)
  817. one_cst = tf.constant(1.0, dtype=embedding_output.dtype)
  818. ten_thousand_cst = tf.constant(-10000.0, dtype=embedding_output.dtype)
  819. extended_attention_mask = tf.multiply(tf.subtract(one_cst, extended_attention_mask), ten_thousand_cst)
  820. # Prepare head mask if needed
  821. # 1.0 in head_mask indicate we keep the head
  822. # attention_probs has shape bsz x n_heads x N x N
  823. # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
  824. # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
  825. if head_mask is not None:
  826. raise NotImplementedError
  827. else:
  828. head_mask = [None] * self.num_hidden_layers
  829. encoder_outputs = self.encoder(
  830. embedding_output,
  831. extended_attention_mask,
  832. head_mask,
  833. output_attentions,
  834. output_hidden_states,
  835. return_dict,
  836. training=training,
  837. )
  838. sequence_output = encoder_outputs[0]
  839. pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
  840. if not return_dict:
  841. return (
  842. sequence_output,
  843. pooled_output,
  844. ) + encoder_outputs[1:]
  845. return TFBaseModelOutputWithPooling(
  846. last_hidden_state=sequence_output,
  847. pooler_output=pooled_output,
  848. hidden_states=encoder_outputs.hidden_states,
  849. attentions=encoder_outputs.attentions,
  850. )
  851. def build(self, input_shape=None):
  852. if self.built:
  853. return
  854. self.built = True
  855. if getattr(self, "embeddings", None) is not None:
  856. with tf.name_scope(self.embeddings.name):
  857. self.embeddings.build(None)
  858. if getattr(self, "encoder", None) is not None:
  859. with tf.name_scope(self.encoder.name):
  860. self.encoder.build(None)
  861. if getattr(self, "pooler", None) is not None:
  862. with tf.name_scope(self.pooler.name):
  863. self.pooler.build(None)
  864. class TFMobileBertPreTrainedModel(TFPreTrainedModel):
  865. """
  866. An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
  867. models.
  868. """
  869. config_class = MobileBertConfig
  870. base_model_prefix = "mobilebert"
@dataclass
class TFMobileBertForPreTrainingOutput(ModelOutput):
    """
    Output type of [`TFMobileBertForPreTraining`].

    Args:
        loss (`tf.Tensor`, *optional*, returned when both `labels` and `next_sentence_label` are provided):
            Combined pretraining loss that `hf_compute_loss` computes from the masked language modeling logits and
            the next sentence prediction logits.
        prediction_logits (`tf.Tensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
        seq_relationship_logits (`tf.Tensor` of shape `(batch_size, 2)`):
            Prediction scores of the next sequence prediction (classification) head (scores of True/False continuation
            before SoftMax).
        hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape
            `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
    """

    # Field order defines the positional order of the output; do not reorder.
    # Stays None unless both label tensors were supplied to the model.
    loss: tf.Tensor | None = None
    # Scores from the masked-LM head.
    prediction_logits: tf.Tensor = None
    # Scores from the next-sentence-prediction head.
    seq_relationship_logits: tf.Tensor = None
    hidden_states: Tuple[tf.Tensor] | None = None
    attentions: Tuple[tf.Tensor] | None = None
# Shared class-level documentation. It is passed to `add_start_docstrings` on every public TF MobileBERT model class
# in this file. The text is rendered into user-facing docs, so its content is runtime data and must not be changed
# casually.
MOBILEBERT_START_DOCSTRING = r"""
This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
etc.)
This model is also a [keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model) subclass. Use it
as a regular TF 2.0 Keras Model and refer to the TF 2.0 documentation for all matter related to general usage and
behavior.
<Tip>
TensorFlow models and layers in `transformers` accept two formats as input:
- having all inputs as keyword arguments (like PyTorch models), or
- having all inputs as a list, tuple or dict in the first positional argument.
The reason the second format is supported is that Keras methods prefer this format when passing inputs to models
and layers. Because of this support, when using methods like `model.fit()` things should "just work" for you - just
pass your inputs and labels in any format that `model.fit()` supports! If, however, you want to use the second
format outside of Keras methods like `fit()` and `predict()`, such as when creating your own layers or models with
the Keras `Functional` API, there are three possibilities you can use to gather all the input Tensors in the first
positional argument:
- a single Tensor with `input_ids` only and nothing else: `model(input_ids)`
- a list of varying length with one or several input Tensors IN THE ORDER given in the docstring:
`model([input_ids, attention_mask])` or `model([input_ids, attention_mask, token_type_ids])`
- a dictionary with one or several input Tensors associated to the input names given in the docstring:
`model({"input_ids": input_ids, "token_type_ids": token_type_ids})`
Note that when creating models and layers with
[subclassing](https://keras.io/guides/making_new_layers_and_models_via_subclassing/) then you don't need to worry
about any of this, as you can just pass inputs like you would to any other Python function!
</Tip>
Parameters:
config ([`MobileBertConfig`]): Model configuration class with all the parameters of the model.
Initializing with a config file does not load the weights associated with the model, only the
configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
# Shared documentation for the `call` arguments of the public models. `{0}` is a placeholder filled through
# `.format(...)` with the input shape string (e.g. "batch_size, sequence_length") before it is passed to
# `add_start_docstrings_to_model_forward`.
MOBILEBERT_INPUTS_DOCSTRING = r"""
Args:
input_ids (`Numpy array` or `tf.Tensor` of shape `({0})`):
Indices of input sequence tokens in the vocabulary.
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.__call__`] and
[`PreTrainedTokenizer.encode`] for details.
[What are input IDs?](../glossary#input-ids)
attention_mask (`Numpy array` or `tf.Tensor` of shape `({0})`, *optional*):
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
[What are attention masks?](../glossary#attention-mask)
token_type_ids (`Numpy array` or `tf.Tensor` of shape `({0})`, *optional*):
Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
1]`:
- 0 corresponds to a *sentence A* token,
- 1 corresponds to a *sentence B* token.
[What are token type IDs?](../glossary#token-type-ids)
position_ids (`Numpy array` or `tf.Tensor` of shape `({0})`, *optional*):
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
config.max_position_embeddings - 1]`.
[What are position IDs?](../glossary#position-ids)
head_mask (`Numpy array` or `tf.Tensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
inputs_embeds (`tf.Tensor` of shape `({0}, hidden_size)`, *optional*):
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
model's internal embedding lookup matrix.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
tensors for more detail. This argument can be used only in eager mode, in graph mode the value in the
config will be used instead.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail. This argument can be used only in eager mode, in graph mode the value in the config will be
used instead.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. This argument can be used in
eager mode, in graph mode the value will always be set to True.
training (`bool`, *optional*, defaults to `False`):
Whether or not to use the model in training mode (some modules like dropout modules have different
behaviors between training and evaluation).
"""
  972. @add_start_docstrings(
  973. "The bare MobileBert Model transformer outputting raw hidden-states without any specific head on top.",
  974. MOBILEBERT_START_DOCSTRING,
  975. )
  976. class TFMobileBertModel(TFMobileBertPreTrainedModel):
  977. def __init__(self, config, *inputs, **kwargs):
  978. super().__init__(config, *inputs, **kwargs)
  979. self.mobilebert = TFMobileBertMainLayer(config, name="mobilebert")
  980. @unpack_inputs
  981. @add_start_docstrings_to_model_forward(MOBILEBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
  982. @add_code_sample_docstrings(
  983. checkpoint=_CHECKPOINT_FOR_DOC,
  984. output_type=TFBaseModelOutputWithPooling,
  985. config_class=_CONFIG_FOR_DOC,
  986. )
  987. def call(
  988. self,
  989. input_ids: TFModelInputType | None = None,
  990. attention_mask: np.ndarray | tf.Tensor | None = None,
  991. token_type_ids: np.ndarray | tf.Tensor | None = None,
  992. position_ids: np.ndarray | tf.Tensor | None = None,
  993. head_mask: np.ndarray | tf.Tensor | None = None,
  994. inputs_embeds: np.ndarray | tf.Tensor | None = None,
  995. output_attentions: Optional[bool] = None,
  996. output_hidden_states: Optional[bool] = None,
  997. return_dict: Optional[bool] = None,
  998. training: Optional[bool] = False,
  999. ) -> Union[Tuple, TFBaseModelOutputWithPooling]:
  1000. outputs = self.mobilebert(
  1001. input_ids=input_ids,
  1002. attention_mask=attention_mask,
  1003. token_type_ids=token_type_ids,
  1004. position_ids=position_ids,
  1005. head_mask=head_mask,
  1006. inputs_embeds=inputs_embeds,
  1007. output_attentions=output_attentions,
  1008. output_hidden_states=output_hidden_states,
  1009. return_dict=return_dict,
  1010. training=training,
  1011. )
  1012. return outputs
  1013. def build(self, input_shape=None):
  1014. if self.built:
  1015. return
  1016. self.built = True
  1017. if getattr(self, "mobilebert", None) is not None:
  1018. with tf.name_scope(self.mobilebert.name):
  1019. self.mobilebert.build(None)
  1020. @add_start_docstrings(
  1021. """
  1022. MobileBert Model with two heads on top as done during the pretraining: a `masked language modeling` head and a
  1023. `next sentence prediction (classification)` head.
  1024. """,
  1025. MOBILEBERT_START_DOCSTRING,
  1026. )
  1027. class TFMobileBertForPreTraining(TFMobileBertPreTrainedModel, TFMobileBertPreTrainingLoss):
  1028. def __init__(self, config, *inputs, **kwargs):
  1029. super().__init__(config, *inputs, **kwargs)
  1030. self.mobilebert = TFMobileBertMainLayer(config, name="mobilebert")
  1031. self.predictions = TFMobileBertMLMHead(config, name="predictions___cls")
  1032. self.seq_relationship = TFMobileBertOnlyNSPHead(config, name="seq_relationship___cls")
  1033. def get_lm_head(self):
  1034. return self.predictions.predictions
  1035. def get_prefix_bias_name(self):
  1036. warnings.warn("The method get_prefix_bias_name is deprecated. Please use `get_bias` instead.", FutureWarning)
  1037. return self.name + "/" + self.predictions.name + "/" + self.predictions.predictions.name
  1038. @unpack_inputs
  1039. @add_start_docstrings_to_model_forward(MOBILEBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
  1040. @replace_return_docstrings(output_type=TFMobileBertForPreTrainingOutput, config_class=_CONFIG_FOR_DOC)
  1041. def call(
  1042. self,
  1043. input_ids: TFModelInputType | None = None,
  1044. attention_mask: np.ndarray | tf.Tensor | None = None,
  1045. token_type_ids: np.ndarray | tf.Tensor | None = None,
  1046. position_ids: np.ndarray | tf.Tensor | None = None,
  1047. head_mask: np.ndarray | tf.Tensor | None = None,
  1048. inputs_embeds: np.ndarray | tf.Tensor | None = None,
  1049. output_attentions: Optional[bool] = None,
  1050. output_hidden_states: Optional[bool] = None,
  1051. return_dict: Optional[bool] = None,
  1052. labels: np.ndarray | tf.Tensor | None = None,
  1053. next_sentence_label: np.ndarray | tf.Tensor | None = None,
  1054. training: Optional[bool] = False,
  1055. ) -> Union[Tuple, TFMobileBertForPreTrainingOutput]:
  1056. r"""
  1057. Return:
  1058. Examples:
  1059. ```python
  1060. >>> import tensorflow as tf
  1061. >>> from transformers import AutoTokenizer, TFMobileBertForPreTraining
  1062. >>> tokenizer = AutoTokenizer.from_pretrained("google/mobilebert-uncased")
  1063. >>> model = TFMobileBertForPreTraining.from_pretrained("google/mobilebert-uncased")
  1064. >>> input_ids = tf.constant(tokenizer.encode("Hello, my dog is cute"))[None, :] # Batch size 1
  1065. >>> outputs = model(input_ids)
  1066. >>> prediction_scores, seq_relationship_scores = outputs[:2]
  1067. ```"""
  1068. outputs = self.mobilebert(
  1069. input_ids,
  1070. attention_mask=attention_mask,
  1071. token_type_ids=token_type_ids,
  1072. position_ids=position_ids,
  1073. head_mask=head_mask,
  1074. inputs_embeds=inputs_embeds,
  1075. output_attentions=output_attentions,
  1076. output_hidden_states=output_hidden_states,
  1077. return_dict=return_dict,
  1078. training=training,
  1079. )
  1080. sequence_output, pooled_output = outputs[:2]
  1081. prediction_scores = self.predictions(sequence_output)
  1082. seq_relationship_score = self.seq_relationship(pooled_output)
  1083. total_loss = None
  1084. if labels is not None and next_sentence_label is not None:
  1085. d_labels = {"labels": labels}
  1086. d_labels["next_sentence_label"] = next_sentence_label
  1087. total_loss = self.hf_compute_loss(labels=d_labels, logits=(prediction_scores, seq_relationship_score))
  1088. if not return_dict:
  1089. output = (prediction_scores, seq_relationship_score) + outputs[2:]
  1090. return ((total_loss,) + output) if total_loss is not None else output
  1091. return TFMobileBertForPreTrainingOutput(
  1092. loss=total_loss,
  1093. prediction_logits=prediction_scores,
  1094. seq_relationship_logits=seq_relationship_score,
  1095. hidden_states=outputs.hidden_states,
  1096. attentions=outputs.attentions,
  1097. )
  1098. def build(self, input_shape=None):
  1099. if self.built:
  1100. return
  1101. self.built = True
  1102. if getattr(self, "mobilebert", None) is not None:
  1103. with tf.name_scope(self.mobilebert.name):
  1104. self.mobilebert.build(None)
  1105. if getattr(self, "predictions", None) is not None:
  1106. with tf.name_scope(self.predictions.name):
  1107. self.predictions.build(None)
  1108. if getattr(self, "seq_relationship", None) is not None:
  1109. with tf.name_scope(self.seq_relationship.name):
  1110. self.seq_relationship.build(None)
  1111. def tf_to_pt_weight_rename(self, tf_weight):
  1112. if tf_weight == "cls.predictions.decoder.weight":
  1113. return tf_weight, "mobilebert.embeddings.word_embeddings.weight"
  1114. else:
  1115. return (tf_weight,)
  1116. @add_start_docstrings("""MobileBert Model with a `language modeling` head on top.""", MOBILEBERT_START_DOCSTRING)
  1117. class TFMobileBertForMaskedLM(TFMobileBertPreTrainedModel, TFMaskedLanguageModelingLoss):
  1118. # names with a '.' represents the authorized unexpected/missing layers when a TF model is loaded from a PT model
  1119. _keys_to_ignore_on_load_unexpected = [
  1120. r"pooler",
  1121. r"seq_relationship___cls",
  1122. r"cls.seq_relationship",
  1123. ]
  1124. def __init__(self, config, *inputs, **kwargs):
  1125. super().__init__(config, *inputs, **kwargs)
  1126. self.mobilebert = TFMobileBertMainLayer(config, add_pooling_layer=False, name="mobilebert")
  1127. self.predictions = TFMobileBertMLMHead(config, name="predictions___cls")
  1128. def get_lm_head(self):
  1129. return self.predictions.predictions
  1130. def get_prefix_bias_name(self):
  1131. warnings.warn("The method get_prefix_bias_name is deprecated. Please use `get_bias` instead.", FutureWarning)
  1132. return self.name + "/" + self.mlm.name + "/" + self.mlm.predictions.name
  1133. @unpack_inputs
  1134. @add_start_docstrings_to_model_forward(MOBILEBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
  1135. @add_code_sample_docstrings(
  1136. checkpoint=_CHECKPOINT_FOR_DOC,
  1137. output_type=TFMaskedLMOutput,
  1138. config_class=_CONFIG_FOR_DOC,
  1139. expected_output="'paris'",
  1140. expected_loss=0.57,
  1141. )
  1142. def call(
  1143. self,
  1144. input_ids: TFModelInputType | None = None,
  1145. attention_mask: np.ndarray | tf.Tensor | None = None,
  1146. token_type_ids: np.ndarray | tf.Tensor | None = None,
  1147. position_ids: np.ndarray | tf.Tensor | None = None,
  1148. head_mask: np.ndarray | tf.Tensor | None = None,
  1149. inputs_embeds: np.ndarray | tf.Tensor | None = None,
  1150. output_attentions: Optional[bool] = None,
  1151. output_hidden_states: Optional[bool] = None,
  1152. return_dict: Optional[bool] = None,
  1153. labels: np.ndarray | tf.Tensor | None = None,
  1154. training: Optional[bool] = False,
  1155. ) -> Union[Tuple, TFMaskedLMOutput]:
  1156. r"""
  1157. labels (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
  1158. Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
  1159. config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
  1160. loss is only computed for the tokens with labels
  1161. """
  1162. outputs = self.mobilebert(
  1163. input_ids,
  1164. attention_mask=attention_mask,
  1165. token_type_ids=token_type_ids,
  1166. position_ids=position_ids,
  1167. head_mask=head_mask,
  1168. inputs_embeds=inputs_embeds,
  1169. output_attentions=output_attentions,
  1170. output_hidden_states=output_hidden_states,
  1171. return_dict=return_dict,
  1172. training=training,
  1173. )
  1174. sequence_output = outputs[0]
  1175. prediction_scores = self.predictions(sequence_output, training=training)
  1176. loss = None if labels is None else self.hf_compute_loss(labels, prediction_scores)
  1177. if not return_dict:
  1178. output = (prediction_scores,) + outputs[2:]
  1179. return ((loss,) + output) if loss is not None else output
  1180. return TFMaskedLMOutput(
  1181. loss=loss,
  1182. logits=prediction_scores,
  1183. hidden_states=outputs.hidden_states,
  1184. attentions=outputs.attentions,
  1185. )
  1186. def build(self, input_shape=None):
  1187. if self.built:
  1188. return
  1189. self.built = True
  1190. if getattr(self, "mobilebert", None) is not None:
  1191. with tf.name_scope(self.mobilebert.name):
  1192. self.mobilebert.build(None)
  1193. if getattr(self, "predictions", None) is not None:
  1194. with tf.name_scope(self.predictions.name):
  1195. self.predictions.build(None)
  1196. def tf_to_pt_weight_rename(self, tf_weight):
  1197. if tf_weight == "cls.predictions.decoder.weight":
  1198. return tf_weight, "mobilebert.embeddings.word_embeddings.weight"
  1199. else:
  1200. return (tf_weight,)
  1201. class TFMobileBertOnlyNSPHead(keras.layers.Layer):
  1202. def __init__(self, config, **kwargs):
  1203. super().__init__(**kwargs)
  1204. self.seq_relationship = keras.layers.Dense(2, name="seq_relationship")
  1205. self.config = config
  1206. def call(self, pooled_output):
  1207. seq_relationship_score = self.seq_relationship(pooled_output)
  1208. return seq_relationship_score
  1209. def build(self, input_shape=None):
  1210. if self.built:
  1211. return
  1212. self.built = True
  1213. if getattr(self, "seq_relationship", None) is not None:
  1214. with tf.name_scope(self.seq_relationship.name):
  1215. self.seq_relationship.build([None, None, self.config.hidden_size])
  1216. @add_start_docstrings(
  1217. """MobileBert Model with a `next sentence prediction (classification)` head on top.""",
  1218. MOBILEBERT_START_DOCSTRING,
  1219. )
  1220. class TFMobileBertForNextSentencePrediction(TFMobileBertPreTrainedModel, TFNextSentencePredictionLoss):
  1221. # names with a '.' represents the authorized unexpected/missing layers when a TF model is loaded from a PT model
  1222. _keys_to_ignore_on_load_unexpected = [r"predictions___cls", r"cls.predictions"]
  1223. def __init__(self, config, *inputs, **kwargs):
  1224. super().__init__(config, *inputs, **kwargs)
  1225. self.mobilebert = TFMobileBertMainLayer(config, name="mobilebert")
  1226. self.cls = TFMobileBertOnlyNSPHead(config, name="seq_relationship___cls")
  1227. @unpack_inputs
  1228. @add_start_docstrings_to_model_forward(MOBILEBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
  1229. @replace_return_docstrings(output_type=TFNextSentencePredictorOutput, config_class=_CONFIG_FOR_DOC)
  1230. def call(
  1231. self,
  1232. input_ids: TFModelInputType | None = None,
  1233. attention_mask: np.ndarray | tf.Tensor | None = None,
  1234. token_type_ids: np.ndarray | tf.Tensor | None = None,
  1235. position_ids: np.ndarray | tf.Tensor | None = None,
  1236. head_mask: np.ndarray | tf.Tensor | None = None,
  1237. inputs_embeds: np.ndarray | tf.Tensor | None = None,
  1238. output_attentions: Optional[bool] = None,
  1239. output_hidden_states: Optional[bool] = None,
  1240. return_dict: Optional[bool] = None,
  1241. next_sentence_label: np.ndarray | tf.Tensor | None = None,
  1242. training: Optional[bool] = False,
  1243. ) -> Union[Tuple, TFNextSentencePredictorOutput]:
  1244. r"""
  1245. Return:
  1246. Examples:
  1247. ```python
  1248. >>> import tensorflow as tf
  1249. >>> from transformers import AutoTokenizer, TFMobileBertForNextSentencePrediction
  1250. >>> tokenizer = AutoTokenizer.from_pretrained("google/mobilebert-uncased")
  1251. >>> model = TFMobileBertForNextSentencePrediction.from_pretrained("google/mobilebert-uncased")
  1252. >>> prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced."
  1253. >>> next_sentence = "The sky is blue due to the shorter wavelength of blue light."
  1254. >>> encoding = tokenizer(prompt, next_sentence, return_tensors="tf")
  1255. >>> logits = model(encoding["input_ids"], token_type_ids=encoding["token_type_ids"])[0]
  1256. ```"""
  1257. outputs = self.mobilebert(
  1258. input_ids,
  1259. attention_mask=attention_mask,
  1260. token_type_ids=token_type_ids,
  1261. position_ids=position_ids,
  1262. head_mask=head_mask,
  1263. inputs_embeds=inputs_embeds,
  1264. output_attentions=output_attentions,
  1265. output_hidden_states=output_hidden_states,
  1266. return_dict=return_dict,
  1267. training=training,
  1268. )
  1269. pooled_output = outputs[1]
  1270. seq_relationship_scores = self.cls(pooled_output)
  1271. next_sentence_loss = (
  1272. None
  1273. if next_sentence_label is None
  1274. else self.hf_compute_loss(labels=next_sentence_label, logits=seq_relationship_scores)
  1275. )
  1276. if not return_dict:
  1277. output = (seq_relationship_scores,) + outputs[2:]
  1278. return ((next_sentence_loss,) + output) if next_sentence_loss is not None else output
  1279. return TFNextSentencePredictorOutput(
  1280. loss=next_sentence_loss,
  1281. logits=seq_relationship_scores,
  1282. hidden_states=outputs.hidden_states,
  1283. attentions=outputs.attentions,
  1284. )
  1285. def build(self, input_shape=None):
  1286. if self.built:
  1287. return
  1288. self.built = True
  1289. if getattr(self, "mobilebert", None) is not None:
  1290. with tf.name_scope(self.mobilebert.name):
  1291. self.mobilebert.build(None)
  1292. if getattr(self, "cls", None) is not None:
  1293. with tf.name_scope(self.cls.name):
  1294. self.cls.build(None)
  1295. @add_start_docstrings(
  1296. """
  1297. MobileBert Model transformer with a sequence classification/regression head on top (a linear layer on top of the
  1298. pooled output) e.g. for GLUE tasks.
  1299. """,
  1300. MOBILEBERT_START_DOCSTRING,
  1301. )
  1302. class TFMobileBertForSequenceClassification(TFMobileBertPreTrainedModel, TFSequenceClassificationLoss):
  1303. # names with a '.' represents the authorized unexpected/missing layers when a TF model is loaded from a PT model
  1304. _keys_to_ignore_on_load_unexpected = [
  1305. r"predictions___cls",
  1306. r"seq_relationship___cls",
  1307. r"cls.predictions",
  1308. r"cls.seq_relationship",
  1309. ]
  1310. _keys_to_ignore_on_load_missing = [r"dropout"]
  1311. def __init__(self, config, *inputs, **kwargs):
  1312. super().__init__(config, *inputs, **kwargs)
  1313. self.num_labels = config.num_labels
  1314. self.mobilebert = TFMobileBertMainLayer(config, name="mobilebert")
  1315. classifier_dropout = (
  1316. config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
  1317. )
  1318. self.dropout = keras.layers.Dropout(classifier_dropout)
  1319. self.classifier = keras.layers.Dense(
  1320. config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="classifier"
  1321. )
  1322. self.config = config
  1323. @unpack_inputs
  1324. @add_start_docstrings_to_model_forward(MOBILEBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
  1325. @add_code_sample_docstrings(
  1326. checkpoint=_CHECKPOINT_FOR_SEQUENCE_CLASSIFICATION,
  1327. output_type=TFSequenceClassifierOutput,
  1328. config_class=_CONFIG_FOR_DOC,
  1329. expected_output=_SEQ_CLASS_EXPECTED_OUTPUT,
  1330. expected_loss=_SEQ_CLASS_EXPECTED_LOSS,
  1331. )
  1332. def call(
  1333. self,
  1334. input_ids: TFModelInputType | None = None,
  1335. attention_mask: np.ndarray | tf.Tensor | None = None,
  1336. token_type_ids: np.ndarray | tf.Tensor | None = None,
  1337. position_ids: np.ndarray | tf.Tensor | None = None,
  1338. head_mask: np.ndarray | tf.Tensor | None = None,
  1339. inputs_embeds: np.ndarray | tf.Tensor | None = None,
  1340. output_attentions: Optional[bool] = None,
  1341. output_hidden_states: Optional[bool] = None,
  1342. return_dict: Optional[bool] = None,
  1343. labels: np.ndarray | tf.Tensor | None = None,
  1344. training: Optional[bool] = False,
  1345. ) -> Union[Tuple, TFSequenceClassifierOutput]:
  1346. r"""
  1347. labels (`tf.Tensor` of shape `(batch_size,)`, *optional*):
  1348. Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
  1349. config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
  1350. `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
  1351. """
  1352. outputs = self.mobilebert(
  1353. input_ids,
  1354. attention_mask=attention_mask,
  1355. token_type_ids=token_type_ids,
  1356. position_ids=position_ids,
  1357. head_mask=head_mask,
  1358. inputs_embeds=inputs_embeds,
  1359. output_attentions=output_attentions,
  1360. output_hidden_states=output_hidden_states,
  1361. return_dict=return_dict,
  1362. training=training,
  1363. )
  1364. pooled_output = outputs[1]
  1365. pooled_output = self.dropout(pooled_output, training=training)
  1366. logits = self.classifier(pooled_output)
  1367. loss = None if labels is None else self.hf_compute_loss(labels, logits)
  1368. if not return_dict:
  1369. output = (logits,) + outputs[2:]
  1370. return ((loss,) + output) if loss is not None else output
  1371. return TFSequenceClassifierOutput(
  1372. loss=loss,
  1373. logits=logits,
  1374. hidden_states=outputs.hidden_states,
  1375. attentions=outputs.attentions,
  1376. )
  1377. def build(self, input_shape=None):
  1378. if self.built:
  1379. return
  1380. self.built = True
  1381. if getattr(self, "mobilebert", None) is not None:
  1382. with tf.name_scope(self.mobilebert.name):
  1383. self.mobilebert.build(None)
  1384. if getattr(self, "classifier", None) is not None:
  1385. with tf.name_scope(self.classifier.name):
  1386. self.classifier.build([None, None, self.config.hidden_size])
  1387. @add_start_docstrings(
  1388. """
  1389. MobileBert Model with a span classification head on top for extractive question-answering tasks like SQuAD (a
  1390. linear layers on top of the hidden-states output to compute `span start logits` and `span end logits`).
  1391. """,
  1392. MOBILEBERT_START_DOCSTRING,
  1393. )
  1394. class TFMobileBertForQuestionAnswering(TFMobileBertPreTrainedModel, TFQuestionAnsweringLoss):
  1395. # names with a '.' represents the authorized unexpected/missing layers when a TF model is loaded from a PT model
  1396. _keys_to_ignore_on_load_unexpected = [
  1397. r"pooler",
  1398. r"predictions___cls",
  1399. r"seq_relationship___cls",
  1400. r"cls.predictions",
  1401. r"cls.seq_relationship",
  1402. ]
  1403. def __init__(self, config, *inputs, **kwargs):
  1404. super().__init__(config, *inputs, **kwargs)
  1405. self.num_labels = config.num_labels
  1406. self.mobilebert = TFMobileBertMainLayer(config, add_pooling_layer=False, name="mobilebert")
  1407. self.qa_outputs = keras.layers.Dense(
  1408. config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="qa_outputs"
  1409. )
  1410. self.config = config
  1411. @unpack_inputs
  1412. @add_start_docstrings_to_model_forward(MOBILEBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
  1413. @add_code_sample_docstrings(
  1414. checkpoint=_CHECKPOINT_FOR_QA,
  1415. output_type=TFQuestionAnsweringModelOutput,
  1416. config_class=_CONFIG_FOR_DOC,
  1417. qa_target_start_index=_QA_TARGET_START_INDEX,
  1418. qa_target_end_index=_QA_TARGET_END_INDEX,
  1419. expected_output=_QA_EXPECTED_OUTPUT,
  1420. expected_loss=_QA_EXPECTED_LOSS,
  1421. )
  1422. def call(
  1423. self,
  1424. input_ids: TFModelInputType | None = None,
  1425. attention_mask: np.ndarray | tf.Tensor | None = None,
  1426. token_type_ids: np.ndarray | tf.Tensor | None = None,
  1427. position_ids: np.ndarray | tf.Tensor | None = None,
  1428. head_mask: np.ndarray | tf.Tensor | None = None,
  1429. inputs_embeds: np.ndarray | tf.Tensor | None = None,
  1430. output_attentions: Optional[bool] = None,
  1431. output_hidden_states: Optional[bool] = None,
  1432. return_dict: Optional[bool] = None,
  1433. start_positions: np.ndarray | tf.Tensor | None = None,
  1434. end_positions: np.ndarray | tf.Tensor | None = None,
  1435. training: Optional[bool] = False,
  1436. ) -> Union[Tuple, TFQuestionAnsweringModelOutput]:
  1437. r"""
  1438. start_positions (`tf.Tensor` of shape `(batch_size,)`, *optional*):
  1439. Labels for position (index) of the start of the labelled span for computing the token classification loss.
  1440. Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
  1441. are not taken into account for computing the loss.
  1442. end_positions (`tf.Tensor` of shape `(batch_size,)`, *optional*):
  1443. Labels for position (index) of the end of the labelled span for computing the token classification loss.
  1444. Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
  1445. are not taken into account for computing the loss.
  1446. """
  1447. outputs = self.mobilebert(
  1448. input_ids,
  1449. attention_mask=attention_mask,
  1450. token_type_ids=token_type_ids,
  1451. position_ids=position_ids,
  1452. head_mask=head_mask,
  1453. inputs_embeds=inputs_embeds,
  1454. output_attentions=output_attentions,
  1455. output_hidden_states=output_hidden_states,
  1456. return_dict=return_dict,
  1457. training=training,
  1458. )
  1459. sequence_output = outputs[0]
  1460. logits = self.qa_outputs(sequence_output)
  1461. start_logits, end_logits = tf.split(logits, 2, axis=-1)
  1462. start_logits = tf.squeeze(start_logits, axis=-1)
  1463. end_logits = tf.squeeze(end_logits, axis=-1)
  1464. loss = None
  1465. if start_positions is not None and end_positions is not None:
  1466. labels = {"start_position": start_positions, "end_position": end_positions}
  1467. loss = self.hf_compute_loss(labels, (start_logits, end_logits))
  1468. if not return_dict:
  1469. output = (start_logits, end_logits) + outputs[2:]
  1470. return ((loss,) + output) if loss is not None else output
  1471. return TFQuestionAnsweringModelOutput(
  1472. loss=loss,
  1473. start_logits=start_logits,
  1474. end_logits=end_logits,
  1475. hidden_states=outputs.hidden_states,
  1476. attentions=outputs.attentions,
  1477. )
  1478. def build(self, input_shape=None):
  1479. if self.built:
  1480. return
  1481. self.built = True
  1482. if getattr(self, "mobilebert", None) is not None:
  1483. with tf.name_scope(self.mobilebert.name):
  1484. self.mobilebert.build(None)
  1485. if getattr(self, "qa_outputs", None) is not None:
  1486. with tf.name_scope(self.qa_outputs.name):
  1487. self.qa_outputs.build([None, None, self.config.hidden_size])
  1488. @add_start_docstrings(
  1489. """
  1490. MobileBert Model with a multiple choice classification head on top (a linear layer on top of the pooled output and
  1491. a softmax) e.g. for RocStories/SWAG tasks.
  1492. """,
  1493. MOBILEBERT_START_DOCSTRING,
  1494. )
  1495. class TFMobileBertForMultipleChoice(TFMobileBertPreTrainedModel, TFMultipleChoiceLoss):
  1496. # names with a '.' represents the authorized unexpected/missing layers when a TF model is loaded from a PT model
  1497. _keys_to_ignore_on_load_unexpected = [
  1498. r"predictions___cls",
  1499. r"seq_relationship___cls",
  1500. r"cls.predictions",
  1501. r"cls.seq_relationship",
  1502. ]
  1503. _keys_to_ignore_on_load_missing = [r"dropout"]
  1504. def __init__(self, config, *inputs, **kwargs):
  1505. super().__init__(config, *inputs, **kwargs)
  1506. self.mobilebert = TFMobileBertMainLayer(config, name="mobilebert")
  1507. self.dropout = keras.layers.Dropout(config.hidden_dropout_prob)
  1508. self.classifier = keras.layers.Dense(
  1509. 1, kernel_initializer=get_initializer(config.initializer_range), name="classifier"
  1510. )
  1511. self.config = config
  1512. @unpack_inputs
  1513. @add_start_docstrings_to_model_forward(
  1514. MOBILEBERT_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length")
  1515. )
  1516. @add_code_sample_docstrings(
  1517. checkpoint=_CHECKPOINT_FOR_DOC,
  1518. output_type=TFMultipleChoiceModelOutput,
  1519. config_class=_CONFIG_FOR_DOC,
  1520. )
  1521. def call(
  1522. self,
  1523. input_ids: TFModelInputType | None = None,
  1524. attention_mask: np.ndarray | tf.Tensor | None = None,
  1525. token_type_ids: np.ndarray | tf.Tensor | None = None,
  1526. position_ids: np.ndarray | tf.Tensor | None = None,
  1527. head_mask: np.ndarray | tf.Tensor | None = None,
  1528. inputs_embeds: np.ndarray | tf.Tensor | None = None,
  1529. output_attentions: Optional[bool] = None,
  1530. output_hidden_states: Optional[bool] = None,
  1531. return_dict: Optional[bool] = None,
  1532. labels: np.ndarray | tf.Tensor | None = None,
  1533. training: Optional[bool] = False,
  1534. ) -> Union[Tuple, TFMultipleChoiceModelOutput]:
  1535. r"""
  1536. labels (`tf.Tensor` of shape `(batch_size,)`, *optional*):
  1537. Labels for computing the multiple choice classification loss. Indices should be in `[0, ..., num_choices]`
  1538. where `num_choices` is the size of the second dimension of the input tensors. (See `input_ids` above)
  1539. """
  1540. if input_ids is not None:
  1541. num_choices = shape_list(input_ids)[1]
  1542. seq_length = shape_list(input_ids)[2]
  1543. else:
  1544. num_choices = shape_list(inputs_embeds)[1]
  1545. seq_length = shape_list(inputs_embeds)[2]
  1546. flat_input_ids = tf.reshape(input_ids, (-1, seq_length)) if input_ids is not None else None
  1547. flat_attention_mask = tf.reshape(attention_mask, (-1, seq_length)) if attention_mask is not None else None
  1548. flat_token_type_ids = tf.reshape(token_type_ids, (-1, seq_length)) if token_type_ids is not None else None
  1549. flat_position_ids = tf.reshape(position_ids, (-1, seq_length)) if position_ids is not None else None
  1550. flat_inputs_embeds = (
  1551. tf.reshape(inputs_embeds, (-1, seq_length, shape_list(inputs_embeds)[3]))
  1552. if inputs_embeds is not None
  1553. else None
  1554. )
  1555. outputs = self.mobilebert(
  1556. flat_input_ids,
  1557. flat_attention_mask,
  1558. flat_token_type_ids,
  1559. flat_position_ids,
  1560. head_mask,
  1561. flat_inputs_embeds,
  1562. output_attentions,
  1563. output_hidden_states,
  1564. return_dict=return_dict,
  1565. training=training,
  1566. )
  1567. pooled_output = outputs[1]
  1568. pooled_output = self.dropout(pooled_output, training=training)
  1569. logits = self.classifier(pooled_output)
  1570. reshaped_logits = tf.reshape(logits, (-1, num_choices))
  1571. loss = None if labels is None else self.hf_compute_loss(labels, reshaped_logits)
  1572. if not return_dict:
  1573. output = (reshaped_logits,) + outputs[2:]
  1574. return ((loss,) + output) if loss is not None else output
  1575. return TFMultipleChoiceModelOutput(
  1576. loss=loss,
  1577. logits=reshaped_logits,
  1578. hidden_states=outputs.hidden_states,
  1579. attentions=outputs.attentions,
  1580. )
  1581. def build(self, input_shape=None):
  1582. if self.built:
  1583. return
  1584. self.built = True
  1585. if getattr(self, "mobilebert", None) is not None:
  1586. with tf.name_scope(self.mobilebert.name):
  1587. self.mobilebert.build(None)
  1588. if getattr(self, "classifier", None) is not None:
  1589. with tf.name_scope(self.classifier.name):
  1590. self.classifier.build([None, None, self.config.hidden_size])
  1591. @add_start_docstrings(
  1592. """
  1593. MobileBert Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g.
  1594. for Named-Entity-Recognition (NER) tasks.
  1595. """,
  1596. MOBILEBERT_START_DOCSTRING,
  1597. )
  1598. class TFMobileBertForTokenClassification(TFMobileBertPreTrainedModel, TFTokenClassificationLoss):
  1599. # names with a '.' represents the authorized unexpected/missing layers when a TF model is loaded from a PT model
  1600. _keys_to_ignore_on_load_unexpected = [
  1601. r"pooler",
  1602. r"predictions___cls",
  1603. r"seq_relationship___cls",
  1604. r"cls.predictions",
  1605. r"cls.seq_relationship",
  1606. ]
  1607. _keys_to_ignore_on_load_missing = [r"dropout"]
  1608. def __init__(self, config, *inputs, **kwargs):
  1609. super().__init__(config, *inputs, **kwargs)
  1610. self.num_labels = config.num_labels
  1611. self.mobilebert = TFMobileBertMainLayer(config, add_pooling_layer=False, name="mobilebert")
  1612. classifier_dropout = (
  1613. config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
  1614. )
  1615. self.dropout = keras.layers.Dropout(classifier_dropout)
  1616. self.classifier = keras.layers.Dense(
  1617. config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="classifier"
  1618. )
  1619. self.config = config
  1620. @unpack_inputs
  1621. @add_start_docstrings_to_model_forward(MOBILEBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
  1622. @add_code_sample_docstrings(
  1623. checkpoint=_CHECKPOINT_FOR_TOKEN_CLASSIFICATION,
  1624. output_type=TFTokenClassifierOutput,
  1625. config_class=_CONFIG_FOR_DOC,
  1626. expected_output=_TOKEN_CLASS_EXPECTED_OUTPUT,
  1627. expected_loss=_TOKEN_CLASS_EXPECTED_LOSS,
  1628. )
  1629. def call(
  1630. self,
  1631. input_ids: TFModelInputType | None = None,
  1632. attention_mask: np.ndarray | tf.Tensor | None = None,
  1633. token_type_ids: np.ndarray | tf.Tensor | None = None,
  1634. position_ids: np.ndarray | tf.Tensor | None = None,
  1635. head_mask: np.ndarray | tf.Tensor | None = None,
  1636. inputs_embeds: np.ndarray | tf.Tensor | None = None,
  1637. output_attentions: Optional[bool] = None,
  1638. output_hidden_states: Optional[bool] = None,
  1639. return_dict: Optional[bool] = None,
  1640. labels: np.ndarray | tf.Tensor | None = None,
  1641. training: Optional[bool] = False,
  1642. ) -> Union[Tuple, TFTokenClassifierOutput]:
  1643. r"""
  1644. labels (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
  1645. Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
  1646. """
  1647. outputs = self.mobilebert(
  1648. input_ids,
  1649. attention_mask=attention_mask,
  1650. token_type_ids=token_type_ids,
  1651. position_ids=position_ids,
  1652. head_mask=head_mask,
  1653. inputs_embeds=inputs_embeds,
  1654. output_attentions=output_attentions,
  1655. output_hidden_states=output_hidden_states,
  1656. return_dict=return_dict,
  1657. training=training,
  1658. )
  1659. sequence_output = outputs[0]
  1660. sequence_output = self.dropout(sequence_output, training=training)
  1661. logits = self.classifier(sequence_output)
  1662. loss = None if labels is None else self.hf_compute_loss(labels, logits)
  1663. if not return_dict:
  1664. output = (logits,) + outputs[2:]
  1665. return ((loss,) + output) if loss is not None else output
  1666. return TFTokenClassifierOutput(
  1667. loss=loss,
  1668. logits=logits,
  1669. hidden_states=outputs.hidden_states,
  1670. attentions=outputs.attentions,
  1671. )
  1672. def build(self, input_shape=None):
  1673. if self.built:
  1674. return
  1675. self.built = True
  1676. if getattr(self, "mobilebert", None) is not None:
  1677. with tf.name_scope(self.mobilebert.name):
  1678. self.mobilebert.build(None)
  1679. if getattr(self, "classifier", None) is not None:
  1680. with tf.name_scope(self.classifier.name):
  1681. self.classifier.build([None, None, self.config.hidden_size])