# NOTE: removed non-source residue (file-size banner and file-viewer line-number runs) left over from extraction.
  1. # coding=utf-8
  2. # Copyright 2024 the HuggingFace Inc. team. All rights reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. """PyTorch Idefics3 model."""
  16. from dataclasses import dataclass
  17. from typing import List, Optional, Tuple, Union
  18. import torch
  19. import torch.utils.checkpoint
  20. from torch import nn
  21. from torch.nn import CrossEntropyLoss
  22. from ... import PreTrainedModel
  23. from ...activations import ACT2FN
  24. from ...cache_utils import Cache, DynamicCache
  25. from ...generation import GenerationMixin
  26. from ...modeling_attn_mask_utils import _prepare_4d_attention_mask
  27. from ...modeling_outputs import BaseModelOutput, ModelOutput
  28. from ...utils import (
  29. add_start_docstrings,
  30. add_start_docstrings_to_model_forward,
  31. is_flash_attn_2_available,
  32. is_flash_attn_greater_or_equal_2_10,
  33. logging,
  34. replace_return_docstrings,
  35. )
  36. from ..auto import AutoModel
  37. from .configuration_idefics3 import Idefics3Config, Idefics3VisionConfig
  38. if is_flash_attn_2_available():
  39. from ...modeling_flash_attention_utils import _flash_attention_forward
  40. logger = logging.get_logger(__name__)
  41. _CONFIG_FOR_DOC = "Idefics3Config"
@dataclass
class Idefics3BaseModelOutputWithPast(ModelOutput):
    """
    Base class for Idefics3 model's outputs that may also contain a past key/values (to speed up sequential decoding).

    Args:
        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the output of the last layer of the model.

            If `past_key_values` is used only the last hidden-state of the sequences of shape `(batch_size, 1,
            hidden_size)` is output.
        past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
            Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
            `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and optionally if
            `config.is_encoder_decoder=True` 2 additional tensors of shape `(batch_size, num_heads,
            encoder_sequence_length, embed_size_per_head)`.

            Contains pre-computed hidden-states (key and values in the self-attention blocks and optionally if
            `config.is_encoder_decoder=True` in the cross-attention blocks) that can be used (see `past_key_values`
            input) to speed up sequential decoding.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
        image_hidden_states (`tuple(torch.FloatTensor)`, *optional*):
            Tuple of `torch.FloatTensor` (one for the output of the image embeddings, `(batch_size, num_images,
            sequence_length, hidden_size)`.

            image_hidden_states of the model produced by the vision encoder
    """

    # All fields default to None; the model populates them depending on the output flags.
    last_hidden_state: Optional[torch.FloatTensor] = None
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None
    image_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
@dataclass
class Idefics3CausalLMOutputWithPast(ModelOutput):
    """
    Base class for Idefics causal language model (or autoregressive) outputs.

    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
            Language modeling loss (for next-token prediction).
        logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
        past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
            Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
            `(batch_size, num_heads, sequence_length, embed_size_per_head)`)

            Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
            `past_key_values` input) to speed up sequential decoding.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
        image_hidden_states (`tuple(torch.FloatTensor)`, *optional*):
            Tuple of `torch.FloatTensor` (one for the output of the image embeddings, `(batch_size, num_images,
            sequence_length, hidden_size)`.

            image_hidden_states of the model produced by the vision encoder
    """

    # All fields default to None; they are filled in depending on labels/output flags.
    loss: Optional[torch.FloatTensor] = None
    logits: Optional[torch.FloatTensor] = None
    past_key_values: Optional[List[torch.FloatTensor]] = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None
    image_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
# Copied from transformers.models.idefics2.modeling_idefics2.Idefics2VisionEmbeddings with Idefics2->Idefics3
class Idefics3VisionEmbeddings(nn.Module):
    """
    This is a modified version of `siglip.modelign_siglip.SiglipVisionEmbeddings` to enable images of variable
    resolution.

    The modifications are adapted from [Patch n' Pack: NaViT, a Vision Transformer for any Aspect Ratio and Resolution](https://arxiv.org/abs/2307.06304)
    which allows treating images in their native aspect ratio and without the need to resize them to the same
    fixed size. In particular, we start from the original pre-trained SigLIP model
    (which uses images of fixed-size square images) and adapt it by training on images of variable resolutions.
    """

    def __init__(self, config: Idefics3VisionConfig):
        super().__init__()
        self.embed_dim = config.hidden_size
        self.image_size = config.image_size
        self.patch_size = config.patch_size

        # Non-overlapping patchification: stride equals kernel size, no padding.
        self.patch_embedding = nn.Conv2d(
            in_channels=config.num_channels,
            out_channels=self.embed_dim,
            kernel_size=self.patch_size,
            stride=self.patch_size,
            padding="valid",
        )

        # Position table is sized for the square reference grid (image_size / patch_size)^2;
        # variable-resolution inputs are mapped onto it via bucketized fractional coords below.
        self.num_patches_per_side = self.image_size // self.patch_size
        self.num_patches = self.num_patches_per_side**2
        self.num_positions = self.num_patches
        self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)

    def forward(self, pixel_values: torch.FloatTensor, patch_attention_mask: torch.BoolTensor) -> torch.Tensor:
        """
        Embed pixel values and add resolution-aware position embeddings.

        Args:
            pixel_values: `(batch_size, channels, height, width)` images.
            patch_attention_mask: boolean mask over the patch grid; True entries mark
                real (non-padding) patches — assumed rectangular per sample, since only
                the first row/column are used to count patches below. TODO confirm.
        """
        batch_size, _, max_im_h, max_im_w = pixel_values.shape

        # (bsz, embed_dim, h', w') -> (bsz, h'*w', embed_dim)
        patch_embeds = self.patch_embedding(pixel_values)
        embeddings = patch_embeds.flatten(2).transpose(1, 2)

        max_nb_patches_h, max_nb_patches_w = max_im_h // self.patch_size, max_im_w // self.patch_size
        # Bucket edges over (0, 1) splitting the unit interval into num_patches_per_side bins.
        boundaries = torch.arange(1 / self.num_patches_per_side, 1.0, 1 / self.num_patches_per_side)
        position_ids = torch.full(size=(batch_size, max_nb_patches_h * max_nb_patches_w), fill_value=0)

        for batch_idx, p_attn_mask in enumerate(patch_attention_mask):
            # Per-sample real grid extent, read from the first column/row of the mask.
            nb_patches_h = p_attn_mask[:, 0].sum()
            nb_patches_w = p_attn_mask[0].sum()

            # Fractional coordinates in [0, 1); the 1 - 1e-6 upper bound keeps the
            # endpoint out so every coordinate falls into a valid bucket.
            fractional_coords_h = torch.arange(0, 1 - 1e-6, 1 / nb_patches_h)
            fractional_coords_w = torch.arange(0, 1 - 1e-6, 1 / nb_patches_w)

            bucket_coords_h = torch.bucketize(fractional_coords_h, boundaries, right=True)
            bucket_coords_w = torch.bucketize(fractional_coords_w, boundaries, right=True)

            # Row-major index into the square reference grid.
            pos_ids = (bucket_coords_h[:, None] * self.num_patches_per_side + bucket_coords_w).flatten()
            # position_ids lives on CPU here (created without a device); boolean index must match.
            position_ids[batch_idx][p_attn_mask.view(-1).cpu()] = pos_ids

        position_ids = position_ids.to(self.position_embedding.weight.device)
        embeddings = embeddings + self.position_embedding(position_ids)
        return embeddings
  157. # Copied from transformers.models.siglip.modeling_siglip.SiglipAttention with Siglip->Idefics3Vision
  158. class Idefics3VisionAttention(nn.Module):
  159. """Multi-headed attention from 'Attention Is All You Need' paper"""
  160. # Copied from transformers.models.clip.modeling_clip.CLIPAttention.__init__
  161. def __init__(self, config):
  162. super().__init__()
  163. self.config = config
  164. self.embed_dim = config.hidden_size
  165. self.num_heads = config.num_attention_heads
  166. self.head_dim = self.embed_dim // self.num_heads
  167. if self.head_dim * self.num_heads != self.embed_dim:
  168. raise ValueError(
  169. f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
  170. f" {self.num_heads})."
  171. )
  172. self.scale = self.head_dim**-0.5
  173. self.dropout = config.attention_dropout
  174. self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
  175. self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
  176. self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
  177. self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)
  178. # Ignore copy
  179. self.is_causal = False
  180. def forward(
  181. self,
  182. hidden_states: torch.Tensor,
  183. attention_mask: Optional[torch.Tensor] = None,
  184. output_attentions: Optional[bool] = False,
  185. ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
  186. """Input shape: Batch x Time x Channel"""
  187. batch_size, q_len, _ = hidden_states.size()
  188. query_states = self.q_proj(hidden_states)
  189. key_states = self.k_proj(hidden_states)
  190. value_states = self.v_proj(hidden_states)
  191. query_states = query_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
  192. key_states = key_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
  193. value_states = value_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
  194. k_v_seq_len = key_states.shape[-2]
  195. attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scale
  196. if attn_weights.size() != (batch_size, self.num_heads, q_len, k_v_seq_len):
  197. raise ValueError(
  198. f"Attention weights should be of size {(batch_size, self.num_heads, q_len, k_v_seq_len)}, but is"
  199. f" {attn_weights.size()}"
  200. )
  201. if attention_mask is not None:
  202. if attention_mask.size() != (batch_size, 1, q_len, k_v_seq_len):
  203. raise ValueError(
  204. f"Attention mask should be of size {(batch_size, 1, q_len, k_v_seq_len)}, but is {attention_mask.size()}"
  205. )
  206. attn_weights = attn_weights + attention_mask
  207. # upcast attention to fp32
  208. attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
  209. attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
  210. attn_output = torch.matmul(attn_weights, value_states)
  211. if attn_output.size() != (batch_size, self.num_heads, q_len, self.head_dim):
  212. raise ValueError(
  213. f"`attn_output` should be of size {(batch_size, self.num_heads, q_len, self.head_dim)}, but is"
  214. f" {attn_output.size()}"
  215. )
  216. attn_output = attn_output.transpose(1, 2).contiguous()
  217. attn_output = attn_output.reshape(batch_size, q_len, self.embed_dim)
  218. attn_output = self.out_proj(attn_output)
  219. return attn_output, attn_weights
# Copied from transformers.models.idefics2.modeling_idefics2.Idefics2VisionFlashAttention2 with Idefics2->Idefics3
class Idefics3VisionFlashAttention2(Idefics3VisionAttention):
    """
    Idefics3Vision flash attention module. This module inherits from `Idefics3VisionAttention` as the weights of the module stays
    untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
    flash attention and deal with padding tokens in case the input contains any of them.
    """

    # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
        # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
        # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
        self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        # Flash attention never materializes the attention matrix, so weights cannot be
        # returned; the flag is forced off regardless of the caller's value.
        output_attentions = False

        bsz, q_len, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        # Flash attention requires the input to have the shape
        # batch_size x seq_length x head_dim x hidden_dim
        # therefore we just need to keep the original shape
        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim)
        key_states = key_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)

        kv_seq_len = key_states.shape[-2]
        if past_key_value is not None:
            # NOTE(review): `self.layer_idx` is not assigned in this class or its visible base;
            # this branch appears to assume a cache is never actually passed for the vision
            # tower — confirm before relying on it.
            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)

        # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
        # to be able to avoid many of these transpose/reshape/view.
        key_states = key_states.transpose(1, 2)
        value_states = value_states.transpose(1, 2)

        # Dropout only during training.
        dropout_rate = self.dropout if self.training else 0.0

        # In PEFT, usually we cast the layer norms in float32 for training stability reasons
        # therefore the input hidden states gets silently casted in float32. Hence, we need
        # cast them back in the correct dtype just to be sure everything works as expected.
        # This might slowdown training & inference so it is recommended to not cast the LayerNorms
        # in fp32. (Idefics3VisionRMSNorm handles it correctly)
        input_dtype = query_states.dtype
        if input_dtype == torch.float32:
            if torch.is_autocast_enabled():
                target_dtype = torch.get_autocast_gpu_dtype()
            # Handle the case where the model is quantized
            elif hasattr(self.config, "_pre_quantization_dtype"):
                target_dtype = self.config._pre_quantization_dtype
            else:
                target_dtype = self.q_proj.weight.dtype

            logger.warning_once(
                f"The input hidden states seems to be silently casted in float32, this might be related to"
                f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
                f" {target_dtype}."
            )

            query_states = query_states.to(target_dtype)
            key_states = key_states.to(target_dtype)
            value_states = value_states.to(target_dtype)

        attn_output = _flash_attention_forward(
            query_states,
            key_states,
            value_states,
            attention_mask,
            q_len,
            dropout=dropout_rate,
            is_causal=self.is_causal,
            use_top_left_mask=self._flash_attn_uses_top_left_mask,
        )

        attn_output = attn_output.reshape(bsz, q_len, self.embed_dim).contiguous()
        attn_output = self.out_proj(attn_output)

        # `output_attentions` is always False here (see above), so the weights are None.
        if not output_attentions:
            attn_weights = None

        return attn_output, attn_weights
# Maps `config._attn_implementation` to the vision attention class to instantiate.
# Only "eager" and "flash_attention_2" are registered here.
IDEFICS_VISION_ATTENTION_CLASSES = {
    "eager": Idefics3VisionAttention,
    "flash_attention_2": Idefics3VisionFlashAttention2,
}
  304. # Copied from transformers.models.siglip.modeling_siglip.SiglipMLP with Siglip->Idefics3Vision
  305. class Idefics3VisionMLP(nn.Module):
  306. def __init__(self, config):
  307. super().__init__()
  308. self.config = config
  309. self.activation_fn = ACT2FN[config.hidden_act]
  310. self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
  311. self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
  312. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  313. hidden_states = self.fc1(hidden_states)
  314. hidden_states = self.activation_fn(hidden_states)
  315. hidden_states = self.fc2(hidden_states)
  316. return hidden_states
  317. class Idefics3SimpleMLP(nn.Module):
  318. def __init__(self, config):
  319. super().__init__()
  320. input_size = config.vision_config.hidden_size * (config.scale_factor**2)
  321. output_size = config.text_config.hidden_size
  322. self.proj = nn.Linear(input_size, output_size, bias=False)
  323. def forward(self, x):
  324. return self.proj(x)
  325. # Copied from transformers.models.idefics2.modeling_idefics2.Idefics2EncoderLayer with Idefics2->Idefics3
  326. class Idefics3EncoderLayer(nn.Module):
  327. def __init__(self, config: Idefics3VisionConfig):
  328. super().__init__()
  329. self.embed_dim = config.hidden_size
  330. self.self_attn = IDEFICS_VISION_ATTENTION_CLASSES[config._attn_implementation](config)
  331. self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
  332. self.mlp = Idefics3VisionMLP(config)
  333. self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
  334. # Copied from transformers.models.siglip.modeling_siglip.SiglipEncoderLayer.forward
  335. def forward(
  336. self,
  337. hidden_states: torch.Tensor,
  338. attention_mask: torch.Tensor,
  339. output_attentions: Optional[bool] = False,
  340. ) -> Tuple[torch.FloatTensor]:
  341. """
  342. Args:
  343. hidden_states (`torch.FloatTensor`):
  344. Input to the layer of shape `(batch, seq_len, embed_dim)`.
  345. attention_mask (`torch.FloatTensor`):
  346. Attention mask of shape `(batch, 1, q_len, k_v_seq_len)` where padding elements are indicated by very large negative values.
  347. output_attentions (`bool`, *optional*, defaults to `False`):
  348. Whether or not to return the attentions tensors of all attention layers. See `attentions` under
  349. returned tensors for more detail.
  350. """
  351. residual = hidden_states
  352. hidden_states = self.layer_norm1(hidden_states)
  353. hidden_states, attn_weights = self.self_attn(
  354. hidden_states=hidden_states,
  355. attention_mask=attention_mask,
  356. output_attentions=output_attentions,
  357. )
  358. hidden_states = residual + hidden_states
  359. residual = hidden_states
  360. hidden_states = self.layer_norm2(hidden_states)
  361. hidden_states = self.mlp(hidden_states)
  362. hidden_states = residual + hidden_states
  363. outputs = (hidden_states,)
  364. if output_attentions:
  365. outputs += (attn_weights,)
  366. return outputs
  367. # Copied from transformers.models.siglip.modeling_siglip.SiglipEncoder with Siglip->Idefics3
  368. class Idefics3Encoder(nn.Module):
  369. """
  370. Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
  371. [`Idefics3EncoderLayer`].
  372. Args:
  373. config: Idefics3Config
  374. """
  375. def __init__(self, config: Idefics3Config):
  376. super().__init__()
  377. self.config = config
  378. self.layers = nn.ModuleList([Idefics3EncoderLayer(config) for _ in range(config.num_hidden_layers)])
  379. self.gradient_checkpointing = False
  380. # Ignore copy
  381. def forward(
  382. self,
  383. inputs_embeds,
  384. attention_mask: Optional[torch.Tensor] = None,
  385. output_attentions: Optional[bool] = None,
  386. output_hidden_states: Optional[bool] = None,
  387. return_dict: Optional[bool] = None,
  388. ) -> Union[Tuple, BaseModelOutput]:
  389. r"""
  390. Args:
  391. inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
  392. Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
  393. This is useful if you want more control over how to convert `input_ids` indices into associated vectors
  394. than the model's internal embedding lookup matrix.
  395. attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
  396. Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
  397. - 1 for tokens that are **not masked**,
  398. - 0 for tokens that are **masked**.
  399. [What are attention masks?](../glossary#attention-mask)
  400. output_attentions (`bool`, *optional*):
  401. Whether or not to return the attentions tensors of all attention layers. See `attentions` under
  402. returned tensors for more detail.
  403. output_hidden_states (`bool`, *optional*):
  404. Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
  405. for more detail.
  406. return_dict (`bool`, *optional*):
  407. Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
  408. """
  409. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  410. output_hidden_states = (
  411. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  412. )
  413. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  414. encoder_states = () if output_hidden_states else None
  415. all_attentions = () if output_attentions else None
  416. hidden_states = inputs_embeds
  417. for encoder_layer in self.layers:
  418. if output_hidden_states:
  419. encoder_states = encoder_states + (hidden_states,)
  420. if self.gradient_checkpointing and self.training:
  421. layer_outputs = self._gradient_checkpointing_func(
  422. encoder_layer.__call__,
  423. hidden_states,
  424. attention_mask,
  425. output_attentions,
  426. )
  427. else:
  428. layer_outputs = encoder_layer(
  429. hidden_states,
  430. attention_mask,
  431. output_attentions=output_attentions,
  432. )
  433. hidden_states = layer_outputs[0]
  434. if output_attentions:
  435. all_attentions = all_attentions + (layer_outputs[1],)
  436. if output_hidden_states:
  437. encoder_states = encoder_states + (hidden_states,)
  438. if not return_dict:
  439. return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
  440. return BaseModelOutput(
  441. last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
  442. )
  443. # Copied from transformers.models.llama.modeling_llama.repeat_kv
  444. def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
  445. """
  446. This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
  447. num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
  448. """
  449. batch, num_key_value_heads, slen, head_dim = hidden_states.shape
  450. if n_rep == 1:
  451. return hidden_states
  452. hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
  453. return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
  454. # Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Idefics3
  455. class Idefics3RMSNorm(nn.Module):
  456. def __init__(self, hidden_size, eps=1e-6):
  457. """
  458. Idefics3RMSNorm is equivalent to T5LayerNorm
  459. """
  460. super().__init__()
  461. self.weight = nn.Parameter(torch.ones(hidden_size))
  462. self.variance_epsilon = eps
  463. def forward(self, hidden_states):
  464. input_dtype = hidden_states.dtype
  465. hidden_states = hidden_states.to(torch.float32)
  466. variance = hidden_states.pow(2).mean(-1, keepdim=True)
  467. hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
  468. return self.weight * hidden_states.to(input_dtype)
  469. def extra_repr(self):
  470. return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
class Idefics3Connector(nn.Module):
    """
    Bridges the vision tower to the language model: shrinks the image token sequence
    by `scale_factor**2` via pixel shuffle (moving spatial detail into the channel
    dimension), then linearly projects into the text model's hidden size.
    """

    def __init__(self, config):
        super().__init__()
        self.scale_factor = config.scale_factor
        self.modality_projection = Idefics3SimpleMLP(config)

    def pixel_shuffle(self, x, scale_factor=2):
        # x: (bsz, seq, embed_dim); seq is assumed to form a perfect-square patch grid
        # (height == width == sqrt(seq)) — TODO confirm against the vision tower output.
        bsz, seq, embed_dim = x.size()
        height = width = int(seq**0.5)
        x = x.view(bsz, height, width, embed_dim)
        # Fold `scale_factor` adjacent columns into the channel dimension.
        x = x.view(bsz, height, int(width / scale_factor), embed_dim * scale_factor)
        x = x.permute(0, 2, 1, 3)
        # Then fold `scale_factor` adjacent rows into the channel dimension as well.
        x = x.reshape(bsz, int(width / scale_factor), int(height / scale_factor), embed_dim * (scale_factor**2))
        x = x.permute(0, 2, 1, 3)
        # Final shape: seq shrinks by scale_factor**2, embed_dim grows by scale_factor**2.
        x = x.reshape(bsz, int(seq / (scale_factor**2)), embed_dim * (scale_factor**2))
        return x

    def forward(self, image_hidden_states):
        # Shuffle first (widening channels), then project to the text hidden size.
        image_hidden_states = self.pixel_shuffle(image_hidden_states, self.scale_factor)
        image_hidden_states = self.modality_projection(image_hidden_states)
        return image_hidden_states
# Shared docstring prologue injected into model classes via `add_start_docstrings`.
IDEFICS3_START_DOCSTRING = r"""
    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
    etc.)

    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
    and behavior.

    Parameters:
        config ([`Idefics3Config`] or [`Idefics3VisionConfig`]):
            Model configuration class with all the parameters of the model. Initializing with a config file does not
            load the weights associated with the model, only the configuration. Check out the
            [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
  503. @add_start_docstrings(
  504. "The bare Idefics3 Model outputting raw hidden-states without any specific head on top.",
  505. IDEFICS3_START_DOCSTRING,
  506. )
  507. class Idefics3PreTrainedModel(PreTrainedModel):
  508. config_class = Idefics3Config
  509. base_model_prefix = "model"
  510. supports_gradient_checkpointing = True
  511. _no_split_modules = ["Idefics3VisionAttention", "Idefics3DecoderLayer"]
  512. _skip_keys_device_placement = "past_key_values"
  513. _supports_flash_attn_2 = True
  514. _supports_sdpa = True
  515. _supports_cache_class = True
  516. # Copied from transformers.models.idefics2.modeling_idefics2.Idefics2PreTrainedModel._init_weights
  517. def _init_weights(self, module):
  518. std = (
  519. self.config.text_config.initializer_range
  520. if hasattr(self.config, "initializer_range")
  521. else self.config.text_config.initializer_range
  522. )
  523. if hasattr(module, "class_embedding"):
  524. module.class_embedding.data.normal_(mean=0.0, std=std)
  525. if isinstance(module, (nn.Linear, nn.Conv2d)):
  526. module.weight.data.normal_(mean=0.0, std=std)
  527. if module.bias is not None:
  528. module.bias.data.zero_()
  529. elif isinstance(module, nn.Embedding):
  530. module.weight.data.normal_(mean=0.0, std=std)
  531. if module.padding_idx is not None:
  532. module.weight.data[module.padding_idx].zero_()
# Class-level documentation for the standalone vision transformer; injected into
# its docstring by the `add_start_docstrings` decorator.
IDEFICS3_VISION_START_DOCSTRING = r"""
    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
    etc.)

    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
    and behavior.

    Parameters:
        config ([`Idefics3VisionConfig`]):
            Model configuration class with all the parameters of the model. Initializing with a config file does not
            load the weights associated with the model, only the configuration. Check out the
            [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
@add_start_docstrings(
    "The Idefics3 Vision Transformer Model outputting raw image embedding.",
    IDEFICS3_VISION_START_DOCSTRING,
)
class Idefics3VisionTransformer(Idefics3PreTrainedModel):
    config_class = Idefics3VisionConfig
    _supports_sdpa = False

    def __init__(self, config: Idefics3VisionConfig):
        super().__init__(config)
        embed_dim = config.hidden_size
        self.embeddings = Idefics3VisionEmbeddings(config)
        self.encoder = Idefics3Encoder(config)
        self.patch_size = config.patch_size
        self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
        self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"

    # Copied from transformers.models.idefics2.modeling_idefics2.Idefics2VisionTransformer.get_input_embeddings
    def get_input_embeddings(self):
        return self.embeddings

    # Copied from transformers.models.idefics2.modeling_idefics2.Idefics2VisionTransformer.set_input_embeddings
    def set_input_embeddings(self, value):
        self.embeddings = value

    def forward(
        self,
        pixel_values,
        patch_attention_mask: Optional[torch.BoolTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutput]:
        """Encode images into patch embeddings.

        Args:
            pixel_values: image batch; indexed here as (batch, channels, height, width).
            patch_attention_mask: boolean mask over the patch grid
                `(batch, height // patch_size, width // patch_size)`; if `None`, a
                full (all-ones) mask is built below.
            output_attentions / output_hidden_states / return_dict: standard
                transformers output toggles, defaulting to the config values.

        Returns:
            `BaseModelOutput` (or a tuple if `return_dict=False`) with the
            post-layernormed last hidden state.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        batch_size = pixel_values.size(0)
        if patch_attention_mask is None:
            # No mask supplied: attend to every patch of every image.
            patch_size = self.patch_size
            patch_attention_mask = torch.ones(
                (
                    batch_size,
                    pixel_values.size(2) // patch_size,
                    pixel_values.size(3) // patch_size,
                )
            )
            patch_attention_mask = patch_attention_mask.to(dtype=torch.bool, device=pixel_values.device)

        hidden_states = self.embeddings(pixel_values=pixel_values, patch_attention_mask=patch_attention_mask)

        # Flatten the 2D patch grid into a sequence-shaped mask.
        patch_attention_mask = patch_attention_mask.view(batch_size, -1)
        # The call to `_upad_input` in `_flash_attention_forward` is expensive
        # So when the `patch_attention_mask` is full of 1s (i.e. attending to the whole sequence),
        # avoiding passing the attention_mask, which is equivalent to attending to the full sequence
        if not torch.any(~patch_attention_mask):
            patch_attention_mask = None
        elif not self._use_flash_attention_2:
            # Non-flash backends expect an additive 4D mask in the hidden-state dtype.
            patch_attention_mask = _prepare_4d_attention_mask(patch_attention_mask, hidden_states.dtype)

        encoder_outputs = self.encoder(
            inputs_embeds=hidden_states,
            attention_mask=patch_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        last_hidden_state = encoder_outputs[0]
        last_hidden_state = self.post_layernorm(last_hidden_state)

        if not return_dict:
            return (last_hidden_state,) + encoder_outputs[1:]

        return BaseModelOutput(
            last_hidden_state=last_hidden_state,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )
# Argument documentation injected into the model `forward` docstrings by
# `add_start_docstrings_to_model_forward`.
IDEFICS3_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
            it.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
            `past_key_values`).

            If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
            and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
            information on the default strategy.

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.
        position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
            config.n_positions - 1]`. [What are position IDs?](../glossary#position-ids)
        past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
            Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
            `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
            `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.

            Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
            blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.

            If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
            don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
            `decoder_input_ids` of shape `(batch_size, sequence_length)`.
        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
            model's internal embedding lookup matrix.
        pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)):
            The tensors corresponding to the input images. Pixel values can be obtained using
            [`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details ([]`LlavaProcessor`] uses
            [`CLIPImageProcessor`] for processing images).
        pixel_attention_mask (`torch.Tensor` of shape `(batch_size, image_size, image_size)`, *optional*):
            Mask to avoid performing attention on padding pixel indices.
        image_hidden_states (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
            The hidden states of the image encoder after modality projection.
        use_cache (`bool`, *optional*):
            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
            `past_key_values`).
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
@add_start_docstrings(
    """Idefics3 model consisting of a SIGLIP vision encoder and Llama3 language decoder""",
    IDEFICS3_START_DOCSTRING,
)
class Idefics3Model(Idefics3PreTrainedModel):
    def __init__(self, config: Idefics3Config):
        super().__init__(config)
        self.padding_idx = self.config.text_config.pad_token_id
        self.vocab_size = self.config.text_config.vocab_size

        # Three stages: vision encoder -> connector (pixel shuffle + projection) -> text decoder.
        self.vision_model = Idefics3VisionTransformer._from_config(config.vision_config)
        self.connector = Idefics3Connector(config)
        self.text_model = AutoModel.from_config(config.text_config)

        # Number of <image> placeholder tokens each image expands to: the squared
        # patch grid, reduced by the connector's pixel-shuffle factor.
        self.image_seq_len = int(
            ((config.vision_config.image_size // config.vision_config.patch_size) ** 2) / (config.scale_factor**2)
        )
        self.image_token_id = self.config.image_token_id

        self._use_flash_attention_2 = config.text_config._attn_implementation == "flash_attention_2"

        self.post_init()

    # Copied from transformers.models.idefics2.modeling_idefics2.Idefics2Model.enable_input_require_grads
    def enable_input_require_grads(self):
        """
        Enables the gradients for the input embeddings.

        This is useful for lora when using gradient checkpointing.
        c.f. https://github.com/huggingface/peft/issues/1402#issuecomment-1913675032

        Override to set output.requires_grad = True for both the decoder's and vision model's embeddings.
        """

        def get_lowest_module(module):
            if len(list(module.children())) == 0:
                # If the module has no children, it is a leaf module (e.g., Linear, Conv2d, etc.)
                return module
            else:
                # Recursively call the function on each child module
                return get_lowest_module(list(module.children())[0])

        def make_inputs_require_grads(module, input, output):
            output.requires_grad_(True)

        self._text_require_grads_hook = self.get_input_embeddings().register_forward_hook(make_inputs_require_grads)
        self._vision_require_grads_hook = get_lowest_module(self.vision_model).register_forward_hook(
            make_inputs_require_grads
        )

    # Copied from transformers.models.idefics2.modeling_idefics2.Idefics2Model.disable_input_require_grads
    def disable_input_require_grads(self):
        self._text_require_grads_hook.remove()
        self._vision_require_grads_hook.remove()

    # Copied from transformers.models.idefics2.modeling_idefics2.Idefics2Model.get_input_embeddings
    def get_input_embeddings(self):
        return self.text_model.get_input_embeddings()

    # Copied from transformers.models.idefics2.modeling_idefics2.Idefics2Model.set_input_embeddings
    def set_input_embeddings(self, value):
        self.text_model.set_input_embeddings(value)

    def inputs_merger(
        self,
        input_ids: torch.LongTensor,
        inputs_embeds: Optional[torch.Tensor],
        image_hidden_states: Optional[torch.Tensor],
    ):
        """
        This method aims at merging the token embeddings with the image hidden states into one single sequence of vectors that are fed to the transformer LM.
        The merging happens as follows:
        - The text token sequence is: `tok_1 tok_2 tok_3 <fake_token_around_image> <image> <image> ... <image> <fake_token_around_image> tok_4`.
        - We get the image hidden states for the image through the vision encoder and that hidden state, after a pixel shuffle operation, is then projected into the text embedding space.
        We thus have a sequence of image hidden states of size (1, image_seq_len, hidden_dim), where 1 is for batch_size of 1 image and hidden_dim is the hidden_dim of the LM transformer.
        - The merging happens so that we obtain the following sequence: `vector_tok_1 vector_tok_2 vector_tok_3 vector_fake_tok_around_image {sequence of image_seq_len image hidden states} vector_fake_toke_around_image vector_tok_4`. That sequence is fed to the LM.
        - To fit the format of that sequence, `input_ids`, `input_embeds`, `attention_mask` are all 3 adapted to insert the image hidden states.
        """
        num_images, _, vision_hidden_size = image_hidden_states.shape
        # Positions of the <image> placeholder tokens to overwrite with image features.
        special_image_token_mask = input_ids == self.image_token_id
        # Fixes RuntimeError: a leaf Variable that requires grad is being used in an in-place operation.
        new_inputs_embeds = inputs_embeds.clone()
        reshaped_image_hidden_states = image_hidden_states.view(-1, vision_hidden_size)
        # cast to the dtype of the input_embeds to support quantized models
        reshaped_image_hidden_states = reshaped_image_hidden_states.to(inputs_embeds.dtype)
        new_inputs_embeds[special_image_token_mask] = reshaped_image_hidden_states
        return new_inputs_embeds

    @add_start_docstrings_to_model_forward(
        """
        Inputs fed to the model can have an arbitrary number of images. To account for this, pixel_values fed to
        the model have image padding -> (batch_size, max_num_images, 3, max_heights, max_widths) where
        max_num_images is the maximum number of images among the batch_size samples in the batch.
        Padding images are not needed beyond padding the pixel_values at the entrance of the model.
        For efficiency, we only pass through the vision_model's forward the real images by
        discarding the padding images i.e. pixel_values of size (image_batch_size, 3, height, width) where
        image_batch_size would be 7 when num_images_per_sample=[1, 3, 1, 2] and max_num_images would be 3.
        """,
        IDEFICS3_INPUTS_DOCSTRING,
    )
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        pixel_values: Optional[torch.FloatTensor] = None,
        pixel_attention_mask: Optional[torch.BoolTensor] = None,
        image_hidden_states: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, Idefics3BaseModelOutputWithPast]:
        """Run the vision encoder (if images are given), merge image features into the
        token embeddings, and forward the result through the text decoder."""
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if self.training and self.text_model.gradient_checkpointing and use_cache:
            logger.warning_once(
                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
            )
            use_cache = False

        # retrieve input_ids and inputs_embeds
        if input_ids is not None:
            batch_size, seq_length = input_ids.shape
        elif inputs_embeds is not None:
            batch_size, seq_length, _ = inputs_embeds.shape
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        past_seen_tokens = 0
        if use_cache:
            if past_key_values is None:
                past_key_values = DynamicCache()
            past_seen_tokens = past_key_values.get_seq_length()

        if inputs_embeds is not None and input_ids is None and past_seen_tokens == 0:
            raise ValueError("When first calling the model, if input_embeds are passed, input_ids should not be None.")

        if inputs_embeds is None:
            inputs_embeds = self.text_model.get_input_embeddings()(input_ids).to(self.device)

        # START VISUAL INPUTS INTEGRATION
        if pixel_values is not None and image_hidden_states is not None:
            raise ValueError("You cannot specify both pixel_values and image_hidden_states at the same time")
        elif pixel_values is not None:
            batch_size, num_images, num_channels, height, width = pixel_values.shape
            pixel_values = pixel_values.to(dtype=self.dtype)  # fp16 compatibility
            # Collapse the per-sample image dimension into the batch dimension.
            pixel_values = pixel_values.view(batch_size * num_images, *pixel_values.shape[2:])

            # Remove padding images - padding images are full 0.
            nb_values_per_image = pixel_values.shape[1:].numel()
            real_images_inds = (pixel_values == 0.0).sum(dim=(-1, -2, -3)) != nb_values_per_image
            pixel_values = pixel_values[real_images_inds].contiguous()

            # Handle the vision attention mask
            if pixel_attention_mask is None:
                pixel_attention_mask = torch.ones(
                    size=(pixel_values.size(0), pixel_values.size(2), pixel_values.size(3)),
                    dtype=torch.bool,
                    device=pixel_values.device,
                )
            else:
                # Remove padding images from the mask
                pixel_attention_mask = pixel_attention_mask.view(
                    batch_size * num_images, *pixel_attention_mask.shape[2:]
                )
                pixel_attention_mask = pixel_attention_mask[real_images_inds].contiguous()

            # Downsample the pixel mask to the patch grid: a patch attends if any of
            # its pixels is unmasked.
            patch_size = self.config.vision_config.patch_size
            patches_subgrid = pixel_attention_mask.unfold(dimension=1, size=patch_size, step=patch_size)
            patches_subgrid = patches_subgrid.unfold(dimension=2, size=patch_size, step=patch_size)
            patch_attention_mask = (patches_subgrid.sum(dim=(-1, -2)) > 0).bool()

            # Get sequence from the vision encoder
            image_hidden_states = self.vision_model(
                pixel_values=pixel_values,
                patch_attention_mask=patch_attention_mask,
            ).last_hidden_state

            # Modality projection & resampling
            image_hidden_states = self.connector(image_hidden_states)

        elif image_hidden_states is not None:
            # NOTE(review): uses `input_ids.device` — assumes `input_ids` is provided
            # whenever precomputed `image_hidden_states` are passed; confirm callers.
            image_hidden_states = image_hidden_states.to(dtype=self.dtype, device=input_ids.device)

        if past_seen_tokens == 0 and inputs_embeds is not None and image_hidden_states is not None:
            # When we generate, we don't want to replace the potential image_token_id that we generated by images
            # that simply don't exist
            inputs_embeds = self.inputs_merger(
                input_ids=input_ids,
                inputs_embeds=inputs_embeds,
                image_hidden_states=image_hidden_states,
            )

        outputs = self.text_model(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        if not return_dict:
            return tuple(v for v in [*outputs, image_hidden_states] if v is not None)

        return Idefics3BaseModelOutputWithPast(
            last_hidden_state=outputs.last_hidden_state,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            image_hidden_states=image_hidden_states,
        )
  865. @add_start_docstrings(
  866. """The Idefics3 Model with a language modeling head. It is made up a SigLIP vision encoder, with a language modeling head on top. """,
  867. IDEFICS3_START_DOCSTRING,
  868. )
  869. class Idefics3ForConditionalGeneration(Idefics3PreTrainedModel, GenerationMixin):
  870. _tied_weights_keys = ["lm_head.weight"]
  871. # Copied from transformers.models.idefics2.modeling_idefics2.Idefics2ForConditionalGeneration.__init__ with Idefics2->Idefics3
  872. def __init__(self, config):
  873. super().__init__(config)
  874. self.model = Idefics3Model(config)
  875. self.image_token_id = self.config.image_token_id
  876. self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False)
  877. self.vocab_size = config.text_config.vocab_size
  878. # Initialize weights and apply final processing
  879. self.post_init()
  880. # Copied from transformers.models.idefics2.modeling_idefics2.Idefics2ForConditionalGeneration.enable_input_require_grads
  881. def enable_input_require_grads(self):
  882. """
  883. Enables the gradients for the input embeddings. This is useful for fine-tuning adapter weights while keeping
  884. the model weights fixed.
  885. """
  886. def make_inputs_require_grads(module, input, output):
  887. output.requires_grad_(True)
  888. self._text_require_grads_hook = self.get_input_embeddings().register_forward_hook(make_inputs_require_grads)
  889. self._vision_require_grads_hook = self.model.vision_model.get_input_embeddings().register_forward_hook(
  890. make_inputs_require_grads
  891. )
  892. # Copied from transformers.models.idefics2.modeling_idefics2.Idefics2ForConditionalGeneration.disable_input_require_grads
  893. def disable_input_require_grads(self):
  894. self._text_require_grads_hook.remove()
  895. self._vision_require_grads_hook.remove()
  896. # Copied from transformers.models.idefics2.modeling_idefics2.Idefics2ForConditionalGeneration.get_input_embeddings
  897. def get_input_embeddings(self):
  898. return self.model.text_model.get_input_embeddings()
  899. # Copied from transformers.models.idefics2.modeling_idefics2.Idefics2ForConditionalGeneration.set_input_embeddings
  900. def set_input_embeddings(self, value):
  901. self.model.text_model.set_input_embeddings(value)
  902. # Copied from transformers.models.idefics2.modeling_idefics2.Idefics2ForConditionalGeneration.get_output_embeddings
  903. def get_output_embeddings(self):
  904. return self.lm_head
  905. # Copied from transformers.models.idefics2.modeling_idefics2.Idefics2ForConditionalGeneration.set_output_embeddings
  906. def set_output_embeddings(self, new_embeddings):
  907. self.lm_head = new_embeddings
  908. # Copied from transformers.models.idefics2.modeling_idefics2.Idefics2ForConditionalGeneration.tie_weights
  909. def tie_weights(self):
  910. """
  911. Overwrite `transformers.modeling_utils.PreTrainedModel.tie_weights` to handle the case of DecoupledLinear and DecoupledEmbedding.
  912. """
  913. output_embeddings = self.get_output_embeddings()
  914. input_embeddings = self.get_input_embeddings()
  915. if getattr(self.config, "tie_word_embeddings", True):
  916. output_embeddings.weight = input_embeddings.weight
    @add_start_docstrings_to_model_forward(IDEFICS3_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=Idefics3CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        pixel_values: Optional[torch.FloatTensor] = None,
        pixel_attention_mask: Optional[torch.BoolTensor] = None,
        image_hidden_states: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, Idefics3CausalLMOutputWithPast]:
        r"""
        Args:
            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
                config.vocab_size]` or `model.image_token_id` (where `model` is your instance of `Idefics3ForConditionalGeneration`).
                Tokens with indices set to `model.image_token_id` are ignored (masked), the loss is only
                computed for the tokens with labels in `[0, ..., config.vocab_size]`.
        Returns:

        Example:

        ```python
        >>> import requests
        >>> import torch
        >>> from PIL import Image
        >>> from io import BytesIO

        >>> from transformers import AutoProcessor, AutoModelForVision2Seq
        >>> from transformers.image_utils import load_image

        >>> # Note that passing the image urls (instead of the actual pil images) to the processor is also possible
        >>> image1 = load_image("https://cdn.britannica.com/61/93061-050-99147DCE/Statue-of-Liberty-Island-New-York-Bay.jpg")
        >>> image2 = load_image("https://cdn.britannica.com/59/94459-050-DBA42467/Skyline-Chicago.jpg")
        >>> image3 = load_image("https://cdn.britannica.com/68/170868-050-8DDE8263/Golden-Gate-Bridge-San-Francisco.jpg")

        >>> processor = AutoProcessor.from_pretrained("HuggingFaceM4/Idefics3-8B-Llama3")
        >>> model = AutoModelForVision2Seq.from_pretrained("HuggingFaceM4/Idefics3-8B-Llama3", torch_dtype=torch.bfloat16, device_map="auto")

        >>> # Create inputs
        >>> messages = [
        ...     {
        ...         "role": "user",
        ...         "content": [
        ...             {"type": "image"},
        ...             {"type": "text", "text": "In this image, we can see the city of New York, and more specifically the Statue of Liberty."},
        ...             {"type": "image"},
        ...             {"type": "text", "text": "What can we see in this image?"},
        ...         ]
        ...     },
        ...     {
        ...         "role": "user",
        ...         "content": [
        ...             {"type": "image"},
        ...             {"type": "text", "text": "In which city is that bridge located?"},
        ...         ]
        ...     }
        ... ]

        >>> prompts = [processor.apply_chat_template([message], add_generation_prompt=True) for message in messages]
        >>> images = [[image1, image2], [image3]]
        >>> inputs = processor(text=prompts, images=images, padding=True, return_tensors="pt").to(model.device)

        >>> # Generate
        >>> generated_ids = model.generate(**inputs, max_new_tokens=256)
        >>> generated_texts = processor.batch_decode(generated_ids, skip_special_tokens=True)

        >>> print(generated_texts[0])
        Assistant: There are buildings, trees, lights, and water visible in this image.

        >>> print(generated_texts[1])
        Assistant: The bridge is in San Francisco.
        ```"""
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            pixel_values=pixel_values,
            pixel_attention_mask=pixel_attention_mask,
            image_hidden_states=image_hidden_states,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = outputs[0]
        logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            # Upcast to float if we need to compute the loss to avoid potential precision issues
            logits = logits.float()
            labels = labels.to(logits.device)
            # Shift so that tokens < n predict n
            if attention_mask is not None:
                # we use the input attention mask to shift the logits and labels, because it is 2D.
                # we also crop attn mask in case it is longer, which happens in PrefixTuning with peft
                shift_attention_mask = attention_mask[:, -(logits.shape[1] - 1) :].to(logits.device)
                shift_logits = logits[..., :-1, :][shift_attention_mask != 0].contiguous()
                shift_labels = labels[..., 1:][shift_attention_mask != 0].contiguous()
            else:
                shift_logits = logits[..., :-1, :].contiguous()
                shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return Idefics3CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            image_hidden_states=outputs.image_hidden_states,
        )
  1038. # Copied from transformers.models.idefics2.modeling_idefics2.Idefics2ForConditionalGeneration.prepare_inputs_for_generation
  1039. def prepare_inputs_for_generation(
  1040. self,
  1041. input_ids,
  1042. past_key_values=None,
  1043. attention_mask=None,
  1044. inputs_embeds=None,
  1045. cache_position=None,
  1046. pixel_values=None,
  1047. pixel_attention_mask=None,
  1048. image_hidden_states=None,
  1049. num_logits_to_keep=None,
  1050. **kwargs,
  1051. ):
  1052. # Overwritten -- there are mutually exclusive inputs (if the logic to make `image_hidden_states` take
  1053. # precedence is moved to the model, we can remove this fn)
  1054. # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
  1055. if past_key_values is not None:
  1056. if inputs_embeds is not None: # Exception 1
  1057. input_ids = input_ids[:, -cache_position.shape[0] :]
  1058. elif input_ids.shape[1] != cache_position.shape[0]:
  1059. input_ids = input_ids[:, cache_position]
  1060. position_ids = kwargs.get("position_ids", None)
  1061. if attention_mask is not None and position_ids is None:
  1062. # create position_ids on the fly for batch generation
  1063. position_ids = attention_mask.long().cumsum(-1) - 1
  1064. position_ids.masked_fill_(attention_mask == 0, 1)
  1065. if past_key_values:
  1066. position_ids = position_ids[:, -input_ids.shape[1] :]
  1067. # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
  1068. # but IDEFICS requires noth ids and embeds to be present
  1069. if inputs_embeds is not None and cache_position[0] == 0:
  1070. model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": input_ids}
  1071. else:
  1072. # The clone here is for the same reason as for `position_ids`.
  1073. model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None}
  1074. if num_logits_to_keep is not None:
  1075. model_inputs["num_logits_to_keep"] = num_logits_to_keep
  1076. if image_hidden_states is not None:
  1077. pixel_values = None
  1078. pixel_attention_mask = None
  1079. else:
  1080. pixel_values = pixel_values
  1081. pixel_attention_mask = pixel_attention_mask
  1082. model_inputs.update(
  1083. {
  1084. "position_ids": position_ids,
  1085. "past_key_values": past_key_values,
  1086. "use_cache": kwargs.get("use_cache"),
  1087. "attention_mask": attention_mask,
  1088. "pixel_values": pixel_values,
  1089. "pixel_attention_mask": pixel_attention_mask,
  1090. "image_hidden_states": image_hidden_states,
  1091. }
  1092. )
  1093. return model_inputs
  1094. # Copied from transformers.models.idefics2.modeling_idefics2.Idefics2ForConditionalGeneration._update_model_kwargs_for_generation
  1095. def _update_model_kwargs_for_generation(self, outputs, model_kwargs, is_encoder_decoder, **kwargs):
  1096. model_kwargs = super()._update_model_kwargs_for_generation(
  1097. outputs=outputs,
  1098. model_kwargs=model_kwargs,
  1099. is_encoder_decoder=is_encoder_decoder,
  1100. **kwargs,
  1101. )
  1102. # Get the precomputed image_hidden_states
  1103. model_kwargs["image_hidden_states"] = outputs.image_hidden_states
  1104. return model_kwargs