# modeling_tf_ctrl.py
  1. # coding=utf-8
  2. # Copyright 2018 Salesforce and HuggingFace Inc. team.
  3. # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. """TF 2.0 CTRL model."""
  17. from __future__ import annotations
  18. from typing import Optional, Tuple, Union
  19. import numpy as np
  20. import tensorflow as tf
  21. from ...modeling_tf_outputs import TFBaseModelOutputWithPast, TFCausalLMOutputWithPast, TFSequenceClassifierOutput
  22. from ...modeling_tf_utils import (
  23. TFCausalLanguageModelingLoss,
  24. TFModelInputType,
  25. TFPreTrainedModel,
  26. TFSequenceClassificationLoss,
  27. get_initializer,
  28. keras,
  29. keras_serializable,
  30. unpack_inputs,
  31. )
  32. from ...tf_utils import check_embeddings_within_bounds, shape_list, stable_softmax
  33. from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging
  34. from .configuration_ctrl import CTRLConfig
  35. logger = logging.get_logger(__name__)
  36. _CHECKPOINT_FOR_DOC = "Salesforce/ctrl"
  37. _CONFIG_FOR_DOC = "CTRLConfig"
  38. def angle_defn(pos, i, d_model_size):
  39. angle_rates = 1 / np.power(10000, (2 * (i // 2)) / d_model_size)
  40. return pos * angle_rates
  41. def positional_encoding(position, d_model_size):
  42. # create the sinusoidal pattern for the positional encoding
  43. angle_rads = angle_defn(np.arange(position)[:, np.newaxis], np.arange(d_model_size)[np.newaxis, :], d_model_size)
  44. sines = np.sin(angle_rads[:, 0::2])
  45. cosines = np.cos(angle_rads[:, 1::2])
  46. pos_encoding = tf.convert_to_tensor(np.concatenate([sines, cosines], axis=-1))
  47. return pos_encoding
  48. def scaled_dot_product_attention(q, k, v, mask, attention_mask=None, head_mask=None):
  49. # calculate attention
  50. matmul_qk = tf.matmul(q, k, transpose_b=True)
  51. dk = tf.cast(shape_list(k)[-1], dtype=matmul_qk.dtype)
  52. scaled_attention_logits = matmul_qk / tf.math.sqrt(dk)
  53. if mask is not None:
  54. scaled_attention_logits += tf.cast(mask * -1e4, dtype=scaled_attention_logits.dtype)
  55. if attention_mask is not None:
  56. # Apply the attention mask
  57. attention_mask = tf.cast(attention_mask, dtype=scaled_attention_logits.dtype)
  58. scaled_attention_logits = scaled_attention_logits + attention_mask
  59. attention_weights = stable_softmax(scaled_attention_logits, axis=-1)
  60. # Mask heads if we want to
  61. if head_mask is not None:
  62. attention_weights = attention_weights * head_mask
  63. output = tf.matmul(attention_weights, v)
  64. return output, attention_weights
class TFMultiHeadAttention(keras.layers.Layer):
    """Multi-head attention block for CTRL.

    Projects queries/keys/values to `d_model_size`, splits them into
    `num_heads` heads of size `d_model_size // num_heads`, runs scaled
    dot-product attention (with optional KV cache), and recombines the heads
    through a final dense projection.
    """

    def __init__(self, d_model_size, num_heads, output_attentions=False, **kwargs):
        super().__init__(**kwargs)
        self.num_heads = num_heads
        self.d_model_size = d_model_size
        self.output_attentions = output_attentions

        # Per-head feature size; assumes d_model_size is divisible by num_heads.
        self.depth = int(d_model_size / self.num_heads)

        self.Wq = keras.layers.Dense(d_model_size, name="Wq")
        self.Wk = keras.layers.Dense(d_model_size, name="Wk")
        self.Wv = keras.layers.Dense(d_model_size, name="Wv")

        self.dense = keras.layers.Dense(d_model_size, name="dense")

    def split_into_heads(self, x, batch_size):
        # (batch, seq, d_model) -> (batch, num_heads, seq, depth)
        x = tf.reshape(x, (batch_size, -1, self.num_heads, self.depth))
        return tf.transpose(x, perm=[0, 2, 1, 3])

    def call(self, v, k, q, mask, layer_past, attention_mask, head_mask, use_cache, output_attentions, training=False):
        """Run attention; returns (output, present[, attentions]).

        `layer_past` is a tensor stacking cached (key, value) along axis 0; when
        given, the cache is prepended on the sequence axis so attention spans
        past + current tokens.
        """
        batch_size = shape_list(q)[0]

        q = self.Wq(q)
        k = self.Wk(k)
        v = self.Wv(v)

        q = self.split_into_heads(q, batch_size)
        k = self.split_into_heads(k, batch_size)
        v = self.split_into_heads(v, batch_size)

        if layer_past is not None:
            past_key, past_value = tf.unstack(layer_past, axis=0)
            k = tf.concat((past_key, k), axis=-2)
            v = tf.concat((past_value, v), axis=-2)

        if use_cache:
            present = tf.stack((k, v), axis=0)
        else:
            # Placeholder keeps the output tuple structure stable when caching is off.
            present = (None,)

        output = scaled_dot_product_attention(q, k, v, mask, attention_mask, head_mask)
        # (batch, num_heads, seq, depth) -> (batch, seq, num_heads, depth)
        scaled_attention = tf.transpose(output[0], perm=[0, 2, 1, 3])
        attn = output[1]
        # Merge heads back into a single d_model_size-wide feature axis.
        original_size_attention = tf.reshape(scaled_attention, (batch_size, -1, self.d_model_size))
        output = self.dense(original_size_attention)

        outputs = (output, present)
        if output_attentions:
            outputs = outputs + (attn,)
        return outputs

    def build(self, input_shape=None):
        # Explicit build with per-sublayer name scopes so weight names match the checkpoint.
        if self.built:
            return
        self.built = True
        if getattr(self, "Wq", None) is not None:
            with tf.name_scope(self.Wq.name):
                self.Wq.build([None, None, self.d_model_size])
        if getattr(self, "Wk", None) is not None:
            with tf.name_scope(self.Wk.name):
                self.Wk.build([None, None, self.d_model_size])
        if getattr(self, "Wv", None) is not None:
            with tf.name_scope(self.Wv.name):
                self.Wv.build([None, None, self.d_model_size])
        if getattr(self, "dense", None) is not None:
            with tf.name_scope(self.dense.name):
                self.dense.build([None, None, self.d_model_size])
  120. class TFPointWiseFeedForwardLayer(keras.layers.Layer):
  121. def __init__(self, d_model_size, dff, **kwargs):
  122. super().__init__(**kwargs)
  123. self.dense_0 = keras.layers.Dense(dff, activation="relu", name="0")
  124. self.dense_2 = keras.layers.Dense(d_model_size, name="2")
  125. self.d_model_size = d_model_size
  126. self.dff = dff
  127. def call(self, inputs, trainable=False):
  128. dense_0_output = self.dense_0(inputs)
  129. dense_2_output = self.dense_2(dense_0_output)
  130. return dense_2_output
  131. def build(self, input_shape=None):
  132. if self.built:
  133. return
  134. self.built = True
  135. if getattr(self, "dense_0", None) is not None:
  136. with tf.name_scope(self.dense_0.name):
  137. self.dense_0.build([None, None, self.d_model_size])
  138. if getattr(self, "dense_2", None) is not None:
  139. with tf.name_scope(self.dense_2.name):
  140. self.dense_2.build([None, None, self.dff])
class TFEncoderLayer(keras.layers.Layer):
    """Single CTRL transformer block.

    Pre-layernorm architecture: each sublayer (multi-head self-attention, then
    the point-wise feed-forward network) normalizes its input first and adds a
    residual connection afterwards, with dropout on each sublayer output.
    """

    def __init__(
        self, d_model_size, num_heads, dff, rate=0.1, layer_norm_epsilon=1e-6, output_attentions=False, **kwargs
    ):
        super().__init__(**kwargs)

        self.output_attentions = output_attentions

        self.multi_head_attention = TFMultiHeadAttention(
            d_model_size, num_heads, output_attentions=self.output_attentions, name="multi_head_attention"
        )
        self.ffn = TFPointWiseFeedForwardLayer(d_model_size, dff, name="ffn")

        self.layernorm1 = keras.layers.LayerNormalization(epsilon=layer_norm_epsilon, name="layernorm1")
        self.layernorm2 = keras.layers.LayerNormalization(epsilon=layer_norm_epsilon, name="layernorm2")

        self.dropout1 = keras.layers.Dropout(rate)
        self.dropout2 = keras.layers.Dropout(rate)
        self.d_model_size = d_model_size

    def call(self, x, mask, layer_past, attention_mask, head_mask, use_cache, output_attentions, training=False):
        # Pre-LN: normalize before self-attention, residual-add after dropout.
        normed = self.layernorm1(x)
        attn_outputs = self.multi_head_attention(
            normed,
            normed,
            normed,
            mask,
            layer_past,
            attention_mask,
            head_mask,
            use_cache,
            output_attentions,
            training=training,
        )
        attn_output = attn_outputs[0]
        attn_output = self.dropout1(attn_output, training=training)
        out1 = x + attn_output

        # Second sublayer: pre-LN feed-forward with its own residual.
        out2 = self.layernorm2(out1)
        ffn_output = self.ffn(out2)
        ffn_output = self.dropout2(ffn_output, training=training)
        out2 = out1 + ffn_output

        # (hidden_states, present[, attentions]) — cache/attention entries pass through.
        outputs = (out2,) + attn_outputs[1:]
        return outputs

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        if getattr(self, "multi_head_attention", None) is not None:
            with tf.name_scope(self.multi_head_attention.name):
                self.multi_head_attention.build(None)
        if getattr(self, "ffn", None) is not None:
            with tf.name_scope(self.ffn.name):
                self.ffn.build(None)
        if getattr(self, "layernorm1", None) is not None:
            with tf.name_scope(self.layernorm1.name):
                self.layernorm1.build([None, None, self.d_model_size])
        if getattr(self, "layernorm2", None) is not None:
            with tf.name_scope(self.layernorm2.name):
                self.layernorm2.build([None, None, self.d_model_size])
@keras_serializable
class TFCTRLMainLayer(keras.layers.Layer):
    """Core CTRL transformer stack.

    Combines scaled token embeddings, a precomputed sinusoidal positional
    table, a stack of `TFEncoderLayer` blocks with optional key/value caching,
    and a final layernorm. Shared by the bare model and the task heads.
    """

    config_class = CTRLConfig

    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)

        self.config = config
        self.output_hidden_states = config.output_hidden_states
        self.output_attentions = config.output_attentions
        self.use_cache = config.use_cache
        self.return_dict = config.use_return_dict

        self.d_model_size = config.n_embd
        self.num_layers = config.n_layer

        # Fixed (non-trainable) sinusoidal table; rows are gathered by position id.
        self.pos_encoding = positional_encoding(config.n_positions, self.d_model_size)

        self.w = keras.layers.Embedding(
            input_dim=config.vocab_size,
            output_dim=config.n_embd,
            embeddings_initializer=get_initializer(config.initializer_range),
            name="w",
        )

        self.dropout = keras.layers.Dropout(config.embd_pdrop)
        self.h = [
            TFEncoderLayer(
                config.n_embd,
                config.n_head,
                config.dff,
                config.resid_pdrop,
                config.layer_norm_epsilon,
                self.output_attentions,
                name=f"h_._{i}",
            )
            for i in range(config.n_layer)
        ]
        self.layernorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_epsilon, name="layernorm")

    def get_input_embeddings(self):
        return self.w

    def set_input_embeddings(self, new_embeddings):
        self.w = new_embeddings

    def _prune_heads(self, heads_to_prune):
        """
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
        """
        raise NotImplementedError

    @unpack_inputs
    def call(
        self,
        input_ids: TFModelInputType | None = None,
        past_key_values: Optional[Tuple[Tuple[Union[np.ndarray, tf.Tensor]]]] = None,
        attention_mask: np.ndarray | tf.Tensor | None = None,
        token_type_ids: np.ndarray | tf.Tensor | None = None,
        position_ids: np.ndarray | tf.Tensor | None = None,
        head_mask: np.ndarray | tf.Tensor | None = None,
        inputs_embeds: np.ndarray | tf.Tensor | None = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        training: Optional[bool] = False,
    ) -> Union[Tuple, TFBaseModelOutputWithPast]:
        # If using past key value states, only the last tokens
        # should be given as an input
        if past_key_values is not None:
            if input_ids is not None:
                input_ids = input_ids[:, -1:]
            if inputs_embeds is not None:
                inputs_embeds = inputs_embeds[:, -1:]
            if token_type_ids is not None:
                token_type_ids = token_type_ids[:, -1:]

        # Exactly one of input_ids / inputs_embeds must be provided.
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            input_shape = shape_list(input_ids)
            input_ids = tf.reshape(input_ids, [-1, input_shape[-1]])
        elif inputs_embeds is not None:
            input_shape = shape_list(inputs_embeds)[:-1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        if past_key_values is None:
            past_length = 0
            # One (empty) cache slot per encoder layer.
            past_key_values = [None] * len(self.h)
        else:
            # Cached key tensors have shape (..., past_seq_len, depth).
            past_length = shape_list(past_key_values[0][0])[-2]

        if position_ids is None:
            # Positions continue after the cached prefix.
            position_ids = tf.expand_dims(tf.range(past_length, input_shape[-1] + past_length, dtype=tf.int32), axis=0)
            position_ids = tf.tile(position_ids, [input_shape[0], 1])

        # Attention mask.
        if attention_mask is not None:
            # We create a 3D attention mask from a 2D tensor mask.
            # Sizes are [batch_size, 1, 1, to_seq_length]
            # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
            # this attention mask is more simple than the triangular masking of causal attention
            # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
            attention_mask = tf.reshape(attention_mask, (input_shape[0], 1, 1, input_shape[1] + past_length))

            # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
            # masked positions, this operation will create a tensor which is 0.0 for
            # positions we want to attend and -10000.0 for masked positions.
            # Since we are adding it to the raw scores before the softmax, this is
            # effectively the same as removing these entirely.
            one_cst = tf.constant(1.0)
            ten_thousand_cst = tf.constant(-10000.0)
            attention_mask = tf.cast(attention_mask, dtype=one_cst.dtype)
            attention_mask = tf.multiply(tf.subtract(one_cst, attention_mask), ten_thousand_cst)

        # Prepare head mask if needed
        # 1.0 in head_mask indicate we keep the head
        # attention_probs has shape bsz x n_heads x N x N
        # head_mask has shape n_layer x batch x n_heads x N x N
        if head_mask is not None:
            raise NotImplementedError
        else:
            head_mask = [None] * self.num_layers

        if token_type_ids is not None:
            token_type_ids = tf.reshape(token_type_ids, [-1, shape_list(token_type_ids)[-1]])
            token_type_embeds = self.w(token_type_ids)
            # Token-type embeddings are scaled like the token embeddings below.
            token_type_embeds *= tf.math.sqrt(tf.cast(self.d_model_size, dtype=token_type_embeds.dtype))
        else:
            token_type_embeds = tf.constant(0.0)
        position_ids = tf.reshape(position_ids, [-1, shape_list(position_ids)[-1]])

        if inputs_embeds is None:
            check_embeddings_within_bounds(input_ids, self.w.input_dim)
            inputs_embeds = self.w(input_ids)
        seq_len = input_shape[-1]
        # Causal mask: 1 above the diagonal (future positions blocked).
        mask = 1 - tf.linalg.band_part(tf.ones((seq_len, seq_len)), -1, 0)

        # Embeddings are scaled by sqrt(d_model), per the original CTRL paper's code.
        inputs_embeds *= tf.math.sqrt(tf.cast(self.d_model_size, inputs_embeds.dtype))

        pos_embeds = tf.gather(self.pos_encoding, position_ids)
        pos_embeds = tf.cast(pos_embeds, dtype=token_type_embeds.dtype)
        hidden_states = inputs_embeds + pos_embeds + token_type_embeds

        hidden_states = self.dropout(hidden_states, training=training)

        output_shape = input_shape + [shape_list(hidden_states)[-1]]
        presents = () if use_cache else None
        all_hidden_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None
        for i, (h, layer_past) in enumerate(zip(self.h, past_key_values)):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (tf.reshape(hidden_states, output_shape),)
            outputs = h(
                hidden_states,
                mask,
                layer_past,
                attention_mask,
                head_mask[i],
                use_cache,
                output_attentions,
                training=training,
            )
            hidden_states, present = outputs[:2]

            if use_cache:
                presents = presents + (present,)

            if output_attentions:
                all_attentions = all_attentions + (outputs[2],)

        hidden_states = self.layernorm(hidden_states)
        hidden_states = tf.reshape(hidden_states, output_shape)
        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if output_attentions:
            # let the number of heads free (-1) so we can extract attention even after head pruning
            attention_output_shape = input_shape[:-1] + [-1] + shape_list(all_attentions[0])[-2:]
            all_attentions = tuple(tf.reshape(t, attention_output_shape) for t in all_attentions)

        if not return_dict:
            return tuple(v for v in [hidden_states, presents, all_hidden_states, all_attentions] if v is not None)

        return TFBaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=presents,
            hidden_states=all_hidden_states,
            attentions=all_attentions,
        )

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        if getattr(self, "w", None) is not None:
            with tf.name_scope(self.w.name):
                self.w.build(None)
        if getattr(self, "layernorm", None) is not None:
            with tf.name_scope(self.layernorm.name):
                self.layernorm.build([None, None, self.config.n_embd])
        if getattr(self, "h", None) is not None:
            for layer in self.h:
                with tf.name_scope(layer.name):
                    layer.build(None)
class TFCTRLPreTrainedModel(TFPreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    # Configuration class associated with this model family.
    config_class = CTRLConfig
    # Attribute name under which head models store the bare transformer.
    base_model_prefix = "transformer"
# Class-level docstring injected into model classes via @add_start_docstrings.
CTRL_START_DOCSTRING = r"""

    This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
    etc.)

    This model is also a [keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model) subclass. Use it
    as a regular TF 2.0 Keras Model and refer to the TF 2.0 documentation for all matter related to general usage and
    behavior.

    <Tip>

    TensorFlow models and layers in `transformers` accept two formats as input:

    - having all inputs as keyword arguments (like PyTorch models), or
    - having all inputs as a list, tuple or dict in the first positional argument.

    The reason the second format is supported is that Keras methods prefer this format when passing inputs to models
    and layers. Because of this support, when using methods like `model.fit()` things should "just work" for you - just
    pass your inputs and labels in any format that `model.fit()` supports! If, however, you want to use the second
    format outside of Keras methods like `fit()` and `predict()`, such as when creating your own layers or models with
    the Keras `Functional` API, there are three possibilities you can use to gather all the input Tensors in the first
    positional argument:

    - a single Tensor with `input_ids` only and nothing else: `model(input_ids)`
    - a list of varying length with one or several input Tensors IN THE ORDER given in the docstring:
    `model([input_ids, attention_mask])` or `model([input_ids, attention_mask, token_type_ids])`
    - a dictionary with one or several input Tensors associated to the input names given in the docstring:
    `model({"input_ids": input_ids, "token_type_ids": token_type_ids})`

    Note that when creating models and layers with
    [subclassing](https://keras.io/guides/making_new_layers_and_models_via_subclassing/) then you don't need to worry
    about any of this, as you can just pass inputs like you would to any other Python function!

    </Tip>

    Parameters:
        config ([`CTRLConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""

# Per-argument documentation injected into `call` via
# @add_start_docstrings_to_model_forward.
CTRL_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`Numpy array` or `tf.Tensor` of shape `(batch_size, input_ids_length)`):
            `input_ids_length` = `sequence_length` if `past` is `None` else `past[0].shape[-2]` (`sequence_length` of
            input past key value states).

            Indices of input sequence tokens in the vocabulary.

            If `past` is used, only input IDs that do not have their past calculated should be passed as `input_ids`.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.__call__`] and
            [`PreTrainedTokenizer.encode`] for details.

            [What are input IDs?](../glossary#input-ids)
        past (`List[tf.Tensor]` of length `config.n_layers`):
            Contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model (see
            `past` output below). Can be used to speed up sequential decoding. The token ids which have their past
            given to this model should not be passed as input ids as they have already been computed.
        attention_mask (`tf.Tensor` or `Numpy array` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)
        token_type_ids (`tf.Tensor` or `Numpy array` of shape `(batch_size, sequence_length)`, *optional*):
            Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
            1]`:

            - 0 corresponds to a *sentence A* token,
            - 1 corresponds to a *sentence B* token.

            [What are token type IDs?](../glossary#token-type-ids)
        position_ids (`tf.Tensor` or `Numpy array` of shape `(batch_size, sequence_length)`, *optional*):
            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
            config.max_position_embeddings - 1]`.

            [What are position IDs?](../glossary#position-ids)
        head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
            Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.
        inputs_embeds (`tf.Tensor` or `Numpy array` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
            model's internal embedding lookup matrix.
        use_cache (`bool`, *optional*):
            If set to `True`, `past` key value states are returned and can be used to speed up decoding (see `past`).
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail. This argument can be used only in eager mode, in graph mode the value in the
            config will be used instead.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail. This argument can be used only in eager mode, in graph mode the value in the config will be
            used instead.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. This argument can be used in
            eager mode, in graph mode the value will always be set to True.
        training (`bool`, *optional*, defaults to `False`):
            Whether or not to use the model in training mode (some modules like dropout modules have different
            behaviors between training and evaluation).
"""
@add_start_docstrings(
    "The bare CTRL Model transformer outputting raw hidden-states without any specific head on top.",
    CTRL_START_DOCSTRING,
)
class TFCTRLModel(TFCTRLPreTrainedModel):
    """Headless CTRL model: a thin wrapper delegating to `TFCTRLMainLayer`."""

    def __init__(self, config, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)
        self.transformer = TFCTRLMainLayer(config, name="transformer")

    @unpack_inputs
    @add_start_docstrings_to_model_forward(CTRL_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=TFBaseModelOutputWithPast,
        config_class=_CONFIG_FOR_DOC,
    )
    def call(
        self,
        input_ids: TFModelInputType | None = None,
        past_key_values: Optional[Tuple[Tuple[Union[np.ndarray, tf.Tensor]]]] = None,
        attention_mask: np.ndarray | tf.Tensor | None = None,
        token_type_ids: np.ndarray | tf.Tensor | None = None,
        position_ids: np.ndarray | tf.Tensor | None = None,
        head_mask: np.ndarray | tf.Tensor | None = None,
        inputs_embeds: np.ndarray | tf.Tensor | None = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        training: Optional[bool] = False,
    ) -> Union[Tuple, TFBaseModelOutputWithPast]:
        # Pure delegation; argument documentation comes from the decorators above.
        outputs = self.transformer(
            input_ids=input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            training=training,
        )
        return outputs

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        if getattr(self, "transformer", None) is not None:
            with tf.name_scope(self.transformer.name):
                self.transformer.build(None)
  517. class TFCTRLBiasLayer(keras.layers.Layer):
  518. """
  519. Bias as a layer. It is used for serialization purposes: `keras.Model.save_weights` stores on a per-layer basis,
  520. so all weights have to be registered in a layer.
  521. """
  522. def __init__(self, shape, initializer, trainable, name, **kwargs):
  523. super().__init__(name=name, **kwargs)
  524. self.shape = shape
  525. self.initializer = initializer
  526. self.trainable = trainable
  527. def build(self, input_shape):
  528. self.bias = self.add_weight(
  529. name="bias", shape=self.shape, initializer=self.initializer, trainable=self.trainable
  530. )
  531. super().build(input_shape)
  532. def call(self, x):
  533. return x + self.bias
  534. @add_start_docstrings(
  535. """
  536. The CTRL Model transformer with a language modeling head on top (linear layer with weights tied to the input
  537. embeddings).
  538. """,
  539. CTRL_START_DOCSTRING,
  540. )
  541. class TFCTRLLMHeadModel(TFCTRLPreTrainedModel, TFCausalLanguageModelingLoss):
    def __init__(self, config, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)
        self.transformer = TFCTRLMainLayer(config, name="transformer")
        # The LM head reuses the input embedding matrix for its weights; only
        # the bias is a separate variable, registered inside a layer so that
        # `save_weights` serializes it.
        self.bias_layer = TFCTRLBiasLayer(
            name="lm_head", shape=[1, config.vocab_size], initializer="zeros", trainable=True
        )
    def get_output_embeddings(self):
        # Output embeddings are tied to the input embedding matrix.
        return self.get_input_embeddings()

    def set_output_embeddings(self, value):
        # Tied weights: setting the output embeddings updates the input ones.
        self.set_input_embeddings(value)

    def get_bias(self):
        return {"lm_head.bias": self.bias_layer.bias}

    def set_bias(self, value):
        # Replaces the existing layers containing bias for correct (de)serialization.
        vocab_size = value["lm_head.bias"].shape[-1]
        self.bias_layer = TFCTRLBiasLayer(
            name="final_logits_bias", shape=[1, vocab_size], initializer="zeros", trainable=True
        )
        self.bias_layer.build(None)
        self.bias_layer.bias.assign(value["lm_head.bias"])
    # Copied from transformers.models.gpt2.modeling_tf_gpt2.TFGPT2LMHeadModel.prepare_inputs_for_generation
    def prepare_inputs_for_generation(self, inputs, past_key_values=None, use_cache=None, **kwargs):
        """Slice inputs down to the last token when a KV cache is present and
        derive position ids from the attention mask for generation."""
        token_type_ids = kwargs.get("token_type_ids", None)
        # only last token for inputs_ids if past is defined in kwargs
        if past_key_values:
            inputs = tf.expand_dims(inputs[:, -1], -1)
            if token_type_ids is not None:
                token_type_ids = tf.expand_dims(token_type_ids[:, -1], -1)

        position_ids = kwargs.get("position_ids", None)
        attention_mask = kwargs.get("attention_mask", None)

        if attention_mask is not None and position_ids is None:
            # Positions count only non-padding tokens (exclusive cumsum of the mask).
            position_ids = tf.math.cumsum(attention_mask, axis=-1, exclusive=True)
            if past_key_values:
                position_ids = tf.expand_dims(position_ids[:, -1], -1)

        return {
            "input_ids": inputs,
            "attention_mask": attention_mask,
            "position_ids": position_ids,
            "past_key_values": past_key_values,
            "use_cache": use_cache,
            "token_type_ids": token_type_ids,
        }
  584. @unpack_inputs
  585. @add_start_docstrings_to_model_forward(CTRL_INPUTS_DOCSTRING)
  586. @add_code_sample_docstrings(
  587. checkpoint=_CHECKPOINT_FOR_DOC,
  588. output_type=TFCausalLMOutputWithPast,
  589. config_class=_CONFIG_FOR_DOC,
  590. )
  591. def call(
  592. self,
  593. input_ids: TFModelInputType | None = None,
  594. past_key_values: Optional[Tuple[Tuple[Union[np.ndarray, tf.Tensor]]]] = None,
  595. attention_mask: np.ndarray | tf.Tensor | None = None,
  596. token_type_ids: np.ndarray | tf.Tensor | None = None,
  597. position_ids: np.ndarray | tf.Tensor | None = None,
  598. head_mask: np.ndarray | tf.Tensor | None = None,
  599. inputs_embeds: np.ndarray | tf.Tensor | None = None,
  600. use_cache: Optional[bool] = None,
  601. output_attentions: Optional[bool] = None,
  602. output_hidden_states: Optional[bool] = None,
  603. return_dict: Optional[bool] = None,
  604. labels: np.ndarray | tf.Tensor | None = None,
  605. training: Optional[bool] = False,
  606. ) -> Union[Tuple, TFCausalLMOutputWithPast]:
  607. r"""
  608. labels (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
  609. Labels for computing the cross entropy classification loss. Indices should be in `[0, ...,
  610. config.vocab_size - 1]`.
  611. """
  612. transformer_outputs = self.transformer(
  613. input_ids=input_ids,
  614. past_key_values=past_key_values,
  615. attention_mask=attention_mask,
  616. token_type_ids=token_type_ids,
  617. position_ids=position_ids,
  618. head_mask=head_mask,
  619. inputs_embeds=inputs_embeds,
  620. use_cache=use_cache,
  621. output_attentions=output_attentions,
  622. output_hidden_states=output_hidden_states,
  623. return_dict=return_dict,
  624. training=training,
  625. )
  626. hidden_states = transformer_outputs[0]
  627. logits = tf.matmul(hidden_states, self.transformer.w.weights, transpose_b=True)
  628. logits = self.bias_layer(logits)
  629. loss = None
  630. if labels is not None:
  631. # shift labels to the left and cut last logit token
  632. shifted_logits = logits[:, :-1]
  633. labels = labels[:, 1:]
  634. loss = self.hf_compute_loss(labels, shifted_logits)
  635. if not return_dict:
  636. output = (logits,) + transformer_outputs[1:]
  637. return ((loss,) + output) if loss is not None else output
  638. return TFCausalLMOutputWithPast(
  639. loss=loss,
  640. logits=logits,
  641. past_key_values=transformer_outputs.past_key_values,
  642. hidden_states=transformer_outputs.hidden_states,
  643. attentions=transformer_outputs.attentions,
  644. )
  645. def build(self, input_shape=None):
  646. if self.built:
  647. return
  648. self.built = True
  649. if getattr(self, "transformer", None) is not None:
  650. with tf.name_scope(self.transformer.name):
  651. self.transformer.build(None)
  652. if getattr(self, "bias_layer", None) is not None:
  653. with tf.name_scope(self.bias_layer.name):
  654. self.bias_layer.build(None)
@add_start_docstrings(
    """
    The CTRL Model transformer with a sequence classification head on top (linear layer).

    [`TFCTRLForSequenceClassification`] uses the last token in order to do the classification, as other causal models
    (e.g. GPT-1, GPT-2) do.

    Since it does classification on the last token, it requires to know the position of the last token. If a
    `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
    no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
    padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
    each row of the batch).
    """,
    CTRL_START_DOCSTRING,
)
class TFCTRLForSequenceClassification(TFCTRLPreTrainedModel, TFSequenceClassificationLoss):
    def __init__(self, config, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)
        self.num_labels = config.num_labels
        # Bias-free linear head mapping the hidden size to `num_labels`.
        self.classifier = keras.layers.Dense(
            config.num_labels,
            kernel_initializer=get_initializer(config.initializer_range),
            name="classifier",
            use_bias=False,
        )
        self.transformer = TFCTRLMainLayer(config, name="transformer")
        self.config = config

    def get_output_embeddings(self):
        # Remove after transformers v4.32. Fix this model's `test_model_common_attributes` test too.
        logger.warning(
            "Sequence classification models do not have output embeddings. `.get_output_embeddings` will be removed "
            "in transformers v4.32."
        )
        return self.transformer.w

    @unpack_inputs
    @add_start_docstrings_to_model_forward(CTRL_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=TFSequenceClassifierOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def call(
        self,
        input_ids: TFModelInputType | None = None,
        past_key_values: Optional[Tuple[Tuple[Union[np.ndarray, tf.Tensor]]]] = None,
        attention_mask: np.ndarray | tf.Tensor | None = None,
        token_type_ids: np.ndarray | tf.Tensor | None = None,
        position_ids: np.ndarray | tf.Tensor | None = None,
        head_mask: np.ndarray | tf.Tensor | None = None,
        inputs_embeds: np.ndarray | tf.Tensor | None = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        labels: np.ndarray | tf.Tensor | None = None,
        training: Optional[bool] = False,
    ) -> Union[Tuple, TFSequenceClassifierOutput]:
        r"""
        labels (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the sequence classification loss. Indices should be in `[0, ...,
            config.num_labels - 1]`.
        """
        transformer_outputs = self.transformer(
            input_ids=input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            training=training,
        )

        hidden_states = transformer_outputs[0]
        logits = self.classifier(hidden_states)
        in_logits = None
        if self.config.pad_token_id is None:
            # No pad token configured: classify on the very last position.
            sequence_lengths = -1
        else:
            if input_ids is not None:
                # Index of the last non-padding token per row (assumes
                # right-padding): argmax returns the position of the first pad
                # token (0 when the row contains none); subtracting 1 points at
                # the token before it, and tf.where maps the no-pad case
                # (result -1) to the final position.
                sequence_lengths = (
                    tf.argmax(tf.cast(tf.math.equal(input_ids, self.config.pad_token_id), input_ids.dtype), axis=-1)
                    - 1
                )
                sequence_lengths = tf.where(sequence_lengths >= 0, sequence_lengths, input_ids.shape[-1] - 1)
                in_logits = tf.gather(logits, sequence_lengths, batch_dims=1, axis=1)
            else:
                # With `inputs_embeds` padding cannot be detected; fall back to
                # the last position and warn once.
                sequence_lengths = -1

                logger.warning_once(
                    f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
                    "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
                )
        loss = None

        if labels is not None:
            if input_ids is not None:
                batch_size, sequence_length = shape_list(input_ids)[:2]
            else:
                batch_size, sequence_length = shape_list(inputs_embeds)[:2]
            if self.config.pad_token_id is None and batch_size != 1:
                raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")

            if not tf.is_tensor(sequence_lengths):
                # sequence_lengths is the Python int -1 here: take the logits
                # at the last position of every row.
                in_logits = logits[0:batch_size, sequence_lengths]

            loss = self.hf_compute_loss(tf.reshape(labels, [-1, 1]), tf.reshape(in_logits, [-1, self.num_labels]))

        pooled_logits = in_logits if in_logits is not None else logits

        if not return_dict:
            output = (pooled_logits,) + transformer_outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return TFSequenceClassifierOutput(
            loss=loss,
            logits=pooled_logits,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )

    def build(self, input_shape=None):
        # Build each sub-layer inside its own name scope so checkpoint weight
        # names stay stable.
        if self.built:
            return
        self.built = True
        if getattr(self, "classifier", None) is not None:
            with tf.name_scope(self.classifier.name):
                self.classifier.build([None, None, self.config.n_embd])
        if getattr(self, "transformer", None) is not None:
            with tf.name_scope(self.transformer.name):
                self.transformer.build(None)