
# coding=utf-8
# Copyright 2022 The Fairseq Authors and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""TF 2.0 OPT model."""

from __future__ import annotations

from typing import Optional, Tuple, Union

import numpy as np
import tensorflow as tf

from ...activations_tf import get_tf_activation
from ...modeling_tf_outputs import TFBaseModelOutputWithPast, TFCausalLMOutputWithPast

# Public API
from ...modeling_tf_utils import (
    TFCausalLanguageModelingLoss,
    TFModelInputType,
    TFPreTrainedModel,
    TFSharedEmbeddings,
    keras,
    keras_serializable,
    unpack_inputs,
)
from ...tf_utils import check_embeddings_within_bounds, shape_list, stable_softmax
from ...utils import (
    add_code_sample_docstrings,
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    logging,
    replace_return_docstrings,
)
from .configuration_opt import OPTConfig


logger = logging.get_logger(__name__)

_CHECKPOINT_FOR_DOC = "facebook/opt-350m"
_CONFIG_FOR_DOC = "OPTConfig"

# Base model docstring
_EXPECTED_OUTPUT_SHAPE = [1, 8, 1024]

# Causal LM output
_CAUSAL_LM_EXPECTED_OUTPUT = (
    "Hey, are you conscious? Can you talk to me?\nI'm not conscious. I'm just a little bit of a weirdo."
)

LARGE_NEGATIVE = -1e8


def _make_causal_mask(input_ids_shape: tf.TensorShape, past_key_values_length: int = 0):
    """
    Make causal mask used for uni-directional (decoder) self-attention.
    """
    bsz = input_ids_shape[0]
    tgt_len = input_ids_shape[1]
    # We need triu with k = 1 but TF expects known compile-time dims for that, so we hack around it
    mask = tf.fill((tgt_len, tgt_len), tf.cast(LARGE_NEGATIVE, tf.float32))
    mask = tf.linalg.band_part(mask, 0, -1) - tf.linalg.band_part(mask, 0, 0)

    if past_key_values_length > 0:
        mask = tf.concat([tf.zeros((tgt_len, past_key_values_length)), mask], axis=-1)

    return tf.tile(mask[None, None, :, :], (bsz, 1, 1, 1))


# Copied from transformers.models.bart.modeling_tf_bart._expand_mask
def _expand_mask(mask: tf.Tensor, tgt_len: Optional[int] = None):
    """
    Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
    """
    src_len = shape_list(mask)[1]
    tgt_len = tgt_len if tgt_len is not None else src_len
    one_cst = tf.constant(1.0)
    mask = tf.cast(mask, dtype=one_cst.dtype)
    expanded_mask = tf.tile(mask[:, None, None, :], (1, 1, tgt_len, 1))

    return (one_cst - expanded_mask) * LARGE_NEGATIVE
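
# A minimal sketch (not part of the original file) of how the two mask helpers compose, assuming a
# batch of 2 sequences of length 4 where the last token of the second row is padding:
#
#     >>> attention_mask = tf.constant([[1, 1, 1, 1], [1, 1, 1, 0]])
#     >>> causal = _make_causal_mask(tf.shape(attention_mask))   # (2, 1, 4, 4); upper triangle is LARGE_NEGATIVE
#     >>> padding = _expand_mask(attention_mask, tgt_len=4)      # (2, 1, 4, 4); padded key positions get LARGE_NEGATIVE
#     >>> combined = causal + padding                             # additive mask added to the attention scores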


class TFOPTLearnedPositionalEmbedding(keras.layers.Embedding):
    """
    This module learns positional embeddings up to a fixed maximum size.
    """

    def __init__(self, num_embeddings: int, embedding_dim: int, **kwargs):
        # OPT is set up so that if padding_idx is specified then offset the embedding ids by 2
        # and adjust num_embeddings appropriately. Other models don't have this hack
        self.offset = 2
        super().__init__(num_embeddings + self.offset, embedding_dim, **kwargs)

    def call(self, attention_mask, past_key_values_length: int = 0):
        """`attention_mask` is expected to be of shape [bsz x seqlen]."""
        attention_mask = tf.cast(attention_mask, tf.int64)

        # create positions depending on attention_mask
        positions = tf.math.cumsum(attention_mask, axis=1) * attention_mask - 1

        # cut positions if `past_key_values_length` is > 0
        positions = positions[:, past_key_values_length:]

        return super().call(positions + self.offset)
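
# A minimal sketch (not part of the original file) of the position computation above, assuming a single
# sequence of length 4 whose last token is padding:
#
#     >>> mask = tf.constant([[1, 1, 1, 0]], dtype=tf.int64)
#     >>> tf.math.cumsum(mask, axis=1) * mask - 1   # -> [[0, 1, 2, -1]]: real tokens count up, padding stays at -1
#
# The layer then adds `self.offset` (2) before the lookup, so the first two rows of the embedding table are
# reserved, matching the original fairseq/OPT checkpoint layout.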


# Copied from transformers.models.bart.modeling_tf_bart.TFBartAttention with Bart->OPT
class TFOPTAttention(keras.layers.Layer):
    """Multi-headed attention from "Attention Is All You Need"."""

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        dropout: float = 0.0,
        is_decoder: bool = False,
        bias: bool = True,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.dropout = keras.layers.Dropout(dropout)
        self.head_dim = embed_dim // num_heads
        if (self.head_dim * num_heads) != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
                f" and `num_heads`: {num_heads})."
            )
        self.scaling = self.head_dim**-0.5
        self.is_decoder = is_decoder

        self.k_proj = keras.layers.Dense(embed_dim, use_bias=bias, name="k_proj")
        self.q_proj = keras.layers.Dense(embed_dim, use_bias=bias, name="q_proj")
        self.v_proj = keras.layers.Dense(embed_dim, use_bias=bias, name="v_proj")
        self.out_proj = keras.layers.Dense(embed_dim, use_bias=bias, name="out_proj")

    def _shape(self, tensor: tf.Tensor, seq_len: int, bsz: int):
        return tf.transpose(tf.reshape(tensor, (bsz, seq_len, self.num_heads, self.head_dim)), (0, 2, 1, 3))

    def call(
        self,
        hidden_states: tf.Tensor,
        key_value_states: tf.Tensor | None = None,
        past_key_value: Tuple[Tuple[tf.Tensor]] | None = None,
        attention_mask: tf.Tensor | None = None,
        layer_head_mask: tf.Tensor | None = None,
        training: Optional[bool] = False,
    ) -> Tuple[tf.Tensor, tf.Tensor | None]:
        """Input shape: Batch x Time x Channel"""

        # if key_value_states are provided this layer is used as a cross-attention layer
        # for the decoder
        is_cross_attention = key_value_states is not None
        bsz, tgt_len, embed_dim = shape_list(hidden_states)

        # get query proj
        query_states = self.q_proj(hidden_states) * self.scaling
        # get key, value proj
        if is_cross_attention and past_key_value is not None:
            # reuse k,v, cross_attentions
            key_states = past_key_value[0]
            value_states = past_key_value[1]
        elif is_cross_attention:
            # cross_attentions
            key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
            value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
        elif past_key_value is not None:
            # reuse k, v, self_attention
            key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
            value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
            key_states = tf.concat([past_key_value[0], key_states], axis=2)
            value_states = tf.concat([past_key_value[1], value_states], axis=2)
        else:
            # self_attention
            key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
            value_states = self._shape(self.v_proj(hidden_states), -1, bsz)

        if self.is_decoder:
            # if cross_attention save Tuple(tf.Tensor, tf.Tensor) of all cross attention key/value_states.
            # Further calls to cross_attention layer can then reuse all cross-attention
            # key/value_states (first "if" case)
            # if uni-directional self-attention (decoder) save Tuple(tf.Tensor, tf.Tensor) of
            # all previous decoder key/value_states. Further calls to uni-directional self-attention
            # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
            # if encoder bi-directional self-attention `past_key_value` is always `None`
            past_key_value = (key_states, value_states)

        proj_shape = (bsz * self.num_heads, -1, self.head_dim)
        query_states = tf.reshape(self._shape(query_states, tgt_len, bsz), proj_shape)
        key_states = tf.reshape(key_states, proj_shape)
        value_states = tf.reshape(value_states, proj_shape)

        src_len = shape_list(key_states)[1]
        attn_weights = tf.matmul(query_states, key_states, transpose_b=True)

        tf.debugging.assert_equal(
            shape_list(attn_weights),
            [bsz * self.num_heads, tgt_len, src_len],
            message=(
                f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
                f" {shape_list(attn_weights)}"
            ),
        )

        if attention_mask is not None:
            tf.debugging.assert_equal(
                shape_list(attention_mask),
                [bsz, 1, tgt_len, src_len],
                message=(
                    f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is"
                    f" {shape_list(attention_mask)}"
                ),
            )

            attention_mask = tf.cast(attention_mask, dtype=attn_weights.dtype)
            attn_weights = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len)) + attention_mask
            attn_weights = tf.reshape(attn_weights, (bsz * self.num_heads, tgt_len, src_len))

        attn_weights = stable_softmax(attn_weights, axis=-1)

        if layer_head_mask is not None:
            tf.debugging.assert_equal(
                shape_list(layer_head_mask),
                [self.num_heads],
                message=(
                    f"Head mask for a single layer should be of size {(self.num_heads)}, but is"
                    f" {shape_list(layer_head_mask)}"
                ),
            )

            attn_weights = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * tf.reshape(
                attn_weights, (bsz, self.num_heads, tgt_len, src_len)
            )
            attn_weights = tf.reshape(attn_weights, (bsz * self.num_heads, tgt_len, src_len))

        attn_probs = self.dropout(attn_weights, training=training)
        attn_output = tf.matmul(attn_probs, value_states)

        tf.debugging.assert_equal(
            shape_list(attn_output),
            [bsz * self.num_heads, tgt_len, self.head_dim],
            message=(
                f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
                f" {shape_list(attn_output)}"
            ),
        )

        attn_output = tf.transpose(
            tf.reshape(attn_output, (bsz, self.num_heads, tgt_len, self.head_dim)), (0, 2, 1, 3)
        )
        attn_output = tf.reshape(attn_output, (bsz, tgt_len, embed_dim))

        attn_output = self.out_proj(attn_output)
        attn_weights: tf.Tensor = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len))

        return attn_output, attn_weights, past_key_value

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        if getattr(self, "k_proj", None) is not None:
            with tf.name_scope(self.k_proj.name):
                self.k_proj.build([None, None, self.embed_dim])
        if getattr(self, "q_proj", None) is not None:
            with tf.name_scope(self.q_proj.name):
                self.q_proj.build([None, None, self.embed_dim])
        if getattr(self, "v_proj", None) is not None:
            with tf.name_scope(self.v_proj.name):
                self.v_proj.build([None, None, self.embed_dim])
        if getattr(self, "out_proj", None) is not None:
            with tf.name_scope(self.out_proj.name):
                self.out_proj.build([None, None, self.embed_dim])
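
# A minimal sketch (not part of the original file) of the shapes flowing through the caching logic above,
# assuming embed_dim=768, num_heads=12 (head_dim=64), a cached prefix of 5 tokens and 1 new token:
#
#     >>> attn = TFOPTAttention(embed_dim=768, num_heads=12, is_decoder=True)
#     >>> new_token = tf.random.normal((1, 1, 768))
#     >>> past = (tf.random.normal((1, 12, 5, 64)), tf.random.normal((1, 12, 5, 64)))
#     >>> out, weights, present = attn(new_token, past_key_value=past)
#     >>> # out: (1, 1, 768), weights: (1, 12, 1, 6), present[0]: (1, 12, 6, 64) -- the 5 cached keys plus the new one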


class TFOPTDecoderLayer(keras.layers.Layer):
    def __init__(self, config: OPTConfig, **kwargs):
        super().__init__(**kwargs)
        self.do_layer_norm_before = config.do_layer_norm_before
        self.embed_dim = config.hidden_size
        self.self_attn = TFOPTAttention(
            embed_dim=self.embed_dim,
            num_heads=config.num_attention_heads,
            dropout=config.attention_dropout,
            name="self_attn",
            is_decoder=True,
        )
        self.dropout = keras.layers.Dropout(config.dropout)
        self.activation_fn = get_tf_activation(config.activation_function)

        self.self_attn_layer_norm = keras.layers.LayerNormalization(epsilon=1e-5, name="self_attn_layer_norm")
        self.fc1 = keras.layers.Dense(config.ffn_dim, name="fc1")
        self.fc2 = keras.layers.Dense(self.embed_dim, name="fc2")
        self.final_layer_norm = keras.layers.LayerNormalization(epsilon=1e-5, name="final_layer_norm")
        self.config = config

    def call(
        self,
        hidden_states: tf.Tensor,
        attention_mask: np.ndarray | tf.Tensor | None = None,
        layer_head_mask: tf.Tensor | None = None,
        past_key_value: Optional[Tuple[Tuple[Union[np.ndarray, tf.Tensor]]]] = None,
        training: Optional[bool] = False,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
    ) -> Tuple[tf.Tensor, tf.Tensor, Tuple[Tuple[tf.Tensor]]]:
        """
        Args:
            hidden_states (`tf.Tensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`tf.Tensor`, *optional*): attention mask of size
                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
            layer_head_mask (`tf.Tensor`, *optional*): mask for attention heads in a given layer of size
                `(decoder_attention_heads,)`
            past_key_value (`Tuple(tf.Tensor)`, *optional*): cached past key and value projection states
            training (`bool`, *optional*, defaults to `False`):
                Whether or not to use the model in training mode (some modules like dropout modules have different
                behaviors between training and evaluation).
        """
        residual = hidden_states

        # 125m, 1.7B, ..., 175B applies layer norm BEFORE attention
        if self.do_layer_norm_before:
            hidden_states = self.self_attn_layer_norm(hidden_states)

        # Self Attention
        # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
        self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None

        # add present self-attn cache to positions 1,2 of present_key_value tuple
        hidden_states, self_attn_weights, present_key_value = self.self_attn(
            hidden_states=hidden_states,
            past_key_value=self_attn_past_key_value,
            attention_mask=attention_mask,
            layer_head_mask=layer_head_mask,
        )
        hidden_states = self.dropout(hidden_states, training=training)
        hidden_states = residual + hidden_states

        # 350m applies layer norm AFTER attention
        if not self.do_layer_norm_before:
            hidden_states = self.self_attn_layer_norm(hidden_states)

        # Fully Connected
        residual = hidden_states
        # 125m, 1.7B, ..., 175B applies layer norm BEFORE the feed-forward block
        if self.do_layer_norm_before:
            hidden_states = self.final_layer_norm(hidden_states)

        hidden_states = self.fc1(hidden_states)
        hidden_states = self.activation_fn(hidden_states)

        hidden_states = self.fc2(hidden_states)
        hidden_states = self.dropout(hidden_states, training=training)
        hidden_states = residual + hidden_states

        # 350m applies layer norm AFTER the feed-forward block
        if not self.do_layer_norm_before:
            hidden_states = self.final_layer_norm(hidden_states)

        return (hidden_states, self_attn_weights, present_key_value)

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        if getattr(self, "self_attn", None) is not None:
            with tf.name_scope(self.self_attn.name):
                self.self_attn.build(None)
        if getattr(self, "self_attn_layer_norm", None) is not None:
            with tf.name_scope(self.self_attn_layer_norm.name):
                self.self_attn_layer_norm.build([None, None, self.embed_dim])
        if getattr(self, "fc1", None) is not None:
            with tf.name_scope(self.fc1.name):
                self.fc1.build([None, None, self.embed_dim])
        if getattr(self, "fc2", None) is not None:
            with tf.name_scope(self.fc2.name):
                self.fc2.build([None, None, self.config.ffn_dim])
        if getattr(self, "final_layer_norm", None) is not None:
            with tf.name_scope(self.final_layer_norm.name):
                self.final_layer_norm.build([None, None, self.embed_dim])
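
# A minimal sketch (not part of the original file) of the two residual-block orderings controlled by
# `config.do_layer_norm_before`:
#
#     pre-LN  (do_layer_norm_before=True,  e.g. 125m ... 175B): x -> LN -> attn/FFN -> dropout -> + residual
#     post-LN (do_layer_norm_before=False, e.g. the 350m ckpt): x -> attn/FFN -> dropout -> + residual -> LN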


OPT_START_DOCSTRING = r"""
    This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
    etc.)

    This model is also a [keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model) subclass. Use it
    as a regular TF 2.0 Keras Model and refer to the TF 2.0 documentation for all matters related to general usage and
    behavior.

    <Tip>

    TensorFlow models and layers in `transformers` accept two formats as input:

    - having all inputs as keyword arguments (like PyTorch models), or
    - having all inputs as a list, tuple or dict in the first positional argument.

    The reason the second format is supported is that Keras methods prefer this format when passing inputs to models
    and layers. Because of this support, when using methods like `model.fit()` things should "just work" for you - just
    pass your inputs and labels in any format that `model.fit()` supports! If, however, you want to use the second
    format outside of Keras methods like `fit()` and `predict()`, such as when creating your own layers or models with
    the Keras `Functional` API, there are three possibilities you can use to gather all the input Tensors in the first
    positional argument:

    - a single Tensor with `input_ids` only and nothing else: `model(input_ids)`
    - a list of varying length with one or several input Tensors IN THE ORDER given in the docstring:
      `model([input_ids, attention_mask])` or `model([input_ids, attention_mask, token_type_ids])`
    - a dictionary with one or several input Tensors associated to the input names given in the docstring:
      `model({"input_ids": input_ids, "token_type_ids": token_type_ids})`

    Note that when creating models and layers with
    [subclassing](https://keras.io/guides/making_new_layers_and_models_via_subclassing/) then you don't need to worry
    about any of this, as you can just pass inputs like you would to any other Python function!

    </Tip>

    Args:
        config ([`OPTConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~TFPreTrainedModel.from_pretrained`] method to load the model weights.
"""


@add_start_docstrings(
    "The bare OPT Model outputting raw hidden-states without any specific head on top.",
    OPT_START_DOCSTRING,
)
class TFOPTPreTrainedModel(TFPreTrainedModel):
    """
    TFOPT Pretrained Model that inherits from transformers.TFPreTrainedModel

    Args:
        config: OPTConfig
    """

    config_class = OPTConfig
    base_model_prefix = "model"


OPT_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`tf.Tensor` of shape `({0})`):
            Indices of input sequence tokens in the vocabulary.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        attention_mask (`tf.Tensor` of shape `({0})`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)
        head_mask (`tf.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*):
            Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in `[0, 1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.
        past_key_values (`Tuple[Tuple[tf.Tensor]]` of length `config.n_layers`):
            Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
            If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
            don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
            `decoder_input_ids` of shape `(batch_size, sequence_length)`.
        use_cache (`bool`, *optional*, defaults to `True`):
            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
            `past_key_values`). Set to `False` during training, `True` during generation.
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail. This argument can be used only in eager mode, in graph mode the value in the
            config will be used instead.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail. This argument can be used only in eager mode, in graph mode the value in the config will be
            used instead.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. This argument can be used in
            eager mode, in graph mode the value will always be set to `True`.
        training (`bool`, *optional*, defaults to `False`):
            Whether or not to use the model in training mode (some modules like dropout modules have different
            behaviors between training and evaluation).
"""


@keras_serializable
class TFOPTDecoder(keras.layers.Layer):
    config_class = OPTConfig

    def __init__(self, config: OPTConfig, **kwargs):
        super().__init__(**kwargs)
        self.config = config
        self.padding_idx = config.pad_token_id
        self.layerdrop = config.layerdrop
        num_embeddings = config.max_position_embeddings
        self.embed_tokens = TFSharedEmbeddings(
            config.vocab_size, config.word_embed_proj_dim, config.pad_token_id, name="embed_tokens"
        )
        self.embed_positions = TFOPTLearnedPositionalEmbedding(
            num_embeddings,
            config.hidden_size,
            name="embed_positions",
        )

        # Note that the only purpose of `config._remove_final_layer_norm` is to keep backward compatibility
        # with checkpoints that have been fine-tuned before transformers v4.20.1
        # see https://github.com/facebookresearch/metaseq/pull/164
        if config.do_layer_norm_before and not config._remove_final_layer_norm:
            self.final_layer_norm = keras.layers.LayerNormalization(epsilon=1e-5, name="final_layer_norm")
        else:
            self.final_layer_norm = None

        if config.word_embed_proj_dim != config.hidden_size:
            self.project_out = keras.layers.Dense(config.word_embed_proj_dim, name="project_out", use_bias=False)
            self.project_in = keras.layers.Dense(config.hidden_size, name="project_in", use_bias=False)
        else:
            self.project_in = None
            self.project_out = None

        self.layers = [TFOPTDecoderLayer(config, name=f"layers.{i}") for i in range(config.num_hidden_layers)]
        self.dropout = keras.layers.Dropout(config.dropout)

    def get_embed_tokens(self):
        return self.embed_tokens

    def set_embed_tokens(self, embed_tokens):
        self.embed_tokens = embed_tokens

    def set_input_embeddings(self, new_embeddings):
        self.embed_tokens.vocab_size = new_embeddings.shape[0]
        self.embed_tokens.weight = new_embeddings

    def get_input_embeddings(self):
        return self.embed_tokens

    def _prepare_decoder_attention_mask(self, attention_mask, input_shape, past_key_values_length):
        # create causal mask
        # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
        _, seq_length = input_shape
        tf.debugging.assert_equal(
            seq_length + past_key_values_length,
            shape_list(attention_mask)[1],
            message="Attention mask shape should be (batch_size, seq_length + past_key_values_length)"
            f" but is {shape_list(attention_mask)[1]} with input_ids shape {input_shape} and past length"
            f" {past_key_values_length}.",
        )

        expanded_attn_mask = _expand_mask(attention_mask, tgt_len=input_shape[-1])
        if seq_length > 1:
            combined_attention_mask = (
                _make_causal_mask(input_shape, past_key_values_length=past_key_values_length) + expanded_attn_mask
            )
        else:
            combined_attention_mask = expanded_attn_mask

        return combined_attention_mask

    @unpack_inputs
    def call(
        self,
        input_ids: TFModelInputType | None = None,
        inputs_embeds: np.ndarray | tf.Tensor | None = None,
        attention_mask: np.ndarray | tf.Tensor | None = None,
        head_mask: np.ndarray | tf.Tensor | None = None,
        past_key_values: Optional[Tuple[Tuple[Union[np.ndarray, tf.Tensor]]]] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        training: Optional[bool] = False,
    ) -> Union[TFBaseModelOutputWithPast, Tuple[tf.Tensor]]:
        r"""
        Args:
            input_ids (`tf.Tensor` of shape `(batch_size, sequence_length)`):
                Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
                provide it.

                Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
                [`PreTrainedTokenizer.__call__`] for details.

                [What are input IDs?](../glossary#input-ids)
            attention_mask (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
                Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

                - 1 for tokens that are **not masked**,
                - 0 for tokens that are **masked**.

                [What are attention masks?](../glossary#attention-mask)
            head_mask (`tf.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
                Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:

                - 1 indicates the head is **not masked**,
                - 0 indicates the head is **masked**.
            past_key_values (`Tuple[Tuple[tf.Tensor]]` of length `config.n_layers` with each tuple having 2 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
                Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up
                decoding.

                If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those
                that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of
                all `decoder_input_ids` of shape `(batch_size, sequence_length)`.
            inputs_embeds (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
                Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
                This is useful if you want more control over how to convert `input_ids` indices into associated
                vectors than the model's internal embedding lookup matrix.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            output_hidden_states (`bool`, *optional*):
                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
                for more detail.
            return_dict (`bool`, *optional*):
                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
            training (`bool`, *optional*, defaults to `False`):
                Whether or not to use the model in training mode (some modules like dropout modules have different
                behaviors between training and evaluation).
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
        elif input_ids is not None:
            input_shape = shape_list(input_ids)
        elif inputs_embeds is not None:
            input_shape = shape_list(inputs_embeds)[:-1]
        else:
            raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")

        past_key_values_length = shape_list(past_key_values[0][0])[2] if past_key_values is not None else 0

        if inputs_embeds is None:
            check_embeddings_within_bounds(input_ids, self.embed_tokens.vocab_size)
            inputs_embeds = self.embed_tokens(input_ids)

        if attention_mask is None:
            attention_mask = tf.ones((input_shape[0], input_shape[1] + past_key_values_length), dtype=tf.bool)
        else:
            tf.debugging.assert_equal(
                shape_list(attention_mask)[1],
                past_key_values_length + input_shape[1],
                message=(
                    f"The provided attention mask has length {tf.shape(attention_mask)[1]}, but its length should be "
                    f"{past_key_values_length + input_shape[1]} (sum of the lengths of current and past inputs)"
                ),
            )

        pos_embeds = self.embed_positions(attention_mask, past_key_values_length)

        attention_mask = self._prepare_decoder_attention_mask(attention_mask, input_shape, past_key_values_length)

        if self.project_in is not None:
            inputs_embeds = self.project_in(inputs_embeds)

        hidden_states = inputs_embeds + pos_embeds

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        present_key_values = () if use_cache else None

        # check if head_mask and cross_attn_head_mask have a correct number of layers specified if desired
        for attn_mask_name, attn_mask in [("head_mask", head_mask)]:
            if attn_mask is not None:
                tf.debugging.assert_equal(
                    shape_list(attn_mask)[0],
                    len(self.layers),
                    message=(
                        f"The {attn_mask_name} should be specified for {len(self.layers)} layers, but it is for"
                        f" {shape_list(attn_mask)[0]}."
                    ),
                )

        for idx, decoder_layer in enumerate(self.layers):
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            past_key_value = past_key_values[idx] if past_key_values is not None else None

            hidden_states, layer_self_attn, present_key_value = decoder_layer(
                hidden_states,
                attention_mask=attention_mask,
                layer_head_mask=head_mask[idx] if head_mask is not None else None,
                past_key_value=past_key_value,
            )

            if use_cache:
                present_key_values += (present_key_value,)

            if output_attentions:
                all_self_attns += (layer_self_attn,)

        if self.final_layer_norm is not None:
            hidden_states = self.final_layer_norm(hidden_states)

        if self.project_out is not None:
            hidden_states = self.project_out(hidden_states)

        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        if not return_dict:
            return tuple(
                v for v in [hidden_states, present_key_values, all_hidden_states, all_self_attns] if v is not None
            )
        else:
            return TFBaseModelOutputWithPast(
                last_hidden_state=hidden_states,
                past_key_values=present_key_values,
                hidden_states=all_hidden_states,
                attentions=all_self_attns,
            )

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        if getattr(self, "embed_tokens", None) is not None:
            with tf.name_scope(self.embed_tokens.name):
                self.embed_tokens.build(None)
        if getattr(self, "embed_positions", None) is not None:
            with tf.name_scope(self.embed_positions.name):
                self.embed_positions.build(None)
        if getattr(self, "final_layer_norm", None) is not None:
            with tf.name_scope(self.final_layer_norm.name):
                self.final_layer_norm.build([None, None, self.config.hidden_size])
        if getattr(self, "project_out", None) is not None:
            with tf.name_scope(self.project_out.name):
                self.project_out.build([None, None, self.config.hidden_size])
        if getattr(self, "project_in", None) is not None:
            with tf.name_scope(self.project_in.name):
                self.project_in.build([None, None, self.config.word_embed_proj_dim])
        if getattr(self, "layers", None) is not None:
            for layer in self.layers:
                with tf.name_scope(layer.name):
                    layer.build(None)
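
# A minimal sketch (not part of the original file) of why `project_in`/`project_out` exist, assuming the
# opt-350m configuration where `word_embed_proj_dim` (512) differs from `hidden_size` (1024):
#
#     token embeddings (..., 512) --project_in--> decoder layers (..., 1024) --project_out--> outputs (..., 512)
#
# For checkpoints where the two dimensions match, both projections are `None` and the embeddings pass through
# unchanged.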


@keras_serializable
class TFOPTMainLayer(keras.layers.Layer):
    config_class = OPTConfig

    def __init__(self, config: OPTConfig, **kwargs):
        super().__init__(**kwargs)
        self.config = config
        self.decoder = TFOPTDecoder(config, name="decoder")

    def get_input_embeddings(self):
        return self.decoder.embed_tokens

    def set_input_embeddings(self, new_embeddings):
        self.decoder.set_input_embeddings(new_embeddings)

    @unpack_inputs
    def call(
        self,
        input_ids: TFModelInputType | None = None,
        attention_mask: np.ndarray | tf.Tensor | None = None,
        head_mask: np.ndarray | tf.Tensor | None = None,
        past_key_values: Optional[Tuple[Tuple[Union[np.ndarray, tf.Tensor]]]] = None,
        inputs_embeds: np.ndarray | tf.Tensor | None = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        training: Optional[bool] = False,
        **kwargs,
    ) -> Union[TFBaseModelOutputWithPast, Tuple[tf.Tensor]]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.decoder(
            input_ids,
            attention_mask=attention_mask,
            head_mask=head_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            training=training,
        )

        if not return_dict:
            return outputs

        return TFBaseModelOutputWithPast(
            last_hidden_state=outputs.last_hidden_state,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        if getattr(self, "decoder", None) is not None:
            with tf.name_scope(self.decoder.name):
                self.decoder.build(None)


@add_start_docstrings(
    "The bare TF OPT Model outputting raw hidden-states without any specific head on top.",
    OPT_START_DOCSTRING,
)
@keras_serializable
class TFOPTModel(TFOPTPreTrainedModel):
    config_class = OPTConfig

    def __init__(self, config: OPTConfig, **kwargs):
        super().__init__(config, **kwargs)
        self.config = config
        self.model = TFOPTMainLayer(config, name="model")

    def get_input_embeddings(self):
        return self.model.decoder.embed_tokens

    def set_input_embeddings(self, new_embeddings):
        self.model.set_input_embeddings(new_embeddings)

    @unpack_inputs
    @add_start_docstrings_to_model_forward(OPT_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=TFBaseModelOutputWithPast,
        config_class=_CONFIG_FOR_DOC,
        expected_output=_EXPECTED_OUTPUT_SHAPE,
    )
    def call(
        self,
        input_ids: TFModelInputType | None = None,
        attention_mask: np.ndarray | tf.Tensor | None = None,
        head_mask: np.ndarray | tf.Tensor | None = None,
        past_key_values: Optional[Tuple[Tuple[Union[np.ndarray, tf.Tensor]]]] = None,
        inputs_embeds: np.ndarray | tf.Tensor | None = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        training: Optional[bool] = False,
        **kwargs,
    ) -> Union[TFBaseModelOutputWithPast, Tuple[tf.Tensor]]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            head_mask=head_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            training=training,
        )

        if not return_dict:
            return outputs

        return TFBaseModelOutputWithPast(
            last_hidden_state=outputs.last_hidden_state,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def serving_output(self, output):
        pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None
        hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
        attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None

        return TFBaseModelOutputWithPast(
            last_hidden_state=output.last_hidden_state,
            past_key_values=pkv,
            hidden_states=hs,
            attentions=attns,
        )

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        if getattr(self, "model", None) is not None:
            with tf.name_scope(self.model.name):
                self.model.build(None)
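
# A minimal sketch (not part of the original file) of running the bare model, assuming the
# "facebook/opt-350m" checkpoint referenced by `_CHECKPOINT_FOR_DOC` is available:
#
#     >>> from transformers import AutoTokenizer, TFOPTModel
#     >>> tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
#     >>> model = TFOPTModel.from_pretrained("facebook/opt-350m")
#     >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="tf")
#     >>> outputs = model(**inputs)
#     >>> outputs.last_hidden_state.shape   # (batch_size, sequence_length, hidden_size), e.g. (1, 8, 1024)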


@add_start_docstrings(
    """
    The OPT Model transformer with a language modeling head on top.
    """,
    OPT_START_DOCSTRING,
)
@keras_serializable
class TFOPTForCausalLM(TFOPTPreTrainedModel, TFCausalLanguageModelingLoss):
    config_class = OPTConfig

    def __init__(self, config: OPTConfig, **kwargs):
        super().__init__(config, **kwargs)
        self.config = config
        self.model = TFOPTMainLayer(config, name="model")

    def get_output_embeddings(self):
        return self.model.get_input_embeddings()

    def prepare_inputs_for_generation(self, inputs, past_key_values=None, use_cache=None, **kwargs):
        attention_mask = kwargs.get("attention_mask", None)

        # only use the last token of input_ids if a past is provided
        if past_key_values:
            inputs = tf.expand_dims(inputs[:, -1], -1)

        return {
            "input_ids": inputs,
            "attention_mask": attention_mask,
            "past_key_values": past_key_values,
            "use_cache": use_cache,
        }
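
    # A minimal sketch (not part of the original file) of what `prepare_inputs_for_generation` does once a
    # cache exists, assuming `lm` is a built TFOPTForCausalLM and `past` / `mask` hold the cache and attention
    # mask from the previous forward pass:
    #
    #     >>> generated = tf.constant([[2, 100, 101, 102, 103, 104]])   # 6 ids generated so far
    #     >>> step = lm.prepare_inputs_for_generation(generated, past_key_values=past, attention_mask=mask)
    #     >>> step["input_ids"].shape   # (1, 1) -- only the newest token is re-encoded; the rest comes from the cache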

    @unpack_inputs
    @replace_return_docstrings(output_type=TFCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=TFCausalLMOutputWithPast,
        config_class=_CONFIG_FOR_DOC,
        expected_output=_CAUSAL_LM_EXPECTED_OUTPUT,
    )
    def call(
        self,
        input_ids: TFModelInputType | None = None,
        past_key_values: Optional[Tuple[Tuple[Union[np.ndarray, tf.Tensor]]]] = None,
        attention_mask: np.ndarray | tf.Tensor | None = None,
        position_ids: np.ndarray | tf.Tensor | None = None,
        head_mask: np.ndarray | tf.Tensor | None = None,
        inputs_embeds: np.ndarray | tf.Tensor | None = None,
        labels: np.ndarray | tf.Tensor | None = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        training: Optional[bool] = False,
        **kwargs,
    ) -> Union[TFCausalLMOutputWithPast, Tuple[tf.Tensor]]:
        r"""
        Args:
            input_ids (`tf.Tensor` of shape `(batch_size, sequence_length)`):
                Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
                provide it.

                Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
                [`PreTrainedTokenizer.__call__`] for details.

                [What are input IDs?](../glossary#input-ids)
            attention_mask (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
                Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

                - 1 for tokens that are **not masked**,
                - 0 for tokens that are **masked**.

                [What are attention masks?](../glossary#attention-mask)
            head_mask (`tf.Tensor` of shape `(num_hidden_layers, num_attention_heads)`, *optional*):
                Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:

                - 1 indicates the head is **not masked**,
                - 0 indicates the head is **masked**.
            past_key_values (`Tuple[Tuple[tf.Tensor]]`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
                Tuple of `Tuple[tf.Tensor]` of length `config.n_layers`, with each tuple having 2 tensors of shape
                `(batch_size, num_heads, sequence_length, embed_size_per_head)` and 2 additional tensors of shape
                `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. The two additional tensors
                are only required when the model is used as a decoder in a Sequence to Sequence model.

                Contains pre-computed hidden-states (key and values in the self-attention blocks and in the
                cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.

                If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that
                don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
                `decoder_input_ids` of shape `(batch_size, sequence_length)`.
            inputs_embeds (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
                Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
                This is useful if you want more control over how to convert `input_ids` indices into associated
                vectors than the model's internal embedding lookup matrix.
            labels (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
                Labels for computing the language modeling loss. Indices should either be in `[0, ...,
                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
                (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
            use_cache (`bool`, *optional*):
                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
                (see `past_key_values`).
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            output_hidden_states (`bool`, *optional*):
                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
                for more detail.
            return_dict (`bool`, *optional*):
                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.model(
            input_ids=input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            training=training,
        )

        logits = self.model.decoder.embed_tokens(outputs[0], mode="linear")
        loss = None
        if labels is not None:
            # shift labels to the left and cut last logit token
            shifted_logits = logits[:, :-1]
            labels = labels[:, 1:]
            loss = self.hf_compute_loss(labels, shifted_logits)

        if not return_dict:
            output = (logits,) + outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return TFCausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def serving_output(self, output):
        pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None
        hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
        attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None

        return TFCausalLMOutputWithPast(
            past_key_values=pkv,
            hidden_states=hs,
            attentions=attns,
            loss=output.loss,
            logits=output.logits,
        )

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        if getattr(self, "model", None) is not None:
            with tf.name_scope(self.model.name):
                self.model.build(None)
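
# A minimal sketch (not part of the original file) of end-to-end generation with the causal LM head,
# assuming the "facebook/opt-350m" checkpoint is available:
#
#     >>> from transformers import AutoTokenizer, TFOPTForCausalLM
#     >>> tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
#     >>> model = TFOPTForCausalLM.from_pretrained("facebook/opt-350m")
#     >>> inputs = tokenizer("Hey, are you conscious? Can you talk to me?", return_tensors="tf")
#     >>> output_ids = model.generate(**inputs, max_new_tokens=30)
#     >>> tokenizer.batch_decode(output_ids, skip_special_tokens=True)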