modeling_tf_albert.py
# coding=utf-8
# Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""TF 2.0 ALBERT model."""


from __future__ import annotations

import math
from dataclasses import dataclass
from typing import Dict, Optional, Tuple, Union

import numpy as np
import tensorflow as tf

from ...activations_tf import get_tf_activation
from ...modeling_tf_outputs import (
    TFBaseModelOutput,
    TFBaseModelOutputWithPooling,
    TFMaskedLMOutput,
    TFMultipleChoiceModelOutput,
    TFQuestionAnsweringModelOutput,
    TFSequenceClassifierOutput,
    TFTokenClassifierOutput,
)
from ...modeling_tf_utils import (
    TFMaskedLanguageModelingLoss,
    TFModelInputType,
    TFMultipleChoiceLoss,
    TFPreTrainedModel,
    TFQuestionAnsweringLoss,
    TFSequenceClassificationLoss,
    TFTokenClassificationLoss,
    get_initializer,
    keras,
    keras_serializable,
    unpack_inputs,
)
from ...tf_utils import check_embeddings_within_bounds, shape_list, stable_softmax
from ...utils import (
    ModelOutput,
    add_code_sample_docstrings,
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    logging,
    replace_return_docstrings,
)
from .configuration_albert import AlbertConfig


logger = logging.get_logger(__name__)

_CHECKPOINT_FOR_DOC = "albert/albert-base-v2"
_CONFIG_FOR_DOC = "AlbertConfig"
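

# The class below combines ALBERT's two pretraining objectives: masked language modeling
# (MLM) and sentence-order prediction (SOP). Positions whose label is -100 contribute
# nothing to either term; for example, with MLM labels [-100, 7, -100] only the middle
# position enters the MLM mean, so the reduction divides by the number of unmasked
# positions (1 here) rather than by the sequence length.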
class TFAlbertPreTrainingLoss:
    """
    Loss function suitable for ALBERT pretraining, that is, the task of pretraining a language model by combining SOP +
    MLM.

    .. note::
        Any label of -100 will be ignored (along with the corresponding logits) in the loss computation.
    """

    def hf_compute_loss(self, labels: tf.Tensor, logits: tf.Tensor) -> tf.Tensor:
        loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True, reduction=keras.losses.Reduction.NONE)
        if self.config.tf_legacy_loss:
            # make sure only labels that are not equal to -100
            # are taken into account as loss
            masked_lm_active_loss = tf.not_equal(tf.reshape(tensor=labels["labels"], shape=(-1,)), -100)
            masked_lm_reduced_logits = tf.boolean_mask(
                tensor=tf.reshape(tensor=logits[0], shape=(-1, shape_list(logits[0])[2])),
                mask=masked_lm_active_loss,
            )
            masked_lm_labels = tf.boolean_mask(
                tensor=tf.reshape(tensor=labels["labels"], shape=(-1,)), mask=masked_lm_active_loss
            )
            sentence_order_active_loss = tf.not_equal(
                tf.reshape(tensor=labels["sentence_order_label"], shape=(-1,)), -100
            )
            sentence_order_reduced_logits = tf.boolean_mask(
                tensor=tf.reshape(tensor=logits[1], shape=(-1, 2)), mask=sentence_order_active_loss
            )
            sentence_order_label = tf.boolean_mask(
                tensor=tf.reshape(tensor=labels["sentence_order_label"], shape=(-1,)), mask=sentence_order_active_loss
            )
            masked_lm_loss = loss_fn(y_true=masked_lm_labels, y_pred=masked_lm_reduced_logits)
            sentence_order_loss = loss_fn(y_true=sentence_order_label, y_pred=sentence_order_reduced_logits)
            masked_lm_loss = tf.reshape(tensor=masked_lm_loss, shape=(-1, shape_list(sentence_order_loss)[0]))
            masked_lm_loss = tf.reduce_mean(input_tensor=masked_lm_loss, axis=0)

            return masked_lm_loss + sentence_order_loss

        # Clip negative labels to zero here to avoid NaNs and errors - those positions will get masked later anyway
        unmasked_lm_losses = loss_fn(y_true=tf.nn.relu(labels["labels"]), y_pred=logits[0])
        # make sure only labels that are not equal to -100
        # are taken into account for the loss computation
        lm_loss_mask = tf.cast(labels["labels"] != -100, dtype=unmasked_lm_losses.dtype)
        masked_lm_losses = unmasked_lm_losses * lm_loss_mask
        reduced_masked_lm_loss = tf.reduce_sum(masked_lm_losses) / tf.reduce_sum(lm_loss_mask)

        sop_logits = tf.reshape(logits[1], (-1, 2))
        # Clip negative labels to zero here to avoid NaNs and errors - those positions will get masked later anyway
        unmasked_sop_loss = loss_fn(y_true=tf.nn.relu(labels["sentence_order_label"]), y_pred=sop_logits)
        sop_loss_mask = tf.cast(labels["sentence_order_label"] != -100, dtype=unmasked_sop_loss.dtype)

        masked_sop_loss = unmasked_sop_loss * sop_loss_mask
        reduced_masked_sop_loss = tf.reduce_sum(masked_sop_loss) / tf.reduce_sum(sop_loss_mask)

        return tf.reshape(reduced_masked_lm_loss + reduced_masked_sop_loss, (1,))
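

# ALBERT factorizes the embedding parameters: the layer below produces vectors of size
# config.embedding_size (e.g. 128 for albert-base-v2), which the transformer later
# projects up to config.hidden_size via `embedding_hidden_mapping_in`. Word, position and
# token-type embeddings are summed, layer-normalized and passed through dropout.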
class TFAlbertEmbeddings(keras.layers.Layer):
    """Construct the embeddings from word, position and token_type embeddings."""

    def __init__(self, config: AlbertConfig, **kwargs):
        super().__init__(**kwargs)

        self.config = config
        self.embedding_size = config.embedding_size
        self.max_position_embeddings = config.max_position_embeddings
        self.initializer_range = config.initializer_range
        self.LayerNorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
        self.dropout = keras.layers.Dropout(rate=config.hidden_dropout_prob)

    def build(self, input_shape=None):
        with tf.name_scope("word_embeddings"):
            self.weight = self.add_weight(
                name="weight",
                shape=[self.config.vocab_size, self.embedding_size],
                initializer=get_initializer(self.initializer_range),
            )

        with tf.name_scope("token_type_embeddings"):
            self.token_type_embeddings = self.add_weight(
                name="embeddings",
                shape=[self.config.type_vocab_size, self.embedding_size],
                initializer=get_initializer(self.initializer_range),
            )

        with tf.name_scope("position_embeddings"):
            self.position_embeddings = self.add_weight(
                name="embeddings",
                shape=[self.max_position_embeddings, self.embedding_size],
                initializer=get_initializer(self.initializer_range),
            )

        if self.built:
            return
        self.built = True
        if getattr(self, "LayerNorm", None) is not None:
            with tf.name_scope(self.LayerNorm.name):
                self.LayerNorm.build([None, None, self.config.embedding_size])

    # Copied from transformers.models.bert.modeling_tf_bert.TFBertEmbeddings.call
    def call(
        self,
        input_ids: tf.Tensor = None,
        position_ids: tf.Tensor = None,
        token_type_ids: tf.Tensor = None,
        inputs_embeds: tf.Tensor = None,
        past_key_values_length=0,
        training: bool = False,
    ) -> tf.Tensor:
        """
        Applies embedding based on inputs tensor.

        Returns:
            final_embeddings (`tf.Tensor`): output embedding tensor.
        """
        if input_ids is None and inputs_embeds is None:
            raise ValueError("Need to provide either `input_ids` or `input_embeds`.")

        if input_ids is not None:
            check_embeddings_within_bounds(input_ids, self.config.vocab_size)
            inputs_embeds = tf.gather(params=self.weight, indices=input_ids)

        input_shape = shape_list(inputs_embeds)[:-1]

        if token_type_ids is None:
            token_type_ids = tf.fill(dims=input_shape, value=0)

        if position_ids is None:
            position_ids = tf.expand_dims(
                tf.range(start=past_key_values_length, limit=input_shape[1] + past_key_values_length), axis=0
            )

        position_embeds = tf.gather(params=self.position_embeddings, indices=position_ids)
        token_type_embeds = tf.gather(params=self.token_type_embeddings, indices=token_type_ids)
        final_embeddings = inputs_embeds + position_embeds + token_type_embeds
        final_embeddings = self.LayerNorm(inputs=final_embeddings)
        final_embeddings = self.dropout(inputs=final_embeddings, training=training)

        return final_embeddings
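

# Standard multi-head scaled dot-product attention: scores are QK^T divided by
# sqrt(attention_head_size) before the softmax, and the additive attention mask applied
# here is the one pre-computed in TFAlbertMainLayer.call() (0 for kept positions,
# -10000 for padding). The output projection, output dropout and the residual LayerNorm
# all live inside this single layer.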
class TFAlbertAttention(keras.layers.Layer):
    """Contains the complete attention sublayer, including both dropouts and layer norm."""

    def __init__(self, config: AlbertConfig, **kwargs):
        super().__init__(**kwargs)

        if config.hidden_size % config.num_attention_heads != 0:
            raise ValueError(
                f"The hidden size ({config.hidden_size}) is not a multiple of the number "
                f"of attention heads ({config.num_attention_heads})"
            )

        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size
        self.sqrt_att_head_size = math.sqrt(self.attention_head_size)
        self.output_attentions = config.output_attentions

        self.query = keras.layers.Dense(
            units=self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="query"
        )
        self.key = keras.layers.Dense(
            units=self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="key"
        )
        self.value = keras.layers.Dense(
            units=self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="value"
        )
        self.dense = keras.layers.Dense(
            units=config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
        )
        self.LayerNorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
        # Two different dropout probabilities; see https://github.com/google-research/albert/blob/master/modeling.py#L971-L993
        self.attention_dropout = keras.layers.Dropout(rate=config.attention_probs_dropout_prob)
        self.output_dropout = keras.layers.Dropout(rate=config.hidden_dropout_prob)
        self.config = config

    def transpose_for_scores(self, tensor: tf.Tensor, batch_size: int) -> tf.Tensor:
        # Reshape from [batch_size, seq_length, all_head_size] to [batch_size, seq_length, num_attention_heads, attention_head_size]
        tensor = tf.reshape(tensor=tensor, shape=(batch_size, -1, self.num_attention_heads, self.attention_head_size))

        # Transpose the tensor from [batch_size, seq_length, num_attention_heads, attention_head_size] to [batch_size, num_attention_heads, seq_length, attention_head_size]
        return tf.transpose(tensor, perm=[0, 2, 1, 3])

    def call(
        self,
        input_tensor: tf.Tensor,
        attention_mask: tf.Tensor,
        head_mask: tf.Tensor,
        output_attentions: bool,
        training: bool = False,
    ) -> Tuple[tf.Tensor]:
        batch_size = shape_list(input_tensor)[0]
        mixed_query_layer = self.query(inputs=input_tensor)
        mixed_key_layer = self.key(inputs=input_tensor)
        mixed_value_layer = self.value(inputs=input_tensor)
        query_layer = self.transpose_for_scores(mixed_query_layer, batch_size)
        key_layer = self.transpose_for_scores(mixed_key_layer, batch_size)
        value_layer = self.transpose_for_scores(mixed_value_layer, batch_size)

        # Take the dot product between "query" and "key" to get the raw attention scores.
        # (batch size, num_heads, seq_len_q, seq_len_k)
        attention_scores = tf.matmul(query_layer, key_layer, transpose_b=True)
        dk = tf.cast(self.sqrt_att_head_size, dtype=attention_scores.dtype)
        attention_scores = tf.divide(attention_scores, dk)

        if attention_mask is not None:
            # Apply the attention mask (precomputed for all layers in the TFAlbertModel call() function)
            attention_scores = tf.add(attention_scores, attention_mask)

        # Normalize the attention scores to probabilities.
        attention_probs = stable_softmax(logits=attention_scores, axis=-1)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs = self.attention_dropout(inputs=attention_probs, training=training)

        # Mask heads if we want to
        if head_mask is not None:
            attention_probs = tf.multiply(attention_probs, head_mask)

        context_layer = tf.matmul(attention_probs, value_layer)
        context_layer = tf.transpose(context_layer, perm=[0, 2, 1, 3])

        # (batch_size, seq_len_q, all_head_size)
        context_layer = tf.reshape(tensor=context_layer, shape=(batch_size, -1, self.all_head_size))
        self_outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
        hidden_states = self_outputs[0]
        hidden_states = self.dense(inputs=hidden_states)
        hidden_states = self.output_dropout(inputs=hidden_states, training=training)
        attention_output = self.LayerNorm(inputs=hidden_states + input_tensor)

        # add attentions if we output them
        outputs = (attention_output,) + self_outputs[1:]

        return outputs

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        if getattr(self, "query", None) is not None:
            with tf.name_scope(self.query.name):
                self.query.build([None, None, self.config.hidden_size])
        if getattr(self, "key", None) is not None:
            with tf.name_scope(self.key.name):
                self.key.build([None, None, self.config.hidden_size])
        if getattr(self, "value", None) is not None:
            with tf.name_scope(self.value.name):
                self.value.build([None, None, self.config.hidden_size])
        if getattr(self, "dense", None) is not None:
            with tf.name_scope(self.dense.name):
                self.dense.build([None, None, self.config.hidden_size])
        if getattr(self, "LayerNorm", None) is not None:
            with tf.name_scope(self.LayerNorm.name):
                self.LayerNorm.build([None, None, self.config.hidden_size])
class TFAlbertLayer(keras.layers.Layer):
    def __init__(self, config: AlbertConfig, **kwargs):
        super().__init__(**kwargs)

        self.attention = TFAlbertAttention(config, name="attention")
        self.ffn = keras.layers.Dense(
            units=config.intermediate_size, kernel_initializer=get_initializer(config.initializer_range), name="ffn"
        )

        if isinstance(config.hidden_act, str):
            self.activation = get_tf_activation(config.hidden_act)
        else:
            self.activation = config.hidden_act

        self.ffn_output = keras.layers.Dense(
            units=config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="ffn_output"
        )
        self.full_layer_layer_norm = keras.layers.LayerNormalization(
            epsilon=config.layer_norm_eps, name="full_layer_layer_norm"
        )
        self.dropout = keras.layers.Dropout(rate=config.hidden_dropout_prob)
        self.config = config

    def call(
        self,
        hidden_states: tf.Tensor,
        attention_mask: tf.Tensor,
        head_mask: tf.Tensor,
        output_attentions: bool,
        training: bool = False,
    ) -> Tuple[tf.Tensor]:
        attention_outputs = self.attention(
            input_tensor=hidden_states,
            attention_mask=attention_mask,
            head_mask=head_mask,
            output_attentions=output_attentions,
            training=training,
        )
        ffn_output = self.ffn(inputs=attention_outputs[0])
        ffn_output = self.activation(ffn_output)
        ffn_output = self.ffn_output(inputs=ffn_output)
        ffn_output = self.dropout(inputs=ffn_output, training=training)
        hidden_states = self.full_layer_layer_norm(inputs=ffn_output + attention_outputs[0])

        # add attentions if we output them
        outputs = (hidden_states,) + attention_outputs[1:]

        return outputs

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        if getattr(self, "attention", None) is not None:
            with tf.name_scope(self.attention.name):
                self.attention.build(None)
        if getattr(self, "ffn", None) is not None:
            with tf.name_scope(self.ffn.name):
                self.ffn.build([None, None, self.config.hidden_size])
        if getattr(self, "ffn_output", None) is not None:
            with tf.name_scope(self.ffn_output.name):
                self.ffn_output.build([None, None, self.config.intermediate_size])
        if getattr(self, "full_layer_layer_norm", None) is not None:
            with tf.name_scope(self.full_layer_layer_norm.name):
                self.full_layer_layer_norm.build([None, None, self.config.hidden_size])
class TFAlbertLayerGroup(keras.layers.Layer):
    def __init__(self, config: AlbertConfig, **kwargs):
        super().__init__(**kwargs)

        self.albert_layers = [
            TFAlbertLayer(config, name=f"albert_layers_._{i}") for i in range(config.inner_group_num)
        ]

    def call(
        self,
        hidden_states: tf.Tensor,
        attention_mask: tf.Tensor,
        head_mask: tf.Tensor,
        output_attentions: bool,
        output_hidden_states: bool,
        training: bool = False,
    ) -> Union[TFBaseModelOutput, Tuple[tf.Tensor]]:
        layer_hidden_states = () if output_hidden_states else None
        layer_attentions = () if output_attentions else None

        for layer_index, albert_layer in enumerate(self.albert_layers):
            if output_hidden_states:
                layer_hidden_states = layer_hidden_states + (hidden_states,)

            layer_output = albert_layer(
                hidden_states=hidden_states,
                attention_mask=attention_mask,
                head_mask=head_mask[layer_index],
                output_attentions=output_attentions,
                training=training,
            )
            hidden_states = layer_output[0]

            if output_attentions:
                layer_attentions = layer_attentions + (layer_output[1],)

        # Add last layer
        if output_hidden_states:
            layer_hidden_states = layer_hidden_states + (hidden_states,)

        return tuple(v for v in [hidden_states, layer_hidden_states, layer_attentions] if v is not None)

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        if getattr(self, "albert_layers", None) is not None:
            for layer in self.albert_layers:
                with tf.name_scope(layer.name):
                    layer.build(None)
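

# This is where ALBERT's cross-layer parameter sharing happens: the transformer below
# instantiates only `num_hidden_groups` layer groups but runs `num_hidden_layers` forward
# passes, reusing the group at index int(i / (num_hidden_layers / num_hidden_groups)).
# With the albert-base-v2 defaults (12 hidden layers, 1 group, inner_group_num=1), the
# same weights are applied 12 times. It also projects the embeddings from embedding_size
# up to hidden_size via `embedding_hidden_mapping_in` before the first layer.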
class TFAlbertTransformer(keras.layers.Layer):
    def __init__(self, config: AlbertConfig, **kwargs):
        super().__init__(**kwargs)

        self.num_hidden_layers = config.num_hidden_layers
        self.num_hidden_groups = config.num_hidden_groups
        # Number of layers in a hidden group
        self.layers_per_group = int(config.num_hidden_layers / config.num_hidden_groups)
        self.embedding_hidden_mapping_in = keras.layers.Dense(
            units=config.hidden_size,
            kernel_initializer=get_initializer(config.initializer_range),
            name="embedding_hidden_mapping_in",
        )
        self.albert_layer_groups = [
            TFAlbertLayerGroup(config, name=f"albert_layer_groups_._{i}") for i in range(config.num_hidden_groups)
        ]
        self.config = config

    def call(
        self,
        hidden_states: tf.Tensor,
        attention_mask: tf.Tensor,
        head_mask: tf.Tensor,
        output_attentions: bool,
        output_hidden_states: bool,
        return_dict: bool,
        training: bool = False,
    ) -> Union[TFBaseModelOutput, Tuple[tf.Tensor]]:
        hidden_states = self.embedding_hidden_mapping_in(inputs=hidden_states)
        all_attentions = () if output_attentions else None
        all_hidden_states = (hidden_states,) if output_hidden_states else None

        for i in range(self.num_hidden_layers):
            # Index of the hidden group
            group_idx = int(i / (self.num_hidden_layers / self.num_hidden_groups))
            layer_group_output = self.albert_layer_groups[group_idx](
                hidden_states=hidden_states,
                attention_mask=attention_mask,
                head_mask=head_mask[group_idx * self.layers_per_group : (group_idx + 1) * self.layers_per_group],
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                training=training,
            )
            hidden_states = layer_group_output[0]

            if output_attentions:
                all_attentions = all_attentions + layer_group_output[-1]

            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(v for v in [hidden_states, all_hidden_states, all_attentions] if v is not None)

        return TFBaseModelOutput(
            last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions
        )

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        if getattr(self, "embedding_hidden_mapping_in", None) is not None:
            with tf.name_scope(self.embedding_hidden_mapping_in.name):
                self.embedding_hidden_mapping_in.build([None, None, self.config.embedding_size])
        if getattr(self, "albert_layer_groups", None) is not None:
            for layer in self.albert_layer_groups:
                with tf.name_scope(layer.name):
                    layer.build(None)
class TFAlbertPreTrainedModel(TFPreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = AlbertConfig
    base_model_prefix = "albert"
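

# The MLM head below ties its output projection to the input word embeddings: scores are
# computed as hidden_states @ decoder.weight^T, so the vocabulary projection itself adds
# no new weight matrix; only the dense transform, the LayerNorm and the two bias vectors
# ("bias" and "decoder/bias") are head-specific parameters. Hidden states are first
# projected back down to embedding_size, activated and layer-normalized before the
# vocabulary projection.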
class TFAlbertMLMHead(keras.layers.Layer):
    def __init__(self, config: AlbertConfig, input_embeddings: keras.layers.Layer, **kwargs):
        super().__init__(**kwargs)

        self.config = config
        self.embedding_size = config.embedding_size
        self.dense = keras.layers.Dense(
            config.embedding_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
        )
        if isinstance(config.hidden_act, str):
            self.activation = get_tf_activation(config.hidden_act)
        else:
            self.activation = config.hidden_act
        self.LayerNorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")

        # The output weights are the same as the input embeddings, but there is
        # an output-only bias for each token.
        self.decoder = input_embeddings

    def build(self, input_shape=None):
        self.bias = self.add_weight(shape=(self.config.vocab_size,), initializer="zeros", trainable=True, name="bias")
        self.decoder_bias = self.add_weight(
            shape=(self.config.vocab_size,), initializer="zeros", trainable=True, name="decoder/bias"
        )

        if self.built:
            return
        self.built = True
        if getattr(self, "dense", None) is not None:
            with tf.name_scope(self.dense.name):
                self.dense.build([None, None, self.config.hidden_size])
        if getattr(self, "LayerNorm", None) is not None:
            with tf.name_scope(self.LayerNorm.name):
                self.LayerNorm.build([None, None, self.config.embedding_size])

    def get_output_embeddings(self) -> keras.layers.Layer:
        return self.decoder

    def set_output_embeddings(self, value: tf.Variable):
        self.decoder.weight = value
        self.decoder.vocab_size = shape_list(value)[0]

    def get_bias(self) -> Dict[str, tf.Variable]:
        return {"bias": self.bias, "decoder_bias": self.decoder_bias}

    def set_bias(self, value: tf.Variable):
        self.bias = value["bias"]
        self.decoder_bias = value["decoder_bias"]
        self.config.vocab_size = shape_list(value["bias"])[0]

    def call(self, hidden_states: tf.Tensor) -> tf.Tensor:
        hidden_states = self.dense(inputs=hidden_states)
        hidden_states = self.activation(hidden_states)
        hidden_states = self.LayerNorm(inputs=hidden_states)
        seq_length = shape_list(tensor=hidden_states)[1]
        hidden_states = tf.reshape(tensor=hidden_states, shape=[-1, self.embedding_size])
        hidden_states = tf.matmul(a=hidden_states, b=self.decoder.weight, transpose_b=True)
        hidden_states = tf.reshape(tensor=hidden_states, shape=[-1, seq_length, self.config.vocab_size])
        hidden_states = tf.nn.bias_add(value=hidden_states, bias=self.decoder_bias)

        return hidden_states
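

# The main layer below wires embeddings, encoder and (optionally) the pooler together.
# The 2D attention_mask of shape [batch_size, seq_length] is converted into an additive
# 4D mask of shape [batch_size, 1, 1, seq_length]: kept positions become 0.0 and padding
# positions become -10000.0, which effectively removes them after the softmax. For
# example, a mask row [1, 1, 0] turns into [0.0, 0.0, -10000.0].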
@keras_serializable
class TFAlbertMainLayer(keras.layers.Layer):
    config_class = AlbertConfig

    def __init__(self, config: AlbertConfig, add_pooling_layer: bool = True, **kwargs):
        super().__init__(**kwargs)

        self.config = config

        self.embeddings = TFAlbertEmbeddings(config, name="embeddings")
        self.encoder = TFAlbertTransformer(config, name="encoder")
        self.pooler = (
            keras.layers.Dense(
                units=config.hidden_size,
                kernel_initializer=get_initializer(config.initializer_range),
                activation="tanh",
                name="pooler",
            )
            if add_pooling_layer
            else None
        )

    def get_input_embeddings(self) -> keras.layers.Layer:
        return self.embeddings

    def set_input_embeddings(self, value: tf.Variable):
        self.embeddings.weight = value
        self.embeddings.vocab_size = shape_list(value)[0]

    def _prune_heads(self, heads_to_prune):
        """
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
        class PreTrainedModel
        """
        raise NotImplementedError

    @unpack_inputs
    def call(
        self,
        input_ids: TFModelInputType | None = None,
        attention_mask: np.ndarray | tf.Tensor | None = None,
        token_type_ids: np.ndarray | tf.Tensor | None = None,
        position_ids: np.ndarray | tf.Tensor | None = None,
        head_mask: np.ndarray | tf.Tensor | None = None,
        inputs_embeds: np.ndarray | tf.Tensor | None = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        training: bool = False,
    ) -> Union[TFBaseModelOutputWithPooling, Tuple[tf.Tensor]]:
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            input_shape = shape_list(input_ids)
        elif inputs_embeds is not None:
            input_shape = shape_list(inputs_embeds)[:-1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        if attention_mask is None:
            attention_mask = tf.fill(dims=input_shape, value=1)

        if token_type_ids is None:
            token_type_ids = tf.fill(dims=input_shape, value=0)

        embedding_output = self.embeddings(
            input_ids=input_ids,
            position_ids=position_ids,
            token_type_ids=token_type_ids,
            inputs_embeds=inputs_embeds,
            training=training,
        )

        # We create a 3D attention mask from a 2D tensor mask.
        # Sizes are [batch_size, 1, 1, to_seq_length]
        # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
        # This attention mask is simpler than the triangular masking of causal attention
        # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
        extended_attention_mask = tf.reshape(attention_mask, (input_shape[0], 1, 1, input_shape[1]))

        # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
        # masked positions, this operation will create a tensor which is 0.0 for
        # positions we want to attend and -10000.0 for masked positions.
        # Since we are adding it to the raw scores before the softmax, this is
        # effectively the same as removing these entirely.
        extended_attention_mask = tf.cast(extended_attention_mask, dtype=embedding_output.dtype)
        one_cst = tf.constant(1.0, dtype=embedding_output.dtype)
        ten_thousand_cst = tf.constant(-10000.0, dtype=embedding_output.dtype)
        extended_attention_mask = tf.multiply(tf.subtract(one_cst, extended_attention_mask), ten_thousand_cst)

        # Prepare head mask if needed
        # 1.0 in head_mask indicates we keep the head
        # attention_probs has shape bsz x n_heads x N x N
        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
        if head_mask is not None:
            raise NotImplementedError
        else:
            head_mask = [None] * self.config.num_hidden_layers

        encoder_outputs = self.encoder(
            hidden_states=embedding_output,
            attention_mask=extended_attention_mask,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            training=training,
        )

        sequence_output = encoder_outputs[0]
        pooled_output = self.pooler(inputs=sequence_output[:, 0]) if self.pooler is not None else None

        if not return_dict:
            return (
                sequence_output,
                pooled_output,
            ) + encoder_outputs[1:]

        return TFBaseModelOutputWithPooling(
            last_hidden_state=sequence_output,
            pooler_output=pooled_output,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        if getattr(self, "embeddings", None) is not None:
            with tf.name_scope(self.embeddings.name):
                self.embeddings.build(None)
        if getattr(self, "encoder", None) is not None:
            with tf.name_scope(self.encoder.name):
                self.encoder.build(None)
        if getattr(self, "pooler", None) is not None:
            with tf.name_scope(self.pooler.name):
                self.pooler.build([None, None, self.config.hidden_size])
@dataclass
class TFAlbertForPreTrainingOutput(ModelOutput):
    """
    Output type of [`TFAlbertForPreTraining`].

    Args:
        prediction_logits (`tf.Tensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
        sop_logits (`tf.Tensor` of shape `(batch_size, 2)`):
            Prediction scores of the next sequence prediction (classification) head (scores of True/False continuation
            before SoftMax).
        hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape
            `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attention weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
    """

    loss: tf.Tensor = None
    prediction_logits: tf.Tensor = None
    sop_logits: tf.Tensor = None
    hidden_states: Tuple[tf.Tensor] | None = None
    attentions: Tuple[tf.Tensor] | None = None
ALBERT_START_DOCSTRING = r"""

    This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
    etc.)

    This model is also a [keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model) subclass. Use it
    as a regular TF 2.0 Keras Model and refer to the TF 2.0 documentation for all matters related to general usage and
    behavior.

    <Tip>

    TensorFlow models and layers in `transformers` accept two formats as input:

    - having all inputs as keyword arguments (like PyTorch models), or
    - having all inputs as a list, tuple or dict in the first positional argument.

    The reason the second format is supported is that Keras methods prefer this format when passing inputs to models
    and layers. Because of this support, when using methods like `model.fit()` things should "just work" for you - just
    pass your inputs and labels in any format that `model.fit()` supports! If, however, you want to use the second
    format outside of Keras methods like `fit()` and `predict()`, such as when creating your own layers or models with
    the Keras `Functional` API, there are three possibilities you can use to gather all the input Tensors in the first
    positional argument:

    - a single Tensor with `input_ids` only and nothing else: `model(input_ids)`
    - a list of varying length with one or several input Tensors IN THE ORDER given in the docstring:
      `model([input_ids, attention_mask])` or `model([input_ids, attention_mask, token_type_ids])`
    - a dictionary with one or several input Tensors associated to the input names given in the docstring:
      `model({"input_ids": input_ids, "token_type_ids": token_type_ids})`

    Note that when creating models and layers with
    [subclassing](https://keras.io/guides/making_new_layers_and_models_via_subclassing/) then you don't need to worry
    about any of this, as you can just pass inputs like you would to any other Python function!

    </Tip>

    Args:
        config ([`AlbertConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""

ALBERT_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`Numpy array` or `tf.Tensor` of shape `({0})`):
            Indices of input sequence tokens in the vocabulary.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.__call__`] and
            [`PreTrainedTokenizer.encode`] for details.

            [What are input IDs?](../glossary#input-ids)
        attention_mask (`Numpy array` or `tf.Tensor` of shape `({0})`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)
        token_type_ids (`Numpy array` or `tf.Tensor` of shape `({0})`, *optional*):
            Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
            1]`:

            - 0 corresponds to a *sentence A* token,
            - 1 corresponds to a *sentence B* token.

            [What are token type IDs?](../glossary#token-type-ids)
        position_ids (`Numpy array` or `tf.Tensor` of shape `({0})`, *optional*):
            Indices of positions of each input sequence token in the position embeddings. Selected in the range `[0,
            config.max_position_embeddings - 1]`.

            [What are position IDs?](../glossary#position-ids)
        head_mask (`Numpy array` or `tf.Tensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
            Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.
        inputs_embeds (`tf.Tensor` of shape `({0}, hidden_size)`, *optional*):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
            model's internal embedding lookup matrix.
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail. This argument can be used only in eager mode, in graph mode the value in the
            config will be used instead.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail. This argument can be used only in eager mode, in graph mode the value in the config will be
            used instead.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. This argument can be used in
            eager mode, in graph mode the value will always be set to True.
        training (`bool`, *optional*, defaults to `False`):
            Whether or not to use the model in training mode (some modules like dropout modules have different
            behaviors between training and evaluation).
"""
@add_start_docstrings(
    "The bare Albert Model transformer outputting raw hidden-states without any specific head on top.",
    ALBERT_START_DOCSTRING,
)
class TFAlbertModel(TFAlbertPreTrainedModel):
    def __init__(self, config: AlbertConfig, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)

        self.albert = TFAlbertMainLayer(config, name="albert")

    @unpack_inputs
    @add_start_docstrings_to_model_forward(ALBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=TFBaseModelOutputWithPooling,
        config_class=_CONFIG_FOR_DOC,
    )
    def call(
        self,
        input_ids: TFModelInputType | None = None,
        attention_mask: np.ndarray | tf.Tensor | None = None,
        token_type_ids: np.ndarray | tf.Tensor | None = None,
        position_ids: np.ndarray | tf.Tensor | None = None,
        head_mask: np.ndarray | tf.Tensor | None = None,
        inputs_embeds: np.ndarray | tf.Tensor | None = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        training: Optional[bool] = False,
    ) -> Union[TFBaseModelOutputWithPooling, Tuple[tf.Tensor]]:
        outputs = self.albert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            training=training,
        )

        return outputs

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        if getattr(self, "albert", None) is not None:
            with tf.name_scope(self.albert.name):
                self.albert.build(None)
@add_start_docstrings(
    """
    Albert Model with two heads on top for pretraining: a `masked language modeling` head and a `sentence order
    prediction` (classification) head.
    """,
    ALBERT_START_DOCSTRING,
)
class TFAlbertForPreTraining(TFAlbertPreTrainedModel, TFAlbertPreTrainingLoss):
    # names with a '.' represent the authorized unexpected/missing layers when a TF model is loaded from a PT model
    _keys_to_ignore_on_load_unexpected = [r"predictions.decoder.weight"]

    def __init__(self, config: AlbertConfig, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)

        self.num_labels = config.num_labels

        self.albert = TFAlbertMainLayer(config, name="albert")
        self.predictions = TFAlbertMLMHead(config, input_embeddings=self.albert.embeddings, name="predictions")
        self.sop_classifier = TFAlbertSOPHead(config, name="sop_classifier")

    def get_lm_head(self) -> keras.layers.Layer:
        return self.predictions

    @unpack_inputs
    @add_start_docstrings_to_model_forward(ALBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @replace_return_docstrings(output_type=TFAlbertForPreTrainingOutput, config_class=_CONFIG_FOR_DOC)
    def call(
        self,
        input_ids: TFModelInputType | None = None,
        attention_mask: np.ndarray | tf.Tensor | None = None,
        token_type_ids: np.ndarray | tf.Tensor | None = None,
        position_ids: np.ndarray | tf.Tensor | None = None,
        head_mask: np.ndarray | tf.Tensor | None = None,
        inputs_embeds: np.ndarray | tf.Tensor | None = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        labels: np.ndarray | tf.Tensor | None = None,
        sentence_order_label: np.ndarray | tf.Tensor | None = None,
        training: Optional[bool] = False,
    ) -> Union[TFAlbertForPreTrainingOutput, Tuple[tf.Tensor]]:
        r"""
        Return:

        Example:

        ```python
        >>> import tensorflow as tf
        >>> from transformers import AutoTokenizer, TFAlbertForPreTraining

        >>> tokenizer = AutoTokenizer.from_pretrained("albert/albert-base-v2")
        >>> model = TFAlbertForPreTraining.from_pretrained("albert/albert-base-v2")

        >>> input_ids = tf.constant(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True))[None, :]
        >>> # Batch size 1
        >>> outputs = model(input_ids)

        >>> prediction_logits = outputs.prediction_logits
        >>> sop_logits = outputs.sop_logits
        ```"""
        outputs = self.albert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            training=training,
        )
        sequence_output, pooled_output = outputs[:2]
        prediction_scores = self.predictions(hidden_states=sequence_output)
        sop_scores = self.sop_classifier(pooled_output=pooled_output, training=training)
        total_loss = None

        if labels is not None and sentence_order_label is not None:
            d_labels = {"labels": labels}
            d_labels["sentence_order_label"] = sentence_order_label
            total_loss = self.hf_compute_loss(labels=d_labels, logits=(prediction_scores, sop_scores))

        if not return_dict:
            output = (prediction_scores, sop_scores) + outputs[2:]
            return ((total_loss,) + output) if total_loss is not None else output

        return TFAlbertForPreTrainingOutput(
            loss=total_loss,
            prediction_logits=prediction_scores,
            sop_logits=sop_scores,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        if getattr(self, "albert", None) is not None:
            with tf.name_scope(self.albert.name):
                self.albert.build(None)
        if getattr(self, "predictions", None) is not None:
            with tf.name_scope(self.predictions.name):
                self.predictions.build(None)
        if getattr(self, "sop_classifier", None) is not None:
            with tf.name_scope(self.sop_classifier.name):
                self.sop_classifier.build(None)
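

# The sentence-order prediction (SOP) head below is dropout plus a single Dense layer
# applied to the pooled [CLS] representation, producing `config.num_labels` logits
# (2 during pretraining: in-order vs. swapped segment pairs).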
class TFAlbertSOPHead(keras.layers.Layer):
    def __init__(self, config: AlbertConfig, **kwargs):
        super().__init__(**kwargs)

        self.dropout = keras.layers.Dropout(rate=config.classifier_dropout_prob)
        self.classifier = keras.layers.Dense(
            units=config.num_labels,
            kernel_initializer=get_initializer(config.initializer_range),
            name="classifier",
        )
        self.config = config

    def call(self, pooled_output: tf.Tensor, training: bool) -> tf.Tensor:
        dropout_pooled_output = self.dropout(inputs=pooled_output, training=training)
        logits = self.classifier(inputs=dropout_pooled_output)

        return logits

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        if getattr(self, "classifier", None) is not None:
            with tf.name_scope(self.classifier.name):
                self.classifier.build([None, None, self.config.hidden_size])
@add_start_docstrings("""Albert Model with a `language modeling` head on top.""", ALBERT_START_DOCSTRING)
class TFAlbertForMaskedLM(TFAlbertPreTrainedModel, TFMaskedLanguageModelingLoss):
    # names with a '.' represent the authorized unexpected/missing layers when a TF model is loaded from a PT model
    _keys_to_ignore_on_load_unexpected = [r"pooler", r"predictions.decoder.weight"]

    def __init__(self, config: AlbertConfig, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)

        self.albert = TFAlbertMainLayer(config, add_pooling_layer=False, name="albert")
        self.predictions = TFAlbertMLMHead(config, input_embeddings=self.albert.embeddings, name="predictions")

    def get_lm_head(self) -> keras.layers.Layer:
        return self.predictions

    @unpack_inputs
    @add_start_docstrings_to_model_forward(ALBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @replace_return_docstrings(output_type=TFMaskedLMOutput, config_class=_CONFIG_FOR_DOC)
    def call(
        self,
        input_ids: TFModelInputType | None = None,
        attention_mask: np.ndarray | tf.Tensor | None = None,
        token_type_ids: np.ndarray | tf.Tensor | None = None,
        position_ids: np.ndarray | tf.Tensor | None = None,
        head_mask: np.ndarray | tf.Tensor | None = None,
        inputs_embeds: np.ndarray | tf.Tensor | None = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        labels: np.ndarray | tf.Tensor | None = None,
        training: Optional[bool] = False,
    ) -> Union[TFMaskedLMOutput, Tuple[tf.Tensor]]:
        r"""
        labels (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
            config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
            loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`

        Returns:

        Example:

        ```python
        >>> import tensorflow as tf
        >>> from transformers import AutoTokenizer, TFAlbertForMaskedLM

        >>> tokenizer = AutoTokenizer.from_pretrained("albert/albert-base-v2")
        >>> model = TFAlbertForMaskedLM.from_pretrained("albert/albert-base-v2")

        >>> # add mask_token
        >>> inputs = tokenizer(f"The capital of [MASK] is Paris.", return_tensors="tf")
        >>> logits = model(**inputs).logits

        >>> # retrieve index of [MASK]
        >>> mask_token_index = tf.where(inputs.input_ids == tokenizer.mask_token_id)[0][1]
        >>> predicted_token_id = tf.math.argmax(logits[0, mask_token_index], axis=-1)
        >>> tokenizer.decode(predicted_token_id)
        'france'
        ```

        ```python
        >>> labels = tokenizer("The capital of France is Paris.", return_tensors="tf")["input_ids"]
        >>> labels = tf.where(inputs.input_ids == tokenizer.mask_token_id, labels, -100)
        >>> outputs = model(**inputs, labels=labels)
        >>> round(float(outputs.loss), 2)
        0.81
        ```
        """
        outputs = self.albert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            training=training,
        )
        sequence_output = outputs[0]
        prediction_scores = self.predictions(hidden_states=sequence_output, training=training)
        loss = None if labels is None else self.hf_compute_loss(labels=labels, logits=prediction_scores)

        if not return_dict:
            output = (prediction_scores,) + outputs[2:]

            return ((loss,) + output) if loss is not None else output

        return TFMaskedLMOutput(
            loss=loss,
            logits=prediction_scores,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        if getattr(self, "albert", None) is not None:
            with tf.name_scope(self.albert.name):
                self.albert.build(None)
        if getattr(self, "predictions", None) is not None:
            with tf.name_scope(self.predictions.name):
                self.predictions.build(None)

@add_start_docstrings(
    """
    Albert Model transformer with a sequence classification/regression head on top (a linear layer on top of the pooled
    output) e.g. for GLUE tasks.
    """,
    ALBERT_START_DOCSTRING,
)
class TFAlbertForSequenceClassification(TFAlbertPreTrainedModel, TFSequenceClassificationLoss):
    # names with a '.' represent the authorized unexpected/missing layers when a TF model is loaded from a PT model
    _keys_to_ignore_on_load_unexpected = [r"predictions"]
    _keys_to_ignore_on_load_missing = [r"dropout"]

    def __init__(self, config: AlbertConfig, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)

        self.num_labels = config.num_labels
        self.albert = TFAlbertMainLayer(config, name="albert")
        self.dropout = keras.layers.Dropout(rate=config.classifier_dropout_prob)
        self.classifier = keras.layers.Dense(
            units=config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="classifier"
        )
        self.config = config

    @unpack_inputs
    @add_start_docstrings_to_model_forward(ALBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        checkpoint="vumichien/albert-base-v2-imdb",
        output_type=TFSequenceClassifierOutput,
        config_class=_CONFIG_FOR_DOC,
        expected_output="'LABEL_1'",
        expected_loss=0.12,
    )
    def call(
        self,
        input_ids: TFModelInputType | None = None,
        attention_mask: np.ndarray | tf.Tensor | None = None,
        token_type_ids: np.ndarray | tf.Tensor | None = None,
        position_ids: np.ndarray | tf.Tensor | None = None,
        head_mask: np.ndarray | tf.Tensor | None = None,
        inputs_embeds: np.ndarray | tf.Tensor | None = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        labels: np.ndarray | tf.Tensor | None = None,
        training: Optional[bool] = False,
    ) -> Union[TFSequenceClassifierOutput, Tuple[tf.Tensor]]:
        r"""
        labels (`tf.Tensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss). If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        outputs = self.albert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            training=training,
        )
        pooled_output = outputs[1]
        pooled_output = self.dropout(inputs=pooled_output, training=training)
        logits = self.classifier(inputs=pooled_output)
        loss = None if labels is None else self.hf_compute_loss(labels=labels, logits=logits)

        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return TFSequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        if getattr(self, "albert", None) is not None:
            with tf.name_scope(self.albert.name):
                self.albert.build(None)
        if getattr(self, "classifier", None) is not None:
            with tf.name_scope(self.classifier.name):
                self.classifier.build([None, None, self.config.hidden_size])
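
# Illustrative only (not part of the original module): a minimal sequence-classification sketch
# with a labelled loss, assuming the "vumichien/albert-base-v2-imdb" checkpoint referenced in the
# docstring decorator above; uncomment to run.
#
#     from transformers import AutoTokenizer, TFAlbertForSequenceClassification
#     import tensorflow as tf
#
#     tokenizer = AutoTokenizer.from_pretrained("vumichien/albert-base-v2-imdb")
#     model = TFAlbertForSequenceClassification.from_pretrained("vumichien/albert-base-v2-imdb")
#     inputs = tokenizer("A genuinely moving film.", return_tensors="tf")
#     outputs = model(**inputs, labels=tf.constant([1]))  # labels in [0, num_labels - 1]
#     print(outputs.loss.numpy(), tf.argmax(outputs.logits, axis=-1).numpy())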

@add_start_docstrings(
    """
    Albert Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for
    Named-Entity-Recognition (NER) tasks.
    """,
    ALBERT_START_DOCSTRING,
)
class TFAlbertForTokenClassification(TFAlbertPreTrainedModel, TFTokenClassificationLoss):
    # names with a '.' represent the authorized unexpected/missing layers when a TF model is loaded from a PT model
    _keys_to_ignore_on_load_unexpected = [r"pooler", r"predictions"]
    _keys_to_ignore_on_load_missing = [r"dropout"]

    def __init__(self, config: AlbertConfig, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)

        self.num_labels = config.num_labels
        self.albert = TFAlbertMainLayer(config, add_pooling_layer=False, name="albert")
        classifier_dropout_prob = (
            config.classifier_dropout_prob
            if config.classifier_dropout_prob is not None
            else config.hidden_dropout_prob
        )
        self.dropout = keras.layers.Dropout(rate=classifier_dropout_prob)
        self.classifier = keras.layers.Dense(
            units=config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="classifier"
        )
        self.config = config

    @unpack_inputs
    @add_start_docstrings_to_model_forward(ALBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=TFTokenClassifierOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def call(
        self,
        input_ids: TFModelInputType | None = None,
        attention_mask: np.ndarray | tf.Tensor | None = None,
        token_type_ids: np.ndarray | tf.Tensor | None = None,
        position_ids: np.ndarray | tf.Tensor | None = None,
        head_mask: np.ndarray | tf.Tensor | None = None,
        inputs_embeds: np.ndarray | tf.Tensor | None = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        labels: np.ndarray | tf.Tensor | None = None,
        training: Optional[bool] = False,
    ) -> Union[TFTokenClassifierOutput, Tuple[tf.Tensor]]:
        r"""
        labels (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
        """
        outputs = self.albert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            training=training,
        )
        sequence_output = outputs[0]
        sequence_output = self.dropout(inputs=sequence_output, training=training)
        logits = self.classifier(inputs=sequence_output)
        loss = None if labels is None else self.hf_compute_loss(labels=labels, logits=logits)

        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return TFTokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        if getattr(self, "albert", None) is not None:
            with tf.name_scope(self.albert.name):
                self.albert.build(None)
        if getattr(self, "classifier", None) is not None:
            with tf.name_scope(self.classifier.name):
                self.classifier.build([None, None, self.config.hidden_size])
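
# Illustrative only (not part of the original module): a minimal token-classification sketch,
# assuming "albert-base-v2" with a freshly (randomly) initialised 9-label head, purely to show the
# output shapes; a fine-tuned NER checkpoint would be needed for meaningful predictions.
# Uncomment to run.
#
#     from transformers import AutoTokenizer, TFAlbertForTokenClassification
#     import tensorflow as tf
#
#     tokenizer = AutoTokenizer.from_pretrained("albert-base-v2")
#     model = TFAlbertForTokenClassification.from_pretrained("albert-base-v2", num_labels=9)
#     inputs = tokenizer("HuggingFace is based in New York City", return_tensors="tf")
#     logits = model(**inputs).logits  # (batch_size, sequence_length, num_labels)
#     predicted_label_ids = tf.argmax(logits, axis=-1)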

@add_start_docstrings(
    """
    Albert Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear
    layer on top of the hidden-states output to compute `span start logits` and `span end logits`).
    """,
    ALBERT_START_DOCSTRING,
)
class TFAlbertForQuestionAnswering(TFAlbertPreTrainedModel, TFQuestionAnsweringLoss):
    # names with a '.' represent the authorized unexpected/missing layers when a TF model is loaded from a PT model
    _keys_to_ignore_on_load_unexpected = [r"pooler", r"predictions"]

    def __init__(self, config: AlbertConfig, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)

        self.num_labels = config.num_labels
        self.albert = TFAlbertMainLayer(config, add_pooling_layer=False, name="albert")
        self.qa_outputs = keras.layers.Dense(
            units=config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="qa_outputs"
        )
        self.config = config

    @unpack_inputs
    @add_start_docstrings_to_model_forward(ALBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        checkpoint="vumichien/albert-base-v2-squad2",
        output_type=TFQuestionAnsweringModelOutput,
        config_class=_CONFIG_FOR_DOC,
        qa_target_start_index=12,
        qa_target_end_index=13,
        expected_output="'a nice puppet'",
        expected_loss=7.36,
    )
    def call(
        self,
        input_ids: TFModelInputType | None = None,
        attention_mask: np.ndarray | tf.Tensor | None = None,
        token_type_ids: np.ndarray | tf.Tensor | None = None,
        position_ids: np.ndarray | tf.Tensor | None = None,
        head_mask: np.ndarray | tf.Tensor | None = None,
        inputs_embeds: np.ndarray | tf.Tensor | None = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        start_positions: np.ndarray | tf.Tensor | None = None,
        end_positions: np.ndarray | tf.Tensor | None = None,
        training: Optional[bool] = False,
    ) -> Union[TFQuestionAnsweringModelOutput, Tuple[tf.Tensor]]:
        r"""
        start_positions (`tf.Tensor` of shape `(batch_size,)`, *optional*):
            Labels for position (index) of the start of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (`sequence_length`). Positions outside of the sequence
            are not taken into account for computing the loss.
        end_positions (`tf.Tensor` of shape `(batch_size,)`, *optional*):
            Labels for position (index) of the end of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (`sequence_length`). Positions outside of the sequence
            are not taken into account for computing the loss.
        """
        outputs = self.albert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            training=training,
        )
        sequence_output = outputs[0]
        logits = self.qa_outputs(inputs=sequence_output)
        start_logits, end_logits = tf.split(value=logits, num_or_size_splits=2, axis=-1)
        start_logits = tf.squeeze(input=start_logits, axis=-1)
        end_logits = tf.squeeze(input=end_logits, axis=-1)
        loss = None

        if start_positions is not None and end_positions is not None:
            labels = {"start_position": start_positions}
            labels["end_position"] = end_positions
            loss = self.hf_compute_loss(labels=labels, logits=(start_logits, end_logits))

        if not return_dict:
            output = (start_logits, end_logits) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return TFQuestionAnsweringModelOutput(
            loss=loss,
            start_logits=start_logits,
            end_logits=end_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        if getattr(self, "albert", None) is not None:
            with tf.name_scope(self.albert.name):
                self.albert.build(None)
        if getattr(self, "qa_outputs", None) is not None:
            with tf.name_scope(self.qa_outputs.name):
                self.qa_outputs.build([None, None, self.config.hidden_size])
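
# Illustrative only (not part of the original module): a minimal extractive question-answering
# sketch, assuming the "vumichien/albert-base-v2-squad2" checkpoint referenced in the docstring
# decorator above; the predicted span is decoded from the argmax of the start/end logits.
# Uncomment to run.
#
#     from transformers import AutoTokenizer, TFAlbertForQuestionAnswering
#     import tensorflow as tf
#
#     tokenizer = AutoTokenizer.from_pretrained("vumichien/albert-base-v2-squad2")
#     model = TFAlbertForQuestionAnswering.from_pretrained("vumichien/albert-base-v2-squad2")
#     question = "Who was Jim Henson?"
#     context = "Jim Henson was a nice puppet."
#     inputs = tokenizer(question, context, return_tensors="tf")
#     outputs = model(**inputs)
#     start = int(tf.argmax(outputs.start_logits, axis=-1)[0])
#     end = int(tf.argmax(outputs.end_logits, axis=-1)[0])
#     print(tokenizer.decode(inputs["input_ids"][0, start : end + 1]))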

@add_start_docstrings(
    """
    Albert Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a
    softmax) e.g. for RocStories/SWAG tasks.
    """,
    ALBERT_START_DOCSTRING,
)
class TFAlbertForMultipleChoice(TFAlbertPreTrainedModel, TFMultipleChoiceLoss):
    # names with a '.' represent the authorized unexpected/missing layers when a TF model is loaded from a PT model
    _keys_to_ignore_on_load_unexpected = [r"pooler", r"predictions"]
    _keys_to_ignore_on_load_missing = [r"dropout"]

    def __init__(self, config: AlbertConfig, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)

        self.albert = TFAlbertMainLayer(config, name="albert")
        self.dropout = keras.layers.Dropout(rate=config.hidden_dropout_prob)
        self.classifier = keras.layers.Dense(
            units=1, kernel_initializer=get_initializer(config.initializer_range), name="classifier"
        )
        self.config = config

    @unpack_inputs
    @add_start_docstrings_to_model_forward(ALBERT_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length"))
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=TFMultipleChoiceModelOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def call(
        self,
        input_ids: TFModelInputType | None = None,
        attention_mask: np.ndarray | tf.Tensor | None = None,
        token_type_ids: np.ndarray | tf.Tensor | None = None,
        position_ids: np.ndarray | tf.Tensor | None = None,
        head_mask: np.ndarray | tf.Tensor | None = None,
        inputs_embeds: np.ndarray | tf.Tensor | None = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        labels: np.ndarray | tf.Tensor | None = None,
        training: Optional[bool] = False,
    ) -> Union[TFMultipleChoiceModelOutput, Tuple[tf.Tensor]]:
        r"""
        labels (`tf.Tensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the multiple choice classification loss. Indices should be in `[0, ...,
            num_choices - 1]` where `num_choices` is the size of the second dimension of the input tensors. (See
            `input_ids` above)
        """
        if input_ids is not None:
            num_choices = shape_list(input_ids)[1]
            seq_length = shape_list(input_ids)[2]
        else:
            num_choices = shape_list(inputs_embeds)[1]
            seq_length = shape_list(inputs_embeds)[2]

        flat_input_ids = tf.reshape(input_ids, (-1, seq_length)) if input_ids is not None else None
        flat_attention_mask = (
            tf.reshape(tensor=attention_mask, shape=(-1, seq_length)) if attention_mask is not None else None
        )
        flat_token_type_ids = (
            tf.reshape(tensor=token_type_ids, shape=(-1, seq_length)) if token_type_ids is not None else None
        )
        flat_position_ids = (
            tf.reshape(tensor=position_ids, shape=(-1, seq_length)) if position_ids is not None else None
        )
        flat_inputs_embeds = (
            tf.reshape(tensor=inputs_embeds, shape=(-1, seq_length, shape_list(inputs_embeds)[3]))
            if inputs_embeds is not None
            else None
        )
        outputs = self.albert(
            input_ids=flat_input_ids,
            attention_mask=flat_attention_mask,
            token_type_ids=flat_token_type_ids,
            position_ids=flat_position_ids,
            head_mask=head_mask,
            inputs_embeds=flat_inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            training=training,
        )
        pooled_output = outputs[1]
        pooled_output = self.dropout(inputs=pooled_output, training=training)
        logits = self.classifier(inputs=pooled_output)
        reshaped_logits = tf.reshape(tensor=logits, shape=(-1, num_choices))
        loss = None if labels is None else self.hf_compute_loss(labels=labels, logits=reshaped_logits)

        if not return_dict:
            output = (reshaped_logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return TFMultipleChoiceModelOutput(
            loss=loss,
            logits=reshaped_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        if getattr(self, "albert", None) is not None:
            with tf.name_scope(self.albert.name):
                self.albert.build(None)
        if getattr(self, "classifier", None) is not None:
            with tf.name_scope(self.classifier.name):
                self.classifier.build([None, None, self.config.hidden_size])
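
# Illustrative only (not part of the original module): a minimal multiple-choice sketch showing the
# expected (batch_size, num_choices, sequence_length) input layout, assuming "albert-base-v2" with a
# freshly initialised choice-scoring head; uncomment to run.
#
#     from transformers import AutoTokenizer, TFAlbertForMultipleChoice
#     import tensorflow as tf
#
#     tokenizer = AutoTokenizer.from_pretrained("albert-base-v2")
#     model = TFAlbertForMultipleChoice.from_pretrained("albert-base-v2")
#     prompt = "The weather today is"
#     choices = ["sunny and warm.", "a kind of pasta."]
#     encoding = tokenizer([prompt, prompt], choices, return_tensors="tf", padding=True)
#     inputs = {k: tf.expand_dims(v, 0) for k, v in encoding.items()}  # add the num_choices axis
#     logits = model(inputs).logits  # (batch_size, num_choices)
#     print(tf.argmax(logits, axis=-1).numpy())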

__all__ = [
    "TFAlbertPreTrainedModel",
    "TFAlbertModel",
    "TFAlbertForPreTraining",
    "TFAlbertForMaskedLM",
    "TFAlbertForSequenceClassification",
    "TFAlbertForTokenClassification",
    "TFAlbertForQuestionAnswering",
    "TFAlbertForMultipleChoice",
    "TFAlbertMainLayer",
]