| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
7037047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135 |
- # coding=utf-8
- # Copyright 2019-present, the HuggingFace Inc. team, The Google AI Language Team and Facebook, Inc.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- """
- TF 2.0 DistilBERT model
- """
- from __future__ import annotations
- import warnings
- from typing import Optional, Tuple, Union
- import numpy as np
- import tensorflow as tf
- from ...activations_tf import get_tf_activation
- from ...modeling_tf_outputs import (
- TFBaseModelOutput,
- TFMaskedLMOutput,
- TFMultipleChoiceModelOutput,
- TFQuestionAnsweringModelOutput,
- TFSequenceClassifierOutput,
- TFTokenClassifierOutput,
- )
- from ...modeling_tf_utils import (
- TFMaskedLanguageModelingLoss,
- TFModelInputType,
- TFMultipleChoiceLoss,
- TFPreTrainedModel,
- TFQuestionAnsweringLoss,
- TFSequenceClassificationLoss,
- TFTokenClassificationLoss,
- get_initializer,
- keras,
- keras_serializable,
- unpack_inputs,
- )
- from ...tf_utils import check_embeddings_within_bounds, shape_list, stable_softmax
- from ...utils import (
- add_code_sample_docstrings,
- add_start_docstrings,
- add_start_docstrings_to_model_forward,
- logging,
- )
- from .configuration_distilbert import DistilBertConfig
# Names injected into the auto-generated code samples of the model docstrings in this module.
_CHECKPOINT_FOR_DOC = "distilbert-base-uncased"
_CONFIG_FOR_DOC = "DistilBertConfig"

# Module logger, obtained through the library's own logging utilities (`...utils.logging`).
logger = logging.get_logger(__name__)
class TFEmbeddings(keras.layers.Layer):
    """Construct the embeddings from word and position embeddings (DistilBERT has no token-type embeddings)."""

    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)
        self.config = config
        self.dim = config.dim
        self.initializer_range = config.initializer_range
        self.max_position_embeddings = config.max_position_embeddings
        self.LayerNorm = keras.layers.LayerNormalization(epsilon=1e-12, name="LayerNorm")
        self.dropout = keras.layers.Dropout(rate=config.dropout)

    def build(self, input_shape=None):
        # Guard first: the variable creation below must happen only once. Previously the guard sat after
        # the add_weight calls, so a second build() would have created a fresh set of embedding variables.
        if self.built:
            return

        with tf.name_scope("word_embeddings"):
            # (vocab_size, dim) lookup table; also reused (tied) by the masked-LM output head.
            self.weight = self.add_weight(
                name="weight",
                shape=[self.config.vocab_size, self.dim],
                initializer=get_initializer(initializer_range=self.initializer_range),
            )

        with tf.name_scope("position_embeddings"):
            # (max_position_embeddings, dim) learned absolute position table.
            self.position_embeddings = self.add_weight(
                name="embeddings",
                shape=[self.max_position_embeddings, self.dim],
                initializer=get_initializer(initializer_range=self.initializer_range),
            )

        self.built = True
        if getattr(self, "LayerNorm", None) is not None:
            with tf.name_scope(self.LayerNorm.name):
                self.LayerNorm.build([None, None, self.config.dim])

    def call(self, input_ids=None, position_ids=None, inputs_embeds=None, training=False):
        """
        Embed the input and add position embeddings, then apply LayerNorm and dropout.

        Args:
            input_ids: token ids to look up in the word-embedding table. If given, any `inputs_embeds`
                argument is ignored and replaced by the lookup result.
            position_ids: optional position indices; defaults to `0 .. seq_len - 1` with shape `(1, seq_len)`,
                broadcast over the batch.
            inputs_embeds: precomputed word embeddings, used only when `input_ids` is None.
            training: enables dropout.

        Returns:
            final_embeddings (`tf.Tensor`): output embedding tensor.
        """
        assert not (input_ids is None and inputs_embeds is None)

        if input_ids is not None:
            # Fail loudly on out-of-vocabulary ids instead of letting gather misbehave.
            check_embeddings_within_bounds(input_ids, self.config.vocab_size)
            inputs_embeds = tf.gather(params=self.weight, indices=input_ids)

        input_shape = shape_list(inputs_embeds)[:-1]

        if position_ids is None:
            position_ids = tf.expand_dims(tf.range(start=0, limit=input_shape[-1]), axis=0)

        position_embeds = tf.gather(params=self.position_embeddings, indices=position_ids)
        final_embeddings = inputs_embeds + position_embeds
        final_embeddings = self.LayerNorm(inputs=final_embeddings)
        final_embeddings = self.dropout(inputs=final_embeddings, training=training)

        return final_embeddings
class TFMultiHeadSelfAttention(keras.layers.Layer):
    """Multi-head scaled dot-product attention used as the self-attention sub-layer of each transformer block."""

    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)

        self.n_heads = config.n_heads
        self.dim = config.dim
        self.dropout = keras.layers.Dropout(config.attention_dropout)
        self.output_attentions = config.output_attentions

        assert self.dim % self.n_heads == 0, f"Hidden size {self.dim} not dividable by number of heads {self.n_heads}"

        self.q_lin = keras.layers.Dense(
            config.dim, kernel_initializer=get_initializer(config.initializer_range), name="q_lin"
        )
        self.k_lin = keras.layers.Dense(
            config.dim, kernel_initializer=get_initializer(config.initializer_range), name="k_lin"
        )
        self.v_lin = keras.layers.Dense(
            config.dim, kernel_initializer=get_initializer(config.initializer_range), name="v_lin"
        )
        self.out_lin = keras.layers.Dense(
            config.dim, kernel_initializer=get_initializer(config.initializer_range), name="out_lin"
        )

        self.pruned_heads = set()
        self.config = config

    def prune_heads(self, heads):
        raise NotImplementedError

    def call(self, query, key, value, mask, head_mask, output_attentions, training=False):
        """
        Parameters:
            query: tf.Tensor(bs, seq_length, dim)
            key: tf.Tensor(bs, seq_length, dim)
            value: tf.Tensor(bs, seq_length, dim)
            mask: tf.Tensor(bs, seq_length); positions with 0 receive a large negative score penalty
            head_mask: optional multiplier applied to the attention probabilities, or None
            output_attentions: whether to also return the attention probabilities
            training: enables attention dropout

        Returns:
            `(context,)`, or `(context, weights)` when `output_attentions` is truthy, where context is
            tf.Tensor(bs, q_length, dim) and weights is tf.Tensor(bs, n_heads, q_length, k_length).
        """
        bs = shape_list(query)[0]
        k_length = shape_list(key)[1]
        # Exact because __init__ asserts dim % n_heads == 0; a plain Python int keeps reshape targets static.
        dim_per_head = self.dim // self.n_heads
        mask_reshape = [bs, 1, 1, k_length]

        def shape(x):
            """separate heads: (bs, length, dim) -> (bs, n_heads, length, dim_per_head)"""
            return tf.transpose(tf.reshape(x, (bs, -1, self.n_heads, dim_per_head)), perm=(0, 2, 1, 3))

        def unshape(x):
            """group heads: (bs, n_heads, length, dim_per_head) -> (bs, length, dim)"""
            return tf.reshape(tf.transpose(x, perm=(0, 2, 1, 3)), (bs, -1, self.n_heads * dim_per_head))

        q = shape(self.q_lin(query))  # (bs, n_heads, q_length, dim_per_head)
        k = shape(self.k_lin(key))  # (bs, n_heads, k_length, dim_per_head)
        v = shape(self.v_lin(value))  # (bs, n_heads, k_length, dim_per_head)

        # Scores are computed in float32 regardless of the projections' dtype.
        q = tf.cast(q, dtype=tf.float32)
        q = tf.multiply(q, tf.math.rsqrt(tf.cast(dim_per_head, dtype=tf.float32)))
        k = tf.cast(k, dtype=q.dtype)
        scores = tf.matmul(q, k, transpose_b=True)  # (bs, n_heads, q_length, k_length)

        # Broadcast the key mask over heads and query positions.
        mask = tf.reshape(mask, mask_reshape)  # (bs, 1, 1, k_length)
        mask = tf.cast(mask, dtype=scores.dtype)
        # Additive penalty stands in for masked_fill(-inf) on masked keys.
        scores = scores - 1e30 * (1.0 - mask)

        weights = stable_softmax(scores, axis=-1)  # (bs, n_heads, q_length, k_length)
        weights = self.dropout(weights, training=training)  # (bs, n_heads, q_length, k_length)

        # Mask heads if we want to
        if head_mask is not None:
            weights = weights * head_mask

        # `weights` is float32 (see the casts above) while `v` comes straight from v_lin and may be float16 under
        # a mixed-precision policy; tf.matmul requires matching dtypes. The cast is a no-op in pure float32.
        context = tf.matmul(tf.cast(weights, dtype=v.dtype), v)  # (bs, n_heads, q_length, dim_per_head)
        context = unshape(context)  # (bs, q_length, dim)
        context = self.out_lin(context)  # (bs, q_length, dim)

        if output_attentions:
            return (context, weights)
        return (context,)

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        if getattr(self, "q_lin", None) is not None:
            with tf.name_scope(self.q_lin.name):
                self.q_lin.build([None, None, self.config.dim])
        if getattr(self, "k_lin", None) is not None:
            with tf.name_scope(self.k_lin.name):
                self.k_lin.build([None, None, self.config.dim])
        if getattr(self, "v_lin", None) is not None:
            with tf.name_scope(self.v_lin.name):
                self.v_lin.build([None, None, self.config.dim])
        if getattr(self, "out_lin", None) is not None:
            with tf.name_scope(self.out_lin.name):
                self.out_lin.build([None, None, self.config.dim])
class TFFFN(keras.layers.Layer):
    """Position-wise feed-forward sub-layer: lin1 (dim -> hidden_dim), activation, lin2 (hidden_dim -> dim), dropout."""

    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)
        # Sub-layers are assigned in the same order as before so Keras tracks them identically.
        self.dropout = keras.layers.Dropout(config.dropout)
        self.lin1 = keras.layers.Dense(
            config.hidden_dim, kernel_initializer=get_initializer(config.initializer_range), name="lin1"
        )
        self.lin2 = keras.layers.Dense(
            config.dim, kernel_initializer=get_initializer(config.initializer_range), name="lin2"
        )
        self.activation = get_tf_activation(config.activation)
        self.config = config

    def call(self, input, training=False):
        expanded = self.activation(self.lin1(input))
        projected = self.lin2(expanded)
        return self.dropout(projected, training=training)

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        # Each Dense is built for its own input width: lin1 reads `dim`, lin2 reads `hidden_dim`.
        for sublayer_name, in_features in (("lin1", self.config.dim), ("lin2", self.config.hidden_dim)):
            sublayer = getattr(self, sublayer_name, None)
            if sublayer is None:
                continue
            with tf.name_scope(sublayer.name):
                sublayer.build([None, None, in_features])
class TFTransformerBlock(keras.layers.Layer):
    """One post-LayerNorm encoder block: self-attention + residual + LayerNorm, then FFN + residual + LayerNorm."""

    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)

        self.n_heads = config.n_heads
        self.dim = config.dim
        self.hidden_dim = config.hidden_dim
        self.dropout = keras.layers.Dropout(config.dropout)
        self.activation = config.activation
        self.output_attentions = config.output_attentions

        assert (
            config.dim % config.n_heads == 0
        ), f"Hidden size {config.dim} not dividable by number of heads {config.n_heads}"

        # Keep this assignment order: it determines the order in which Keras tracks the sub-layers.
        self.attention = TFMultiHeadSelfAttention(config, name="attention")
        self.sa_layer_norm = keras.layers.LayerNormalization(epsilon=1e-12, name="sa_layer_norm")

        self.ffn = TFFFN(config, name="ffn")
        self.output_layer_norm = keras.layers.LayerNormalization(epsilon=1e-12, name="output_layer_norm")
        self.config = config

    def call(self, x, attn_mask, head_mask, output_attentions, training=False):  # removed: src_enc=None, src_len=None
        """
        Parameters:
            x: tf.Tensor(bs, seq_length, dim)
            attn_mask: tf.Tensor(bs, seq_length)

        Outputs:
            `(ffn_output,)`, or `(sa_weights, ffn_output)` when `output_attentions` is truthy, where sa_weights is
            tf.Tensor(bs, n_heads, seq_length, seq_length) and ffn_output is tf.Tensor(bs, seq_length, dim).
        """
        attention_outputs = self.attention(x, x, x, attn_mask, head_mask, output_attentions, training=training)

        # Residual connection around attention, then LayerNorm.
        attended = self.sa_layer_norm(attention_outputs[0] + x)  # (bs, seq_length, dim)

        # Residual connection around the feed-forward network, then LayerNorm.
        block_output = self.output_layer_norm(self.ffn(attended, training=training) + attended)

        if not output_attentions:
            return (block_output,)
        return (attention_outputs[1], block_output)

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        # (attribute name, whether the sub-layer is built with an explicit [None, None, dim] shape)
        build_plan = (
            ("attention", False),
            ("sa_layer_norm", True),
            ("ffn", False),
            ("output_layer_norm", True),
        )
        for attr_name, takes_shape in build_plan:
            sublayer = getattr(self, attr_name, None)
            if sublayer is None:
                continue
            with tf.name_scope(sublayer.name):
                sublayer.build([None, None, self.config.dim] if takes_shape else None)
class TFTransformer(keras.layers.Layer):
    """The DistilBERT encoder: a stack of `config.n_layers` TFTransformerBlock layers applied in sequence."""

    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)
        self.n_layers = config.n_layers
        self.output_hidden_states = config.output_hidden_states
        self.output_attentions = config.output_attentions

        self.layer = [TFTransformerBlock(config, name=f"layer_._{i}") for i in range(config.n_layers)]

    def call(self, x, attn_mask, head_mask, output_attentions, output_hidden_states, return_dict, training=False):
        # docstyle-ignore
        """
        Parameters:
            x: tf.Tensor(bs, seq_length, dim) Input sequence embedded.
            attn_mask: tf.Tensor(bs, seq_length) Attention mask on the sequence.
            head_mask: sequence indexed per layer (`head_mask[i]` is passed to layer i); entries may be None.

        Returns:
            hidden_state: tf.Tensor(bs, seq_length, dim)
                Sequence of hidden states in the last (top) layer
            all_hidden_states: Tuple[tf.Tensor(bs, seq_length, dim)]
                Tuple of length n_layers + 1: the input `x` followed by the output of each layer.
                Optional: only if output_hidden_states=True
            all_attentions: Tuple[tf.Tensor(bs, n_heads, seq_length, seq_length)]
                Tuple of length n_layers with the attention weights from each layer
                Optional: only if output_attentions=True

            Returned as a `TFBaseModelOutput` when `return_dict` is truthy, otherwise as a tuple of the non-None
            entries above, in that order.
        """
        all_hidden_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None

        hidden_state = x
        for i, layer_module in enumerate(self.layer):
            if output_hidden_states:
                # Records the input to layer i (the embeddings for i == 0).
                all_hidden_states = all_hidden_states + (hidden_state,)

            layer_outputs = layer_module(hidden_state, attn_mask, head_mask[i], output_attentions, training=training)
            # The block puts its hidden state last: (hidden,) or (attn_weights, hidden).
            hidden_state = layer_outputs[-1]

            if output_attentions:
                assert len(layer_outputs) == 2
                attentions = layer_outputs[0]
                all_attentions = all_attentions + (attentions,)
            else:
                assert len(layer_outputs) == 1, f"Incorrect number of outputs {len(layer_outputs)} instead of 1"

        # Add last layer
        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_state,)

        if not return_dict:
            return tuple(v for v in [hidden_state, all_hidden_states, all_attentions] if v is not None)
        return TFBaseModelOutput(
            last_hidden_state=hidden_state, hidden_states=all_hidden_states, attentions=all_attentions
        )

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        if getattr(self, "layer", None) is not None:
            for layer in self.layer:
                with tf.name_scope(layer.name):
                    layer.build(None)
@keras_serializable
class TFDistilBertMainLayer(keras.layers.Layer):
    """Embeddings followed by the transformer encoder; the head-less core shared by the TF DistilBERT models."""

    config_class = DistilBertConfig

    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)

        self.config = config
        self.num_hidden_layers = config.num_hidden_layers
        self.output_attentions = config.output_attentions
        self.output_hidden_states = config.output_hidden_states
        self.return_dict = config.use_return_dict

        self.embeddings = TFEmbeddings(config, name="embeddings")  # Embeddings
        self.transformer = TFTransformer(config, name="transformer")  # Encoder

    def get_input_embeddings(self):
        return self.embeddings

    def set_input_embeddings(self, value):
        self.embeddings.weight = value
        # NOTE(review): TFEmbeddings bounds-checks ids against `config.vocab_size`, not this attribute.
        self.embeddings.vocab_size = value.shape[0]

    def _prune_heads(self, heads_to_prune):
        raise NotImplementedError

    @unpack_inputs
    def call(
        self,
        input_ids=None,
        attention_mask=None,
        head_mask=None,
        inputs_embeds=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        training=False,
    ):
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            input_shape = shape_list(input_ids)
        elif inputs_embeds is not None:
            input_shape = shape_list(inputs_embeds)[:-1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        if attention_mask is None:
            # No mask given: attend to every position.
            attention_mask = tf.ones(input_shape)  # (bs, seq_length)
        attention_mask = tf.cast(attention_mask, dtype=tf.float32)

        # Head masking is not implemented in this TF port; only the "no mask" path is supported.
        if head_mask is not None:
            raise NotImplementedError
        else:
            # One entry per layer so TFTransformer can index head_mask[i].
            head_mask = [None] * self.num_hidden_layers

        # Forward `training` explicitly (as done for the transformer below) instead of relying on Keras
        # call-context propagation to switch the embedding dropout on.
        embedding_output = self.embeddings(
            input_ids, inputs_embeds=inputs_embeds, training=training
        )  # (bs, seq_length, dim)
        tfmr_output = self.transformer(
            embedding_output,
            attention_mask,
            head_mask,
            output_attentions,
            output_hidden_states,
            return_dict,
            training=training,
        )

        return tfmr_output  # last-layer hidden-state, (all hidden_states), (all attentions)

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        if getattr(self, "embeddings", None) is not None:
            with tf.name_scope(self.embeddings.name):
                self.embeddings.build(None)
        if getattr(self, "transformer", None) is not None:
            with tf.name_scope(self.transformer.name):
                self.transformer.build(None)
- # INTERFACE FOR ENCODER AND TASK SPECIFIC MODEL #
class TFDistilBertPreTrainedModel(TFPreTrainedModel):
    """
    Abstract base of the TF DistilBERT models. It binds them to `DistilBertConfig` and declares "distilbert" as the
    base-model prefix (subclasses store their main layer as `self.distilbert`). Weight initialization and pretrained
    download/loading come from `TFPreTrainedModel`.
    """

    base_model_prefix = "distilbert"
    config_class = DistilBertConfig
# Shared class-level docstring prepended to every public TF DistilBERT model via `add_start_docstrings`.
# The input-format examples only use arguments DistilBERT actually accepts: it has no `token_type_ids`, and in
# list form a third element would be bound to `head_mask`, which this implementation rejects.
DISTILBERT_START_DOCSTRING = r"""

    This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
    etc.)

    This model is also a [keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model) subclass. Use it
    as a regular TF 2.0 Keras Model and refer to the TF 2.0 documentation for all matter related to general usage and
    behavior.

    <Tip>

    TensorFlow models and layers in `transformers` accept two formats as input:

    - having all inputs as keyword arguments (like PyTorch models), or
    - having all inputs as a list, tuple or dict in the first positional argument.

    The reason the second format is supported is that Keras methods prefer this format when passing inputs to models
    and layers. Because of this support, when using methods like `model.fit()` things should "just work" for you - just
    pass your inputs and labels in any format that `model.fit()` supports! If, however, you want to use the second
    format outside of Keras methods like `fit()` and `predict()`, such as when creating your own layers or models with
    the Keras `Functional` API, there are three possibilities you can use to gather all the input Tensors in the first
    positional argument:

    - a single Tensor with `input_ids` only and nothing else: `model(input_ids)`
    - a list of varying length with one or several input Tensors IN THE ORDER given in the docstring:
    `model([input_ids, attention_mask])`
    - a dictionary with one or several input Tensors associated to the input names given in the docstring:
    `model({"input_ids": input_ids, "attention_mask": attention_mask})`

    Note that when creating models and layers with
    [subclassing](https://keras.io/guides/making_new_layers_and_models_via_subclassing/) then you don't need to worry
    about any of this, as you can just pass inputs like you would to any other Python function!

    </Tip>

    Parameters:
        config ([`DistilBertConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
# Shared forward-pass argument docs; `{0}` is filled with the input shape via `.format(...)` at each use site,
# so this template must not contain any other curly braces.
DISTILBERT_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`Numpy array` or `tf.Tensor` of shape `({0})`):
            Indices of input sequence tokens in the vocabulary.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.__call__`] and
            [`PreTrainedTokenizer.encode`] for details.

            [What are input IDs?](../glossary#input-ids)
        attention_mask (`Numpy array` or `tf.Tensor` of shape `({0})`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)
        head_mask (`Numpy array` or `tf.Tensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
            Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.

            Head masking is not yet implemented in the TensorFlow DistilBERT model: passing any value other than
            `None` raises a `NotImplementedError`.
        inputs_embeds (`tf.Tensor` of shape `({0}, hidden_size)`, *optional*):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
            model's internal embedding lookup matrix.
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail. This argument can be used only in eager mode, in graph mode the value in the
            config will be used instead.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail. This argument can be used only in eager mode, in graph mode the value in the config will be
            used instead.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. This argument can be used in
            eager mode, in graph mode the value will always be set to True.
        training (`bool`, *optional*, defaults to `False`):
            Whether or not to use the model in training mode (some modules like dropout modules have different
            behaviors between training and evaluation).
"""
@add_start_docstrings(
    "The bare DistilBERT encoder/transformer outputting raw hidden-states without any specific head on top.",
    DISTILBERT_START_DOCSTRING,
)
class TFDistilBertModel(TFDistilBertPreTrainedModel):
    # Note: no class/`call` docstrings here on purpose -- the decorators above assemble the public docs, and a
    # docstring of our own would be appended to them.

    def __init__(self, config, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)
        # Head-less main layer (embeddings + transformer encoder).
        self.distilbert = TFDistilBertMainLayer(config, name="distilbert")

    @unpack_inputs
    @add_start_docstrings_to_model_forward(DISTILBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=TFBaseModelOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def call(
        self,
        input_ids: TFModelInputType | None = None,
        attention_mask: np.ndarray | tf.Tensor | None = None,
        head_mask: np.ndarray | tf.Tensor | None = None,
        inputs_embeds: np.ndarray | tf.Tensor | None = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        training: Optional[bool] = False,
    ) -> Union[TFBaseModelOutput, Tuple[tf.Tensor]]:
        # Pure delegation: the main layer's output is returned unchanged.
        return self.distilbert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            training=training,
        )

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        main_layer = getattr(self, "distilbert", None)
        if main_layer is None:
            return
        with tf.name_scope(main_layer.name):
            main_layer.build(None)
class TFDistilBertLMHead(keras.layers.Layer):
    """
    Vocabulary projection for masked LM. The projection matrix is the word-embedding table of `input_embeddings`
    (weight tying); this layer only owns the per-token output bias.
    """

    def __init__(self, config, input_embeddings, **kwargs):
        super().__init__(**kwargs)

        self.config = config
        self.dim = config.dim

        # The output weights are the same as the input embeddings, but there is
        # an output-only bias for each token.
        self.input_embeddings = input_embeddings

    def build(self, input_shape=None):
        # Same guard as every other layer's build(): without it a repeated build() would add a second `bias`.
        if self.built:
            return
        self.bias = self.add_weight(shape=(self.config.vocab_size,), initializer="zeros", trainable=True, name="bias")

        super().build(input_shape)

    def get_output_embeddings(self):
        return self.input_embeddings

    def set_output_embeddings(self, value):
        self.input_embeddings.weight = value
        self.input_embeddings.vocab_size = shape_list(value)[0]

    def get_bias(self):
        return {"bias": self.bias}

    def set_bias(self, value):
        self.bias = value["bias"]
        # Keep the configured vocab size in sync: call() reshapes logits with config.vocab_size.
        self.config.vocab_size = shape_list(value["bias"])[0]

    def call(self, hidden_states):
        """
        Project hidden states onto the vocabulary.

        Args:
            hidden_states: tensor whose second axis is the sequence and last axis is `dim`
                (presumably (bs, seq_length, dim) -- the reshape below assumes rank 3).

        Returns:
            Logits of shape (-1, seq_length, vocab_size), with the output bias added.
        """
        seq_length = shape_list(tensor=hidden_states)[1]
        # Flatten to 2-D so the tied embedding table can be applied with a single matmul.
        hidden_states = tf.reshape(tensor=hidden_states, shape=[-1, self.dim])
        hidden_states = tf.matmul(a=hidden_states, b=self.input_embeddings.weight, transpose_b=True)
        hidden_states = tf.reshape(tensor=hidden_states, shape=[-1, seq_length, self.config.vocab_size])
        hidden_states = tf.nn.bias_add(value=hidden_states, bias=self.bias)

        return hidden_states
@add_start_docstrings(
    """DistilBert Model with a `masked language modeling` head on top.""",
    DISTILBERT_START_DOCSTRING,
)
class TFDistilBertForMaskedLM(TFDistilBertPreTrainedModel, TFMaskedLanguageModelingLoss):
    # Class docstring intentionally omitted: `add_start_docstrings` builds the public documentation.

    def __init__(self, config, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)
        self.config = config

        # Sub-layer assignment order is kept as-is (Keras tracks layers in attribute-assignment order).
        self.distilbert = TFDistilBertMainLayer(config, name="distilbert")
        self.vocab_transform = keras.layers.Dense(
            config.dim, kernel_initializer=get_initializer(config.initializer_range), name="vocab_transform"
        )
        self.act = get_tf_activation(config.activation)
        self.vocab_layer_norm = keras.layers.LayerNormalization(epsilon=1e-12, name="vocab_layer_norm")
        # Output projection tied to the main layer's word-embedding table.
        self.vocab_projector = TFDistilBertLMHead(config, self.distilbert.embeddings, name="vocab_projector")

    def get_lm_head(self):
        return self.vocab_projector

    def get_prefix_bias_name(self):
        warnings.warn("The method get_prefix_bias_name is deprecated. Please use `get_bias` instead.", FutureWarning)
        return f"{self.name}/{self.vocab_projector.name}"

    @unpack_inputs
    @add_start_docstrings_to_model_forward(DISTILBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=TFMaskedLMOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def call(
        self,
        input_ids: TFModelInputType | None = None,
        attention_mask: np.ndarray | tf.Tensor | None = None,
        head_mask: np.ndarray | tf.Tensor | None = None,
        inputs_embeds: np.ndarray | tf.Tensor | None = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        labels: np.ndarray | tf.Tensor | None = None,
        training: Optional[bool] = False,
    ) -> Union[TFMaskedLMOutput, Tuple[tf.Tensor]]:
        r"""
        labels (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
            config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
            loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
        """
        distilbert_output = self.distilbert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            training=training,
        )

        # Prediction head: dense -> activation -> LayerNorm, each (bs, seq_length, dim), then vocab projection.
        head_features = distilbert_output[0]
        for transform in (self.vocab_transform, self.act, self.vocab_layer_norm):
            head_features = transform(head_features)
        prediction_logits = self.vocab_projector(head_features)

        loss = self.hf_compute_loss(labels, prediction_logits) if labels is not None else None

        if return_dict:
            return TFMaskedLMOutput(
                loss=loss,
                logits=prediction_logits,
                hidden_states=distilbert_output.hidden_states,
                attentions=distilbert_output.attentions,
            )

        tuple_output = (prediction_logits,) + distilbert_output[1:]
        return tuple_output if loss is None else (loss,) + tuple_output

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        # (attribute name, whether it is built with an explicit [None, None, dim] input shape), in build order.
        build_plan = (
            ("distilbert", False),
            ("vocab_transform", True),
            ("vocab_layer_norm", True),
            ("vocab_projector", False),
        )
        for attr_name, takes_shape in build_plan:
            sublayer = getattr(self, attr_name, None)
            if sublayer is None:
                continue
            with tf.name_scope(sublayer.name):
                sublayer.build([None, None, self.config.dim] if takes_shape else None)
@add_start_docstrings(
    """
    DistilBert Model transformer with a sequence classification/regression head on top (a linear layer on top of the
    pooled output) e.g. for GLUE tasks.
    """,
    DISTILBERT_START_DOCSTRING,
)
class TFDistilBertForSequenceClassification(TFDistilBertPreTrainedModel, TFSequenceClassificationLoss):
    def __init__(self, config, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)
        self.num_labels = config.num_labels

        self.distilbert = TFDistilBertMainLayer(config, name="distilbert")
        # ReLU projection of the pooled (first-token) vector; keeps the width at `config.dim`.
        self.pre_classifier = keras.layers.Dense(
            config.dim,
            kernel_initializer=get_initializer(config.initializer_range),
            activation="relu",
            name="pre_classifier",
        )
        # Final projection to one score per label.
        self.classifier = keras.layers.Dense(
            config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="classifier"
        )
        self.dropout = keras.layers.Dropout(config.seq_classif_dropout)
        self.config = config

    @unpack_inputs
    @add_start_docstrings_to_model_forward(DISTILBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=TFSequenceClassifierOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def call(
        self,
        input_ids: TFModelInputType | None = None,
        attention_mask: np.ndarray | tf.Tensor | None = None,
        head_mask: np.ndarray | tf.Tensor | None = None,
        inputs_embeds: np.ndarray | tf.Tensor | None = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        labels: np.ndarray | tf.Tensor | None = None,
        training: Optional[bool] = False,
    ) -> Union[TFSequenceClassifierOutput, Tuple[tf.Tensor]]:
        r"""
        labels (`tf.Tensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), if
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        distilbert_output = self.distilbert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            training=training,
        )
        hidden_state = distilbert_output[0]  # (bs, seq_len, dim)
        pooled_output = hidden_state[:, 0]  # (bs, dim) -- representation of the first token
        pooled_output = self.pre_classifier(pooled_output)  # (bs, dim)
        pooled_output = self.dropout(pooled_output, training=training)  # (bs, dim)
        logits = self.classifier(pooled_output)  # (bs, num_labels)

        loss = None if labels is None else self.hf_compute_loss(labels, logits)

        if not return_dict:
            # Tuple form: (loss, logits, *extra backbone outputs), with loss omitted when no labels were given.
            output = (logits,) + distilbert_output[1:]
            return ((loss,) + output) if loss is not None else output

        return TFSequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=distilbert_output.hidden_states,
            attentions=distilbert_output.attentions,
        )

    def build(self, input_shape=None):
        """Build the backbone and the classification head once, each inside its own name scope."""
        if self.built:
            return
        self.built = True
        if getattr(self, "distilbert", None) is not None:
            with tf.name_scope(self.distilbert.name):
                self.distilbert.build(None)
        if getattr(self, "pre_classifier", None) is not None:
            with tf.name_scope(self.pre_classifier.name):
                # Dense kernels depend only on the last dimension, so a rank-3 spec also fits the rank-2 pooled input.
                self.pre_classifier.build([None, None, self.config.dim])
        if getattr(self, "classifier", None) is not None:
            with tf.name_scope(self.classifier.name):
                self.classifier.build([None, None, self.config.dim])
@add_start_docstrings(
    """
    DistilBert Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g.
    for Named-Entity-Recognition (NER) tasks.
    """,
    DISTILBERT_START_DOCSTRING,
)
class TFDistilBertForTokenClassification(TFDistilBertPreTrainedModel, TFTokenClassificationLoss):
    def __init__(self, config, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)
        self.num_labels = config.num_labels

        self.distilbert = TFDistilBertMainLayer(config, name="distilbert")
        self.dropout = keras.layers.Dropout(config.dropout)
        # Per-token projection to one score per label.
        self.classifier = keras.layers.Dense(
            config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="classifier"
        )
        self.config = config

    @unpack_inputs
    @add_start_docstrings_to_model_forward(DISTILBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=TFTokenClassifierOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def call(
        self,
        input_ids: TFModelInputType | None = None,
        attention_mask: np.ndarray | tf.Tensor | None = None,
        head_mask: np.ndarray | tf.Tensor | None = None,
        inputs_embeds: np.ndarray | tf.Tensor | None = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        labels: np.ndarray | tf.Tensor | None = None,
        training: Optional[bool] = False,
    ) -> Union[TFTokenClassifierOutput, Tuple[tf.Tensor]]:
        r"""
        labels (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
        """
        outputs = self.distilbert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            training=training,
        )
        sequence_output = outputs[0]  # (bs, seq_len, dim)
        sequence_output = self.dropout(sequence_output, training=training)
        logits = self.classifier(sequence_output)  # (bs, seq_len, num_labels)

        loss = None if labels is None else self.hf_compute_loss(labels, logits)

        if not return_dict:
            # Tuple form: (loss, logits, *extra backbone outputs), with loss omitted when no labels were given.
            output = (logits,) + outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return TFTokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def build(self, input_shape=None):
        """Build the backbone and the token classifier once, each inside its own name scope."""
        if self.built:
            return
        self.built = True
        if getattr(self, "distilbert", None) is not None:
            with tf.name_scope(self.distilbert.name):
                self.distilbert.build(None)
        if getattr(self, "classifier", None) is not None:
            with tf.name_scope(self.classifier.name):
                # The classifier consumes the backbone's hidden states; size it from `config.dim`, the backbone
                # width every other DistilBERT head uses for the same input.
                self.classifier.build([None, None, self.config.dim])
@add_start_docstrings(
    """
    DistilBert Model with a multiple choice classification head on top (a linear layer on top of the pooled output and
    a softmax) e.g. for RocStories/SWAG tasks.
    """,
    DISTILBERT_START_DOCSTRING,
)
class TFDistilBertForMultipleChoice(TFDistilBertPreTrainedModel, TFMultipleChoiceLoss):
    def __init__(self, config, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)

        self.distilbert = TFDistilBertMainLayer(config, name="distilbert")
        self.dropout = keras.layers.Dropout(config.seq_classif_dropout)
        # ReLU projection of the pooled (first-token) vector; keeps the width at `config.dim`.
        self.pre_classifier = keras.layers.Dense(
            config.dim,
            kernel_initializer=get_initializer(config.initializer_range),
            activation="relu",
            name="pre_classifier",
        )
        # One score per (example, choice) pair; scores are regrouped per example in `call`.
        self.classifier = keras.layers.Dense(
            1, kernel_initializer=get_initializer(config.initializer_range), name="classifier"
        )
        self.config = config

    @unpack_inputs
    @add_start_docstrings_to_model_forward(
        DISTILBERT_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length")
    )
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=TFMultipleChoiceModelOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def call(
        self,
        input_ids: TFModelInputType | None = None,
        attention_mask: np.ndarray | tf.Tensor | None = None,
        head_mask: np.ndarray | tf.Tensor | None = None,
        inputs_embeds: np.ndarray | tf.Tensor | None = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        labels: np.ndarray | tf.Tensor | None = None,
        training: Optional[bool] = False,
    ) -> Union[TFMultipleChoiceModelOutput, Tuple[tf.Tensor]]:
        r"""
        labels (`tf.Tensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the multiple choice classification loss. Indices should be in `[0, ..., num_choices -
            1]` where `num_choices` is the size of the second dimension of the input tensors. (See `input_ids` above)
        """
        # Inputs are laid out as (bs, num_choices, seq_len) for ids, plus a trailing embedding axis for embeds.
        if input_ids is not None:
            input_shape = shape_list(input_ids)
        elif inputs_embeds is not None:
            input_shape = shape_list(inputs_embeds)
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")
        num_choices, seq_length = input_shape[1], input_shape[2]

        # Fold the choice axis into the batch axis so the backbone sees ordinary (bs * num_choices, seq_len) inputs.
        flat_input_ids = tf.reshape(input_ids, (-1, seq_length)) if input_ids is not None else None
        flat_attention_mask = tf.reshape(attention_mask, (-1, seq_length)) if attention_mask is not None else None
        flat_inputs_embeds = (
            tf.reshape(inputs_embeds, (-1, seq_length, shape_list(inputs_embeds)[3]))
            if inputs_embeds is not None
            else None
        )

        # Keyword arguments (as in the sibling heads) so the call does not depend on the backbone's positional order.
        distilbert_output = self.distilbert(
            input_ids=flat_input_ids,
            attention_mask=flat_attention_mask,
            head_mask=head_mask,
            inputs_embeds=flat_inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            training=training,
        )
        hidden_state = distilbert_output[0]  # (bs * num_choices, seq_len, dim)
        pooled_output = hidden_state[:, 0]  # (bs * num_choices, dim) -- first-token representation
        pooled_output = self.pre_classifier(pooled_output)  # (bs * num_choices, dim)
        pooled_output = self.dropout(pooled_output, training=training)  # (bs * num_choices, dim)
        logits = self.classifier(pooled_output)  # (bs * num_choices, 1)
        # Unfold back to one row of choice scores per example.
        reshaped_logits = tf.reshape(logits, (-1, num_choices))  # (bs, num_choices)

        loss = None if labels is None else self.hf_compute_loss(labels, reshaped_logits)

        if not return_dict:
            # Tuple form: (loss, logits, *extra backbone outputs), with loss omitted when no labels were given.
            output = (reshaped_logits,) + distilbert_output[1:]
            return ((loss,) + output) if loss is not None else output

        return TFMultipleChoiceModelOutput(
            loss=loss,
            logits=reshaped_logits,
            hidden_states=distilbert_output.hidden_states,
            attentions=distilbert_output.attentions,
        )

    def build(self, input_shape=None):
        """Build the backbone and the choice-scoring head once, each inside its own name scope."""
        if self.built:
            return
        self.built = True
        if getattr(self, "distilbert", None) is not None:
            with tf.name_scope(self.distilbert.name):
                self.distilbert.build(None)
        if getattr(self, "pre_classifier", None) is not None:
            with tf.name_scope(self.pre_classifier.name):
                self.pre_classifier.build([None, None, self.config.dim])
        if getattr(self, "classifier", None) is not None:
            with tf.name_scope(self.classifier.name):
                self.classifier.build([None, None, self.config.dim])
@add_start_docstrings(
    """
    DistilBert Model with a span classification head on top for extractive question-answering tasks like SQuAD (a
    linear layer on top of the hidden-states output to compute `span start logits` and `span end logits`).
    """,
    DISTILBERT_START_DOCSTRING,
)
class TFDistilBertForQuestionAnswering(TFDistilBertPreTrainedModel, TFQuestionAnsweringLoss):
    def __init__(self, config, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)
        # The head splits its output into exactly two per-token scores (span start, span end), so any other label
        # count is a configuration error. Checked explicitly (not with `assert`, which `python -O` strips) and
        # before any sub-layer is created.
        if config.num_labels != 2:
            raise ValueError(f"Incorrect number of labels {config.num_labels} instead of 2")

        self.distilbert = TFDistilBertMainLayer(config, name="distilbert")
        self.qa_outputs = keras.layers.Dense(
            config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="qa_outputs"
        )
        self.dropout = keras.layers.Dropout(config.qa_dropout)
        self.config = config

    @unpack_inputs
    @add_start_docstrings_to_model_forward(DISTILBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=TFQuestionAnsweringModelOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def call(
        self,
        input_ids: TFModelInputType | None = None,
        attention_mask: np.ndarray | tf.Tensor | None = None,
        head_mask: np.ndarray | tf.Tensor | None = None,
        inputs_embeds: np.ndarray | tf.Tensor | None = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        start_positions: np.ndarray | tf.Tensor | None = None,
        end_positions: np.ndarray | tf.Tensor | None = None,
        training: Optional[bool] = False,
    ) -> Union[TFQuestionAnsweringModelOutput, Tuple[tf.Tensor]]:
        r"""
        start_positions (`tf.Tensor` of shape `(batch_size,)`, *optional*):
            Labels for position (index) of the start of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (`sequence_length`). Positions outside of the sequence
            are not taken into account for computing the loss.
        end_positions (`tf.Tensor` of shape `(batch_size,)`, *optional*):
            Labels for position (index) of the end of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (`sequence_length`). Positions outside of the sequence
            are not taken into account for computing the loss.
        """
        distilbert_output = self.distilbert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            training=training,
        )
        hidden_states = distilbert_output[0]  # (bs, max_query_len, dim)
        hidden_states = self.dropout(hidden_states, training=training)  # (bs, max_query_len, dim)
        logits = self.qa_outputs(hidden_states)  # (bs, max_query_len, 2)
        # Split the last axis into the start and end scores, then drop the size-1 axis: (bs, max_query_len) each.
        start_logits, end_logits = tf.split(logits, 2, axis=-1)
        start_logits = tf.squeeze(start_logits, axis=-1)
        end_logits = tf.squeeze(end_logits, axis=-1)

        # The loss is only computed when both span boundaries are supplied.
        loss = None
        if start_positions is not None and end_positions is not None:
            labels = {"start_position": start_positions, "end_position": end_positions}
            loss = self.hf_compute_loss(labels, (start_logits, end_logits))

        if not return_dict:
            # Tuple form: (loss, start_logits, end_logits, *extra backbone outputs), loss omitted when not computed.
            output = (start_logits, end_logits) + distilbert_output[1:]
            return ((loss,) + output) if loss is not None else output

        return TFQuestionAnsweringModelOutput(
            loss=loss,
            start_logits=start_logits,
            end_logits=end_logits,
            hidden_states=distilbert_output.hidden_states,
            attentions=distilbert_output.attentions,
        )

    def build(self, input_shape=None):
        """Build the backbone and the span-scoring head once, each inside its own name scope."""
        if self.built:
            return
        self.built = True
        if getattr(self, "distilbert", None) is not None:
            with tf.name_scope(self.distilbert.name):
                self.distilbert.build(None)
        if getattr(self, "qa_outputs", None) is not None:
            with tf.name_scope(self.qa_outputs.name):
                self.qa_outputs.build([None, None, self.config.dim])
|