| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
70370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232123312341235123612371238123912401241124212431244124512461247124812491250125112521253125412551256125712581259126012611262126312641265126612671268126912701271127212731274127512761
27712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432143314341435143614371438143914401441144214431444144514461447144814491450145114521453145414551456145714581459146014611462146314641465146614671468146914701471147214731474147514761477147814791480148114821483148414851486148714881489149014911492149314941495149614971498149915001501150215031504150515061507150815091510151115121513151415151516151715181519152015211522152315241525152615271528152915301531153215331534153515361537153815391540154115421543154415451546154715481549155015511552155315541555155615571558155915601561156215631564156515661567156815691570157115721573157415751576157715781579158015811582158315841585158615871588158915901591159215931594159515961597159815991600160116021603160416051606160716081609161016111612161316141615161616171618161916201621162216231624162516261627162816291630163116321633163416351636163716381639164016411642164316441645164616471648164916501651165216531654165516561657165816591660166116621663166416651666166716681669167016711672167316741675167616771678167916801681168216831684168516861687168816891690169116921693169416951696169716981699170017011702170317041705170617071708170917101711171217131714171517161717171817191720172117221723172417251726172717281729173017311732173317341735173617371738173917401741174217431744174517461747174817491750175117521753175417551756175717581759176017611762176317641765176617671768176917701771177217731774177517761
777177817791780178117821783178417851786178717881789179017911792179317941795179617971798179918001801180218031804180518061807180818091810181118121813181418151816181718181819182018211822182318241825182618271828182918301831183218331834183518361837183818391840184118421843184418451846184718481849185018511852185318541855185618571858185918601861186218631864186518661867186818691870187118721873187418751876187718781879188018811882188318841885188618871888188918901891189218931894189518961897189818991900190119021903190419051906190719081909191019111912191319141915191619171918191919201921192219231924192519261927192819291930193119321933193419351936193719381939194019411942194319441945194619471948194919501951195219531954195519561957195819591960196119621963196419651966196719681969197019711972197319741975197619771978197919801981198219831984198519861987198819891990199119921993199419951996199719981999200020012002200320042005200620072008200920102011201220132014201520162017201820192020202120222023202420252026202720282029203020312032203320342035203620372038203920402041204220432044204520462047204820492050205120522053205420552056205720582059206020612062206320642065206620672068206920702071207220732074207520762077207820792080208120822083208420852086208720882089209020912092209320942095209620972098209921002101210221032104210521062107210821092110 |
- # coding=utf-8
- # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
- # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- """TF 2.0 BERT model."""
- from __future__ import annotations
- import math
- import warnings
- from dataclasses import dataclass
- from typing import Dict, Optional, Tuple, Union
- import numpy as np
- import tensorflow as tf
- from ...activations_tf import get_tf_activation
- from ...modeling_tf_outputs import (
- TFBaseModelOutputWithPastAndCrossAttentions,
- TFBaseModelOutputWithPoolingAndCrossAttentions,
- TFCausalLMOutputWithCrossAttentions,
- TFMaskedLMOutput,
- TFMultipleChoiceModelOutput,
- TFNextSentencePredictorOutput,
- TFQuestionAnsweringModelOutput,
- TFSequenceClassifierOutput,
- TFTokenClassifierOutput,
- )
- from ...modeling_tf_utils import (
- TFCausalLanguageModelingLoss,
- TFMaskedLanguageModelingLoss,
- TFModelInputType,
- TFMultipleChoiceLoss,
- TFNextSentencePredictionLoss,
- TFPreTrainedModel,
- TFQuestionAnsweringLoss,
- TFSequenceClassificationLoss,
- TFTokenClassificationLoss,
- get_initializer,
- keras,
- keras_serializable,
- unpack_inputs,
- )
- from ...tf_utils import check_embeddings_within_bounds, shape_list, stable_softmax
- from ...utils import (
- ModelOutput,
- add_code_sample_docstrings,
- add_start_docstrings,
- add_start_docstrings_to_model_forward,
- logging,
- replace_return_docstrings,
- )
- from .configuration_bert import BertConfig
logger = logging.get_logger(__name__)

# Checkpoint / config names for documentation code samples. Presumably consumed by the docstring
# decorators imported above (e.g. `add_code_sample_docstrings`) later in the file — not visible here.
_CHECKPOINT_FOR_DOC = "google-bert/bert-base-uncased"
_CONFIG_FOR_DOC = "BertConfig"

# TokenClassification docstring
# Checkpoint plus (per the constant names) the expected predicted labels and loss shown in the sample.
_CHECKPOINT_FOR_TOKEN_CLASSIFICATION = "dbmdz/bert-large-cased-finetuned-conll03-english"
_TOKEN_CLASS_EXPECTED_OUTPUT = (
    "['O', 'I-ORG', 'I-ORG', 'I-ORG', 'O', 'O', 'O', 'O', 'O', 'I-LOC', 'O', 'I-LOC', 'I-LOC'] "
)
_TOKEN_CLASS_EXPECTED_LOSS = 0.01

# QuestionAnswering docstring
# Checkpoint, expected answer text and loss, and the start/end token indices used as sample targets
# (meaning inferred from the constant names — verify against the decorator usage).
_CHECKPOINT_FOR_QA = "ydshieh/bert-base-cased-squad2"
_QA_EXPECTED_OUTPUT = "'a nice puppet'"
_QA_EXPECTED_LOSS = 7.41
_QA_TARGET_START_INDEX = 14
_QA_TARGET_END_INDEX = 15

# SequenceClassification docstring
# Checkpoint plus (per the constant names) the expected predicted label and loss shown in the sample.
_CHECKPOINT_FOR_SEQUENCE_CLASSIFICATION = "ydshieh/bert-base-uncased-yelp-polarity"
_SEQ_CLASS_EXPECTED_OUTPUT = "'LABEL_1'"
_SEQ_CLASS_EXPECTED_LOSS = 0.01
class TFBertPreTrainingLoss:
    """
    Loss function suitable for BERT-like pretraining, that is, the task of pretraining a language model by combining
    NSP (next sentence prediction) + MLM (masked language modeling).

    .. note::

        Any label of -100 will be ignored (along with the corresponding logits) in the loss computation.
    """

    @staticmethod
    def _masked_mean_loss(loss_fn, labels: tf.Tensor, logits: tf.Tensor) -> tf.Tensor:
        """Return the mean per-position loss over positions whose label is not -100.

        Returns 0.0 (instead of NaN) when every label is -100.
        """
        # Clip negative labels to zero here to avoid NaNs and errors - those positions get masked out below anyway.
        unmasked_losses = loss_fn(y_true=tf.nn.relu(labels), y_pred=logits)
        # Make sure only labels that are not equal to -100 are taken into account for the loss computation.
        loss_mask = tf.cast(labels != -100, dtype=unmasked_losses.dtype)
        masked_losses = unmasked_losses * loss_mask
        # A plain division gives 0/0 = NaN when the whole batch is ignored (all labels -100). That NaN would
        # poison the summed loss and every gradient. divide_no_nan yields 0 for that case instead.
        return tf.math.divide_no_nan(tf.reduce_sum(masked_losses), tf.reduce_sum(loss_mask))

    def hf_compute_loss(self, labels: Dict[str, tf.Tensor], logits: Tuple[tf.Tensor, tf.Tensor]) -> tf.Tensor:
        """Compute the combined MLM + NSP pretraining loss.

        Args:
            labels: mapping with key `"labels"` (MLM targets) and key `"next_sentence_label"` (NSP targets).
                Positions labelled -100 are ignored.
            logits: pair `(mlm_logits, nsp_logits)`, in that order.

        Returns:
            A tensor of shape `(1,)` holding the sum of the masked-mean MLM loss and the masked-mean NSP loss.
        """
        # Per-element losses (no reduction) so the -100 mask can be applied before averaging.
        loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True, reduction=keras.losses.Reduction.NONE)
        reduced_masked_lm_loss = self._masked_mean_loss(loss_fn, labels["labels"], logits[0])
        reduced_masked_ns_loss = self._masked_mean_loss(loss_fn, labels["next_sentence_label"], logits[1])
        return tf.reshape(reduced_masked_lm_loss + reduced_masked_ns_loss, (1,))
class TFBertEmbeddings(keras.layers.Layer):
    """Construct the embeddings from word, position and token_type embeddings."""

    def __init__(self, config: BertConfig, **kwargs):
        super().__init__(**kwargs)

        self.config = config
        self.hidden_size = config.hidden_size
        self.max_position_embeddings = config.max_position_embeddings
        self.initializer_range = config.initializer_range
        self.LayerNorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
        self.dropout = keras.layers.Dropout(rate=config.hidden_dropout_prob)

    def build(self, input_shape=None):
        """Create the word / token-type / position embedding tables and build the LayerNorm.

        The `built` guard is checked first so that a repeated explicit `build()` call is a no-op. Otherwise it
        would create the three embedding weights a second time. This matches the guard in every other layer.
        """
        if self.built:
            return

        with tf.name_scope("word_embeddings"):
            self.weight = self.add_weight(
                name="weight",
                shape=[self.config.vocab_size, self.hidden_size],
                initializer=get_initializer(self.initializer_range),
            )

        with tf.name_scope("token_type_embeddings"):
            self.token_type_embeddings = self.add_weight(
                name="embeddings",
                shape=[self.config.type_vocab_size, self.hidden_size],
                initializer=get_initializer(self.initializer_range),
            )

        with tf.name_scope("position_embeddings"):
            self.position_embeddings = self.add_weight(
                name="embeddings",
                shape=[self.max_position_embeddings, self.hidden_size],
                initializer=get_initializer(self.initializer_range),
            )

        self.built = True
        if getattr(self, "LayerNorm", None) is not None:
            with tf.name_scope(self.LayerNorm.name):
                self.LayerNorm.build([None, None, self.config.hidden_size])

    def call(
        self,
        input_ids: tf.Tensor = None,
        position_ids: tf.Tensor = None,
        token_type_ids: tf.Tensor = None,
        inputs_embeds: tf.Tensor = None,
        past_key_values_length=0,
        training: bool = False,
    ) -> tf.Tensor:
        """
        Applies embedding based on inputs tensor.

        Args:
            input_ids: token ids, looked up in the word-embedding table. Takes precedence over `inputs_embeds`
                when both are given.
            position_ids: positions to look up. Defaults to
                `past_key_values_length .. past_key_values_length + seq_len - 1`, broadcast over the batch.
            token_type_ids: segment ids. Default to all zeros.
            inputs_embeds: precomputed word embeddings, used when `input_ids` is `None`.
            past_key_values_length: offset for the default position ids, e.g. when decoding with a cache.
            training: enables dropout.

        Returns:
            final_embeddings (`tf.Tensor`): output embedding tensor.

        Raises:
            ValueError: if neither `input_ids` nor `inputs_embeds` is provided.
        """
        if input_ids is None and inputs_embeds is None:
            raise ValueError("Need to provide either `input_ids` or `inputs_embeds`.")

        if input_ids is not None:
            check_embeddings_within_bounds(input_ids, self.config.vocab_size)
            inputs_embeds = tf.gather(params=self.weight, indices=input_ids)

        # Every dimension except the trailing embedding axis.
        input_shape = shape_list(inputs_embeds)[:-1]

        if token_type_ids is None:
            token_type_ids = tf.fill(dims=input_shape, value=0)

        if position_ids is None:
            # Shape (1, seq_len); broadcasts over the batch when added below.
            position_ids = tf.expand_dims(
                tf.range(start=past_key_values_length, limit=input_shape[1] + past_key_values_length), axis=0
            )

        position_embeds = tf.gather(params=self.position_embeddings, indices=position_ids)
        token_type_embeds = tf.gather(params=self.token_type_embeddings, indices=token_type_ids)
        final_embeddings = inputs_embeds + position_embeds + token_type_embeds
        final_embeddings = self.LayerNorm(inputs=final_embeddings)
        final_embeddings = self.dropout(inputs=final_embeddings, training=training)

        return final_embeddings
class TFBertSelfAttention(keras.layers.Layer):
    """Multi-head scaled dot-product attention.

    Used for plain self-attention and, when `encoder_hidden_states` is supplied, for cross-attention. When
    `config.is_decoder` is set, the key/value projections are returned so callers can cache them.
    """

    def __init__(self, config: BertConfig, **kwargs):
        super().__init__(**kwargs)

        heads = config.num_attention_heads
        if config.hidden_size % heads != 0:
            raise ValueError(
                f"The hidden size ({config.hidden_size}) is not a multiple of the number "
                f"of attention heads ({config.num_attention_heads})"
            )

        self.num_attention_heads = heads
        self.attention_head_size = int(config.hidden_size / heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size
        self.sqrt_att_head_size = math.sqrt(self.attention_head_size)

        def make_projection(layer_name):
            # One Dense projection to all_head_size units; each call uses its own fresh initializer.
            return keras.layers.Dense(
                units=self.all_head_size,
                kernel_initializer=get_initializer(config.initializer_range),
                name=layer_name,
            )

        # Assignment order (query, key, value, dropout) is kept so sublayer tracking order is unchanged.
        self.query = make_projection("query")
        self.key = make_projection("key")
        self.value = make_projection("value")
        self.dropout = keras.layers.Dropout(rate=config.attention_probs_dropout_prob)

        self.is_decoder = config.is_decoder
        self.config = config

    def transpose_for_scores(self, tensor: tf.Tensor, batch_size: int) -> tf.Tensor:
        """Split the last axis into heads and move the head axis ahead of the sequence axis.

        [batch_size, seq_length, all_head_size] -> [batch_size, num_attention_heads, seq_length, attention_head_size]
        """
        split_shape = (batch_size, -1, self.num_attention_heads, self.attention_head_size)
        per_head = tf.reshape(tensor=tensor, shape=split_shape)
        return tf.transpose(per_head, perm=[0, 2, 1, 3])

    def call(
        self,
        hidden_states: tf.Tensor,
        attention_mask: tf.Tensor,
        head_mask: tf.Tensor,
        encoder_hidden_states: tf.Tensor,
        encoder_attention_mask: tf.Tensor,
        past_key_value: Tuple[tf.Tensor],
        output_attentions: bool,
        training: bool = False,
    ) -> Tuple[tf.Tensor]:
        """Attend from `hidden_states` to itself or, for cross-attention, to `encoder_hidden_states`.

        Returns a tuple `(context,)`, with the attention probabilities appended when `output_attentions` is set.
        When `self.is_decoder` is set, the `(key, value)` pair used for this step is appended last.
        """
        batch_size = shape_list(hidden_states)[0]
        mixed_query_layer = self.query(inputs=hidden_states)

        if encoder_hidden_states is not None:
            # Cross-attention: keys/values come from the encoder. Mask the encoder's padding rather than ours.
            attention_mask = encoder_attention_mask
            if past_key_value is not None:
                # Reuse the encoder keys/values cached on a previous step.
                key_layer, value_layer = past_key_value[0], past_key_value[1]
            else:
                key_layer = self.transpose_for_scores(self.key(inputs=encoder_hidden_states), batch_size)
                value_layer = self.transpose_for_scores(self.value(inputs=encoder_hidden_states), batch_size)
        else:
            key_layer = self.transpose_for_scores(self.key(inputs=hidden_states), batch_size)
            value_layer = self.transpose_for_scores(self.value(inputs=hidden_states), batch_size)
            if past_key_value is not None:
                # Uni-directional decoding: prepend the cached keys/values along the sequence axis.
                key_layer = tf.concat([past_key_value[0], key_layer], axis=2)
                value_layer = tf.concat([past_key_value[1], value_layer], axis=2)

        query_layer = self.transpose_for_scores(mixed_query_layer, batch_size)

        # Decoders hand back the keys/values actually used, so the next step can reuse or extend them.
        # For bi-directional encoder self-attention no cache is produced.
        present_key_value = (key_layer, value_layer) if self.is_decoder else None

        # Raw scores: (batch size, num_heads, seq_len_q, seq_len_k), scaled by sqrt(head size).
        scores = tf.matmul(query_layer, key_layer, transpose_b=True)
        scale = tf.cast(self.sqrt_att_head_size, dtype=scores.dtype)
        scores = tf.divide(scores, scale)

        if attention_mask is not None:
            # Additive mask (precomputed once for all layers by the calling model).
            scores = tf.add(scores, attention_mask)

        probs = stable_softmax(logits=scores, axis=-1)
        # Dropping whole attended-to tokens, as in the original Transformer paper.
        probs = self.dropout(inputs=probs, training=training)

        if head_mask is not None:
            probs = tf.multiply(probs, head_mask)

        context = tf.matmul(probs, value_layer)
        context = tf.transpose(context, perm=[0, 2, 1, 3])
        # Merge heads back: (batch_size, seq_len_q, all_head_size)
        context = tf.reshape(tensor=context, shape=(batch_size, -1, self.all_head_size))

        outputs = (context, probs) if output_attentions else (context,)
        if self.is_decoder:
            outputs += (present_key_value,)
        return outputs

    def build(self, input_shape=None):
        """Build the query/key/value projections, each taking `hidden_size` input features."""
        if self.built:
            return
        self.built = True
        for attr_name in ("query", "key", "value"):
            projection = getattr(self, attr_name, None)
            if projection is None:
                continue
            with tf.name_scope(projection.name):
                projection.build([None, None, self.config.hidden_size])
class TFBertSelfOutput(keras.layers.Layer):
    """Project the attention context to `hidden_size`, apply dropout, then a residual LayerNorm."""

    def __init__(self, config: BertConfig, **kwargs):
        super().__init__(**kwargs)

        kernel_init = get_initializer(config.initializer_range)
        self.dense = keras.layers.Dense(units=config.hidden_size, kernel_initializer=kernel_init, name="dense")
        self.LayerNorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
        self.dropout = keras.layers.Dropout(rate=config.hidden_dropout_prob)
        self.config = config

    def call(self, hidden_states: tf.Tensor, input_tensor: tf.Tensor, training: bool = False) -> tf.Tensor:
        """Return `LayerNorm(dropout(dense(hidden_states)) + input_tensor)`."""
        projected = self.dense(inputs=hidden_states)
        projected = self.dropout(inputs=projected, training=training)
        return self.LayerNorm(inputs=projected + input_tensor)

    def build(self, input_shape=None):
        """Build the dense projection and the LayerNorm for `hidden_size` features."""
        if self.built:
            return
        self.built = True
        width = self.config.hidden_size
        for sublayer in (getattr(self, "dense", None), getattr(self, "LayerNorm", None)):
            if sublayer is None:
                continue
            with tf.name_scope(sublayer.name):
                sublayer.build([None, None, width])
class TFBertAttention(keras.layers.Layer):
    """Attention sub-block: `TFBertSelfAttention` followed by `TFBertSelfOutput` (projection + residual)."""

    def __init__(self, config: BertConfig, **kwargs):
        super().__init__(**kwargs)

        self.self_attention = TFBertSelfAttention(config, name="self")
        self.dense_output = TFBertSelfOutput(config, name="output")

    def prune_heads(self, heads):
        """Head pruning is not implemented for this layer."""
        raise NotImplementedError

    def call(
        self,
        input_tensor: tf.Tensor,
        attention_mask: tf.Tensor,
        head_mask: tf.Tensor,
        encoder_hidden_states: tf.Tensor,
        encoder_attention_mask: tf.Tensor,
        past_key_value: Tuple[tf.Tensor],
        output_attentions: bool,
        training: bool = False,
    ) -> Tuple[tf.Tensor]:
        """Run attention and the output projection.

        Returns `(attention_output, *extras)`. `extras` is everything the self-attention returned after its
        context tensor (attention probabilities and/or key/value cache).
        """
        attention_results = self.self_attention(
            hidden_states=input_tensor,
            attention_mask=attention_mask,
            head_mask=head_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            training=training,
        )
        context = attention_results[0]
        extras = attention_results[1:]

        projected = self.dense_output(hidden_states=context, input_tensor=input_tensor, training=training)
        return (projected,) + extras

    def build(self, input_shape=None):
        """Build the attention and output sublayers."""
        if self.built:
            return
        self.built = True
        for attr_name in ("self_attention", "dense_output"):
            sublayer = getattr(self, attr_name, None)
            if sublayer is not None:
                with tf.name_scope(sublayer.name):
                    sublayer.build(None)
class TFBertIntermediate(keras.layers.Layer):
    """Feed-forward expansion: dense layer to `intermediate_size` followed by the configured activation."""

    def __init__(self, config: BertConfig, **kwargs):
        super().__init__(**kwargs)

        self.dense = keras.layers.Dense(
            units=config.intermediate_size,
            kernel_initializer=get_initializer(config.initializer_range),
            name="dense",
        )
        # `hidden_act` is either an activation name (resolved here) or already a callable.
        activation = config.hidden_act
        self.intermediate_act_fn = get_tf_activation(activation) if isinstance(activation, str) else activation
        self.config = config

    def call(self, hidden_states: tf.Tensor) -> tf.Tensor:
        """Return `activation(dense(hidden_states))`."""
        return self.intermediate_act_fn(self.dense(inputs=hidden_states))

    def build(self, input_shape=None):
        """Build the dense layer for `hidden_size` input features."""
        if self.built:
            return
        self.built = True
        dense = getattr(self, "dense", None)
        if dense is None:
            return
        with tf.name_scope(dense.name):
            dense.build([None, None, self.config.hidden_size])
class TFBertOutput(keras.layers.Layer):
    """Project the feed-forward activations back to `hidden_size`, apply dropout, then a residual LayerNorm."""

    def __init__(self, config: BertConfig, **kwargs):
        super().__init__(**kwargs)

        kernel_init = get_initializer(config.initializer_range)
        self.dense = keras.layers.Dense(units=config.hidden_size, kernel_initializer=kernel_init, name="dense")
        self.LayerNorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
        self.dropout = keras.layers.Dropout(rate=config.hidden_dropout_prob)
        self.config = config

    def call(self, hidden_states: tf.Tensor, input_tensor: tf.Tensor, training: bool = False) -> tf.Tensor:
        """Return `LayerNorm(dropout(dense(hidden_states)) + input_tensor)`."""
        projected = self.dropout(inputs=self.dense(inputs=hidden_states), training=training)
        return self.LayerNorm(inputs=projected + input_tensor)

    def build(self, input_shape=None):
        """Build the dense layer (input: `intermediate_size`) and the LayerNorm (`hidden_size`)."""
        if self.built:
            return
        self.built = True
        build_specs = (
            (getattr(self, "dense", None), self.config.intermediate_size),
            (getattr(self, "LayerNorm", None), self.config.hidden_size),
        )
        for sublayer, in_features in build_specs:
            if sublayer is None:
                continue
            with tf.name_scope(sublayer.name):
                sublayer.build([None, None, in_features])
class TFBertLayer(keras.layers.Layer):
    """One transformer block: self-attention, optional cross-attention (decoders only), then feed-forward."""

    def __init__(self, config: BertConfig, **kwargs):
        super().__init__(**kwargs)

        self.attention = TFBertAttention(config, name="attention")
        self.is_decoder = config.is_decoder
        self.add_cross_attention = config.add_cross_attention
        if self.add_cross_attention and not self.is_decoder:
            raise ValueError(f"{self} should be used as a decoder model if cross attention is added")
        if self.add_cross_attention:
            self.crossattention = TFBertAttention(config, name="crossattention")
        self.intermediate = TFBertIntermediate(config, name="intermediate")
        self.bert_output = TFBertOutput(config, name="output")

    def call(
        self,
        hidden_states: tf.Tensor,
        attention_mask: tf.Tensor,
        head_mask: tf.Tensor,
        encoder_hidden_states: tf.Tensor | None,
        encoder_attention_mask: tf.Tensor | None,
        past_key_value: Tuple[tf.Tensor] | None,
        output_attentions: bool,
        training: bool = False,
    ) -> Tuple[tf.Tensor]:
        """Run the block.

        Returns `(layer_output, *attention_outputs)`. When `self.is_decoder` is set, the present key/value cache
        is appended last: the self-attention pair, followed by the cross-attention pair when cross-attention ran.

        Raises:
            ValueError: if a decoder receives `encoder_hidden_states` but was built without cross-attention.
        """
        # The first two cached entries belong to the self-attention.
        self_attn_cache = None if past_key_value is None else past_key_value[:2]
        self_attn_results = self.attention(
            input_tensor=hidden_states,
            attention_mask=attention_mask,
            head_mask=head_mask,
            encoder_hidden_states=None,
            encoder_attention_mask=None,
            past_key_value=self_attn_cache,
            output_attentions=output_attentions,
            training=training,
        )
        attention_output = self_attn_results[0]

        if self.is_decoder:
            # Decoder attention appends its key/value cache as the last element; split it off.
            extra_outputs = self_attn_results[1:-1]
            present_key_value = self_attn_results[-1]
        else:
            extra_outputs = self_attn_results[1:]  # self-attention weights, if requested

        if self.is_decoder and encoder_hidden_states is not None:
            if not hasattr(self, "crossattention"):
                raise ValueError(
                    f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers"
                    " by setting `config.add_cross_attention=True`"
                )

            # The last two cached entries belong to the cross-attention.
            cross_attn_cache = None if past_key_value is None else past_key_value[-2:]
            cross_results = self.crossattention(
                input_tensor=attention_output,
                attention_mask=attention_mask,
                head_mask=head_mask,
                encoder_hidden_states=encoder_hidden_states,
                encoder_attention_mask=encoder_attention_mask,
                past_key_value=cross_attn_cache,
                output_attentions=output_attentions,
                training=training,
            )
            attention_output = cross_results[0]
            extra_outputs = extra_outputs + cross_results[1:-1]  # cross-attention weights, if requested
            # Extend the present cache with the cross-attention key/value pair.
            present_key_value = present_key_value + cross_results[-1]

        intermediate_output = self.intermediate(hidden_states=attention_output)
        layer_output = self.bert_output(
            hidden_states=intermediate_output, input_tensor=attention_output, training=training
        )

        outputs = (layer_output,) + extra_outputs
        if self.is_decoder:
            outputs = outputs + (present_key_value,)
        return outputs

    def build(self, input_shape=None):
        """Build every sublayer that exists (cross-attention only when configured)."""
        if self.built:
            return
        self.built = True
        for attr_name in ("attention", "intermediate", "bert_output", "crossattention"):
            sublayer = getattr(self, attr_name, None)
            if sublayer is not None:
                with tf.name_scope(sublayer.name):
                    sublayer.build(None)
class TFBertEncoder(keras.layers.Layer):
    """Stack of `config.num_hidden_layers` `TFBertLayer` blocks applied one after another."""

    def __init__(self, config: BertConfig, **kwargs):
        super().__init__(**kwargs)
        self.config = config
        self.layer = [TFBertLayer(config, name=f"layer_._{i}") for i in range(config.num_hidden_layers)]

    def call(
        self,
        hidden_states: tf.Tensor,
        attention_mask: tf.Tensor,
        head_mask: tf.Tensor,
        encoder_hidden_states: tf.Tensor | None,
        encoder_attention_mask: tf.Tensor | None,
        past_key_values: Tuple[Tuple[tf.Tensor]] | None,
        use_cache: Optional[bool],
        output_attentions: bool,
        output_hidden_states: bool,
        return_dict: bool,
        training: bool = False,
    ) -> Union[TFBaseModelOutputWithPastAndCrossAttentions, Tuple[tf.Tensor]]:
        """Run `hidden_states` through every layer of the stack.

        Args:
            hidden_states: input to the first layer.
            attention_mask: passed unchanged to every layer.
            head_mask: indexable by layer position (`head_mask[i]` goes to layer `i`), or `None` for no head
                masking in any layer.
            encoder_hidden_states / encoder_attention_mask: forwarded to every layer (used by cross-attention).
            past_key_values: per-layer cache indexed by layer position, or `None`.
            use_cache: when truthy, the last output of each layer is collected as `past_key_values`.
            output_attentions: collect each layer's attention weights (and cross-attention weights when
                `config.add_cross_attention` is set and `encoder_hidden_states` is given).
            output_hidden_states: collect the input to each layer plus the final output.
            return_dict: return a `TFBaseModelOutputWithPastAndCrossAttentions` instead of a tuple.

        Returns:
            With `return_dict`, the model output dataclass. Otherwise a tuple of the non-`None` items among
            (last hidden state, all hidden states, all attentions, all cross-attentions). Note that the tuple
            form does not include the collected cache.
        """
        all_hidden_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None
        all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None

        next_decoder_cache = () if use_cache else None
        for i, layer_module in enumerate(self.layer):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            past_key_value = past_key_values[i] if past_key_values is not None else None
            # Treat head_mask=None as "no head masking" rather than failing on `None[i]`.
            layer_head_mask = head_mask[i] if head_mask is not None else None

            layer_outputs = layer_module(
                hidden_states=hidden_states,
                attention_mask=attention_mask,
                head_mask=layer_head_mask,
                encoder_hidden_states=encoder_hidden_states,
                encoder_attention_mask=encoder_attention_mask,
                past_key_value=past_key_value,
                output_attentions=output_attentions,
                training=training,
            )
            hidden_states = layer_outputs[0]

            if use_cache:
                # Decoder layers return their present key/value cache as the last element.
                next_decoder_cache += (layer_outputs[-1],)

            if output_attentions:
                all_attentions = all_attentions + (layer_outputs[1],)
                if self.config.add_cross_attention and encoder_hidden_states is not None:
                    all_cross_attentions = all_cross_attentions + (layer_outputs[2],)

        # Add last layer
        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(
                v for v in [hidden_states, all_hidden_states, all_attentions, all_cross_attentions] if v is not None
            )

        return TFBaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=next_decoder_cache,
            hidden_states=all_hidden_states,
            attentions=all_attentions,
            cross_attentions=all_cross_attentions,
        )

    def build(self, input_shape=None):
        """Build every `TFBertLayer` in the stack under its own name scope."""
        if self.built:
            return
        self.built = True
        if getattr(self, "layer", None) is not None:
            for layer in self.layer:
                with tf.name_scope(layer.name):
                    layer.build(None)
class TFBertPooler(keras.layers.Layer):
    """Pool a sequence by feeding the hidden state of its first token through a tanh-activated dense layer."""

    def __init__(self, config: BertConfig, **kwargs):
        super().__init__(**kwargs)

        kernel_init = get_initializer(config.initializer_range)
        self.dense = keras.layers.Dense(
            units=config.hidden_size, kernel_initializer=kernel_init, activation="tanh", name="dense"
        )
        self.config = config

    def call(self, hidden_states: tf.Tensor) -> tf.Tensor:
        """Return `tanh(dense(hidden_states[:, 0]))`, i.e. only position 0 of each sequence is used."""
        return self.dense(inputs=hidden_states[:, 0])

    def build(self, input_shape=None):
        """Build the dense layer for `hidden_size` input features."""
        if self.built:
            return
        self.built = True
        dense = getattr(self, "dense", None)
        if dense is None:
            return
        with tf.name_scope(dense.name):
            dense.build([None, None, self.config.hidden_size])
class TFBertPredictionHeadTransform(keras.layers.Layer):
    """Dense projection, then activation, then layer normalization (used ahead of the LM decoder)."""

    def __init__(self, config: BertConfig, **kwargs):
        super().__init__(**kwargs)
        self.dense = keras.layers.Dense(
            units=config.hidden_size,
            kernel_initializer=get_initializer(config.initializer_range),
            name="dense",
        )
        # `hidden_act` is either a name resolved through get_tf_activation, or used as given.
        activation = config.hidden_act
        self.transform_act_fn = get_tf_activation(activation) if isinstance(activation, str) else activation
        self.LayerNorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
        self.config = config

    def call(self, hidden_states: tf.Tensor) -> tf.Tensor:
        projected = self.dense(inputs=hidden_states)
        activated = self.transform_act_fn(projected)
        return self.LayerNorm(inputs=activated)

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        # Build the projection first, then the normalization, each in its own name scope.
        for attr_name in ("dense", "LayerNorm"):
            sublayer = getattr(self, attr_name, None)
            if sublayer is None:
                continue
            with tf.name_scope(sublayer.name):
                sublayer.build([None, None, self.config.hidden_size])
class TFBertLMPredictionHead(keras.layers.Layer):
    """Language-modeling head whose output matrix is the input embedding matrix.

    Hidden states are transformed, multiplied by the transpose of ``input_embeddings.weight``
    and shifted by a per-token bias that this layer owns.
    """

    def __init__(self, config: BertConfig, input_embeddings: keras.layers.Layer, **kwargs):
        super().__init__(**kwargs)
        self.config = config
        self.hidden_size = config.hidden_size
        self.transform = TFBertPredictionHeadTransform(config, name="transform")
        # Decoder weights are shared with the embedding layer; only the bias lives here.
        self.input_embeddings = input_embeddings

    def build(self, input_shape=None):
        # NOTE(review): the bias is created on every call, before the `built` guard below, so a
        # second explicit build() would replace the current bias with a fresh zero-initialised weight.
        self.bias = self.add_weight(shape=(self.config.vocab_size,), initializer="zeros", trainable=True, name="bias")
        if self.built:
            return
        self.built = True
        transform = getattr(self, "transform", None)
        if transform is not None:
            with tf.name_scope(transform.name):
                transform.build(None)

    def get_output_embeddings(self) -> keras.layers.Layer:
        return self.input_embeddings

    def set_output_embeddings(self, value: tf.Variable):
        embeddings = self.input_embeddings
        embeddings.weight = value
        embeddings.vocab_size = shape_list(value)[0]

    def get_bias(self) -> Dict[str, tf.Variable]:
        return {"bias": self.bias}

    def set_bias(self, value: tf.Variable):
        new_bias = value["bias"]
        self.bias = new_bias
        self.config.vocab_size = shape_list(new_bias)[0]

    def call(self, hidden_states: tf.Tensor) -> tf.Tensor:
        transformed = self.transform(hidden_states=hidden_states)
        seq_length = shape_list(transformed)[1]
        # Collapse the leading axes so one matmul against the embedding matrix covers every token.
        flat = tf.reshape(tensor=transformed, shape=[-1, self.hidden_size])
        logits = tf.matmul(a=flat, b=self.input_embeddings.weight, transpose_b=True)
        logits = tf.reshape(tensor=logits, shape=[-1, seq_length, self.config.vocab_size])
        return tf.nn.bias_add(value=logits, bias=self.bias)
class TFBertMLMHead(keras.layers.Layer):
    """Wrapper that exposes a TFBertLMPredictionHead under the ``predictions`` name."""

    def __init__(self, config: BertConfig, input_embeddings: keras.layers.Layer, **kwargs):
        super().__init__(**kwargs)
        self.predictions = TFBertLMPredictionHead(config, input_embeddings, name="predictions")

    def call(self, sequence_output: tf.Tensor) -> tf.Tensor:
        return self.predictions(hidden_states=sequence_output)

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        head = getattr(self, "predictions", None)
        if head is None:
            return
        with tf.name_scope(head.name):
            head.build(None)
class TFBertNSPHead(keras.layers.Layer):
    """Next-sentence-prediction head: a dense layer mapping the pooled output to two scores."""

    def __init__(self, config: BertConfig, **kwargs):
        super().__init__(**kwargs)
        init = get_initializer(config.initializer_range)
        self.seq_relationship = keras.layers.Dense(
            units=2,
            kernel_initializer=init,
            name="seq_relationship",
        )
        self.config = config

    def call(self, pooled_output: tf.Tensor) -> tf.Tensor:
        return self.seq_relationship(inputs=pooled_output)

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        classifier = getattr(self, "seq_relationship", None)
        if classifier is not None:
            with tf.name_scope(classifier.name):
                classifier.build([None, None, self.config.hidden_size])
@keras_serializable
class TFBertMainLayer(keras.layers.Layer):
    """Core BERT stack: embeddings, encoder and an optional pooler."""

    config_class = BertConfig

    def __init__(self, config: BertConfig, add_pooling_layer: bool = True, **kwargs):
        super().__init__(**kwargs)

        self.config = config
        self.is_decoder = config.is_decoder

        self.embeddings = TFBertEmbeddings(config, name="embeddings")
        self.encoder = TFBertEncoder(config, name="encoder")
        self.pooler = TFBertPooler(config, name="pooler") if add_pooling_layer else None

    def get_input_embeddings(self) -> keras.layers.Layer:
        return self.embeddings

    def set_input_embeddings(self, value: tf.Variable):
        self.embeddings.weight = value
        self.embeddings.vocab_size = shape_list(value)[0]

    def _prune_heads(self, heads_to_prune):
        """
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
        class PreTrainedModel
        """
        raise NotImplementedError

    @unpack_inputs
    def call(
        self,
        input_ids: TFModelInputType | None = None,
        attention_mask: np.ndarray | tf.Tensor | None = None,
        token_type_ids: np.ndarray | tf.Tensor | None = None,
        position_ids: np.ndarray | tf.Tensor | None = None,
        head_mask: np.ndarray | tf.Tensor | None = None,
        inputs_embeds: np.ndarray | tf.Tensor | None = None,
        encoder_hidden_states: np.ndarray | tf.Tensor | None = None,
        encoder_attention_mask: np.ndarray | tf.Tensor | None = None,
        past_key_values: Optional[Tuple[Tuple[Union[np.ndarray, tf.Tensor]]]] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        training: bool = False,
    ) -> Union[TFBaseModelOutputWithPoolingAndCrossAttentions, Tuple[tf.Tensor]]:
        """Embed the inputs, build additive attention masks, run the encoder and optionally pool.

        Returns a `TFBaseModelOutputWithPoolingAndCrossAttentions`, or a tuple starting with
        `(sequence_output, pooled_output)` when `return_dict` is falsy (`pooled_output` is `None`
        when the layer was built without a pooler).

        Raises:
            ValueError: if both or neither of `input_ids` / `inputs_embeds` are given, or if, in
                decoder mode, `encoder_attention_mask` is neither 2D nor 3D.
            NotImplementedError: if `head_mask` is given.
        """
        if not self.config.is_decoder:
            use_cache = False

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            input_shape = shape_list(input_ids)
        elif inputs_embeds is not None:
            input_shape = shape_list(inputs_embeds)[:-1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        batch_size, seq_length = input_shape

        if past_key_values is None:
            past_key_values_length = 0
            past_key_values = [None] * len(self.encoder.layer)
        else:
            # Cached keys are indexed so that axis -2 is the cached sequence length.
            past_key_values_length = shape_list(past_key_values[0][0])[-2]

        if attention_mask is None:
            attention_mask = tf.fill(dims=(batch_size, seq_length + past_key_values_length), value=1)

        if token_type_ids is None:
            token_type_ids = tf.fill(dims=input_shape, value=0)

        embedding_output = self.embeddings(
            input_ids=input_ids,
            position_ids=position_ids,
            token_type_ids=token_type_ids,
            inputs_embeds=inputs_embeds,
            past_key_values_length=past_key_values_length,
            training=training,
        )

        # We create a 3D attention mask from a 2D tensor mask.
        # Sizes are [batch_size, 1, 1, to_seq_length]
        # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
        # this attention mask is more simple than the triangular masking of causal attention
        # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
        attention_mask_shape = shape_list(attention_mask)

        mask_seq_length = seq_length + past_key_values_length
        # Copied from `modeling_tf_t5.py`
        # Provided a padding mask of dimensions [batch_size, mask_seq_length]
        # - if the model is a decoder, apply a causal mask in addition to the padding mask
        # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, mask_seq_length, mask_seq_length]
        if self.is_decoder:
            seq_ids = tf.range(mask_seq_length)
            # causal_mask[b, i, j] is true when key position j <= query position i.
            causal_mask = tf.less_equal(
                tf.tile(seq_ids[None, None, :], (batch_size, mask_seq_length, 1)),
                seq_ids[None, :, None],
            )
            causal_mask = tf.cast(causal_mask, dtype=attention_mask.dtype)
            extended_attention_mask = causal_mask * attention_mask[:, None, :]
            attention_mask_shape = shape_list(extended_attention_mask)
            extended_attention_mask = tf.reshape(
                extended_attention_mask, (attention_mask_shape[0], 1, attention_mask_shape[1], attention_mask_shape[2])
            )
            if past_key_values[0] is not None:
                # attention_mask needs to be sliced to the shape `[batch_size, 1, from_seq_length - cached_seq_length, to_seq_length]`
                extended_attention_mask = extended_attention_mask[:, :, -seq_length:, :]
        else:
            extended_attention_mask = tf.reshape(
                attention_mask, (attention_mask_shape[0], 1, 1, attention_mask_shape[1])
            )

        # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
        # masked positions, this operation will create a tensor which is 0.0 for
        # positions we want to attend and -10000.0 for masked positions.
        # Since we are adding it to the raw scores before the softmax, this is
        # effectively the same as removing these entirely.
        extended_attention_mask = tf.cast(extended_attention_mask, dtype=embedding_output.dtype)
        one_cst = tf.constant(1.0, dtype=embedding_output.dtype)
        ten_thousand_cst = tf.constant(-10000.0, dtype=embedding_output.dtype)
        extended_attention_mask = tf.multiply(tf.subtract(one_cst, extended_attention_mask), ten_thousand_cst)

        # Copied from `modeling_tf_t5.py` with -1e9 -> -10000
        if self.is_decoder and encoder_attention_mask is not None:
            # If a 2D or 3D attention mask is provided for the cross-attention
            # we need to make it broadcastable to [batch_size, num_heads, seq_length, seq_length]
            encoder_attention_mask = tf.cast(encoder_attention_mask, dtype=extended_attention_mask.dtype)
            num_dims_encoder_attention_mask = len(shape_list(encoder_attention_mask))
            if num_dims_encoder_attention_mask == 3:
                encoder_extended_attention_mask = encoder_attention_mask[:, None, :, :]
            elif num_dims_encoder_attention_mask == 2:
                encoder_extended_attention_mask = encoder_attention_mask[:, None, None, :]
            else:
                # Previously any other rank fell through and crashed with an UnboundLocalError.
                raise ValueError(
                    "`encoder_attention_mask` must be a 2D or 3D tensor, got a tensor of rank "
                    f"{num_dims_encoder_attention_mask}"
                )

            # T5 has a mask that can compare sequence ids, we can simulate this here with this transposition
            # Cf. https://github.com/tensorflow/mesh/blob/8d2465e9bc93129b913b5ccc6a59aa97abd96ec6/mesh_tensorflow/transformer/transformer_layers.py#L270
            # encoder_extended_attention_mask = tf.math.equal(encoder_extended_attention_mask,
            #                                        tf.transpose(encoder_extended_attention_mask, perm=(-1, -2)))

            encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * -10000.0
        else:
            encoder_extended_attention_mask = None

        # Prepare head mask if needed
        # 1.0 in head_mask indicate we keep the head
        # attention_probs has shape bsz x n_heads x N x N
        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
        if head_mask is not None:
            raise NotImplementedError
        else:
            head_mask = [None] * self.config.num_hidden_layers

        encoder_outputs = self.encoder(
            hidden_states=embedding_output,
            attention_mask=extended_attention_mask,
            head_mask=head_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_extended_attention_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            training=training,
        )

        sequence_output = encoder_outputs[0]
        pooled_output = self.pooler(hidden_states=sequence_output) if self.pooler is not None else None

        if not return_dict:
            return (
                sequence_output,
                pooled_output,
            ) + encoder_outputs[1:]

        return TFBaseModelOutputWithPoolingAndCrossAttentions(
            last_hidden_state=sequence_output,
            pooler_output=pooled_output,
            past_key_values=encoder_outputs.past_key_values,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
            cross_attentions=encoder_outputs.cross_attentions,
        )

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        if getattr(self, "embeddings", None) is not None:
            with tf.name_scope(self.embeddings.name):
                self.embeddings.build(None)
        if getattr(self, "encoder", None) is not None:
            with tf.name_scope(self.encoder.name):
                self.encoder.build(None)
        if getattr(self, "pooler", None) is not None:
            with tf.name_scope(self.pooler.name):
                self.pooler.build(None)
class TFBertPreTrainedModel(TFPreTrainedModel):
    """
    Abstract base of the TF BERT models. It binds them to `BertConfig` and the "bert" weight prefix, and inherits the
    weight-initialization and pretrained-download/loading interface from `TFPreTrainedModel`.
    """

    base_model_prefix = "bert"
    config_class = BertConfig
@dataclass
class TFBertForPreTrainingOutput(ModelOutput):
    """
    Output type of [`TFBertForPreTraining`].

    Args:
        loss (`tf.Tensor`, *optional*, returned when both `labels` and `next_sentence_label` are provided):
            Combined pretraining loss computed from the masked language modeling logits and the next sentence
            prediction logits.
        prediction_logits (`tf.Tensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
        seq_relationship_logits (`tf.Tensor` of shape `(batch_size, 2)`):
            Prediction scores of the next sequence prediction (classification) head (scores of True/False continuation
            before SoftMax).
        hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape
            `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
    """

    # Field order defines the tuple order of the output; keep it stable.
    loss: tf.Tensor | None = None
    prediction_logits: tf.Tensor = None
    seq_relationship_logits: tf.Tensor = None
    hidden_states: Optional[Union[Tuple[tf.Tensor], tf.Tensor]] = None
    attentions: Optional[Union[Tuple[tf.Tensor], tf.Tensor]] = None
# Shared class-level docstring prefix injected into every public TF BERT model via add_start_docstrings.
BERT_START_DOCSTRING = r"""
This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the
library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
etc.)
This model is also a [keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model) subclass. Use it
as a regular TF 2.0 Keras Model and refer to the TF 2.0 documentation for all matter related to general usage and
behavior.
<Tip>
TensorFlow models and layers in `transformers` accept two formats as input:
- having all inputs as keyword arguments (like PyTorch models), or
- having all inputs as a list, tuple or dict in the first positional argument.
The reason the second format is supported is that Keras methods prefer this format when passing inputs to models
and layers. Because of this support, when using methods like `model.fit()` things should "just work" for you - just
pass your inputs and labels in any format that `model.fit()` supports! If, however, you want to use the second
format outside of Keras methods like `fit()` and `predict()`, such as when creating your own layers or models with
the Keras `Functional` API, there are three possibilities you can use to gather all the input Tensors in the first
positional argument:
- a single Tensor with `input_ids` only and nothing else: `model(input_ids)`
- a list of varying length with one or several input Tensors IN THE ORDER given in the docstring:
`model([input_ids, attention_mask])` or `model([input_ids, attention_mask, token_type_ids])`
- a dictionary with one or several input Tensors associated to the input names given in the docstring:
`model({"input_ids": input_ids, "token_type_ids": token_type_ids})`
Note that when creating models and layers with
[subclassing](https://keras.io/guides/making_new_layers_and_models_via_subclassing/) then you don't need to worry
about any of this, as you can just pass inputs like you would to any other Python function!
</Tip>
Args:
config ([`BertConfig`]): Model configuration class with all the parameters of the model.
Initializing with a config file does not load the weights associated with the model, only the
configuration. Check out the [`~TFPreTrainedModel.from_pretrained`] method to load the model weights.
"""
# Shared argument documentation for the forward (`call`) methods; `{0}` is filled with the input shape via .format().
BERT_INPUTS_DOCSTRING = r"""
Args:
input_ids (`np.ndarray`, `tf.Tensor`, `List[tf.Tensor]`, `Dict[str, tf.Tensor]` or `Dict[str, np.ndarray]` and each example must have the shape `({0})`):
Indices of input sequence tokens in the vocabulary.
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.__call__`] and
[`PreTrainedTokenizer.encode`] for details.
[What are input IDs?](../glossary#input-ids)
attention_mask (`np.ndarray` or `tf.Tensor` of shape `({0})`, *optional*):
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
[What are attention masks?](../glossary#attention-mask)
token_type_ids (`np.ndarray` or `tf.Tensor` of shape `({0})`, *optional*):
Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
1]`:
- 0 corresponds to a *sentence A* token,
- 1 corresponds to a *sentence B* token.
[What are token type IDs?](../glossary#token-type-ids)
position_ids (`np.ndarray` or `tf.Tensor` of shape `({0})`, *optional*):
Indices of positions of each input sequence token in the position embeddings. Selected in the range `[0,
config.max_position_embeddings - 1]`.
[What are position IDs?](../glossary#position-ids)
head_mask (`np.ndarray` or `tf.Tensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
inputs_embeds (`np.ndarray` or `tf.Tensor` of shape `({0}, hidden_size)`, *optional*):
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
model's internal embedding lookup matrix.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
tensors for more detail. This argument can be used only in eager mode, in graph mode the value in the
config will be used instead.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail. This argument can be used only in eager mode, in graph mode the value in the config will be
used instead.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. This argument can be used in
eager mode, in graph mode the value will always be set to True.
training (`bool`, *optional*, defaults to `False`):
Whether or not to use the model in training mode (some modules like dropout modules have different
behaviors between training and evaluation).
"""
@add_start_docstrings(
    "The bare Bert Model transformer outputting raw hidden-states without any specific head on top.",
    BERT_START_DOCSTRING,
)
class TFBertModel(TFBertPreTrainedModel):
    def __init__(self, config: BertConfig, add_pooling_layer: bool = True, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)
        self.bert = TFBertMainLayer(config, add_pooling_layer, name="bert")

    @unpack_inputs
    @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=TFBaseModelOutputWithPoolingAndCrossAttentions,
        config_class=_CONFIG_FOR_DOC,
    )
    def call(
        self,
        input_ids: TFModelInputType | None = None,
        attention_mask: np.ndarray | tf.Tensor | None = None,
        token_type_ids: np.ndarray | tf.Tensor | None = None,
        position_ids: np.ndarray | tf.Tensor | None = None,
        head_mask: np.ndarray | tf.Tensor | None = None,
        inputs_embeds: np.ndarray | tf.Tensor | None = None,
        encoder_hidden_states: np.ndarray | tf.Tensor | None = None,
        encoder_attention_mask: np.ndarray | tf.Tensor | None = None,
        past_key_values: Optional[Tuple[Tuple[Union[np.ndarray, tf.Tensor]]]] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        training: Optional[bool] = False,
    ) -> Union[TFBaseModelOutputWithPoolingAndCrossAttentions, Tuple[tf.Tensor]]:
        r"""
        encoder_hidden_states (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
            the model is configured as a decoder.
        encoder_attention_mask (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
            the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

        past_key_values (`Tuple[Tuple[tf.Tensor]]`, one entry per encoder layer, *optional*):
            Precomputed key and value hidden states of the attention blocks, used to speed up decoding. When
            `past_key_values` are given, only the last `input_ids` (those without cached key/value states), of shape
            `(batch_size, 1)`, need to be passed instead of the full `(batch_size, sequence_length)` sequence.
        use_cache (`bool`, *optional*, defaults to `True`):
            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
            `past_key_values`). Set to `False` during training, `True` during generation.
        """
        return self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            training=training,
        )

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        main_layer = getattr(self, "bert", None)
        if main_layer is None:
            return
        with tf.name_scope(main_layer.name):
            main_layer.build(None)
@add_start_docstrings(
    """
Bert Model with two heads on top as done during the pretraining:
a `masked language modeling` head and a `next sentence prediction (classification)` head.
""",
    BERT_START_DOCSTRING,
)
class TFBertForPreTraining(TFBertPreTrainedModel, TFBertPreTrainingLoss):
    # names with a '.' represents the authorized unexpected/missing layers when a TF model is loaded from a PT model
    _keys_to_ignore_on_load_unexpected = [
        r"position_ids",
        r"cls.predictions.decoder.weight",
        r"cls.predictions.decoder.bias",
    ]

    def __init__(self, config: BertConfig, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)

        self.bert = TFBertMainLayer(config, name="bert")
        self.nsp = TFBertNSPHead(config, name="nsp___cls")
        # The MLM head reuses the word-embedding layer of the main layer as its output projection.
        self.mlm = TFBertMLMHead(config, input_embeddings=self.bert.embeddings, name="mlm___cls")

    def get_lm_head(self) -> keras.layers.Layer:
        return self.mlm.predictions

    def get_prefix_bias_name(self) -> str:
        warnings.warn("The method get_prefix_bias_name is deprecated. Please use `get_bias` instead.", FutureWarning)
        return self.name + "/" + self.mlm.name + "/" + self.mlm.predictions.name

    @unpack_inputs
    @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @replace_return_docstrings(output_type=TFBertForPreTrainingOutput, config_class=_CONFIG_FOR_DOC)
    def call(
        self,
        input_ids: TFModelInputType | None = None,
        attention_mask: np.ndarray | tf.Tensor | None = None,
        token_type_ids: np.ndarray | tf.Tensor | None = None,
        position_ids: np.ndarray | tf.Tensor | None = None,
        head_mask: np.ndarray | tf.Tensor | None = None,
        inputs_embeds: np.ndarray | tf.Tensor | None = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        labels: np.ndarray | tf.Tensor | None = None,
        next_sentence_label: np.ndarray | tf.Tensor | None = None,
        training: Optional[bool] = False,
    ) -> Union[TFBertForPreTrainingOutput, Tuple[tf.Tensor]]:
        r"""
        labels (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
            config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
            loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
        next_sentence_label (`tf.Tensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the next sequence prediction (classification) loss. Input should be a sequence pair
            (see `input_ids` docstring) Indices should be in `[0, 1]`:

            - 0 indicates sequence B is a continuation of sequence A,
            - 1 indicates sequence B is a random sequence.

            The loss is only computed when both `labels` and `next_sentence_label` are provided.

        Return:

        Examples:

        ```python
        >>> import tensorflow as tf
        >>> from transformers import AutoTokenizer, TFBertForPreTraining

        >>> tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-uncased")
        >>> model = TFBertForPreTraining.from_pretrained("google-bert/bert-base-uncased")
        >>> input_ids = tokenizer("Hello, my dog is cute", add_special_tokens=True, return_tensors="tf")
        >>> # Batch size 1

        >>> outputs = model(input_ids)
        >>> prediction_logits, seq_relationship_logits = outputs[:2]
        ```"""
        outputs = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            training=training,
        )
        sequence_output, pooled_output = outputs[:2]
        prediction_scores = self.mlm(sequence_output=sequence_output, training=training)
        seq_relationship_score = self.nsp(pooled_output=pooled_output)
        total_loss = None

        # The pretraining loss needs both label sets; with only one of them no loss is computed.
        if labels is not None and next_sentence_label is not None:
            d_labels = {"labels": labels, "next_sentence_label": next_sentence_label}
            total_loss = self.hf_compute_loss(labels=d_labels, logits=(prediction_scores, seq_relationship_score))

        if not return_dict:
            output = (prediction_scores, seq_relationship_score) + outputs[2:]
            return ((total_loss,) + output) if total_loss is not None else output

        return TFBertForPreTrainingOutput(
            loss=total_loss,
            prediction_logits=prediction_scores,
            seq_relationship_logits=seq_relationship_score,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        if getattr(self, "bert", None) is not None:
            with tf.name_scope(self.bert.name):
                self.bert.build(None)
        if getattr(self, "nsp", None) is not None:
            with tf.name_scope(self.nsp.name):
                self.nsp.build(None)
        if getattr(self, "mlm", None) is not None:
            with tf.name_scope(self.mlm.name):
                self.mlm.build(None)
- @add_start_docstrings("""Bert Model with a `language modeling` head on top.""", BERT_START_DOCSTRING)
- class TFBertForMaskedLM(TFBertPreTrainedModel, TFMaskedLanguageModelingLoss):
- # names with a '.' represents the authorized unexpected/missing layers when a TF model is loaded from a PT model
- _keys_to_ignore_on_load_unexpected = [
- r"pooler",
- r"cls.seq_relationship",
- r"cls.predictions.decoder.weight",
- r"nsp___cls",
- ]
- def __init__(self, config: BertConfig, *inputs, **kwargs):
- super().__init__(config, *inputs, **kwargs)
- if config.is_decoder:
- logger.warning(
- "If you want to use `TFBertForMaskedLM` make sure `config.is_decoder=False` for "
- "bi-directional self-attention."
- )
- self.bert = TFBertMainLayer(config, add_pooling_layer=False, name="bert")
- self.mlm = TFBertMLMHead(config, input_embeddings=self.bert.embeddings, name="mlm___cls")
- def get_lm_head(self) -> keras.layers.Layer:
- return self.mlm.predictions
- def get_prefix_bias_name(self) -> str:
- warnings.warn("The method get_prefix_bias_name is deprecated. Please use `get_bias` instead.", FutureWarning)
- return self.name + "/" + self.mlm.name + "/" + self.mlm.predictions.name
- @unpack_inputs
@add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings(
    checkpoint=_CHECKPOINT_FOR_DOC,
    output_type=TFMaskedLMOutput,
    config_class=_CONFIG_FOR_DOC,
    expected_output="'paris'",
    expected_loss=0.88,
)
def call(
    self,
    input_ids: TFModelInputType | None = None,
    attention_mask: np.ndarray | tf.Tensor | None = None,
    token_type_ids: np.ndarray | tf.Tensor | None = None,
    position_ids: np.ndarray | tf.Tensor | None = None,
    head_mask: np.ndarray | tf.Tensor | None = None,
    inputs_embeds: np.ndarray | tf.Tensor | None = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    labels: np.ndarray | tf.Tensor | None = None,
    training: Optional[bool] = False,
) -> Union[TFMaskedLMOutput, Tuple[tf.Tensor]]:
    r"""
    labels (`tf.Tensor` or `np.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
        Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
        config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
        loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
    """
    # Run the shared BERT trunk over the inputs.
    encoder_outputs = self.bert(
        input_ids=input_ids,
        attention_mask=attention_mask,
        token_type_ids=token_type_ids,
        position_ids=position_ids,
        head_mask=head_mask,
        inputs_embeds=inputs_embeds,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
        training=training,
    )

    # Element 0 is the per-token last hidden state; the MLM head maps it to vocabulary scores.
    token_states = encoder_outputs[0]
    mlm_logits = self.mlm(sequence_output=token_states, training=training)

    masked_lm_loss = None
    if labels is not None:
        masked_lm_loss = self.hf_compute_loss(labels=labels, logits=mlm_logits)

    if return_dict:
        return TFMaskedLMOutput(
            loss=masked_lm_loss,
            logits=mlm_logits,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )

    # Tuple form: [loss,] logits, then the trunk outputs after index 1 (the pooled-output slot).
    tuple_output = (mlm_logits,) + encoder_outputs[2:]
    if masked_lm_loss is None:
        return tuple_output
    return (masked_lm_loss,) + tuple_output
def build(self, input_shape=None):
    """Build the BERT trunk and the MLM head once, each inside its own name scope."""
    if self.built:
        return
    self.built = True
    for sublayer_attr in ("bert", "mlm"):
        sublayer = getattr(self, sublayer_attr, None)
        if sublayer is None:
            continue
        # Presumably scoped so the sub-layer's weights are named under the sub-layer's name.
        with tf.name_scope(sublayer.name):
            sublayer.build(None)
class TFBertLMHeadModel(TFBertPreTrainedModel, TFCausalLanguageModelingLoss):
    """BERT trunk plus the MLM prediction head, trained with a shifted (next-token) causal LM loss."""

    # names with a '.' represents the authorized unexpected/missing layers when a TF model is loaded from a PT model
    _keys_to_ignore_on_load_unexpected = [
        r"pooler",
        r"cls.seq_relationship",
        r"cls.predictions.decoder.weight",
        r"nsp___cls",
    ]

    def __init__(self, config: BertConfig, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)

        if not config.is_decoder:
            logger.warning("If you want to use `TFBertLMHeadModel` as a standalone, add `is_decoder=True.`")

        self.bert = TFBertMainLayer(config, add_pooling_layer=False, name="bert")
        # The head receives the trunk's embedding layer (presumably for input/output weight sharing).
        self.mlm = TFBertMLMHead(config, input_embeddings=self.bert.embeddings, name="mlm___cls")

    def get_lm_head(self) -> keras.layers.Layer:
        """Return the prediction layer of the language-modeling head."""
        return self.mlm.predictions

    def get_prefix_bias_name(self) -> str:
        """Deprecated: return the name path of the LM head's prediction layer; use `get_bias` instead."""
        warnings.warn("The method get_prefix_bias_name is deprecated. Please use `get_bias` instead.", FutureWarning)
        return self.name + "/" + self.mlm.name + "/" + self.mlm.predictions.name

    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attention_mask=None, **model_kwargs):
        """Assemble the inputs for a single generation step.

        Args:
            input_ids: Token ids generated so far.
            past_key_values: Cached key/value states from previous steps, or None on the first step.
            attention_mask: Optional mask; an all-ones mask is created when omitted.
            **model_kwargs: Accepted for API compatibility and ignored.

        Returns:
            dict with `input_ids`, `attention_mask` and `past_key_values`.
        """
        input_shape = input_ids.shape
        # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
        # (built from the full, uncut sequence length so it also covers cached positions)
        if attention_mask is None:
            attention_mask = tf.ones(input_shape)

        # cut decoder_input_ids if past is used: only the newest token needs to be fed
        if past_key_values is not None:
            input_ids = input_ids[:, -1:]

        return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past_key_values}

    @unpack_inputs
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=TFCausalLMOutputWithCrossAttentions,
        config_class=_CONFIG_FOR_DOC,
    )
    def call(
        self,
        input_ids: TFModelInputType | None = None,
        attention_mask: np.ndarray | tf.Tensor | None = None,
        token_type_ids: np.ndarray | tf.Tensor | None = None,
        position_ids: np.ndarray | tf.Tensor | None = None,
        head_mask: np.ndarray | tf.Tensor | None = None,
        inputs_embeds: np.ndarray | tf.Tensor | None = None,
        encoder_hidden_states: np.ndarray | tf.Tensor | None = None,
        encoder_attention_mask: np.ndarray | tf.Tensor | None = None,
        past_key_values: Optional[Tuple[Tuple[Union[np.ndarray, tf.Tensor]]]] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        labels: np.ndarray | tf.Tensor | None = None,
        training: Optional[bool] = False,
        **kwargs,
    ) -> Union[TFCausalLMOutputWithCrossAttentions, Tuple[tf.Tensor]]:
        r"""
        encoder_hidden_states (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
            the model is configured as a decoder.
        encoder_attention_mask (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
            the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

        past_key_values (`Tuple[Tuple[tf.Tensor]]` of length `config.num_hidden_layers`, *optional*):
            Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
            If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
            don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
            `decoder_input_ids` of shape `(batch_size, sequence_length)`.
        use_cache (`bool`, *optional*, defaults to `True`):
            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
            `past_key_values`). Set to `False` during training, `True` during generation
        labels (`tf.Tensor` or `np.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the cross entropy classification loss. Indices should be in `[0, ...,
            config.vocab_size - 1]`.
        """
        outputs = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            training=training,
        )
        sequence_output = outputs[0]
        logits = self.mlm(sequence_output=sequence_output, training=training)
        loss = None

        if labels is not None:
            # shift labels to the left and cut last logit token: position t predicts token t + 1
            shifted_logits = logits[:, :-1]
            labels = labels[:, 1:]
            loss = self.hf_compute_loss(labels=labels, logits=shifted_logits)

        if not return_dict:
            # Tuple form: [loss,] logits, then the trunk outputs after index 1 (the pooled-output slot).
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return TFCausalLMOutputWithCrossAttentions(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            cross_attentions=outputs.cross_attentions,
        )

    def build(self, input_shape=None):
        """Build the trunk and the LM head once, each inside its own name scope."""
        if self.built:
            return
        self.built = True
        if getattr(self, "bert", None) is not None:
            with tf.name_scope(self.bert.name):
                self.bert.build(None)
        if getattr(self, "mlm", None) is not None:
            with tf.name_scope(self.mlm.name):
                self.mlm.build(None)
@add_start_docstrings(
    """Bert Model with a `next sentence prediction (classification)` head on top.""",
    BERT_START_DOCSTRING,
)
class TFBertForNextSentencePrediction(TFBertPreTrainedModel, TFNextSentencePredictionLoss):
    # names with a '.' represents the authorized unexpected/missing layers when a TF model is loaded from a PT model
    _keys_to_ignore_on_load_unexpected = [r"mlm___cls", r"cls.predictions"]

    def __init__(self, config: BertConfig, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)

        # Trunk (with pooler, since the NSP head reads the pooled output) and the NSP classifier.
        self.bert = TFBertMainLayer(config, name="bert")
        self.nsp = TFBertNSPHead(config, name="nsp___cls")

    @unpack_inputs
    @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @replace_return_docstrings(output_type=TFNextSentencePredictorOutput, config_class=_CONFIG_FOR_DOC)
    def call(
        self,
        input_ids: TFModelInputType | None = None,
        attention_mask: np.ndarray | tf.Tensor | None = None,
        token_type_ids: np.ndarray | tf.Tensor | None = None,
        position_ids: np.ndarray | tf.Tensor | None = None,
        head_mask: np.ndarray | tf.Tensor | None = None,
        inputs_embeds: np.ndarray | tf.Tensor | None = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        next_sentence_label: np.ndarray | tf.Tensor | None = None,
        training: Optional[bool] = False,
    ) -> Union[TFNextSentencePredictorOutput, Tuple[tf.Tensor]]:
        r"""
        Return:

        Examples:

        ```python
        >>> import tensorflow as tf
        >>> from transformers import AutoTokenizer, TFBertForNextSentencePrediction

        >>> tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-uncased")
        >>> model = TFBertForNextSentencePrediction.from_pretrained("google-bert/bert-base-uncased")

        >>> prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced."
        >>> next_sentence = "The sky is blue due to the shorter wavelength of blue light."
        >>> encoding = tokenizer(prompt, next_sentence, return_tensors="tf")

        >>> logits = model(encoding["input_ids"], token_type_ids=encoding["token_type_ids"])[0]
        >>> assert logits[0][0] < logits[0][1]  # the next sentence was random
        ```"""
        trunk_outputs = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            training=training,
        )

        # Element 1 of the trunk output is the pooled representation fed to the NSP head.
        nsp_logits = self.nsp(pooled_output=trunk_outputs[1])

        nsp_loss = None
        if next_sentence_label is not None:
            nsp_loss = self.hf_compute_loss(labels=next_sentence_label, logits=nsp_logits)

        if return_dict:
            return TFNextSentencePredictorOutput(
                loss=nsp_loss,
                logits=nsp_logits,
                hidden_states=trunk_outputs.hidden_states,
                attentions=trunk_outputs.attentions,
            )

        # Tuple form: [loss,] logits, then whatever the trunk returned after the pooled output.
        tuple_output = (nsp_logits,) + trunk_outputs[2:]
        if nsp_loss is None:
            return tuple_output
        return (nsp_loss,) + tuple_output

    def build(self, input_shape=None):
        """Build the trunk and the NSP head once, each inside its own name scope."""
        if self.built:
            return
        self.built = True
        for sublayer_attr in ("bert", "nsp"):
            sublayer = getattr(self, sublayer_attr, None)
            if sublayer is None:
                continue
            with tf.name_scope(sublayer.name):
                sublayer.build(None)
@add_start_docstrings(
    """
    Bert Model transformer with a sequence classification/regression head on top (a linear layer on top of the pooled
    output) e.g. for GLUE tasks.
    """,
    BERT_START_DOCSTRING,
)
class TFBertForSequenceClassification(TFBertPreTrainedModel, TFSequenceClassificationLoss):
    # names with a '.' represents the authorized unexpected/missing layers when a TF model is loaded from a PT model
    _keys_to_ignore_on_load_unexpected = [r"mlm___cls", r"nsp___cls", r"cls.predictions", r"cls.seq_relationship"]
    _keys_to_ignore_on_load_missing = [r"dropout"]

    def __init__(self, config: BertConfig, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)

        self.num_labels = config.num_labels

        self.bert = TFBertMainLayer(config, name="bert")
        # Head-specific dropout rate, falling back to the trunk's hidden dropout when unset.
        if config.classifier_dropout is not None:
            head_dropout_rate = config.classifier_dropout
        else:
            head_dropout_rate = config.hidden_dropout_prob
        self.dropout = keras.layers.Dropout(rate=head_dropout_rate)
        self.classifier = keras.layers.Dense(
            units=config.num_labels,
            kernel_initializer=get_initializer(config.initializer_range),
            name="classifier",
        )
        self.config = config

    @unpack_inputs
    @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_SEQUENCE_CLASSIFICATION,
        output_type=TFSequenceClassifierOutput,
        config_class=_CONFIG_FOR_DOC,
        expected_output=_SEQ_CLASS_EXPECTED_OUTPUT,
        expected_loss=_SEQ_CLASS_EXPECTED_LOSS,
    )
    def call(
        self,
        input_ids: TFModelInputType | None = None,
        attention_mask: np.ndarray | tf.Tensor | None = None,
        token_type_ids: np.ndarray | tf.Tensor | None = None,
        position_ids: np.ndarray | tf.Tensor | None = None,
        head_mask: np.ndarray | tf.Tensor | None = None,
        inputs_embeds: np.ndarray | tf.Tensor | None = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        labels: np.ndarray | tf.Tensor | None = None,
        training: Optional[bool] = False,
    ) -> Union[TFSequenceClassifierOutput, Tuple[tf.Tensor]]:
        r"""
        labels (`tf.Tensor` or `np.ndarray` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        trunk_outputs = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            training=training,
        )

        # Classify from the pooled representation (element 1), regularized by dropout.
        pooled = self.dropout(inputs=trunk_outputs[1], training=training)
        class_logits = self.classifier(inputs=pooled)

        classification_loss = None
        if labels is not None:
            classification_loss = self.hf_compute_loss(labels=labels, logits=class_logits)

        if return_dict:
            return TFSequenceClassifierOutput(
                loss=classification_loss,
                logits=class_logits,
                hidden_states=trunk_outputs.hidden_states,
                attentions=trunk_outputs.attentions,
            )

        tuple_output = (class_logits,) + trunk_outputs[2:]
        if classification_loss is None:
            return tuple_output
        return (classification_loss,) + tuple_output

    def build(self, input_shape=None):
        """Build the trunk and the classifier once, each inside its own name scope."""
        if self.built:
            return
        self.built = True
        trunk = getattr(self, "bert", None)
        if trunk is not None:
            with tf.name_scope(trunk.name):
                trunk.build(None)
        head = getattr(self, "classifier", None)
        if head is not None:
            # The Dense head only needs the feature (last) dimension to create its kernel.
            with tf.name_scope(head.name):
                head.build([None, None, self.config.hidden_size])
@add_start_docstrings(
    """
    Bert Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a
    softmax) e.g. for RocStories/SWAG tasks.
    """,
    BERT_START_DOCSTRING,
)
class TFBertForMultipleChoice(TFBertPreTrainedModel, TFMultipleChoiceLoss):
    # names with a '.' represents the authorized unexpected/missing layers when a TF model is loaded from a PT model
    _keys_to_ignore_on_load_unexpected = [r"mlm___cls", r"nsp___cls", r"cls.predictions", r"cls.seq_relationship"]
    _keys_to_ignore_on_load_missing = [r"dropout"]

    def __init__(self, config: BertConfig, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)

        self.bert = TFBertMainLayer(config, name="bert")
        # Resolve the head dropout the same way as the other classification heads: prefer
        # `classifier_dropout`, fall back to `hidden_dropout_prob` (identical when it is unset).
        classifier_dropout = (
            config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
        )
        self.dropout = keras.layers.Dropout(rate=classifier_dropout)
        # One score per choice; scores are regrouped into (batch, num_choices) in `call`.
        self.classifier = keras.layers.Dense(
            units=1, kernel_initializer=get_initializer(config.initializer_range), name="classifier"
        )
        self.config = config

    @unpack_inputs
    @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length"))
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=TFMultipleChoiceModelOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def call(
        self,
        input_ids: TFModelInputType | None = None,
        attention_mask: np.ndarray | tf.Tensor | None = None,
        token_type_ids: np.ndarray | tf.Tensor | None = None,
        position_ids: np.ndarray | tf.Tensor | None = None,
        head_mask: np.ndarray | tf.Tensor | None = None,
        inputs_embeds: np.ndarray | tf.Tensor | None = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        labels: np.ndarray | tf.Tensor | None = None,
        training: Optional[bool] = False,
    ) -> Union[TFMultipleChoiceModelOutput, Tuple[tf.Tensor]]:
        r"""
        labels (`tf.Tensor` or `np.ndarray` of shape `(batch_size,)`, *optional*):
            Labels for computing the multiple choice classification loss. Indices should be in `[0, ...,
            num_choices - 1]` where `num_choices` is the size of the second dimension of the input tensors. (See
            `input_ids` above)
        """
        # Inputs are (batch_size, num_choices, sequence_length[, hidden]); input_ids takes precedence.
        if input_ids is not None:
            reference = input_ids
        elif inputs_embeds is not None:
            reference = inputs_embeds
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")
        reference_shape = shape_list(reference)
        num_choices = reference_shape[1]
        seq_length = reference_shape[2]

        def _flatten_choices(tensor):
            # Fold the choice axis into the batch axis: (batch, num_choices, seq) -> (batch * num_choices, seq).
            if tensor is None:
                return None
            return tf.reshape(tensor=tensor, shape=(-1, seq_length))

        flat_input_ids = _flatten_choices(input_ids)
        flat_attention_mask = _flatten_choices(attention_mask)
        flat_token_type_ids = _flatten_choices(token_type_ids)
        flat_position_ids = _flatten_choices(position_ids)
        # Embeddings keep their trailing hidden dimension when folded.
        flat_inputs_embeds = (
            tf.reshape(tensor=inputs_embeds, shape=(-1, seq_length, shape_list(inputs_embeds)[3]))
            if inputs_embeds is not None
            else None
        )
        outputs = self.bert(
            input_ids=flat_input_ids,
            attention_mask=flat_attention_mask,
            token_type_ids=flat_token_type_ids,
            position_ids=flat_position_ids,
            head_mask=head_mask,
            inputs_embeds=flat_inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            training=training,
        )
        pooled_output = outputs[1]
        pooled_output = self.dropout(inputs=pooled_output, training=training)
        logits = self.classifier(inputs=pooled_output)
        # Un-fold the per-choice scores back to (batch_size, num_choices).
        reshaped_logits = tf.reshape(tensor=logits, shape=(-1, num_choices))
        loss = None if labels is None else self.hf_compute_loss(labels=labels, logits=reshaped_logits)

        if not return_dict:
            output = (reshaped_logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return TFMultipleChoiceModelOutput(
            loss=loss,
            logits=reshaped_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def build(self, input_shape=None):
        """Build the trunk and the classifier once, each inside its own name scope."""
        if self.built:
            return
        self.built = True
        if getattr(self, "bert", None) is not None:
            with tf.name_scope(self.bert.name):
                self.bert.build(None)
        if getattr(self, "classifier", None) is not None:
            with tf.name_scope(self.classifier.name):
                self.classifier.build([None, None, self.config.hidden_size])
@add_start_docstrings(
    """
    Bert Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for
    Named-Entity-Recognition (NER) tasks.
    """,
    BERT_START_DOCSTRING,
)
class TFBertForTokenClassification(TFBertPreTrainedModel, TFTokenClassificationLoss):
    # names with a '.' represents the authorized unexpected/missing layers when a TF model is loaded from a PT model
    _keys_to_ignore_on_load_unexpected = [
        r"pooler",
        r"mlm___cls",
        r"nsp___cls",
        r"cls.predictions",
        r"cls.seq_relationship",
    ]
    _keys_to_ignore_on_load_missing = [r"dropout"]

    def __init__(self, config: BertConfig, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)

        self.num_labels = config.num_labels

        # Per-token classification works on the sequence output, so the trunk is built without a pooler.
        self.bert = TFBertMainLayer(config, add_pooling_layer=False, name="bert")
        if config.classifier_dropout is not None:
            head_dropout_rate = config.classifier_dropout
        else:
            head_dropout_rate = config.hidden_dropout_prob
        self.dropout = keras.layers.Dropout(rate=head_dropout_rate)
        self.classifier = keras.layers.Dense(
            units=config.num_labels,
            kernel_initializer=get_initializer(config.initializer_range),
            name="classifier",
        )
        self.config = config

    @unpack_inputs
    @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_TOKEN_CLASSIFICATION,
        output_type=TFTokenClassifierOutput,
        config_class=_CONFIG_FOR_DOC,
        expected_output=_TOKEN_CLASS_EXPECTED_OUTPUT,
        expected_loss=_TOKEN_CLASS_EXPECTED_LOSS,
    )
    def call(
        self,
        input_ids: TFModelInputType | None = None,
        attention_mask: np.ndarray | tf.Tensor | None = None,
        token_type_ids: np.ndarray | tf.Tensor | None = None,
        position_ids: np.ndarray | tf.Tensor | None = None,
        head_mask: np.ndarray | tf.Tensor | None = None,
        inputs_embeds: np.ndarray | tf.Tensor | None = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        labels: np.ndarray | tf.Tensor | None = None,
        training: Optional[bool] = False,
    ) -> Union[TFTokenClassifierOutput, Tuple[tf.Tensor]]:
        r"""
        labels (`tf.Tensor` or `np.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
        """
        trunk_outputs = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            training=training,
        )

        # Score every token from the last hidden state (element 0), after dropout.
        token_states = self.dropout(inputs=trunk_outputs[0], training=training)
        token_logits = self.classifier(inputs=token_states)

        token_loss = None
        if labels is not None:
            token_loss = self.hf_compute_loss(labels=labels, logits=token_logits)

        if return_dict:
            return TFTokenClassifierOutput(
                loss=token_loss,
                logits=token_logits,
                hidden_states=trunk_outputs.hidden_states,
                attentions=trunk_outputs.attentions,
            )

        tuple_output = (token_logits,) + trunk_outputs[2:]
        if token_loss is None:
            return tuple_output
        return (token_loss,) + tuple_output

    def build(self, input_shape=None):
        """Build the trunk and the classifier once, each inside its own name scope."""
        if self.built:
            return
        self.built = True
        trunk = getattr(self, "bert", None)
        if trunk is not None:
            with tf.name_scope(trunk.name):
                trunk.build(None)
        head = getattr(self, "classifier", None)
        if head is not None:
            with tf.name_scope(head.name):
                head.build([None, None, self.config.hidden_size])
@add_start_docstrings(
    """
    Bert Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear
    layer on top of the hidden-states output to compute `span start logits` and `span end logits`).
    """,
    BERT_START_DOCSTRING,
)
class TFBertForQuestionAnswering(TFBertPreTrainedModel, TFQuestionAnsweringLoss):
    # names with a '.' represents the authorized unexpected/missing layers when a TF model is loaded from a PT model
    _keys_to_ignore_on_load_unexpected = [
        r"pooler",
        r"mlm___cls",
        r"nsp___cls",
        r"cls.predictions",
        r"cls.seq_relationship",
    ]

    def __init__(self, config: BertConfig, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)

        self.num_labels = config.num_labels

        self.bert = TFBertMainLayer(config, add_pooling_layer=False, name="bert")
        # Produces `num_labels` scores per token; `call` splits them into start/end halves
        # (so this presumably expects num_labels == 2 — TODO confirm with callers).
        self.qa_outputs = keras.layers.Dense(
            units=config.num_labels,
            kernel_initializer=get_initializer(config.initializer_range),
            name="qa_outputs",
        )
        self.config = config

    @unpack_inputs
    @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_QA,
        output_type=TFQuestionAnsweringModelOutput,
        config_class=_CONFIG_FOR_DOC,
        qa_target_start_index=_QA_TARGET_START_INDEX,
        qa_target_end_index=_QA_TARGET_END_INDEX,
        expected_output=_QA_EXPECTED_OUTPUT,
        expected_loss=_QA_EXPECTED_LOSS,
    )
    def call(
        self,
        input_ids: TFModelInputType | None = None,
        attention_mask: np.ndarray | tf.Tensor | None = None,
        token_type_ids: np.ndarray | tf.Tensor | None = None,
        position_ids: np.ndarray | tf.Tensor | None = None,
        head_mask: np.ndarray | tf.Tensor | None = None,
        inputs_embeds: np.ndarray | tf.Tensor | None = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        start_positions: np.ndarray | tf.Tensor | None = None,
        end_positions: np.ndarray | tf.Tensor | None = None,
        training: Optional[bool] = False,
    ) -> Union[TFQuestionAnsweringModelOutput, Tuple[tf.Tensor]]:
        r"""
        start_positions (`tf.Tensor` or `np.ndarray` of shape `(batch_size,)`, *optional*):
            Labels for position (index) of the start of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (`sequence_length`). Positions outside of the sequence
            are not taken into account for computing the loss.
        end_positions (`tf.Tensor` or `np.ndarray` of shape `(batch_size,)`, *optional*):
            Labels for position (index) of the end of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (`sequence_length`). Positions outside of the sequence
            are not taken into account for computing the loss.
        """
        trunk_outputs = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            training=training,
        )

        # Per-token scores, split along the last axis into start and end logits with that axis squeezed away.
        span_scores = self.qa_outputs(inputs=trunk_outputs[0])
        start_scores, end_scores = tf.split(value=span_scores, num_or_size_splits=2, axis=-1)
        start_logits = tf.squeeze(input=start_scores, axis=-1)
        end_logits = tf.squeeze(input=end_scores, axis=-1)

        # The loss is only computed when both span boundaries are supplied.
        span_loss = None
        if start_positions is not None and end_positions is not None:
            span_labels = {"start_position": start_positions, "end_position": end_positions}
            span_loss = self.hf_compute_loss(labels=span_labels, logits=(start_logits, end_logits))

        if return_dict:
            return TFQuestionAnsweringModelOutput(
                loss=span_loss,
                start_logits=start_logits,
                end_logits=end_logits,
                hidden_states=trunk_outputs.hidden_states,
                attentions=trunk_outputs.attentions,
            )

        tuple_output = (start_logits, end_logits) + trunk_outputs[2:]
        if span_loss is None:
            return tuple_output
        return (span_loss,) + tuple_output

    def build(self, input_shape=None):
        """Build the trunk and the span head once, each inside its own name scope."""
        if self.built:
            return
        self.built = True
        trunk = getattr(self, "bert", None)
        if trunk is not None:
            with tf.name_scope(trunk.name):
                trunk.build(None)
        span_head = getattr(self, "qa_outputs", None)
        if span_head is not None:
            with tf.name_scope(span_head.name):
                span_head.build([None, None, self.config.hidden_size])
|