modeling_tf_distilbert.py 48 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495969798991001011021031041051061071081091101111121131141151161171181191201211221231241251261271281291301311321331341351361371381391401411421431441451461471481491501511521531541551561571581591601611621631641651661671681691701711721731741751761771781791801811821831841851861871881891901911921931941951961971981992002012022032042052062072082092102112122132142152162172182192202212222232242252262272282292302312322332342352362372382392402412422432442452462472482492502512522532542552562572582592602612622632642652662672682692702712722732742752762772782792802812822832842852862872882892902912922932942952962972982993003013023033043053063073083093103113123133143153163173183193203213223233243253263273283293303313323333343353363373383393403413423433443453463473483493503513523533543553563573583593603613623633643653663673683693703713723733743753763773783793803813823833843853863873883893903913923933943953963973983994004014024034044054064074084094104114124134144154164174184194204214224234244254264274284294304314324334344354364374384394404414424434444454464474484494504514524534544554564574584594604614624634644654664674684694704714724734744754764774784794804814824834844854864874884894904914924934944954964974984995005015025035045055065075085095105115125135145155165175185195205215225235245255265275285295305315325335345355365375385395405415425435445455465475485495505515525535545555565575585595605615625635645655665675685695705715725735745755765775785795805815825835845855865875885895905915925935945955965975985996006016026036046056066076086096106116126136146156166176186196206216226236246256266276286296306316326336346356366376386396406416426436446456466476486496506516526536546556566576586596606616626636646656666676686696706716726736746756766776786796806816826836846856866876886896906916926936946956966976986997007017027037047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135
  1. # coding=utf-8
  2. # Copyright 2019-present, the HuggingFace Inc. team, The Google AI Language Team and Facebook, Inc.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. """
  16. TF 2.0 DistilBERT model
  17. """
  18. from __future__ import annotations
  19. import warnings
  20. from typing import Optional, Tuple, Union
  21. import numpy as np
  22. import tensorflow as tf
  23. from ...activations_tf import get_tf_activation
  24. from ...modeling_tf_outputs import (
  25. TFBaseModelOutput,
  26. TFMaskedLMOutput,
  27. TFMultipleChoiceModelOutput,
  28. TFQuestionAnsweringModelOutput,
  29. TFSequenceClassifierOutput,
  30. TFTokenClassifierOutput,
  31. )
  32. from ...modeling_tf_utils import (
  33. TFMaskedLanguageModelingLoss,
  34. TFModelInputType,
  35. TFMultipleChoiceLoss,
  36. TFPreTrainedModel,
  37. TFQuestionAnsweringLoss,
  38. TFSequenceClassificationLoss,
  39. TFTokenClassificationLoss,
  40. get_initializer,
  41. keras,
  42. keras_serializable,
  43. unpack_inputs,
  44. )
  45. from ...tf_utils import check_embeddings_within_bounds, shape_list, stable_softmax
  46. from ...utils import (
  47. add_code_sample_docstrings,
  48. add_start_docstrings,
  49. add_start_docstrings_to_model_forward,
  50. logging,
  51. )
  52. from .configuration_distilbert import DistilBertConfig
  53. logger = logging.get_logger(__name__)
  54. _CHECKPOINT_FOR_DOC = "distilbert-base-uncased"
  55. _CONFIG_FOR_DOC = "DistilBertConfig"
  56. class TFEmbeddings(keras.layers.Layer):
  57. """Construct the embeddings from word, position and token_type embeddings."""
  58. def __init__(self, config, **kwargs):
  59. super().__init__(**kwargs)
  60. self.config = config
  61. self.dim = config.dim
  62. self.initializer_range = config.initializer_range
  63. self.max_position_embeddings = config.max_position_embeddings
  64. self.LayerNorm = keras.layers.LayerNormalization(epsilon=1e-12, name="LayerNorm")
  65. self.dropout = keras.layers.Dropout(rate=config.dropout)
  66. def build(self, input_shape=None):
  67. with tf.name_scope("word_embeddings"):
  68. self.weight = self.add_weight(
  69. name="weight",
  70. shape=[self.config.vocab_size, self.dim],
  71. initializer=get_initializer(initializer_range=self.initializer_range),
  72. )
  73. with tf.name_scope("position_embeddings"):
  74. self.position_embeddings = self.add_weight(
  75. name="embeddings",
  76. shape=[self.max_position_embeddings, self.dim],
  77. initializer=get_initializer(initializer_range=self.initializer_range),
  78. )
  79. if self.built:
  80. return
  81. self.built = True
  82. if getattr(self, "LayerNorm", None) is not None:
  83. with tf.name_scope(self.LayerNorm.name):
  84. self.LayerNorm.build([None, None, self.config.dim])
  85. def call(self, input_ids=None, position_ids=None, inputs_embeds=None, training=False):
  86. """
  87. Applies embedding based on inputs tensor.
  88. Returns:
  89. final_embeddings (`tf.Tensor`): output embedding tensor.
  90. """
  91. assert not (input_ids is None and inputs_embeds is None)
  92. if input_ids is not None:
  93. check_embeddings_within_bounds(input_ids, self.config.vocab_size)
  94. inputs_embeds = tf.gather(params=self.weight, indices=input_ids)
  95. input_shape = shape_list(inputs_embeds)[:-1]
  96. if position_ids is None:
  97. position_ids = tf.expand_dims(tf.range(start=0, limit=input_shape[-1]), axis=0)
  98. position_embeds = tf.gather(params=self.position_embeddings, indices=position_ids)
  99. final_embeddings = inputs_embeds + position_embeds
  100. final_embeddings = self.LayerNorm(inputs=final_embeddings)
  101. final_embeddings = self.dropout(inputs=final_embeddings, training=training)
  102. return final_embeddings
  103. class TFMultiHeadSelfAttention(keras.layers.Layer):
  104. def __init__(self, config, **kwargs):
  105. super().__init__(**kwargs)
  106. self.n_heads = config.n_heads
  107. self.dim = config.dim
  108. self.dropout = keras.layers.Dropout(config.attention_dropout)
  109. self.output_attentions = config.output_attentions
  110. assert self.dim % self.n_heads == 0, f"Hidden size {self.dim} not dividable by number of heads {self.n_heads}"
  111. self.q_lin = keras.layers.Dense(
  112. config.dim, kernel_initializer=get_initializer(config.initializer_range), name="q_lin"
  113. )
  114. self.k_lin = keras.layers.Dense(
  115. config.dim, kernel_initializer=get_initializer(config.initializer_range), name="k_lin"
  116. )
  117. self.v_lin = keras.layers.Dense(
  118. config.dim, kernel_initializer=get_initializer(config.initializer_range), name="v_lin"
  119. )
  120. self.out_lin = keras.layers.Dense(
  121. config.dim, kernel_initializer=get_initializer(config.initializer_range), name="out_lin"
  122. )
  123. self.pruned_heads = set()
  124. self.config = config
  125. def prune_heads(self, heads):
  126. raise NotImplementedError
  127. def call(self, query, key, value, mask, head_mask, output_attentions, training=False):
  128. """
  129. Parameters:
  130. query: tf.Tensor(bs, seq_length, dim)
  131. key: tf.Tensor(bs, seq_length, dim)
  132. value: tf.Tensor(bs, seq_length, dim)
  133. mask: tf.Tensor(bs, seq_length)
  134. Returns:
  135. weights: tf.Tensor(bs, n_heads, seq_length, seq_length) Attention weights context: tf.Tensor(bs,
  136. seq_length, dim) Contextualized layer. Optional: only if `output_attentions=True`
  137. """
  138. bs, q_length, dim = shape_list(query)
  139. k_length = shape_list(key)[1]
  140. # assert dim == self.dim, f'Dimensions do not match: {dim} input vs {self.dim} configured'
  141. # assert key.size() == value.size()
  142. dim_per_head = int(self.dim / self.n_heads)
  143. dim_per_head = tf.cast(dim_per_head, dtype=tf.int32)
  144. mask_reshape = [bs, 1, 1, k_length]
  145. def shape(x):
  146. """separate heads"""
  147. return tf.transpose(tf.reshape(x, (bs, -1, self.n_heads, dim_per_head)), perm=(0, 2, 1, 3))
  148. def unshape(x):
  149. """group heads"""
  150. return tf.reshape(tf.transpose(x, perm=(0, 2, 1, 3)), (bs, -1, self.n_heads * dim_per_head))
  151. q = shape(self.q_lin(query)) # (bs, n_heads, q_length, dim_per_head)
  152. k = shape(self.k_lin(key)) # (bs, n_heads, k_length, dim_per_head)
  153. v = shape(self.v_lin(value)) # (bs, n_heads, k_length, dim_per_head)
  154. q = tf.cast(q, dtype=tf.float32)
  155. q = tf.multiply(q, tf.math.rsqrt(tf.cast(dim_per_head, dtype=tf.float32)))
  156. k = tf.cast(k, dtype=q.dtype)
  157. scores = tf.matmul(q, k, transpose_b=True) # (bs, n_heads, q_length, k_length)
  158. mask = tf.reshape(mask, mask_reshape) # (bs, n_heads, qlen, klen)
  159. # scores.masked_fill_(mask, -float('inf')) # (bs, n_heads, q_length, k_length)
  160. mask = tf.cast(mask, dtype=scores.dtype)
  161. scores = scores - 1e30 * (1.0 - mask)
  162. weights = stable_softmax(scores, axis=-1) # (bs, n_heads, qlen, klen)
  163. weights = self.dropout(weights, training=training) # (bs, n_heads, qlen, klen)
  164. # Mask heads if we want to
  165. if head_mask is not None:
  166. weights = weights * head_mask
  167. context = tf.matmul(weights, v) # (bs, n_heads, qlen, dim_per_head)
  168. context = unshape(context) # (bs, q_length, dim)
  169. context = self.out_lin(context) # (bs, q_length, dim)
  170. if output_attentions:
  171. return (context, weights)
  172. else:
  173. return (context,)
  174. def build(self, input_shape=None):
  175. if self.built:
  176. return
  177. self.built = True
  178. if getattr(self, "q_lin", None) is not None:
  179. with tf.name_scope(self.q_lin.name):
  180. self.q_lin.build([None, None, self.config.dim])
  181. if getattr(self, "k_lin", None) is not None:
  182. with tf.name_scope(self.k_lin.name):
  183. self.k_lin.build([None, None, self.config.dim])
  184. if getattr(self, "v_lin", None) is not None:
  185. with tf.name_scope(self.v_lin.name):
  186. self.v_lin.build([None, None, self.config.dim])
  187. if getattr(self, "out_lin", None) is not None:
  188. with tf.name_scope(self.out_lin.name):
  189. self.out_lin.build([None, None, self.config.dim])
  190. class TFFFN(keras.layers.Layer):
  191. def __init__(self, config, **kwargs):
  192. super().__init__(**kwargs)
  193. self.dropout = keras.layers.Dropout(config.dropout)
  194. self.lin1 = keras.layers.Dense(
  195. config.hidden_dim, kernel_initializer=get_initializer(config.initializer_range), name="lin1"
  196. )
  197. self.lin2 = keras.layers.Dense(
  198. config.dim, kernel_initializer=get_initializer(config.initializer_range), name="lin2"
  199. )
  200. self.activation = get_tf_activation(config.activation)
  201. self.config = config
  202. def call(self, input, training=False):
  203. x = self.lin1(input)
  204. x = self.activation(x)
  205. x = self.lin2(x)
  206. x = self.dropout(x, training=training)
  207. return x
  208. def build(self, input_shape=None):
  209. if self.built:
  210. return
  211. self.built = True
  212. if getattr(self, "lin1", None) is not None:
  213. with tf.name_scope(self.lin1.name):
  214. self.lin1.build([None, None, self.config.dim])
  215. if getattr(self, "lin2", None) is not None:
  216. with tf.name_scope(self.lin2.name):
  217. self.lin2.build([None, None, self.config.hidden_dim])
  218. class TFTransformerBlock(keras.layers.Layer):
  219. def __init__(self, config, **kwargs):
  220. super().__init__(**kwargs)
  221. self.n_heads = config.n_heads
  222. self.dim = config.dim
  223. self.hidden_dim = config.hidden_dim
  224. self.dropout = keras.layers.Dropout(config.dropout)
  225. self.activation = config.activation
  226. self.output_attentions = config.output_attentions
  227. assert (
  228. config.dim % config.n_heads == 0
  229. ), f"Hidden size {config.dim} not dividable by number of heads {config.n_heads}"
  230. self.attention = TFMultiHeadSelfAttention(config, name="attention")
  231. self.sa_layer_norm = keras.layers.LayerNormalization(epsilon=1e-12, name="sa_layer_norm")
  232. self.ffn = TFFFN(config, name="ffn")
  233. self.output_layer_norm = keras.layers.LayerNormalization(epsilon=1e-12, name="output_layer_norm")
  234. self.config = config
  235. def call(self, x, attn_mask, head_mask, output_attentions, training=False): # removed: src_enc=None, src_len=None
  236. """
  237. Parameters:
  238. x: tf.Tensor(bs, seq_length, dim)
  239. attn_mask: tf.Tensor(bs, seq_length)
  240. Outputs: sa_weights: tf.Tensor(bs, n_heads, seq_length, seq_length) The attention weights ffn_output:
  241. tf.Tensor(bs, seq_length, dim) The output of the transformer block contextualization.
  242. """
  243. # Self-Attention
  244. sa_output = self.attention(x, x, x, attn_mask, head_mask, output_attentions, training=training)
  245. if output_attentions:
  246. sa_output, sa_weights = sa_output # (bs, seq_length, dim), (bs, n_heads, seq_length, seq_length)
  247. else: # To handle these `output_attentions` or `output_hidden_states` cases returning tuples
  248. # assert type(sa_output) == tuple
  249. sa_output = sa_output[0]
  250. sa_output = self.sa_layer_norm(sa_output + x) # (bs, seq_length, dim)
  251. # Feed Forward Network
  252. ffn_output = self.ffn(sa_output, training=training) # (bs, seq_length, dim)
  253. ffn_output = self.output_layer_norm(ffn_output + sa_output) # (bs, seq_length, dim)
  254. output = (ffn_output,)
  255. if output_attentions:
  256. output = (sa_weights,) + output
  257. return output
  258. def build(self, input_shape=None):
  259. if self.built:
  260. return
  261. self.built = True
  262. if getattr(self, "attention", None) is not None:
  263. with tf.name_scope(self.attention.name):
  264. self.attention.build(None)
  265. if getattr(self, "sa_layer_norm", None) is not None:
  266. with tf.name_scope(self.sa_layer_norm.name):
  267. self.sa_layer_norm.build([None, None, self.config.dim])
  268. if getattr(self, "ffn", None) is not None:
  269. with tf.name_scope(self.ffn.name):
  270. self.ffn.build(None)
  271. if getattr(self, "output_layer_norm", None) is not None:
  272. with tf.name_scope(self.output_layer_norm.name):
  273. self.output_layer_norm.build([None, None, self.config.dim])
  274. class TFTransformer(keras.layers.Layer):
  275. def __init__(self, config, **kwargs):
  276. super().__init__(**kwargs)
  277. self.n_layers = config.n_layers
  278. self.output_hidden_states = config.output_hidden_states
  279. self.output_attentions = config.output_attentions
  280. self.layer = [TFTransformerBlock(config, name=f"layer_._{i}") for i in range(config.n_layers)]
  281. def call(self, x, attn_mask, head_mask, output_attentions, output_hidden_states, return_dict, training=False):
  282. # docstyle-ignore
  283. """
  284. Parameters:
  285. x: tf.Tensor(bs, seq_length, dim) Input sequence embedded.
  286. attn_mask: tf.Tensor(bs, seq_length) Attention mask on the sequence.
  287. Returns:
  288. hidden_state: tf.Tensor(bs, seq_length, dim)
  289. Sequence of hidden states in the last (top) layer
  290. all_hidden_states: Tuple[tf.Tensor(bs, seq_length, dim)]
  291. Tuple of length n_layers with the hidden states from each layer.
  292. Optional: only if output_hidden_states=True
  293. all_attentions: Tuple[tf.Tensor(bs, n_heads, seq_length, seq_length)]
  294. Tuple of length n_layers with the attention weights from each layer
  295. Optional: only if output_attentions=True
  296. """
  297. all_hidden_states = () if output_hidden_states else None
  298. all_attentions = () if output_attentions else None
  299. hidden_state = x
  300. for i, layer_module in enumerate(self.layer):
  301. if output_hidden_states:
  302. all_hidden_states = all_hidden_states + (hidden_state,)
  303. layer_outputs = layer_module(hidden_state, attn_mask, head_mask[i], output_attentions, training=training)
  304. hidden_state = layer_outputs[-1]
  305. if output_attentions:
  306. assert len(layer_outputs) == 2
  307. attentions = layer_outputs[0]
  308. all_attentions = all_attentions + (attentions,)
  309. else:
  310. assert len(layer_outputs) == 1, f"Incorrect number of outputs {len(layer_outputs)} instead of 1"
  311. # Add last layer
  312. if output_hidden_states:
  313. all_hidden_states = all_hidden_states + (hidden_state,)
  314. if not return_dict:
  315. return tuple(v for v in [hidden_state, all_hidden_states, all_attentions] if v is not None)
  316. return TFBaseModelOutput(
  317. last_hidden_state=hidden_state, hidden_states=all_hidden_states, attentions=all_attentions
  318. )
  319. def build(self, input_shape=None):
  320. if self.built:
  321. return
  322. self.built = True
  323. if getattr(self, "layer", None) is not None:
  324. for layer in self.layer:
  325. with tf.name_scope(layer.name):
  326. layer.build(None)
  327. @keras_serializable
  328. class TFDistilBertMainLayer(keras.layers.Layer):
  329. config_class = DistilBertConfig
  330. def __init__(self, config, **kwargs):
  331. super().__init__(**kwargs)
  332. self.config = config
  333. self.num_hidden_layers = config.num_hidden_layers
  334. self.output_attentions = config.output_attentions
  335. self.output_hidden_states = config.output_hidden_states
  336. self.return_dict = config.use_return_dict
  337. self.embeddings = TFEmbeddings(config, name="embeddings") # Embeddings
  338. self.transformer = TFTransformer(config, name="transformer") # Encoder
  339. def get_input_embeddings(self):
  340. return self.embeddings
  341. def set_input_embeddings(self, value):
  342. self.embeddings.weight = value
  343. self.embeddings.vocab_size = value.shape[0]
  344. def _prune_heads(self, heads_to_prune):
  345. raise NotImplementedError
  346. @unpack_inputs
  347. def call(
  348. self,
  349. input_ids=None,
  350. attention_mask=None,
  351. head_mask=None,
  352. inputs_embeds=None,
  353. output_attentions=None,
  354. output_hidden_states=None,
  355. return_dict=None,
  356. training=False,
  357. ):
  358. if input_ids is not None and inputs_embeds is not None:
  359. raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
  360. elif input_ids is not None:
  361. input_shape = shape_list(input_ids)
  362. elif inputs_embeds is not None:
  363. input_shape = shape_list(inputs_embeds)[:-1]
  364. else:
  365. raise ValueError("You have to specify either input_ids or inputs_embeds")
  366. if attention_mask is None:
  367. attention_mask = tf.ones(input_shape) # (bs, seq_length)
  368. attention_mask = tf.cast(attention_mask, dtype=tf.float32)
  369. # Prepare head mask if needed
  370. # 1.0 in head_mask indicate we keep the head
  371. # attention_probs has shape bsz x n_heads x N x N
  372. # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
  373. # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
  374. if head_mask is not None:
  375. raise NotImplementedError
  376. else:
  377. head_mask = [None] * self.num_hidden_layers
  378. embedding_output = self.embeddings(input_ids, inputs_embeds=inputs_embeds) # (bs, seq_length, dim)
  379. tfmr_output = self.transformer(
  380. embedding_output,
  381. attention_mask,
  382. head_mask,
  383. output_attentions,
  384. output_hidden_states,
  385. return_dict,
  386. training=training,
  387. )
  388. return tfmr_output # last-layer hidden-state, (all hidden_states), (all attentions)
  389. def build(self, input_shape=None):
  390. if self.built:
  391. return
  392. self.built = True
  393. if getattr(self, "embeddings", None) is not None:
  394. with tf.name_scope(self.embeddings.name):
  395. self.embeddings.build(None)
  396. if getattr(self, "transformer", None) is not None:
  397. with tf.name_scope(self.transformer.name):
  398. self.transformer.build(None)
  399. # INTERFACE FOR ENCODER AND TASK SPECIFIC MODEL #
  400. class TFDistilBertPreTrainedModel(TFPreTrainedModel):
  401. """
  402. An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
  403. models.
  404. """
  405. config_class = DistilBertConfig
  406. base_model_prefix = "distilbert"
  407. DISTILBERT_START_DOCSTRING = r"""
  408. This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the
  409. library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
  410. etc.)
  411. This model is also a [keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model) subclass. Use it
  412. as a regular TF 2.0 Keras Model and refer to the TF 2.0 documentation for all matter related to general usage and
  413. behavior.
  414. <Tip>
  415. TensorFlow models and layers in `transformers` accept two formats as input:
  416. - having all inputs as keyword arguments (like PyTorch models), or
  417. - having all inputs as a list, tuple or dict in the first positional argument.
  418. The reason the second format is supported is that Keras methods prefer this format when passing inputs to models
  419. and layers. Because of this support, when using methods like `model.fit()` things should "just work" for you - just
  420. pass your inputs and labels in any format that `model.fit()` supports! If, however, you want to use the second
  421. format outside of Keras methods like `fit()` and `predict()`, such as when creating your own layers or models with
  422. the Keras `Functional` API, there are three possibilities you can use to gather all the input Tensors in the first
  423. positional argument:
  424. - a single Tensor with `input_ids` only and nothing else: `model(input_ids)`
  425. - a list of varying length with one or several input Tensors IN THE ORDER given in the docstring:
  426. `model([input_ids, attention_mask])` or `model([input_ids, attention_mask, token_type_ids])`
  427. - a dictionary with one or several input Tensors associated to the input names given in the docstring:
  428. `model({"input_ids": input_ids, "token_type_ids": token_type_ids})`
  429. Note that when creating models and layers with
  430. [subclassing](https://keras.io/guides/making_new_layers_and_models_via_subclassing/) then you don't need to worry
  431. about any of this, as you can just pass inputs like you would to any other Python function!
  432. </Tip>
  433. Parameters:
  434. config ([`DistilBertConfig`]): Model configuration class with all the parameters of the model.
  435. Initializing with a config file does not load the weights associated with the model, only the
  436. configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
  437. """
  438. DISTILBERT_INPUTS_DOCSTRING = r"""
  439. Args:
  440. input_ids (`Numpy array` or `tf.Tensor` of shape `({0})`):
  441. Indices of input sequence tokens in the vocabulary.
  442. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.__call__`] and
  443. [`PreTrainedTokenizer.encode`] for details.
  444. [What are input IDs?](../glossary#input-ids)
  445. attention_mask (`Numpy array` or `tf.Tensor` of shape `({0})`, *optional*):
  446. Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
  447. - 1 for tokens that are **not masked**,
  448. - 0 for tokens that are **masked**.
  449. [What are attention masks?](../glossary#attention-mask)
  450. head_mask (`Numpy array` or `tf.Tensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
  451. Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
  452. - 1 indicates the head is **not masked**,
  453. - 0 indicates the head is **masked**.
  454. inputs_embeds (`tf.Tensor` of shape `({0}, hidden_size)`, *optional*):
  455. Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
  456. is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
  457. model's internal embedding lookup matrix.
  458. output_attentions (`bool`, *optional*):
  459. Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
  460. tensors for more detail. This argument can be used only in eager mode, in graph mode the value in the
  461. config will be used instead.
  462. output_hidden_states (`bool`, *optional*):
  463. Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
  464. more detail. This argument can be used only in eager mode, in graph mode the value in the config will be
  465. used instead.
  466. return_dict (`bool`, *optional*):
  467. Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. This argument can be used in
  468. eager mode, in graph mode the value will always be set to True.
  469. training (`bool`, *optional*, defaults to `False`):
  470. Whether or not to use the model in training mode (some modules like dropout modules have different
  471. behaviors between training and evaluation).
  472. """
  473. @add_start_docstrings(
  474. "The bare DistilBERT encoder/transformer outputting raw hidden-states without any specific head on top.",
  475. DISTILBERT_START_DOCSTRING,
  476. )
  477. class TFDistilBertModel(TFDistilBertPreTrainedModel):
  478. def __init__(self, config, *inputs, **kwargs):
  479. super().__init__(config, *inputs, **kwargs)
  480. self.distilbert = TFDistilBertMainLayer(config, name="distilbert") # Embeddings
  481. @unpack_inputs
  482. @add_start_docstrings_to_model_forward(DISTILBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
  483. @add_code_sample_docstrings(
  484. checkpoint=_CHECKPOINT_FOR_DOC,
  485. output_type=TFBaseModelOutput,
  486. config_class=_CONFIG_FOR_DOC,
  487. )
  488. def call(
  489. self,
  490. input_ids: TFModelInputType | None = None,
  491. attention_mask: np.ndarray | tf.Tensor | None = None,
  492. head_mask: np.ndarray | tf.Tensor | None = None,
  493. inputs_embeds: np.ndarray | tf.Tensor | None = None,
  494. output_attentions: Optional[bool] = None,
  495. output_hidden_states: Optional[bool] = None,
  496. return_dict: Optional[bool] = None,
  497. training: Optional[bool] = False,
  498. ) -> Union[TFBaseModelOutput, Tuple[tf.Tensor]]:
  499. outputs = self.distilbert(
  500. input_ids=input_ids,
  501. attention_mask=attention_mask,
  502. head_mask=head_mask,
  503. inputs_embeds=inputs_embeds,
  504. output_attentions=output_attentions,
  505. output_hidden_states=output_hidden_states,
  506. return_dict=return_dict,
  507. training=training,
  508. )
  509. return outputs
  510. def build(self, input_shape=None):
  511. if self.built:
  512. return
  513. self.built = True
  514. if getattr(self, "distilbert", None) is not None:
  515. with tf.name_scope(self.distilbert.name):
  516. self.distilbert.build(None)
  517. class TFDistilBertLMHead(keras.layers.Layer):
  518. def __init__(self, config, input_embeddings, **kwargs):
  519. super().__init__(**kwargs)
  520. self.config = config
  521. self.dim = config.dim
  522. # The output weights are the same as the input embeddings, but there is
  523. # an output-only bias for each token.
  524. self.input_embeddings = input_embeddings
  525. def build(self, input_shape):
  526. self.bias = self.add_weight(shape=(self.config.vocab_size,), initializer="zeros", trainable=True, name="bias")
  527. super().build(input_shape)
  528. def get_output_embeddings(self):
  529. return self.input_embeddings
  530. def set_output_embeddings(self, value):
  531. self.input_embeddings.weight = value
  532. self.input_embeddings.vocab_size = shape_list(value)[0]
  533. def get_bias(self):
  534. return {"bias": self.bias}
  535. def set_bias(self, value):
  536. self.bias = value["bias"]
  537. self.config.vocab_size = shape_list(value["bias"])[0]
  538. def call(self, hidden_states):
  539. seq_length = shape_list(tensor=hidden_states)[1]
  540. hidden_states = tf.reshape(tensor=hidden_states, shape=[-1, self.dim])
  541. hidden_states = tf.matmul(a=hidden_states, b=self.input_embeddings.weight, transpose_b=True)
  542. hidden_states = tf.reshape(tensor=hidden_states, shape=[-1, seq_length, self.config.vocab_size])
  543. hidden_states = tf.nn.bias_add(value=hidden_states, bias=self.bias)
  544. return hidden_states
  545. @add_start_docstrings(
  546. """DistilBert Model with a `masked language modeling` head on top.""",
  547. DISTILBERT_START_DOCSTRING,
  548. )
  549. class TFDistilBertForMaskedLM(TFDistilBertPreTrainedModel, TFMaskedLanguageModelingLoss):
  550. def __init__(self, config, *inputs, **kwargs):
  551. super().__init__(config, *inputs, **kwargs)
  552. self.config = config
  553. self.distilbert = TFDistilBertMainLayer(config, name="distilbert")
  554. self.vocab_transform = keras.layers.Dense(
  555. config.dim, kernel_initializer=get_initializer(config.initializer_range), name="vocab_transform"
  556. )
  557. self.act = get_tf_activation(config.activation)
  558. self.vocab_layer_norm = keras.layers.LayerNormalization(epsilon=1e-12, name="vocab_layer_norm")
  559. self.vocab_projector = TFDistilBertLMHead(config, self.distilbert.embeddings, name="vocab_projector")
  560. def get_lm_head(self):
  561. return self.vocab_projector
  562. def get_prefix_bias_name(self):
  563. warnings.warn("The method get_prefix_bias_name is deprecated. Please use `get_bias` instead.", FutureWarning)
  564. return self.name + "/" + self.vocab_projector.name
  565. @unpack_inputs
  566. @add_start_docstrings_to_model_forward(DISTILBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
  567. @add_code_sample_docstrings(
  568. checkpoint=_CHECKPOINT_FOR_DOC,
  569. output_type=TFMaskedLMOutput,
  570. config_class=_CONFIG_FOR_DOC,
  571. )
  572. def call(
  573. self,
  574. input_ids: TFModelInputType | None = None,
  575. attention_mask: np.ndarray | tf.Tensor | None = None,
  576. head_mask: np.ndarray | tf.Tensor | None = None,
  577. inputs_embeds: np.ndarray | tf.Tensor | None = None,
  578. output_attentions: Optional[bool] = None,
  579. output_hidden_states: Optional[bool] = None,
  580. return_dict: Optional[bool] = None,
  581. labels: np.ndarray | tf.Tensor | None = None,
  582. training: Optional[bool] = False,
  583. ) -> Union[TFMaskedLMOutput, Tuple[tf.Tensor]]:
  584. r"""
  585. labels (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
  586. Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
  587. config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
  588. loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
  589. """
  590. distilbert_output = self.distilbert(
  591. input_ids=input_ids,
  592. attention_mask=attention_mask,
  593. head_mask=head_mask,
  594. inputs_embeds=inputs_embeds,
  595. output_attentions=output_attentions,
  596. output_hidden_states=output_hidden_states,
  597. return_dict=return_dict,
  598. training=training,
  599. )
  600. hidden_states = distilbert_output[0] # (bs, seq_length, dim)
  601. prediction_logits = self.vocab_transform(hidden_states) # (bs, seq_length, dim)
  602. prediction_logits = self.act(prediction_logits) # (bs, seq_length, dim)
  603. prediction_logits = self.vocab_layer_norm(prediction_logits) # (bs, seq_length, dim)
  604. prediction_logits = self.vocab_projector(prediction_logits)
  605. loss = None if labels is None else self.hf_compute_loss(labels, prediction_logits)
  606. if not return_dict:
  607. output = (prediction_logits,) + distilbert_output[1:]
  608. return ((loss,) + output) if loss is not None else output
  609. return TFMaskedLMOutput(
  610. loss=loss,
  611. logits=prediction_logits,
  612. hidden_states=distilbert_output.hidden_states,
  613. attentions=distilbert_output.attentions,
  614. )
  615. def build(self, input_shape=None):
  616. if self.built:
  617. return
  618. self.built = True
  619. if getattr(self, "distilbert", None) is not None:
  620. with tf.name_scope(self.distilbert.name):
  621. self.distilbert.build(None)
  622. if getattr(self, "vocab_transform", None) is not None:
  623. with tf.name_scope(self.vocab_transform.name):
  624. self.vocab_transform.build([None, None, self.config.dim])
  625. if getattr(self, "vocab_layer_norm", None) is not None:
  626. with tf.name_scope(self.vocab_layer_norm.name):
  627. self.vocab_layer_norm.build([None, None, self.config.dim])
  628. if getattr(self, "vocab_projector", None) is not None:
  629. with tf.name_scope(self.vocab_projector.name):
  630. self.vocab_projector.build(None)
  631. @add_start_docstrings(
  632. """
  633. DistilBert Model transformer with a sequence classification/regression head on top (a linear layer on top of the
  634. pooled output) e.g. for GLUE tasks.
  635. """,
  636. DISTILBERT_START_DOCSTRING,
  637. )
  638. class TFDistilBertForSequenceClassification(TFDistilBertPreTrainedModel, TFSequenceClassificationLoss):
  639. def __init__(self, config, *inputs, **kwargs):
  640. super().__init__(config, *inputs, **kwargs)
  641. self.num_labels = config.num_labels
  642. self.distilbert = TFDistilBertMainLayer(config, name="distilbert")
  643. self.pre_classifier = keras.layers.Dense(
  644. config.dim,
  645. kernel_initializer=get_initializer(config.initializer_range),
  646. activation="relu",
  647. name="pre_classifier",
  648. )
  649. self.classifier = keras.layers.Dense(
  650. config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="classifier"
  651. )
  652. self.dropout = keras.layers.Dropout(config.seq_classif_dropout)
  653. self.config = config
  654. @unpack_inputs
  655. @add_start_docstrings_to_model_forward(DISTILBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
  656. @add_code_sample_docstrings(
  657. checkpoint=_CHECKPOINT_FOR_DOC,
  658. output_type=TFSequenceClassifierOutput,
  659. config_class=_CONFIG_FOR_DOC,
  660. )
  661. def call(
  662. self,
  663. input_ids: TFModelInputType | None = None,
  664. attention_mask: np.ndarray | tf.Tensor | None = None,
  665. head_mask: np.ndarray | tf.Tensor | None = None,
  666. inputs_embeds: np.ndarray | tf.Tensor | None = None,
  667. output_attentions: Optional[bool] = None,
  668. output_hidden_states: Optional[bool] = None,
  669. return_dict: Optional[bool] = None,
  670. labels: np.ndarray | tf.Tensor | None = None,
  671. training: Optional[bool] = False,
  672. ) -> Union[TFSequenceClassifierOutput, Tuple[tf.Tensor]]:
  673. r"""
  674. labels (`tf.Tensor` of shape `(batch_size,)`, *optional*):
  675. Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
  676. config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
  677. `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
  678. """
  679. distilbert_output = self.distilbert(
  680. input_ids=input_ids,
  681. attention_mask=attention_mask,
  682. head_mask=head_mask,
  683. inputs_embeds=inputs_embeds,
  684. output_attentions=output_attentions,
  685. output_hidden_states=output_hidden_states,
  686. return_dict=return_dict,
  687. training=training,
  688. )
  689. hidden_state = distilbert_output[0] # (bs, seq_len, dim)
  690. pooled_output = hidden_state[:, 0] # (bs, dim)
  691. pooled_output = self.pre_classifier(pooled_output) # (bs, dim)
  692. pooled_output = self.dropout(pooled_output, training=training) # (bs, dim)
  693. logits = self.classifier(pooled_output) # (bs, dim)
  694. loss = None if labels is None else self.hf_compute_loss(labels, logits)
  695. if not return_dict:
  696. output = (logits,) + distilbert_output[1:]
  697. return ((loss,) + output) if loss is not None else output
  698. return TFSequenceClassifierOutput(
  699. loss=loss,
  700. logits=logits,
  701. hidden_states=distilbert_output.hidden_states,
  702. attentions=distilbert_output.attentions,
  703. )
  704. def build(self, input_shape=None):
  705. if self.built:
  706. return
  707. self.built = True
  708. if getattr(self, "distilbert", None) is not None:
  709. with tf.name_scope(self.distilbert.name):
  710. self.distilbert.build(None)
  711. if getattr(self, "pre_classifier", None) is not None:
  712. with tf.name_scope(self.pre_classifier.name):
  713. self.pre_classifier.build([None, None, self.config.dim])
  714. if getattr(self, "classifier", None) is not None:
  715. with tf.name_scope(self.classifier.name):
  716. self.classifier.build([None, None, self.config.dim])
  717. @add_start_docstrings(
  718. """
  719. DistilBert Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g.
  720. for Named-Entity-Recognition (NER) tasks.
  721. """,
  722. DISTILBERT_START_DOCSTRING,
  723. )
  724. class TFDistilBertForTokenClassification(TFDistilBertPreTrainedModel, TFTokenClassificationLoss):
  725. def __init__(self, config, *inputs, **kwargs):
  726. super().__init__(config, *inputs, **kwargs)
  727. self.num_labels = config.num_labels
  728. self.distilbert = TFDistilBertMainLayer(config, name="distilbert")
  729. self.dropout = keras.layers.Dropout(config.dropout)
  730. self.classifier = keras.layers.Dense(
  731. config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="classifier"
  732. )
  733. self.config = config
  734. @unpack_inputs
  735. @add_start_docstrings_to_model_forward(DISTILBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
  736. @add_code_sample_docstrings(
  737. checkpoint=_CHECKPOINT_FOR_DOC,
  738. output_type=TFTokenClassifierOutput,
  739. config_class=_CONFIG_FOR_DOC,
  740. )
  741. def call(
  742. self,
  743. input_ids: TFModelInputType | None = None,
  744. attention_mask: np.ndarray | tf.Tensor | None = None,
  745. head_mask: np.ndarray | tf.Tensor | None = None,
  746. inputs_embeds: np.ndarray | tf.Tensor | None = None,
  747. output_attentions: Optional[bool] = None,
  748. output_hidden_states: Optional[bool] = None,
  749. return_dict: Optional[bool] = None,
  750. labels: np.ndarray | tf.Tensor | None = None,
  751. training: Optional[bool] = False,
  752. ) -> Union[TFTokenClassifierOutput, Tuple[tf.Tensor]]:
  753. r"""
  754. labels (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
  755. Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
  756. """
  757. outputs = self.distilbert(
  758. input_ids=input_ids,
  759. attention_mask=attention_mask,
  760. head_mask=head_mask,
  761. inputs_embeds=inputs_embeds,
  762. output_attentions=output_attentions,
  763. output_hidden_states=output_hidden_states,
  764. return_dict=return_dict,
  765. training=training,
  766. )
  767. sequence_output = outputs[0]
  768. sequence_output = self.dropout(sequence_output, training=training)
  769. logits = self.classifier(sequence_output)
  770. loss = None if labels is None else self.hf_compute_loss(labels, logits)
  771. if not return_dict:
  772. output = (logits,) + outputs[1:]
  773. return ((loss,) + output) if loss is not None else output
  774. return TFTokenClassifierOutput(
  775. loss=loss,
  776. logits=logits,
  777. hidden_states=outputs.hidden_states,
  778. attentions=outputs.attentions,
  779. )
  780. def build(self, input_shape=None):
  781. if self.built:
  782. return
  783. self.built = True
  784. if getattr(self, "distilbert", None) is not None:
  785. with tf.name_scope(self.distilbert.name):
  786. self.distilbert.build(None)
  787. if getattr(self, "classifier", None) is not None:
  788. with tf.name_scope(self.classifier.name):
  789. self.classifier.build([None, None, self.config.hidden_size])
  790. @add_start_docstrings(
  791. """
  792. DistilBert Model with a multiple choice classification head on top (a linear layer on top of the pooled output and
  793. a softmax) e.g. for RocStories/SWAG tasks.
  794. """,
  795. DISTILBERT_START_DOCSTRING,
  796. )
  797. class TFDistilBertForMultipleChoice(TFDistilBertPreTrainedModel, TFMultipleChoiceLoss):
  798. def __init__(self, config, *inputs, **kwargs):
  799. super().__init__(config, *inputs, **kwargs)
  800. self.distilbert = TFDistilBertMainLayer(config, name="distilbert")
  801. self.dropout = keras.layers.Dropout(config.seq_classif_dropout)
  802. self.pre_classifier = keras.layers.Dense(
  803. config.dim,
  804. kernel_initializer=get_initializer(config.initializer_range),
  805. activation="relu",
  806. name="pre_classifier",
  807. )
  808. self.classifier = keras.layers.Dense(
  809. 1, kernel_initializer=get_initializer(config.initializer_range), name="classifier"
  810. )
  811. self.config = config
  812. @unpack_inputs
  813. @add_start_docstrings_to_model_forward(
  814. DISTILBERT_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length")
  815. )
  816. @add_code_sample_docstrings(
  817. checkpoint=_CHECKPOINT_FOR_DOC,
  818. output_type=TFMultipleChoiceModelOutput,
  819. config_class=_CONFIG_FOR_DOC,
  820. )
  821. def call(
  822. self,
  823. input_ids: TFModelInputType | None = None,
  824. attention_mask: np.ndarray | tf.Tensor | None = None,
  825. head_mask: np.ndarray | tf.Tensor | None = None,
  826. inputs_embeds: np.ndarray | tf.Tensor | None = None,
  827. output_attentions: Optional[bool] = None,
  828. output_hidden_states: Optional[bool] = None,
  829. return_dict: Optional[bool] = None,
  830. labels: np.ndarray | tf.Tensor | None = None,
  831. training: Optional[bool] = False,
  832. ) -> Union[TFMultipleChoiceModelOutput, Tuple[tf.Tensor]]:
  833. r"""
  834. labels (`tf.Tensor` of shape `(batch_size,)`, *optional*):
  835. Labels for computing the multiple choice classification loss. Indices should be in `[0, ..., num_choices]`
  836. where `num_choices` is the size of the second dimension of the input tensors. (See `input_ids` above)
  837. """
  838. if input_ids is not None:
  839. num_choices = shape_list(input_ids)[1]
  840. seq_length = shape_list(input_ids)[2]
  841. else:
  842. num_choices = shape_list(inputs_embeds)[1]
  843. seq_length = shape_list(inputs_embeds)[2]
  844. flat_input_ids = tf.reshape(input_ids, (-1, seq_length)) if input_ids is not None else None
  845. flat_attention_mask = tf.reshape(attention_mask, (-1, seq_length)) if attention_mask is not None else None
  846. flat_inputs_embeds = (
  847. tf.reshape(inputs_embeds, (-1, seq_length, shape_list(inputs_embeds)[3]))
  848. if inputs_embeds is not None
  849. else None
  850. )
  851. distilbert_output = self.distilbert(
  852. flat_input_ids,
  853. flat_attention_mask,
  854. head_mask,
  855. flat_inputs_embeds,
  856. output_attentions,
  857. output_hidden_states,
  858. return_dict=return_dict,
  859. training=training,
  860. )
  861. hidden_state = distilbert_output[0] # (bs, seq_len, dim)
  862. pooled_output = hidden_state[:, 0] # (bs, dim)
  863. pooled_output = self.pre_classifier(pooled_output) # (bs, dim)
  864. pooled_output = self.dropout(pooled_output, training=training) # (bs, dim)
  865. logits = self.classifier(pooled_output)
  866. reshaped_logits = tf.reshape(logits, (-1, num_choices))
  867. loss = None if labels is None else self.hf_compute_loss(labels, reshaped_logits)
  868. if not return_dict:
  869. output = (reshaped_logits,) + distilbert_output[1:]
  870. return ((loss,) + output) if loss is not None else output
  871. return TFMultipleChoiceModelOutput(
  872. loss=loss,
  873. logits=reshaped_logits,
  874. hidden_states=distilbert_output.hidden_states,
  875. attentions=distilbert_output.attentions,
  876. )
  877. def build(self, input_shape=None):
  878. if self.built:
  879. return
  880. self.built = True
  881. if getattr(self, "distilbert", None) is not None:
  882. with tf.name_scope(self.distilbert.name):
  883. self.distilbert.build(None)
  884. if getattr(self, "pre_classifier", None) is not None:
  885. with tf.name_scope(self.pre_classifier.name):
  886. self.pre_classifier.build([None, None, self.config.dim])
  887. if getattr(self, "classifier", None) is not None:
  888. with tf.name_scope(self.classifier.name):
  889. self.classifier.build([None, None, self.config.dim])
  890. @add_start_docstrings(
  891. """
  892. DistilBert Model with a span classification head on top for extractive question-answering tasks like SQuAD (a
  893. linear layer on top of the hidden-states output to compute `span start logits` and `span end logits`).
  894. """,
  895. DISTILBERT_START_DOCSTRING,
  896. )
  897. class TFDistilBertForQuestionAnswering(TFDistilBertPreTrainedModel, TFQuestionAnsweringLoss):
  898. def __init__(self, config, *inputs, **kwargs):
  899. super().__init__(config, *inputs, **kwargs)
  900. self.distilbert = TFDistilBertMainLayer(config, name="distilbert")
  901. self.qa_outputs = keras.layers.Dense(
  902. config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="qa_outputs"
  903. )
  904. assert config.num_labels == 2, f"Incorrect number of labels {config.num_labels} instead of 2"
  905. self.dropout = keras.layers.Dropout(config.qa_dropout)
  906. self.config = config
  907. @unpack_inputs
  908. @add_start_docstrings_to_model_forward(DISTILBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
  909. @add_code_sample_docstrings(
  910. checkpoint=_CHECKPOINT_FOR_DOC,
  911. output_type=TFQuestionAnsweringModelOutput,
  912. config_class=_CONFIG_FOR_DOC,
  913. )
  914. def call(
  915. self,
  916. input_ids: TFModelInputType | None = None,
  917. attention_mask: np.ndarray | tf.Tensor | None = None,
  918. head_mask: np.ndarray | tf.Tensor | None = None,
  919. inputs_embeds: np.ndarray | tf.Tensor | None = None,
  920. output_attentions: Optional[bool] = None,
  921. output_hidden_states: Optional[bool] = None,
  922. return_dict: Optional[bool] = None,
  923. start_positions: np.ndarray | tf.Tensor | None = None,
  924. end_positions: np.ndarray | tf.Tensor | None = None,
  925. training: Optional[bool] = False,
  926. ) -> Union[TFQuestionAnsweringModelOutput, Tuple[tf.Tensor]]:
  927. r"""
  928. start_positions (`tf.Tensor` of shape `(batch_size,)`, *optional*):
  929. Labels for position (index) of the start of the labelled span for computing the token classification loss.
  930. Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
  931. are not taken into account for computing the loss.
  932. end_positions (`tf.Tensor` of shape `(batch_size,)`, *optional*):
  933. Labels for position (index) of the end of the labelled span for computing the token classification loss.
  934. Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
  935. are not taken into account for computing the loss.
  936. """
  937. distilbert_output = self.distilbert(
  938. input_ids=input_ids,
  939. attention_mask=attention_mask,
  940. head_mask=head_mask,
  941. inputs_embeds=inputs_embeds,
  942. output_attentions=output_attentions,
  943. output_hidden_states=output_hidden_states,
  944. return_dict=return_dict,
  945. training=training,
  946. )
  947. hidden_states = distilbert_output[0] # (bs, max_query_len, dim)
  948. hidden_states = self.dropout(hidden_states, training=training) # (bs, max_query_len, dim)
  949. logits = self.qa_outputs(hidden_states) # (bs, max_query_len, 2)
  950. start_logits, end_logits = tf.split(logits, 2, axis=-1)
  951. start_logits = tf.squeeze(start_logits, axis=-1)
  952. end_logits = tf.squeeze(end_logits, axis=-1)
  953. loss = None
  954. if start_positions is not None and end_positions is not None:
  955. labels = {"start_position": start_positions}
  956. labels["end_position"] = end_positions
  957. loss = self.hf_compute_loss(labels, (start_logits, end_logits))
  958. if not return_dict:
  959. output = (start_logits, end_logits) + distilbert_output[1:]
  960. return ((loss,) + output) if loss is not None else output
  961. return TFQuestionAnsweringModelOutput(
  962. loss=loss,
  963. start_logits=start_logits,
  964. end_logits=end_logits,
  965. hidden_states=distilbert_output.hidden_states,
  966. attentions=distilbert_output.attentions,
  967. )
  968. def build(self, input_shape=None):
  969. if self.built:
  970. return
  971. self.built = True
  972. if getattr(self, "distilbert", None) is not None:
  973. with tf.name_scope(self.distilbert.name):
  974. self.distilbert.build(None)
  975. if getattr(self, "qa_outputs", None) is not None:
  976. with tf.name_scope(self.qa_outputs.name):
  977. self.qa_outputs.build([None, None, self.config.dim])