modeling_tf_idefics.py

# coding=utf-8
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""TF 2.0 Idefics model."""
from __future__ import annotations

from dataclasses import dataclass
from typing import List, Optional, Tuple, Union

import tensorflow as tf

from ... import TFPreTrainedModel
from ...activations_tf import get_tf_activation
from ...modeling_tf_outputs import ModelOutput
from ...modeling_tf_utils import (
    TFCausalLanguageModelingLoss,
    TFModelInputType,
    keras_serializable,
    shape_list,
    unpack_inputs,
)
from ...tf_utils import invert_attention_mask, scaled_dot_product_attention
from ...utils import (
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    logging,
    replace_return_docstrings,
)
from .configuration_idefics import IdeficsConfig
from .perceiver_tf import TFIdeficsPerceiverResampler
from .vision_tf import TFIdeficsVisionTransformer


logger = logging.get_logger(__name__)

_CONFIG_FOR_DOC = "IdeficsConfig"


@dataclass
class TFIdeficsBaseModelOutputWithPast(ModelOutput):
    """
    Base class for Idefics model's outputs that may also contain a past key/values (to speed up sequential decoding).

    Args:
        last_hidden_state (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the output of the last layer of the model.

            If `past_key_values` is used, only the last hidden-state of the sequences of shape `(batch_size, 1,
            hidden_size)` is output.
        past_key_values (`tuple(tuple(tf.Tensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
            Tuple of `tuple(tf.Tensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
            `(batch_size, num_heads, sequence_length, embed_size_per_head)` and optionally, if
            `config.is_encoder_decoder=True`, 2 additional tensors of shape `(batch_size, num_heads,
            encoder_sequence_length, embed_size_per_head)`.

            Contains pre-computed hidden-states (key and values in the self-attention blocks and, optionally, if
            `config.is_encoder_decoder=True`, in the cross-attention blocks) that can be used (see `past_key_values`
            input) to speed up sequential decoding.
        hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `tf.Tensor` (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
        attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attention weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
        image_hidden_states (`tuple(tf.Tensor)`, *optional*):
            Tuple of `tf.Tensor` (one for the output of the image embeddings) of shape `(batch_size, num_images,
            sequence_length, hidden_size)`.

            Image hidden-states of the model produced by the vision encoder and, optionally, by the perceiver.
    """

    last_hidden_state: tf.Tensor = None
    past_key_values: Optional[Tuple[Tuple[tf.Tensor]]] = None
    hidden_states: Optional[Tuple[tf.Tensor]] = None
    attentions: Optional[Tuple[tf.Tensor]] = None
    image_hidden_states: Optional[Tuple[tf.Tensor]] = None


@dataclass
class TFIdeficsCausalLMOutputWithPast(ModelOutput):
    """
    Base class for Idefics causal language model (or autoregressive) outputs.

    Args:
        loss (`tf.Tensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
            Language modeling loss (for next-token prediction).
        logits (`tf.Tensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
        past_key_values (`tuple(tuple(tf.Tensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
            Tuple of `tuple(tf.Tensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
            `(batch_size, num_heads, sequence_length, embed_size_per_head)`.

            Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
            `past_key_values` input) to speed up sequential decoding.
        hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `tf.Tensor` (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
        attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attention weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
        image_hidden_states (`tuple(tf.Tensor)`, *optional*):
            Tuple of `tf.Tensor` (one for the output of the image embeddings) of shape `(batch_size, num_images,
            sequence_length, hidden_size)`.

            Image hidden-states of the model produced by the vision encoder and, optionally, by the perceiver.
    """

    loss: Optional[tf.Tensor] = None
    logits: tf.Tensor = None
    past_key_values: Optional[List[tf.Tensor]] = None
    hidden_states: Optional[Tuple[tf.Tensor]] = None
    attentions: Optional[Tuple[tf.Tensor]] = None
    image_hidden_states: Optional[Tuple[tf.Tensor]] = None


def expand_inputs_for_generation(
    input_ids,
    expand_size=1,
    is_encoder_decoder=False,
    attention_mask=None,
    encoder_outputs=None,
    **model_kwargs,
):
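    # Duplicate every batch row `expand_size` times so generation strategies such
    # as beam search or multi-sample decoding see one copy per candidate: e.g.
    # with expand_size=2, batch rows [A, B] become [A, A, B, B]. The image-related
    # kwargs below are expanded with the same index pattern.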
    expanded_return_idx = tf.reshape(tf.repeat(tf.range(tf.shape(input_ids)[0]), expand_size), [-1])
    input_ids = tf.gather(input_ids, expanded_return_idx)
    model_kwargs["pixel_values"] = model_kwargs.get("pixel_values", None)
    model_kwargs["image_encoder_embeddings"] = model_kwargs.get("image_encoder_embeddings", None)
    model_kwargs["perceiver_embeddings"] = model_kwargs.get("perceiver_embeddings", None)
    model_kwargs["image_attention_mask"] = model_kwargs.get("image_attention_mask", None)

    if "token_type_ids" in model_kwargs:
        token_type_ids = model_kwargs["token_type_ids"]
        model_kwargs["token_type_ids"] = tf.gather(token_type_ids, expanded_return_idx)

    if attention_mask is not None:
        model_kwargs["attention_mask"] = tf.gather(attention_mask, expanded_return_idx)

    if model_kwargs["image_attention_mask"] is not None:
        model_kwargs["image_attention_mask"] = tf.gather(model_kwargs["image_attention_mask"], expanded_return_idx)

    if model_kwargs["pixel_values"] is not None:
        model_kwargs["pixel_values"] = tf.gather(model_kwargs["pixel_values"], expanded_return_idx)
    elif model_kwargs["image_encoder_embeddings"] is not None:
        model_kwargs["image_encoder_embeddings"] = tf.gather(
            model_kwargs["image_encoder_embeddings"], expanded_return_idx
        )
    elif model_kwargs["perceiver_embeddings"] is not None:
        model_kwargs["perceiver_embeddings"] = tf.gather(model_kwargs["perceiver_embeddings"], expanded_return_idx)

    return input_ids, model_kwargs


def update_model_kwargs_for_generation(outputs, model_kwargs):
    # must have this key set to at least None
    if "past_key_values" in outputs:
        model_kwargs["past_key_values"] = outputs.past_key_values
    else:
        model_kwargs["past_key_values"] = None

    # update token_type_ids with last value
    if "token_type_ids" in model_kwargs:
        token_type_ids = model_kwargs["token_type_ids"]
        model_kwargs["token_type_ids"] = tf.concat([token_type_ids, token_type_ids[:, -1:, ...]], axis=-1)

    # update attention masks
    if "attention_mask" in model_kwargs:
        attention_mask = model_kwargs["attention_mask"]
        model_kwargs["attention_mask"] = tf.concat(
            [attention_mask, tf.ones_like(attention_mask[:, -1:, ...])], axis=-1
        )
    if "image_attention_mask" in model_kwargs:
        image_attention_mask = model_kwargs["image_attention_mask"]
        last_mask = image_attention_mask[:, -1:, ...]
        model_kwargs["image_attention_mask"] = last_mask
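        # Only the mask for the most recent position is kept: with cached
        # key/values, the next decoding step processes a single new token, so the
        # image attention mask is reduced to that last step as well.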
    # Get the precomputed image_hidden_states
    model_kwargs["image_hidden_states"] = outputs.image_hidden_states
    return model_kwargs


def prepare_inputs_for_generation(input_ids, past_key_values=None, **kwargs):
    token_type_ids = kwargs.get("token_type_ids", None)
    # only last token for input_ids if past is defined in kwargs
    if past_key_values is not None:
        input_ids = input_ids[:, -1:]
        if token_type_ids is not None:
            token_type_ids = token_type_ids[:, -1:]

    attention_mask = kwargs.get("attention_mask", None)
    position_ids = kwargs.get("position_ids", None)

    if attention_mask is not None and position_ids is None:
        # create position_ids on the fly for batch generation
        position_ids = tf.math.cumsum(tf.cast(attention_mask, dtype=tf.int64), axis=-1) - 1
        position_ids = tf.where(attention_mask == 0, 1, position_ids)
        if past_key_values is not None:
            position_ids = position_ids[:, -1:]

    pixel_values = kwargs.get("pixel_values", None)
    image_encoder_embeddings = kwargs.get("image_encoder_embeddings", None)
    perceiver_embeddings = kwargs.get("perceiver_embeddings", None)
    image_attention_mask = kwargs.get("image_attention_mask", None)
    interpolate_pos_encoding = kwargs.get("interpolate_pos_encoding", False)

    return {
        "input_ids": input_ids,
        "past_key_values": past_key_values,
        "use_cache": kwargs.get("use_cache"),
        "position_ids": position_ids,
        "attention_mask": attention_mask,
        "token_type_ids": token_type_ids,
        "pixel_values": pixel_values,
        "image_encoder_embeddings": image_encoder_embeddings,
        "perceiver_embeddings": perceiver_embeddings,
        "image_attention_mask": image_attention_mask,
        "interpolate_pos_encoding": interpolate_pos_encoding,
    }


def freeze_model(model, module_exceptions=[]):
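    # Mark `model` (a bare layer or a model with sub-layers) as non-trainable,
    # except for layer classes whose short names are listed in
    # `module_exceptions`, e.g. module_exceptions=["Dense"] keeps Dense layers
    # trainable.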
    mapping = {
        "LayerNorm": tf.keras.layers.LayerNormalization,
        "Dense": tf.keras.layers.Dense,
        "Embedding": tf.keras.layers.Embedding,
    }
    module_exceptions_mapped = [mapping[m] for m in module_exceptions]
    if not hasattr(model, "layers"):
        model.trainable = False  # It is just a layer
        return model
    for layer in model.layers:
        if module_exceptions and any(isinstance(layer, t) for t in module_exceptions_mapped):
            layer.trainable = True  # Explicitly setting it to true to avoid any mistakes
        else:
            layer.trainable = False
    return model


class TFIdeficsDecoupledEmbedding(tf.keras.layers.Embedding):
    """
    Implements a decoupling of parameters to allow freezing (or not) a subset of the embeddings. In practice, the
    regular `weight` can be trained or frozen (i.e. `partially_freeze=True`), and if `num_additional_embeddings` > 0,
    then it will create `num_additional_embeddings` additional parameters that are always trained. If
    `num_additional_embeddings=0`, then the module defaults back to the regular behavior of
    `tf.keras.layers.Embedding`.
    """

    def __init__(
        self,
        num_embeddings,
        num_additional_embeddings,
        embedding_dim,
        partially_freeze: Optional[bool] = False,
        dtype=None,
        **kwargs,
    ) -> None:
        """
        Args:
            num_embeddings (`int`):
                Size of the dictionary of embeddings
            num_additional_embeddings (`int`):
                Number of additional embeddings. Only useful when `partially_freeze=True`.
            embedding_dim (`int`):
                The size of each embedding vector
            partially_freeze (`bool`, *optional*, defaults to `False`):
                If `True`, the regular `weight` will be frozen. `additional_weight` is never frozen.

        Note: there are a lot of other parameters to initialize a standard `tf.keras.layers.Embedding` such as
        `mask_zero`, `input_length` or `embeddings_initializer`. We are not supporting these.
        """
        super().__init__(
            input_dim=num_embeddings,
            output_dim=embedding_dim,
            dtype=dtype,
            **kwargs,
        )
        self.num_embeddings = num_embeddings
        self.num_additional_embeddings = num_additional_embeddings
        self.partially_freeze = partially_freeze

        if partially_freeze:
            self.trainable = False

        if self.num_additional_embeddings > 0:
            self.additional_embedding = tf.keras.layers.Embedding(
                input_dim=self.num_additional_embeddings,
                output_dim=embedding_dim,
                dtype=dtype,
                name="additional_embedding",
            )

    def call(self, input_ids):
        """
        We have 2 embeddings, with different indices - one pretrained self.weight and another
        self.additional_embedding.weight that is being trained.

        In order to make a lookup of the input ids, we:
        1. find out the indices of the entries belonging to the 2nd embedding
        2. extract those values while subtracting the size of the first embedding (num_embeddings), since the 2nd
           embedding starts from 0 and not num_embeddings
        3. perform the 2nd embedding lookup
        4. now we handle the 1st embedding, we overwrite indices belonging to the 2nd embedding with a padding index
        5. perform the 1st embedding lookup
        6. now we overwrite the values in the 1st embedding lookup with the values of the 2nd embedding lookup

        Note: for the 1st embedding lookup we could have looked up only the low indices and not done the padding, but
        then we have to create a new tensor and populate it with 2 tensors that are spread out across various indices -
        i.e. not a simple concat - I haven't benchmarked whether the complex case is any faster, given that seqlens are
        usually relatively short it's probably not faster, or if faster not by much - but might be a good idea to
        measure.
        """
        if self.num_additional_embeddings == 0:
            return super().call(input_ids)

        # Clone so that we don't modify the original input_ids later on
        input_ids = tf.identity(input_ids)
        additional_vocab_indices = tf.where(input_ids >= self.num_embeddings)
        input_ids_additional_vocab = tf.gather_nd(input_ids, additional_vocab_indices)
        additional_embeddings = self.additional_embedding(input_ids_additional_vocab - self.num_embeddings)

        # for successful lookup replace input_ids with 0, the results of these will be discarded anyway
        input_ids = tf.tensor_scatter_nd_update(
            input_ids,
            additional_vocab_indices,
            # tensor filled with 0, having the same length as additional_vocab_indices
            tf.zeros(tf.shape(additional_vocab_indices)[0], dtype=input_ids.dtype),
        )
        full_vector = super().call(input_ids)

        # overwrite the records with high indices
        full_vector = tf.tensor_scatter_nd_update(full_vector, additional_vocab_indices, additional_embeddings)

        return full_vector

    def extra_repr(self) -> str:
        return "num_embeddings={}, num_additional_embeddings={}, embedding_dim={}, partially_freeze={}".format(
            self.num_embeddings,
            self.num_additional_embeddings,
            self.output_dim,
            self.partially_freeze,
        )


class TFIdeficsDecoupledLinear(tf.keras.layers.Layer):
    """
    Implements a decoupling of parameters to allow freezing (or not) a subset of the parameters. In practice, the
    regular `weight` can be trained or frozen (i.e. `partially_freeze=True`), and if `out_additional_features` > 0,
    then it will create `out_additional_features * in_features` additional parameters that are always trained. If
    `out_additional_features=0`, then the module defaults back to the regular behavior of `tf.keras.layers.Dense`.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        out_additional_features: int = 0,
        bias: bool = True,
        partially_freeze: bool = True,
        **kwargs,
    ) -> None:
        """
        out_additional_features: int. Number of additional trainable dimensions. Only makes sense when
        `partially_freeze=True`.
        partially_freeze: bool. If True, the regular `weight` will be frozen and extra parameters (if any) will be
        trainable. If False, default to the regular behavior of tf.keras.layers.Dense.
        """
        super().__init__(**kwargs)
        self.out_additional_features = out_additional_features
        self.partially_freeze = partially_freeze

        self.in_features = in_features
        self.out_features = out_features
        self.use_bias = bias

        if out_additional_features > 0:
            self.additional_fc = tf.keras.layers.Dense(
                units=out_additional_features, use_bias=bias, name="additional_fc"
            )

    def call(self, inputs: tf.Tensor) -> tf.Tensor:
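        # Base projection y = inputs @ weight^T (+ bias) uses the possibly-frozen
        # kernel created in `build`; the always-trainable `additional_fc` output
        # is then concatenated along the feature axis, growing the output
        # dimension by `out_additional_features`.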
        output = tf.linalg.matmul(a=inputs, b=self.weight, transpose_b=True)
        if self.bias is not None:
            output = tf.nn.bias_add(output, self.bias)
        if self.out_additional_features > 0:
            additional_features = self.additional_fc(inputs)
            output = tf.concat([output, additional_features], axis=-1)
        return output

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "in_features": self.in_features,
                "out_features": self.out_features,
                "out_additional_features": self.out_additional_features,
                # use the stored flag so this also works before `build` has run
                "bias": self.use_bias,
                "partially_freeze": self.partially_freeze,
            }
        )
        return config

    def extra_repr(self) -> str:
        """Overwriting `nn.Linear.extra_repr` to include new parameters."""
        return "in_features={}, out_features={}, out_additional_features={}, bias={}, partially_freeze={}".format(
            self.in_features,
            self.out_features,
            self.out_additional_features,
            self.use_bias,
            self.partially_freeze,
        )

    @classmethod
    def from_config(cls, config):
        return cls(**config)

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        self.weight = self.add_weight(
            shape=(self.out_features, self.in_features), trainable=not self.partially_freeze, name="weight"
        )
        if self.use_bias:
            self.bias = self.add_weight(shape=(self.out_features,), trainable=not self.partially_freeze, name="bias")
        else:
            self.bias = None
        if getattr(self, "additional_fc", None) is not None:
            with tf.name_scope(self.additional_fc.name):
                self.additional_fc.build(self.in_features)


def _make_causal_mask(input_ids_shape, dtype, past_key_values_length=0):
    """
    Make causal mask used for uni-directional (causal) self-attention, supporting both static and dynamic shapes.
    """
    bsz, tgt_len = input_ids_shape

    # Create a matrix where only the lower triangle and diagonal are filled with zeros (causal mask)
    mask = tf.fill((tgt_len, tgt_len), tf.dtypes.as_dtype(dtype).min)
    mask_cond = tf.range(tgt_len)
    mask = tf.where(mask_cond[:, None] >= mask_cond[None, :], 0.0, mask)

    if past_key_values_length > 0:
        mask = tf.concat([tf.zeros((tgt_len, past_key_values_length), dtype=dtype), mask], axis=-1)

    if bsz is None:
        # When batch size is dynamic, expand and tile
        # so we can compile a functional model
        mask = tf.expand_dims(mask, 0)
        mask = tf.expand_dims(mask, 0)  # shape: (1, 1, tgt_len, tgt_len + past_key_values_length)
        mask = tf.tile(mask, [bsz, 1, 1, 1])
    else:
        # When batch size is static, directly use broadcast_to
        mask = tf.broadcast_to(mask[None, None, :, :], (bsz, 1, tgt_len, tgt_len + past_key_values_length))

    return mask


def _expand_mask(mask, dtype, tgt_len=None):
    """
    Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
    """
    bsz, src_len = shape_list(mask)
    tgt_len = tgt_len if tgt_len is not None else src_len

    expanded_mask = tf.expand_dims(tf.expand_dims(mask, 1), 1)
    expanded_mask = tf.broadcast_to(expanded_mask, [bsz, 1, tgt_len, src_len])

    inverted_mask = 1.0 - tf.cast(expanded_mask, dtype)

    return tf.where(
        tf.cast(inverted_mask, bool), tf.fill(dims=shape_list(inverted_mask), value=tf.float32.min), inverted_mask
    )


class TFIdeficsRMSNorm(tf.keras.layers.Layer):
    def __init__(self, hidden_size, eps=1e-6, **kwargs):
        """
        TFIdeficsRMSNorm is equivalent to T5LayerNorm
        """
        super().__init__(**kwargs)
        self.hidden_size = hidden_size
        self.variance_epsilon = eps

    def build(self, input_shape):
        if self.built:
            return
        self.built = True
        self.weight = self.add_weight(name="weight", shape=[self.hidden_size], initializer="ones")
        super().build(input_shape)

    def call(self, hidden_states):
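        # RMSNorm: y = weight * x / sqrt(mean(x^2) + eps), with the mean taken
        # over the hidden dimension and computed in float32 for numerical
        # stability.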
        variance = tf.math.reduce_mean(tf.math.square(tf.cast(hidden_states, tf.float32)), axis=-1, keepdims=True)
        hidden_states = hidden_states * tf.math.rsqrt(variance + self.variance_epsilon)

        # convert into half-precision if necessary
        if self.weight.dtype in [tf.float16, tf.bfloat16]:
            hidden_states = tf.cast(hidden_states, self.weight.dtype)

        return self.weight * hidden_states


class TFIdeficsEmbedding(tf.keras.layers.Layer):
    def __init__(self, dim, max_position_embeddings=2048, base=10000, **kwargs):
        super().__init__(**kwargs)
        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        self.inv_freq = tf.constant(
            1.0 / (self.base ** (tf.range(start=0, limit=self.dim, delta=2, dtype=tf.float32) / self.dim))
        )
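        # inv_freq[i] = 1 / base^(2i / dim) is the standard rotary (RoPE)
        # frequency schedule: one frequency per pair of channels, decaying
        # geometrically from 1 down to roughly 1/base.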

    def _compute_cos_sin(self, seq_len):
        t = tf.range(seq_len, dtype=self.inv_freq.dtype)
        freqs = tf.einsum("i, j -> ij", t, self.inv_freq)  # Outer multiplication
        emb = tf.concat((freqs, freqs), axis=-1)

        return tf.cos(emb), tf.sin(emb)

    def call(self, x, seq_len=None):
        # x: [bs, num_attention_heads, seq_len, head_size]
        if seq_len is None:
            seq_len = shape_list(x)[2]
        return self._compute_cos_sin(seq_len=seq_len)


def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return tf.concat((-x2, x1), axis=-1)


def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
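    # Standard rotary position embedding: each query/key vector is rotated by a
    # position-dependent angle via q * cos + rotate_half(q) * sin (and likewise
    # for k), which encodes relative positions directly in the dot products.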
    cos = tf.gather(cos, position_ids)  # [seq_len, dim] -> [batch_size, seq_len, dim]
    sin = tf.gather(sin, position_ids)
    cos = tf.expand_dims(cos, 1)  # -> [batch_size, 1, seq_len, dim]
    sin = tf.expand_dims(sin, 1)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed


class TFIdeficsMLP(tf.keras.layers.Layer):
    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.gate_proj = tf.keras.layers.Dense(intermediate_size, use_bias=False, name="gate_proj")
        self.down_proj = tf.keras.layers.Dense(hidden_size, use_bias=False, name="down_proj")
        self.up_proj = tf.keras.layers.Dense(intermediate_size, use_bias=False, name="up_proj")
        self.act_fn = get_tf_activation(hidden_act)
        self.intermediate_size = intermediate_size
        self.hidden_size = hidden_size

    def call(self, x):
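        # Gated (SwiGLU-style) MLP: the activated gate branch multiplies the
        # linear up-projection before being projected back down to hidden_size.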
        return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        if getattr(self, "gate_proj", None) is not None:
            with tf.name_scope(self.gate_proj.name):
                self.gate_proj.build(self.hidden_size)
        if getattr(self, "down_proj", None) is not None:
            with tf.name_scope(self.down_proj.name):
                self.down_proj.build(self.intermediate_size)
        if getattr(self, "up_proj", None) is not None:
            with tf.name_scope(self.up_proj.name):
                self.up_proj.build(self.hidden_size)


class TFIdeficsAttention(tf.keras.layers.Layer):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        dropout: float = 0.0,
        is_cross_attention: bool = False,
        config: IdeficsConfig = None,
        qk_layer_norms: bool = False,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads
        self.dropout = dropout
        self.config = config
        self.is_causal = True

        if (self.head_dim * num_heads) != self.hidden_size:
            raise ValueError(
                f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
                f" and `num_heads`: {num_heads})."
            )

        self.is_cross_attention = is_cross_attention

        self.q_proj = tf.keras.layers.Dense(
            num_heads * self.head_dim,
            use_bias=False,
            name="q_proj",
        )
        self.k_proj = tf.keras.layers.Dense(
            num_heads * self.head_dim,
            use_bias=False,
            name="k_proj",
        )
        self.v_proj = tf.keras.layers.Dense(
            num_heads * self.head_dim,
            use_bias=False,
            name="v_proj",
        )
        self.o_proj = tf.keras.layers.Dense(
            hidden_size,
            use_bias=False,
            name="o_proj",
        )
        self.rotary_emb = TFIdeficsEmbedding(self.head_dim, name="rotary_emb")

        self.qk_layer_norms = qk_layer_norms
        if self.qk_layer_norms:
            self.q_layer_norm = TFIdeficsRMSNorm(self.head_dim, eps=config.rms_norm_eps, name="q_layer_norm")
            self.k_layer_norm = TFIdeficsRMSNorm(self.head_dim, eps=config.rms_norm_eps, name="k_layer_norm")

    def _shape(self, tensor: tf.Tensor, seq_len: int, bsz: int):
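        # [bsz, seq_len, num_heads * head_dim] -> [bsz, num_heads, seq_len, head_dim]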
        return tf.transpose(tf.reshape(tensor, (bsz, seq_len, self.num_heads, self.head_dim)), perm=[0, 2, 1, 3])

    def call(
        self,
        hidden_states: tf.Tensor,
        key_value_states: Optional[tf.Tensor] = None,
        attention_mask: Optional[tf.Tensor] = None,
        position_ids: Optional[tf.Tensor] = None,
        past_key_value: Optional[Tuple[tf.Tensor]] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
    ) -> Tuple[tf.Tensor, Optional[tf.Tensor], Optional[Tuple[tf.Tensor]]]:
        # if key_value_states are provided this layer is used as a cross-attention layer
        is_cross_attention = self.is_cross_attention or key_value_states is not None

        bsz, q_len, _ = shape_list(hidden_states)

        query_states = self._shape(self.q_proj(hidden_states), q_len, bsz)
        if not is_cross_attention:
            key_states = self._shape(self.k_proj(hidden_states), q_len, bsz)
            value_states = self._shape(self.v_proj(hidden_states), q_len, bsz)
        else:
            _, kv_len, _ = shape_list(key_value_states)  # Note that, in this case, `kv_len` == `kv_seq_len`
            key_states = self._shape(self.k_proj(key_value_states), kv_len, bsz)
            value_states = self._shape(self.v_proj(key_value_states), kv_len, bsz)

        kv_seq_len = shape_list(key_states)[-2]
        if past_key_value is not None:
            kv_seq_len += shape_list(past_key_value[0])[-2]
        if not is_cross_attention:
            # Below is to allow symbolic tensors compilation
            if tf.is_tensor(kv_seq_len):
                # element-wise maximum; tf.reduce_max would treat q_len as an axis
                seq_len = tf.maximum(kv_seq_len, q_len)
            else:
                seq_len = max(kv_seq_len, q_len)
            cos, sin = self.rotary_emb(value_states, seq_len)
            query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
        # [bsz, nh, t, hd]

        if past_key_value is not None:
            # reuse k, v, self_attention
            key_states = tf.concat([past_key_value[0], key_states], axis=2)
            value_states = tf.concat([past_key_value[1], value_states], axis=2)

        past_key_value = (key_states, value_states) if use_cache else None

        if self.qk_layer_norms:
            query_states = self.q_layer_norm(query_states)
            key_states = self.k_layer_norm(key_states)

        if attention_mask is not None:
            tf.debugging.assert_equal(
                tf.shape(attention_mask),
                [bsz, 1, q_len, kv_seq_len],
                message=f"Attention mask should be of size {[bsz, 1, q_len, kv_seq_len]}, but is {tf.shape(attention_mask)}",
            )

        attn_output = scaled_dot_product_attention(
            query_states,
            key_states,
            value_states,
            attn_mask=attention_mask,
            # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
            is_causal=self.is_causal and attention_mask is None and q_len > 1,
        )

        tf.debugging.assert_equal(
            tf.shape(attn_output),
            [bsz, self.num_heads, q_len, self.head_dim],
            message=f"`attn_output` should be of size {[bsz, self.num_heads, q_len, self.head_dim]}, but is {tf.shape(attn_output)}",
        )

        attn_output = tf.reshape(tf.transpose(attn_output, perm=[0, 2, 1, 3]), (bsz, q_len, self.hidden_size))
        attn_output = self.o_proj(attn_output)

        attn_weights = None
        if output_attentions:
            logger.warning_once(
                "attn_weights are not extracted in scaled_dot_product_attention. The model returns None instead"
            )

        return attn_output, attn_weights, past_key_value

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        if self.is_cross_attention:
            kv_input_dim = (
                self.hidden_size
                if not hasattr(self.config.vision_config, "embed_dim")
                else self.config.vision_config.embed_dim
            )
        else:
            kv_input_dim = self.hidden_size
        if getattr(self, "o_proj", None) is not None:
            with tf.name_scope(self.o_proj.name):
                self.o_proj.build(self.num_heads * self.head_dim)
        if getattr(self, "q_proj", None) is not None:
            with tf.name_scope(self.q_proj.name):
                self.q_proj.build(self.hidden_size)
        if getattr(self, "k_proj", None) is not None:
            with tf.name_scope(self.k_proj.name):
                self.k_proj.build(kv_input_dim)
        if getattr(self, "v_proj", None) is not None:
            with tf.name_scope(self.v_proj.name):
                self.v_proj.build(kv_input_dim)
        if getattr(self, "rotary_emb", None) is not None:
            with tf.name_scope(self.rotary_emb.name):
                self.rotary_emb.build(None)


class TFIdeficsDecoderLayer(tf.keras.layers.Layer):
    def __init__(self, config: IdeficsConfig, **kwargs):
        super().__init__(**kwargs)
        self.hidden_size = config.hidden_size
        self.self_attn = TFIdeficsAttention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            dropout=config.dropout,
            config=config,
            name="self_attn",
        )
        self.mlp = TFIdeficsMLP(
            hidden_size=self.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            name="mlp",
        )
        self.input_layernorm = TFIdeficsRMSNorm(config.hidden_size, eps=config.rms_norm_eps, name="input_layernorm")
        self.post_attention_layernorm = TFIdeficsRMSNorm(
            config.hidden_size, eps=config.rms_norm_eps, name="post_attention_layernorm"
        )
        self.dropout = config.dropout

    def call(
        self,
        hidden_states: tf.Tensor,
        attention_mask: Optional[tf.Tensor] = None,
        position_ids: Optional[tf.Tensor] = None,
        past_key_value: Optional[Tuple[tf.Tensor]] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        training=False,
    ) -> Tuple[tf.Tensor, Optional[Tuple[tf.Tensor, tf.Tensor]]]:
        """
        Args:
            hidden_states (`tf.Tensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`tf.Tensor`, *optional*): attention mask of size
                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            use_cache (`bool`, *optional*):
                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
                (see `past_key_values`).
            past_key_value (`Tuple(tf.Tensor)`, *optional*): cached past key and value projection states
        """
        residual = hidden_states

        hidden_states = self.input_layernorm(hidden_states)

        # Self Attention
        hidden_states, self_attn_weights, present_key_value = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
        )
        # only apply dropout during training
        if training:
            hidden_states = tf.nn.dropout(hidden_states, rate=self.dropout)
        hidden_states = residual + hidden_states

        # Fully Connected
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        if training:
            hidden_states = tf.nn.dropout(hidden_states, rate=self.dropout)
        hidden_states = residual + hidden_states

        outputs = (hidden_states,)

        if output_attentions:
            outputs += (self_attn_weights,)

        if use_cache:
            outputs += (present_key_value,)

        return outputs

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        if getattr(self, "self_attn", None) is not None:
            with tf.name_scope(self.self_attn.name):
                self.self_attn.build(None)
        if getattr(self, "mlp", None) is not None:
            with tf.name_scope(self.mlp.name):
                self.mlp.build(None)
        if getattr(self, "input_layernorm", None) is not None:
            with tf.name_scope(self.input_layernorm.name):
                self.input_layernorm.build(None)
        if getattr(self, "post_attention_layernorm", None) is not None:
            with tf.name_scope(self.post_attention_layernorm.name):
                self.post_attention_layernorm.build(None)


class TFIdeficsGatedCrossAttentionLayer(tf.keras.layers.Layer):
    def __init__(self, config: IdeficsConfig, **kwargs):
        super().__init__(**kwargs)
        self.hidden_size = config.hidden_size
        self.cross_attn = TFIdeficsAttention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            is_cross_attention=True,
            dropout=config.dropout,
            config=config,
            qk_layer_norms=config.qk_layer_norms,
            name="cross_attn",
        )
        self.mlp = TFIdeficsMLP(
            hidden_size=self.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            name="mlp",
        )
        self.input_layernorm = TFIdeficsRMSNorm(config.hidden_size, eps=config.rms_norm_eps, name="input_layernorm")
        self.post_attention_layernorm = TFIdeficsRMSNorm(
            config.hidden_size, eps=config.rms_norm_eps, name="post_attention_layernorm"
        )
        # store the dropout rate under its own name instead of shadowing `self.config`
        self.dropout = config.dropout

        self.act_cross_attn = tf.keras.activations.tanh
        self.act_dense = tf.keras.activations.tanh

        self.alpha_initializer = config.alpha_initializer
        self.alpha_type = config.alpha_type
        self.alphas_initializer_range = config.alphas_initializer_range

    def build(self, input_shape):
        if self.built:
            return
        self.built = True
        if self.alpha_initializer == "zeros":
            if self.alpha_type == "vector":
                self.alpha_cross_attn = self.add_weight(
                    shape=(1, 1, self.hidden_size), initializer="zeros", trainable=True, name="alpha_cross_attn"
                )
                self.alpha_dense = self.add_weight(
                    shape=(1, 1, self.hidden_size), initializer="zeros", trainable=True, name="alpha_dense"
                )
            elif self.alpha_type == "float":
                self.alpha_cross_attn = self.add_weight(
                    shape=(1,), initializer="zeros", trainable=True, name="alpha_cross_attn"
                )
                self.alpha_dense = self.add_weight(shape=(1,), initializer="zeros", trainable=True, name="alpha_dense")
            else:
                raise ValueError(f"Unknown value for `alpha_type` ({self.alpha_type})")
        elif self.alpha_initializer == "ones":
            if self.alpha_type == "vector":
                self.alpha_cross_attn = self.add_weight(
                    shape=(1, 1, self.hidden_size), initializer="ones", trainable=True, name="alpha_cross_attn"
                )
                self.alpha_dense = self.add_weight(
                    shape=(1, 1, self.hidden_size), initializer="ones", trainable=True, name="alpha_dense"
                )
            elif self.alpha_type == "float":
                self.alpha_cross_attn = self.add_weight(
                    shape=(1,), initializer="ones", trainable=True, name="alpha_cross_attn"
                )
                self.alpha_dense = self.add_weight(shape=(1,), initializer="ones", trainable=True, name="alpha_dense")
            else:
                raise ValueError(f"Unknown value for `alpha_type` ({self.alpha_type})")
        elif self.alpha_initializer in {"normal", "gaussian", "random"}:
            if self.alpha_type == "vector":
                self.alpha_cross_attn = self.add_weight(
                    shape=(1, 1, self.hidden_size),
                    initializer=tf.keras.initializers.RandomNormal(mean=0.0, stddev=self.alphas_initializer_range),
                    trainable=True,
                    name="alpha_cross_attn",
                )
                self.alpha_dense = self.add_weight(
                    shape=(1, 1, self.hidden_size),
                    initializer=tf.keras.initializers.RandomNormal(mean=0.0, stddev=self.alphas_initializer_range),
                    trainable=True,
                    name="alpha_dense",
                )
            elif self.alpha_type == "float":
                self.alpha_cross_attn = self.add_weight(
                    shape=(1,),
                    initializer=tf.keras.initializers.RandomNormal(mean=0.0, stddev=self.alphas_initializer_range),
                    trainable=True,
                    name="alpha_cross_attn",
                )
                self.alpha_dense = self.add_weight(
                    shape=(1,),
                    initializer=tf.keras.initializers.RandomNormal(mean=0.0, stddev=self.alphas_initializer_range),
                    trainable=True,
                    name="alpha_dense",
                )
            else:
                raise ValueError(f"Unknown value for `alpha_type` ({self.alpha_type})")
        else:
            raise NotImplementedError(f"Alpha initialization scheme {self.alpha_initializer} not yet implemented!")

        if not (hasattr(self, "alpha_cross_attn") and hasattr(self, "alpha_dense")):
            raise ValueError("Alpha parameters not initialized correctly!")

        with tf.name_scope(self.cross_attn.name):
            self.cross_attn.build(None)
        with tf.name_scope(self.mlp.name):
            self.mlp.build(None)
        with tf.name_scope(self.input_layernorm.name):
            self.input_layernorm.build(None)
        with tf.name_scope(self.post_attention_layernorm.name):
            self.post_attention_layernorm.build(None)
        super().build(input_shape)

    def call(
        self,
        hidden_states: tf.Tensor,
        attention_mask: Optional[tf.Tensor] = None,
        image_hidden_states: Optional[tf.Tensor] = None,
        image_attention_mask: Optional[tf.Tensor] = None,
        cross_attention_gate: Optional[tf.Tensor] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        past_key_value: Optional[Tuple[tf.Tensor]] = None,
    ) -> Tuple[tf.Tensor, Optional[Tuple[tf.Tensor, tf.Tensor]]]:
        """
        Args:
            hidden_states (`tf.Tensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`tf.Tensor`, *optional*): attention mask of size
                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
            image_hidden_states (`tf.Tensor`, *optional*): visual features the layer is conditioned on.
            image_attention_mask (`tf.Tensor`, *optional*): mask over the image tokens.
            cross_attention_gate (`tf.Tensor`, *optional*): gate of size `(batch, seq_len)` used to zero-out the
                cross-attention output for tokens attending to no images.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            use_cache (`bool`, *optional*):
                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
                (see `past_key_values`).
            past_key_value (`Tuple(tf.Tensor)`, *optional*): cached past key and value projection states
        """
        if image_hidden_states is None:
            raise ValueError(
                "`image_hidden_states` is required for Idefics cross attention module which are visual features to be"
                " conditioned on."
            )

        if cross_attention_gate is None:
            raise ValueError(
                "`cross_attention_gate` is required for Idefics cross attention module to zero-out the cross-attention"
                " hidden_states attending to no images."
            )

        if past_key_value is not None:
            raise NotImplementedError("Past key value states are not implemented for Idefics cross attention module.")

        residual = hidden_states

        hidden_states = self.input_layernorm(hidden_states)

        # Cross Attention
        hidden_states, self_attn_weights, present_key_value = self.cross_attn(
            hidden_states=hidden_states,
            key_value_states=image_hidden_states,
            attention_mask=image_attention_mask,
            output_attentions=output_attentions,
        )
        hidden_states = tf.nn.dropout(hidden_states, rate=self.dropout)
        mask = tf.cast(cross_attention_gate == 0, dtype=hidden_states.dtype)
        # Expand dimensions of mask to match hidden_states
        mask = tf.expand_dims(mask, -1)
        hidden_states = tf.where(
            tf.broadcast_to(mask, tf.shape(hidden_states)) == 1, tf.zeros_like(hidden_states), hidden_states
        )
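        # The tf.where above zeroes the cross-attention output for every token
        # whose cross_attention_gate is 0 (a token attending to no image), before
        # the tanh-gated residual connection below is applied.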
        # when there are no images the model is used in pure language mode
        # gate = 0 if no_images else 1
        hidden_states = residual + self.act_cross_attn(self.alpha_cross_attn) * hidden_states

        # Fully Connected
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = tf.nn.dropout(hidden_states, rate=self.dropout)
        hidden_states = residual + self.act_dense(self.alpha_dense) * hidden_states

        outputs = (hidden_states,)

        if output_attentions:
            outputs += (self_attn_weights,)

        if use_cache:
            outputs += (present_key_value,)

        return outputs


LLAMA_START_DOCSTRING = r"""
    This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
    etc.)

    This model is also a TensorFlow [tf.keras.layers.Layer](https://www.tensorflow.org/api_docs/python/tf/keras/layers/Layer)
    subclass. Use it as a regular TensorFlow Layer and refer to the TensorFlow documentation for all matters related
    to general usage and behavior.

    Parameters:
        config ([`IdeficsConfig`]):
            Model configuration class with all the parameters of the model. Initializing with a config file does not
            load the weights associated with the model, only the configuration. Check out the
            [`~TFPreTrainedModel.from_pretrained`] method to load the model weights.
"""


@add_start_docstrings(
    "The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
    LLAMA_START_DOCSTRING,
)
class TFIdeficsPreTrainedModel(TFPreTrainedModel):
    config_class = IdeficsConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["TFIdeficsDecoderLayer", "TFIdeficsGatedCrossAttentionLayer"]


LLAMA_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`tf.Tensor` of shape `(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
            it.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        attention_mask (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
            `past_key_values`).

            If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
            and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
            information on the default strategy.
        position_ids (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
            config.n_positions - 1]`. [What are position IDs?](../glossary#position-ids)
        past_key_values (`tuple(tuple(tf.Tensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
            Tuple of `tuple(tf.Tensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
            `(batch_size, num_heads, sequence_length, embed_size_per_head)` and 2 additional tensors of shape
            `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.

            Contains pre-computed hidden-states (key and values in the self-attention blocks and in the
            cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.

            If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
            don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
            `decoder_input_ids` of shape `(batch_size, sequence_length)`.
        inputs_embeds (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
            model's internal embedding lookup matrix.
        use_cache (`bool`, *optional*):
            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
            `past_key_values`).
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""


@add_start_docstrings(
    "The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
    LLAMA_START_DOCSTRING,
)
@keras_serializable
class TFIdeficsMainLayer(tf.keras.layers.Layer):
    """
    Transformer decoder consisting of `config.num_hidden_layers` layers. Each layer is a [`TFIdeficsDecoderLayer`].

    Args:
        config: IdeficsConfig
    """

    config_class = IdeficsConfig

    def __init__(self, config: IdeficsConfig, add_pooling_layer: bool = True, **kwargs):
        super().__init__(**kwargs)
        self.config = config
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = TFIdeficsDecoupledEmbedding(
            num_embeddings=config.vocab_size,
            num_additional_embeddings=config.additional_vocab_size,
            embedding_dim=config.hidden_size,
            partially_freeze=config.freeze_text_layers,
            name="embed_tokens",
        )

        self.image_size = config.vision_config.image_size
        self.vision_config = config.vision_config
        self.vision_model = TFIdeficsVisionTransformer(config.vision_config, name="vision_model")

        # Perceiver Resampler
        if config.use_resampler:
            perceiver_config = config.perceiver_config
            self.perceiver_resampler = TFIdeficsPerceiverResampler(
                config,
                config.vision_config.embed_dim,
                perceiver_config.resampler_depth,
                perceiver_config.resampler_n_heads,
                perceiver_config.resampler_head_dim,
                perceiver_config.resampler_n_latents,
                name="perceiver_resampler",
            )

        self.decoder_layers = [
            TFIdeficsDecoderLayer(config, name=f"layers.{i}") for i in range(config.num_hidden_layers)
        ]

        self.cross_layer_interval = config.cross_layer_interval
        num_cross_layers = config.num_hidden_layers // self.cross_layer_interval
        self.gated_cross_attn_layers = [
            TFIdeficsGatedCrossAttentionLayer(config, name=f"gated_cross_attn_layers.{i}")
            for i in range(num_cross_layers)
        ]
  1026. self.gradient_checkpointing = False
  1027. self.norm = TFIdeficsRMSNorm(config.hidden_size, eps=config.rms_norm_eps, name="norm")
  1028. self.gradient_checkpointing = False
  1029. self.freeze_relevant_params(config)

    def freeze_relevant_params(self, config=None):
        if config is None:
            config = self.config

        if config.freeze_text_layers:
            self.freeze_text_layers(config.freeze_text_module_exceptions)

        if config.freeze_vision_layers:
            freeze_model(self.vision_model, module_exceptions=config.freeze_vision_module_exceptions)

    def freeze_text_layers(self, module_exceptions=[]):
        for module in [self.decoder_layers, self.norm]:
            freeze_model(module, module_exceptions=module_exceptions)

    def freeze_vision_layers(self, module_exceptions=[]):
        freeze_model(self.vision_model, module_exceptions=module_exceptions)

    def get_input_embeddings(self):
        return self.embed_tokens

    def set_input_embeddings(self, value):
        self.embed_tokens = value

    def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
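        """
        Build the additive decoder self-attention mask. Sketch of the shapes involved (the exact values come
        from the `_make_causal_mask` and `_expand_mask` helpers used below):

        ```python
        >>> # attention_mask: [bsz, seq_len], 1 = attend, 0 = padding
        >>> # returned mask:  [bsz, 1, tgt_seq_len, src_seq_len], additive
        >>> # (0 where visible, a large negative value where masked)
        ```
        """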
        # create causal mask
        # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
        combined_attention_mask = None
        # if input_shape[-1] > 1:
        combined_attention_mask = _make_causal_mask(
            input_shape,
            inputs_embeds.dtype,
            past_key_values_length=past_key_values_length,
        )

        if attention_mask is not None:
            # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
            expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1])
            combined_attention_mask = (
                expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
            )

        return combined_attention_mask

    @unpack_inputs
    @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
    def call(
        self,
        input_ids: TFModelInputType | None = None,
        attention_mask: Optional[tf.Tensor] = None,
        position_ids: Optional[tf.Tensor] = None,
        past_key_values: Optional[List[tf.Tensor]] = None,
        inputs_embeds: Optional[tf.Tensor] = None,
        pixel_values: Optional[tf.Tensor] = None,
        image_encoder_embeddings: Optional[tf.Tensor] = None,
        perceiver_embeddings: Optional[tf.Tensor] = None,
        image_attention_mask: Optional[tf.Tensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        interpolate_pos_encoding: Optional[bool] = False,
        return_dict: Optional[bool] = None,
        training: Optional[bool] = None,
    ) -> Union[TFIdeficsBaseModelOutputWithPast, Tuple[tf.Tensor]]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # retrieve input_ids and inputs_embeds
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            batch_size, seq_length = shape_list(input_ids)
        elif inputs_embeds is not None:
            batch_size, seq_length, _ = shape_list(inputs_embeds)
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")
        seq_length_with_past = seq_length
        past_key_values_length = 0

        if past_key_values is not None:
            past_key_values_length = shape_list(past_key_values[0][0])[2]
            seq_length_with_past = seq_length_with_past + past_key_values_length

        if attention_mask is not None and position_ids is None:
            # create position_ids on the fly for batch generation
            position_ids = tf.math.cumsum(tf.cast(attention_mask, dtype=tf.int32), axis=-1) - 1
            position_ids = tf.where(attention_mask == 0, 1, position_ids)
        elif position_ids is None:
            position_ids = tf.range(past_key_values_length, seq_length + past_key_values_length, dtype=tf.int32)
            position_ids = tf.expand_dims(position_ids, 0)
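
        # e.g. a left-padded mask [[0, 0, 1, 1, 1]] yields position_ids [[1, 1, 0, 1, 2]]:
        # cumsum - 1 numbers the real tokens and the padded slots are overwritten with 1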
        no_images = False
        if (
            sum((int(pixel_values is None), int(image_encoder_embeddings is None), int(perceiver_embeddings is None)))
            != 2
        ):
            raise ValueError(
                "Exactly one of `pixel_values`, `image_encoder_embeddings` or `perceiver_embeddings` has to be not-None."
            )
        elif pixel_values is not None:
            no_images = tf.reduce_sum(tf.cast(pixel_values, dtype=tf.int32)) == 0
            pixel_values = tf.cast(pixel_values, dtype=self.dtype)  # fp16 compatibility
            # Below hack is because when cross-loading pytorch weights, there is an
            # initial forward pass with dummy input and code below is here to handle that
            if len(pixel_values.shape) == 4:
                batch_size = shape_list(pixel_values)[0]
                num_images = shape_list(pixel_values)[0]
                # pixel_values = tf.reshape(pixel_values, [batch_size * num_images, *pixel_values.shape[1:]])
            elif len(pixel_values.shape) == 5:
                batch_size, num_images = shape_list(pixel_values)[:2]
                pixel_values = tf.reshape(pixel_values, [batch_size * num_images, *pixel_values.shape[2:]])

            # Get sequence from the vision encoder
            image_hidden_states = self.vision_model(
                pixel_values=pixel_values, interpolate_pos_encoding=interpolate_pos_encoding
            ).last_hidden_state
        elif image_encoder_embeddings is not None:
            batch_size, num_images, image_seq_len, image_hidden_size = shape_list(image_encoder_embeddings)
            image_hidden_states = tf.cast(image_encoder_embeddings, dtype=self.dtype)
            image_hidden_states = tf.reshape(
                image_hidden_states, (batch_size * num_images, image_seq_len, image_hidden_size)
            )

        if self.config.use_resampler:
            if perceiver_embeddings is None:
                perceiver_embeddings = self.perceiver_resampler(image_hidden_states)
                image_seq_len, image_hidden_size = shape_list(perceiver_embeddings)[1:3]
            else:
                batch_size, num_images, image_seq_len, image_hidden_size = shape_list(perceiver_embeddings)
            image_hidden_states = perceiver_embeddings
        elif perceiver_embeddings is None:
            image_seq_len, image_hidden_size = shape_list(image_hidden_states)[1:3]
        else:
            raise ValueError("If `perceiver_embeddings` are passed, use_resampler should be True")

        image_hidden_states = tf.reshape(
            image_hidden_states, (batch_size, num_images * image_seq_len, image_hidden_size)
        )
        # # Hack to use the model in full language modeling mode
        # image_attention_mask = tf.zeros((batch_size, seq_length, 1), dtype=tf.int32)

        # this is to account for the dummy inputs
        if pixel_values is not None and len(pixel_values.shape) == 4 and image_attention_mask is None:
            image_attention_mask = tf.zeros((batch_size, seq_length, 1), dtype=tf.int32)

        text_seq_len = shape_list(image_attention_mask)[1]
        image_attention_mask = tf.expand_dims(image_attention_mask, -1)
        image_attention_mask = tf.repeat(image_attention_mask, repeats=image_seq_len)
        image_attention_mask = tf.reshape(image_attention_mask, (batch_size, text_seq_len, num_images * image_seq_len))

        if image_hidden_states is not None:
            image_batch_size, image_sequence_length, _ = shape_list(image_hidden_states)
            image_hidden_shape = (image_batch_size, image_sequence_length)
            if image_attention_mask is None:
                image_attention_mask = tf.ones(image_hidden_shape, dtype=tf.int32)
            image_attention_mask = invert_attention_mask(image_attention_mask)
        else:
            image_attention_mask = None

        cross_attention_gate = tf.squeeze(
            tf.cast(tf.reduce_any(image_attention_mask == 0, axis=-1), dtype=self.dtype), axis=1
        )
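
        # cross_attention_gate has shape (batch_size, text_seq_len): 1.0 where a text position can attend to
        # at least one image token, 0.0 where its image-attention row is fully masked, so the gated
        # cross-attention contribution can be zeroed out for text with no visible image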
        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        # embed positions
        if attention_mask is None:
            attention_mask = tf.ones((batch_size, seq_length_with_past), dtype=tf.bool)
        attention_mask = self._prepare_decoder_attention_mask(
            attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
        )

        hidden_states = inputs_embeds

        if self.gradient_checkpointing and training:
            if use_cache:
                logger.warning_once(
                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                )
                use_cache = False

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        next_decoder_cache = () if use_cache else None

        for idx, decoder_layer in enumerate(self.decoder_layers):
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            past_key_value = past_key_values[idx] if past_key_values is not None else None
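
            # `vblock` interleaves gated cross-attention with the regular decoder layer: with
            # cross_layer_interval = k, layers 0, k, 2k, ... first attend to the image tokens via
            # gated_cross_attn_layers[layer_idx // k] and then run the usual self-attention block.
            # Bundling both into one function lets tf.recompute_grad checkpoint them together.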
            def vblock(
                main_block,
                hidden_states,
                attention_mask,
                position_ids,
                past_key_value,
                image_hidden_states,
                image_attention_mask,
                cross_attention_gate,
                output_attentions,
                use_cache,
                layer_idx,
                cross_layer_interval,
                gated_cross_attn_layers,
            ):
                # TODO(ls): Add cross attention values to respective lists
                if layer_idx % cross_layer_interval == 0:
                    xblock = gated_cross_attn_layers[layer_idx // cross_layer_interval]
                    outputs = xblock(
                        hidden_states,
                        attention_mask=attention_mask,
                        image_hidden_states=image_hidden_states,
                        image_attention_mask=image_attention_mask,
                        cross_attention_gate=cross_attention_gate,
                        output_attentions=output_attentions,
                        use_cache=use_cache,
                        past_key_value=None,  # not implemented
                    )
                    hidden_states = outputs[0]

                layer_outputs = main_block(
                    hidden_states,
                    attention_mask=attention_mask,
                    position_ids=position_ids,
                    past_key_value=past_key_value,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                )

                return layer_outputs

            if self.gradient_checkpointing and training:
                past_key_value = None
                if use_cache:
                    logger.warning_once(
                        "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                    )
                    use_cache = False

                # positional arguments must line up with `vblock`'s signature; in particular
                # `cross_attention_gate` sits between `image_attention_mask` and `output_attentions`
                layer_outputs = tf.recompute_grad(
                    vblock,
                    decoder_layer,
                    hidden_states,
                    attention_mask,
                    position_ids,
                    past_key_value,
                    image_hidden_states,
                    image_attention_mask,
                    cross_attention_gate,
                    output_attentions,
                    use_cache,
                    idx,
                    self.cross_layer_interval,
                    self.gated_cross_attn_layers,
                )
            else:
                layer_outputs = vblock(
                    decoder_layer,
                    hidden_states,
                    attention_mask=attention_mask,
                    position_ids=position_ids,
                    past_key_value=past_key_value,
                    image_hidden_states=image_hidden_states,
                    image_attention_mask=image_attention_mask,
                    cross_attention_gate=cross_attention_gate,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                    layer_idx=idx,
                    cross_layer_interval=self.cross_layer_interval,
                    gated_cross_attn_layers=self.gated_cross_attn_layers,
                )

            hidden_states = layer_outputs[0]

            if use_cache:
                next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        hidden_states = self.norm(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        next_cache = next_decoder_cache if use_cache else None
        image_hidden_states = tf.reshape(
            image_hidden_states, (batch_size, num_images, image_seq_len, image_hidden_size)
        )
        if not return_dict:
            return tuple(
                v
                for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, image_hidden_states]
                if v is not None
            )
        return TFIdeficsBaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
            image_hidden_states=image_hidden_states,
        )

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        if getattr(self, "embed_tokens", None) is not None:
            with tf.name_scope(self.embed_tokens.name):
                self.embed_tokens.build(None)
        if getattr(self, "vision_model", None) is not None:
            with tf.name_scope(self.vision_model.name):
                self.vision_model.build(None)
        if getattr(self, "norm", None) is not None:
            with tf.name_scope(self.norm.name):
                self.norm.build(None)
        if getattr(self, "perceiver_resampler", None) is not None:
            with tf.name_scope(self.perceiver_resampler.name):
                self.perceiver_resampler.build(None)
        if getattr(self, "decoder_layers", None) is not None:
            for layer in self.decoder_layers:
                with tf.name_scope(layer.name):
                    layer.build(None)
        if getattr(self, "gated_cross_attn_layers", None) is not None:
            for layer in self.gated_cross_attn_layers:
                with tf.name_scope(layer.name):
                    layer.build(None)


class TFIdeficsModel(TFIdeficsPreTrainedModel):
    def __init__(self, config: IdeficsConfig, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)

        self.model = TFIdeficsMainLayer(config, name="model")

    def call(
        self,
        input_ids: TFModelInputType | None = None,
        attention_mask: Optional[tf.Tensor] = None,
        position_ids: Optional[tf.Tensor] = None,
        past_key_values: Optional[List[tf.Tensor]] = None,
        inputs_embeds: Optional[tf.Tensor] = None,
        pixel_values: Optional[tf.Tensor] = None,
        image_encoder_embeddings: Optional[tf.Tensor] = None,
        perceiver_embeddings: Optional[tf.Tensor] = None,
        image_attention_mask: Optional[tf.Tensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        interpolate_pos_encoding: Optional[bool] = False,
        return_dict: Optional[bool] = None,
        training: Optional[bool] = None,
    ) -> Union[TFIdeficsBaseModelOutputWithPast, Tuple[tf.Tensor]]:
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            pixel_values=pixel_values,
            image_encoder_embeddings=image_encoder_embeddings,
            perceiver_embeddings=perceiver_embeddings,
            image_attention_mask=image_attention_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            interpolate_pos_encoding=interpolate_pos_encoding,
            return_dict=return_dict,
            training=training,
        )
        return outputs

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        if getattr(self, "model", None) is not None:
            with tf.name_scope(self.model.name):
                self.model.build(None)


class TFIdeficsForVisionText2Text(TFPreTrainedModel, TFCausalLanguageModelingLoss):
    _keys_to_ignore_on_load_missing = [r"lm_head.weight"]
    _tied_weights_keys = ["model.embed_tokens.weight", "lm_head.weight"]
    config_class = IdeficsConfig

    def __init__(self, config, vision_model=None, **kwargs):
        super().__init__(config, **kwargs)
        self.model = TFIdeficsMainLayer(config, name="model")
        self.lm_head = TFIdeficsDecoupledLinear(
            config.hidden_size,
            config.vocab_size,
            config.additional_vocab_size,
            bias=False,
            partially_freeze=config.freeze_lm_head,
            name="lm_head",
        )

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        self.model = decoder

    def get_decoder(self):
        return self.model

    def tie_weights(self):
        """
        Overwrite `transformers.modeling_utils.PreTrainedModel.tie_weights` to handle the case of
        IdeficsDecoupledLinear and IdeficsDecoupledEmbedding.
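
        Example (illustrative only; `model` stands for an already-initialized `TFIdeficsForVisionText2Text`,
        so the calls are left as comments):

        ```python
        >>> # model.tie_weights()
        >>> # with the default `tie_word_embeddings=True`, the LM head now shares its
        >>> # weight object with the input embedding matrix:
        >>> # model.lm_head.weight is model.model.embed_tokens.weight  -> True
        ```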
  1402. """
  1403. output_embeddings = self.get_output_embeddings()
  1404. input_embeddings = self.get_input_embeddings()
  1405. if getattr(self.config, "tie_word_embeddings", True):
  1406. output_embeddings.weight = input_embeddings.weight
  1407. if input_embeddings.num_additional_embeddings > 0:
  1408. assert output_embeddings.out_additional_features == input_embeddings.num_additional_embeddings
  1409. output_embeddings.additional_fc.weight = input_embeddings.additional_embedding.weight
  1410. if hasattr(output_embeddings, "out_features") and hasattr(input_embeddings, "num_embeddings"):
  1411. output_embeddings.out_features = input_embeddings.num_embeddings
  1412. if hasattr(output_embeddings, "out_additional_features") and hasattr(
  1413. input_embeddings, "num_additional_embeddings"
  1414. ):
  1415. output_embeddings.out_additional_features = input_embeddings.num_additional_embeddings

    @unpack_inputs
    @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=TFIdeficsCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
    def call(
        self,
        input_ids: TFModelInputType | None = None,
        attention_mask: Optional[tf.Tensor] = None,
        position_ids: Optional[tf.Tensor] = None,
        past_key_values: Optional[List[tf.Tensor]] = None,
        inputs_embeds: Optional[tf.Tensor] = None,
        pixel_values: Optional[tf.Tensor] = None,
        image_encoder_embeddings: Optional[tf.Tensor] = None,
        perceiver_embeddings: Optional[tf.Tensor] = None,
        image_attention_mask: Optional[tf.Tensor] = None,
        labels: Optional[tf.Tensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        interpolate_pos_encoding: Optional[bool] = False,
        return_dict: Optional[bool] = None,
        training=False,
    ) -> Union[TFIdeficsCausalLMOutputWithPast, Tuple[tf.Tensor]]:
  1438. r"""
  1439. Args:
  1440. labels (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
  1441. Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
  1442. config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
  1443. (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
  1444. Returns:
  1445. Example:
  1446. ```python
  1447. >> from transformers import AutoTokenizer, TFIdeficsForVisionText2Text
  1448. >> model = TFIdeficsForVisionText2Text.from_pretrained("HuggingFaceM4/idefics-9b")
  1449. >> tokenizer = AutoTokenizer.from_pretrained("HuggingFaceM4/idefics-9b")
  1450. >> prompt = "Hey, are you consciours? Can you talk to me?"
  1451. >> inputs = tokenizer(prompt, return_tensors="tf")
  1452. >> # Generate
  1453. >> generate_ids = model.generate(inputs.input_ids, max_length=30)
  1454. >> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
  1455. "Hey, are you consciours? Can you talk to me?\nI'm not consciours, but I can talk to you."
  1456. ```"""
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # decoder outputs consist of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            pixel_values=pixel_values,
            image_encoder_embeddings=image_encoder_embeddings,
            perceiver_embeddings=perceiver_embeddings,
            image_attention_mask=image_attention_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            interpolate_pos_encoding=interpolate_pos_encoding,
            return_dict=return_dict,
            training=training,
        )

        hidden_states = outputs[0]
        logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n
            if attention_mask is not None:
                shift_attention_mask = attention_mask[..., 1:]
                shift_logits = logits[..., :-1, :][shift_attention_mask != 0]
                shift_labels = labels[..., 1:][shift_attention_mask != 0]
            else:
                shift_logits = logits[..., :-1, :]
                shift_labels = labels[..., 1:]
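
            # e.g. with labels [[t0, t1, t2]], the logits at step i are trained to predict t_{i+1}:
            # shift_logits keeps steps 0..n-2 and shift_labels keeps tokens t1..t_{n-1}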
            # Flatten the tokens
            loss = self.hf_compute_loss(
                labels=tf.reshape(shift_labels, [-1]), logits=tf.reshape(shift_logits, [-1, shift_logits.shape[-1]])
            )

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return TFIdeficsCausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            image_hidden_states=outputs.image_hidden_states,
        )

    def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
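        """
        Route cached image features to the right keyword argument before delegating to the module-level
        `prepare_inputs_for_generation` helper. Illustrative effect (`feats` is a placeholder):

        ```python
        >>> # with config.use_resampler=True and cached features `feats`:
        >>> # inputs = model.prepare_inputs_for_generation(input_ids, image_hidden_states=feats)
        >>> # -> inputs["perceiver_embeddings"] is feats and inputs["pixel_values"] is None
        ```
        """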
        image_hidden_states = kwargs.pop("image_hidden_states", None)
        if image_hidden_states is not None:
            if self.config.use_resampler:
                kwargs["perceiver_embeddings"] = image_hidden_states
            else:
                kwargs["image_encoder_embeddings"] = image_hidden_states
            kwargs["pixel_values"] = None
        inputs = prepare_inputs_for_generation(input_ids, past=past, **kwargs)
        unwanted_kwargs = ["token_type_ids"]
        for kwarg in unwanted_kwargs:
            inputs.pop(kwarg, None)
        return inputs

    @staticmethod
    def _expand_inputs_for_generation(
        *args,
        **model_kwargs,
    ):
        return expand_inputs_for_generation(*args, **model_kwargs)

    @staticmethod
    def _update_model_kwargs_for_generation(outputs, model_kwargs, is_encoder_decoder):
        return update_model_kwargs_for_generation(outputs, model_kwargs)

    @staticmethod
    def _reorder_cache(past, beam_idx):
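        """
        Reorder every layer's cached key/value tensors along the batch axis so they follow the beams kept by
        beam search. Shape sketch (illustrative):

        ```python
        >>> # each past_state: [batch_size * num_beams, num_heads, seq_len, head_dim]
        >>> # tf.gather(past_state, beam_idx) selects the rows of the surviving beams
        ```
        """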
        reordered_past = ()
        for layer_past in past:
            reordered_past += (tuple(tf.gather(past_state, beam_idx) for past_state in layer_past),)
        return reordered_past

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        if getattr(self, "model", None) is not None:
            with tf.name_scope(self.model.name):
                self.model.build(None)
        if getattr(self, "lm_head", None) is not None:
            with tf.name_scope(self.lm_head.name):
                self.lm_head.build(None)