# coding=utf-8
# Copyright 2023 the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch Llava model."""

from dataclasses import dataclass
from typing import List, Optional, Tuple, Union

import torch
import torch.utils.checkpoint
from torch import nn

from ...activations import ACT2FN
from ...generation import GenerationMixin
from ...modeling_outputs import ModelOutput
from ...modeling_utils import PreTrainedModel
from ...utils import (
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    logging,
    replace_return_docstrings,
)
from ..auto import AutoModel, AutoModelForCausalLM
from .configuration_llava import LlavaConfig


logger = logging.get_logger(__name__)

_CONFIG_FOR_DOC = "LlavaConfig"

# Base docstring
_CHECKPOINT_FOR_DOC = "llava-hf/llava-1.5-7b-hf"


@dataclass
class LlavaCausalLMOutputWithPast(ModelOutput):
    """
    Base class for Llava causal language model (or autoregressive) outputs.

    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
            Language modeling loss (for next-token prediction).
        logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
        past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
            Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
            `(batch_size, num_heads, sequence_length, embed_size_per_head)`.

            Contains pre-computed hidden-states (keys and values in the self-attention blocks) that can be used (see
            `past_key_values` input) to speed up sequential decoding.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attention weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
        image_hidden_states (`torch.FloatTensor`, *optional*):
            A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`.

            Image hidden states of the model, produced by the vision encoder and after projecting the last hidden state.
    """

    loss: Optional[torch.FloatTensor] = None
    logits: torch.FloatTensor = None
    past_key_values: Optional[List[torch.FloatTensor]] = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None
    image_hidden_states: Optional[torch.FloatTensor] = None


class LlavaMultiModalProjector(nn.Module):
    def __init__(self, config: LlavaConfig):
        super().__init__()
        self.linear_1 = nn.Linear(config.vision_config.hidden_size, config.text_config.hidden_size, bias=True)
        self.act = ACT2FN[config.projector_hidden_act]
        self.linear_2 = nn.Linear(config.text_config.hidden_size, config.text_config.hidden_size, bias=True)

    def forward(self, image_features):
        hidden_states = self.linear_1(image_features)
        hidden_states = self.act(hidden_states)
        hidden_states = self.linear_2(hidden_states)
        return hidden_states
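

# Shape sketch for `LlavaMultiModalProjector` above (illustrative only; the 1024/4096 sizes assume the
# default CLIP-ViT-L/14 vision tower and 7B LLaMA language model of llava-1.5-7b-hf, they are not read
# from any config here):
#   image_features : (num_images, num_patches, vision_hidden_size=1024)
#       -> linear_1 -> act -> linear_2 ->
#   projected      : (num_images, num_patches, text_hidden_size=4096)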


LLAVA_START_DOCSTRING = r"""
    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads,
    etc.)

    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage
    and behavior.

    Parameters:
        config ([`LlavaConfig`] or [`LlavaVisionConfig`]):
            Model configuration class with all the parameters of the model. Initializing with a config file does not
            load the weights associated with the model, only the configuration. Check out the
            [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""


@add_start_docstrings(
    "The bare LLaVA Model outputting raw hidden-states without any specific head on top.",
    LLAVA_START_DOCSTRING,
)
class LlavaPreTrainedModel(PreTrainedModel):
    config_class = LlavaConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["LlavaVisionAttention"]
    _skip_keys_device_placement = "past_key_values"
    _supports_cache_class = True
    _supports_flash_attn_2 = True
    _supports_sdpa = True

    def _init_weights(self, module):
        # important: this ported version of Llava isn't meant for training from scratch - only
        # inference and fine-tuning - so the proper init weights code has been removed - the original codebase
        # https://github.com/haotian-liu/LLaVA/tree/main/llava should serve for that purpose
        std = (
            self.config.initializer_range
            if hasattr(self.config, "initializer_range")
            else self.config.text_config.initializer_range
        )

        if hasattr(module, "class_embedding"):
            module.class_embedding.data.normal_(mean=0.0, std=std)

        if isinstance(module, (nn.Linear, nn.Conv2d)):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()


LLAVA_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
            it.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
            The tensors corresponding to the input images. Pixel values can be obtained using
            [`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details ([`LlavaProcessor`] uses
            [`CLIPImageProcessor`] for processing images).
        attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
            `past_key_values`).

            If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
            and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
            information on the default strategy.
        position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Indices of positions of each input sequence token in the position embeddings. Selected in the range `[0,
            config.n_positions - 1]`. [What are position IDs?](../glossary#position-ids)
        past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
            Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
            `(batch_size, num_heads, sequence_length, embed_size_per_head)` and 2 additional tensors of shape
            `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.

            Contains pre-computed hidden-states (keys and values in the self-attention blocks and in the cross-attention
            blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.

            If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
            don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
            `decoder_input_ids` of shape `(batch_size, sequence_length)`.
        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
            model's internal embedding lookup matrix.
        vision_feature_layer (`int`, *optional*, defaults to -2):
            The index of the layer to select the vision feature.
        vision_feature_select_strategy (`str`, *optional*, defaults to `"default"`):
            The feature selection strategy used to select the vision feature from the vision backbone.
            Can be one of `"default"` or `"full"`.
        use_cache (`bool`, *optional*):
            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
            `past_key_values`).
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
        cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
            Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
            this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
            the complete sequence length.
"""


@add_start_docstrings(
    """The LLAVA model which consists of a vision backbone and a language model.""",
    LLAVA_START_DOCSTRING,
)
class LlavaForConditionalGeneration(LlavaPreTrainedModel, GenerationMixin):
    def __init__(self, config: LlavaConfig):
        super().__init__(config)
        self.vision_tower = AutoModel.from_config(config.vision_config)
        self.multi_modal_projector = LlavaMultiModalProjector(config)
        self.vocab_size = config.text_config.vocab_size
        self.language_model = AutoModelForCausalLM.from_config(config.text_config)
        self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1
        self.post_init()

    def get_input_embeddings(self):
        return self.language_model.get_input_embeddings()

    def set_input_embeddings(self, value):
        self.language_model.set_input_embeddings(value)

    def get_output_embeddings(self):
        return self.language_model.get_output_embeddings()

    def set_output_embeddings(self, new_embeddings):
        self.language_model.set_output_embeddings(new_embeddings)

    def set_decoder(self, decoder):
        self.language_model.set_decoder(decoder)

    def get_decoder(self):
        return self.language_model.get_decoder()

    def tie_weights(self):
        return self.language_model.tie_weights()

    def resize_token_embeddings(self, new_num_tokens: Optional[int] = None, pad_to_multiple_of=None) -> nn.Embedding:
        model_embeds = self.language_model.resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
        # update vocab size
        self.config.text_config.vocab_size = model_embeds.num_embeddings
        self.vocab_size = model_embeds.num_embeddings
        return model_embeds

    def get_image_features(
        self, pixel_values: torch.FloatTensor, vision_feature_layer: int, vision_feature_select_strategy: str
    ):
        """
        Obtains image last hidden states from the vision tower and applies the multimodal projection.

        Args:
            pixel_values (`torch.FloatTensor` of shape `(batch_size, channels, height, width)`):
                The tensors corresponding to the input images.
            vision_feature_layer (`int`):
                The index of the layer to select the vision feature.
            vision_feature_select_strategy (`str`):
                The feature selection strategy used to select the vision feature from the vision backbone.
                Can be one of `"default"` or `"full"`.
        Returns:
            image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`.
        """
        image_outputs = self.vision_tower(pixel_values, output_hidden_states=True)
        # this is not memory efficient at all: output_hidden_states=True keeps all the hidden states.
        selected_image_feature = image_outputs.hidden_states[vision_feature_layer]
        if vision_feature_select_strategy == "default":
            selected_image_feature = selected_image_feature[:, 1:]
        elif vision_feature_select_strategy == "full":
            selected_image_feature = selected_image_feature
        else:
            raise ValueError(f"Unexpected select feature strategy: {self.config.vision_feature_select_strategy}")
        image_features = self.multi_modal_projector(selected_image_feature)
        return image_features
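
    # Feature-selection sketch for `get_image_features` above (illustrative; the 577/576 counts assume the
    # 336x336 CLIP-ViT-L/14 tower of llava-1.5-7b-hf, i.e. (336/14)**2 patch tokens plus one CLS token):
    #   hidden_states[vision_feature_layer]: (num_images, 577, 1024)
    #   "default" -> drop the leading CLS token -> (num_images, 576, 1024)
    #   "full"    -> keep every position        -> (num_images, 577, 1024)
    #   multi_modal_projector                   -> (num_images, 576 or 577, text_hidden_size)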

    def _merge_input_ids_with_image_features(self, image_features, inputs_embeds, input_ids, attention_mask, labels):
        num_images, num_image_patches, embed_dim = image_features.shape
        batch_size, sequence_length = input_ids.shape
        left_padding = not torch.sum(input_ids[:, -1] == torch.tensor(self.pad_token_id))
        # 1. Create a mask to know where special image tokens are
        special_image_token_mask = input_ids == self.config.image_token_index
        num_special_image_tokens = torch.sum(special_image_token_mask, dim=-1)
        # Compute the maximum embed dimension
        max_embed_dim = (num_special_image_tokens.max() * (num_image_patches - 1)) + sequence_length
        batch_indices, non_image_indices = torch.where(input_ids != self.config.image_token_index)

        # 2. Compute the positions where text should be written
        # Calculate new positions for text tokens in the merged image-text sequence.
        # `special_image_token_mask` identifies image tokens. Each image token will be replaced by
        # `num_image_patches` image features, shifting every subsequent token by `num_image_patches - 1` positions.
        # `torch.cumsum` computes how each image token shifts subsequent text token positions.
        # - 1 to adjust for zero-based indexing, as `cumsum` inherently increases indices by one.
        new_token_positions = torch.cumsum((special_image_token_mask * (num_image_patches - 1) + 1), -1) - 1
        nb_image_pad = max_embed_dim - 1 - new_token_positions[:, -1]
        if left_padding:
            new_token_positions += nb_image_pad[:, None]  # offset for left padding
        text_to_overwrite = new_token_positions[batch_indices, non_image_indices]

        # 3. Create the full embedding, already padded to the maximum position
        final_embedding = torch.zeros(
            batch_size, max_embed_dim, embed_dim, dtype=inputs_embeds.dtype, device=inputs_embeds.device
        )
        final_attention_mask = torch.zeros(
            batch_size, max_embed_dim, dtype=attention_mask.dtype, device=inputs_embeds.device
        )
        if labels is not None:
            final_labels = torch.full(
                (batch_size, max_embed_dim), self.config.ignore_index, dtype=input_ids.dtype, device=input_ids.device
            )
        # In case the Vision model or the Language model has been offloaded to CPU, we need to manually
        # set the corresponding tensors into their correct target device.
        target_device = inputs_embeds.device
        batch_indices, non_image_indices, text_to_overwrite = (
            batch_indices.to(target_device),
            non_image_indices.to(target_device),
            text_to_overwrite.to(target_device),
        )
        attention_mask = attention_mask.to(target_device)

        # 4. Fill the embeddings based on the mask. If we have ["hey", "<image>", "how", "are"] and 576 image patches,
        # the text embeddings are index-copied to positions [0, 577, 578] and the image features fill positions 1-576.
        final_embedding[batch_indices, text_to_overwrite] = inputs_embeds[batch_indices, non_image_indices]
        final_attention_mask[batch_indices, text_to_overwrite] = attention_mask[batch_indices, non_image_indices]
        if labels is not None:
            final_labels[batch_indices, text_to_overwrite] = labels[batch_indices, non_image_indices]

        # 5. Fill the embeddings corresponding to the images. Anything that is not `text_positions` needs filling (#29835)
        image_to_overwrite = torch.full(
            (batch_size, max_embed_dim), True, dtype=torch.bool, device=inputs_embeds.device
        )
        image_to_overwrite[batch_indices, text_to_overwrite] = False
        image_to_overwrite &= image_to_overwrite.cumsum(-1) - 1 >= nb_image_pad[:, None].to(target_device)

        if image_to_overwrite.sum() != image_features.shape[:-1].numel():
            raise ValueError(
                f"The inputs provided to the model are wrong. The number of image tokens is {torch.sum(special_image_token_mask)} while"
                f" the number of images given to the model is {num_images}. This prevents correct indexing and breaks batch generation."
            )

        final_embedding[image_to_overwrite] = image_features.contiguous().reshape(-1, embed_dim).to(target_device)
        final_attention_mask |= image_to_overwrite
        position_ids = (final_attention_mask.cumsum(-1) - 1).masked_fill_((final_attention_mask == 0), 1)

        # 6. Mask out the embedding at padding positions, as we later use the past_key_value value to determine the non-attended tokens.
        batch_indices, pad_indices = torch.where(input_ids == self.pad_token_id)
        indices_to_mask = new_token_positions[batch_indices, pad_indices]

        final_embedding[batch_indices, indices_to_mask] = 0

        if labels is None:
            final_labels = None

        return final_embedding, final_attention_mask, final_labels, position_ids
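
    # Worked example for `_merge_input_ids_with_image_features` above (illustrative only, assuming a single
    # image that expands to 576 patches, as in llava-1.5-7b-hf):
    #   input_ids           = ["USER:", "<image>", "hi"]           -> sequence_length = 3
    #   max_embed_dim       = 1 * (576 - 1) + 3 = 578
    #   new_token_positions = cumsum([1, 576, 1]) - 1 = [0, 576, 577]
    #   text embeddings land at positions [0, 577], image features fill positions 1-576, and
    #   `final_attention_mask` / `position_ids` are rebuilt over the expanded length of 578.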

    @add_start_docstrings_to_model_forward(LLAVA_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=LlavaCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        pixel_values: torch.FloatTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        vision_feature_layer: Optional[int] = None,
        vision_feature_select_strategy: Optional[str] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        num_logits_to_keep: int = 0,
    ) -> Union[Tuple, LlavaCausalLMOutputWithPast]:
        r"""
        Args:
            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
                (masked); the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

            num_logits_to_keep (`int`, *optional*):
                Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all
                `input_ids` (special case). Only the last token logits are needed for generation, and calculating them
                only for that token can save memory, which becomes quite significant for long sequences or large
                vocabulary sizes.

        Returns:

        Example:

        ```python
        >>> from PIL import Image
        >>> import requests
        >>> from transformers import AutoProcessor, LlavaForConditionalGeneration

        >>> model = LlavaForConditionalGeneration.from_pretrained("llava-hf/llava-1.5-7b-hf")
        >>> processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")

        >>> prompt = "USER: <image>\nWhat's the content of the image? ASSISTANT:"
        >>> url = "https://www.ilankelman.org/stopsigns/australia.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> inputs = processor(images=image, text=prompt, return_tensors="pt")

        >>> # Generate
        >>> generate_ids = model.generate(**inputs, max_new_tokens=15)
        >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "USER: \nWhat's the content of the image? ASSISTANT: The image features a busy city street with a stop sign prominently displayed"
        ```"""

        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        vision_feature_layer = (
            vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
        )
        vision_feature_select_strategy = (
            vision_feature_select_strategy
            if vision_feature_select_strategy is not None
            else self.config.vision_feature_select_strategy
        )

        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if pixel_values is not None and inputs_embeds is not None:
            raise ValueError(
                "You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
            )

        legacy_processing = False
        if inputs_embeds is None:
            inputs_embeds = self.get_input_embeddings()(input_ids)

            # if the number of image tokens is more than image embeddings seq length, then prob we expanded it in processing
            # not very reliable, but we don't expect one to actually pass 500+ images for one prompt
            # In case we're in decoding stage, legacy behavior is checked by presence of pixel values even if use_cache=True
            legacy_processing = (
                (input_ids == self.config.image_token_index).sum(1).max() < self.config.image_seq_length
            ) or (input_ids.shape[-1] == 1 and pixel_values is not None)

        image_features = None
        if pixel_values is not None:
            image_features = self.get_image_features(
                pixel_values=pixel_values,
                vision_feature_layer=vision_feature_layer,
                vision_feature_select_strategy=vision_feature_select_strategy,
            )

        if legacy_processing:
            logger.warning_once(
                "Expanding inputs for image tokens in LLaVa should be done in processing. "
                "Please add `patch_size` and `vision_feature_select_strategy` to the model's processing config or set directly "
                "with `processor.patch_size = {{patch_size}}` and `processor.vision_feature_select_strategy = {{vision_feature_select_strategy}}`. "
                "Using processors without these attributes in the config is deprecated and will throw an error in v4.47."
            )
            # prefill stage vs decoding stage (legacy behavior copied)
            if input_ids.shape[1] != 1:
                inputs_embeds, attention_mask, labels, position_ids = self._merge_input_ids_with_image_features(
                    image_features, inputs_embeds, input_ids, attention_mask, labels
                )
                cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device)
            else:
                # Retrieve the first layer to inspect the logits and mask out the hidden states
                # that are set to 0
                first_layer_past_key_value = past_key_values[0][0][:, :, :, 0]

                # Sum all dimensions of head_dim (-2) to avoid random errors such as: https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941
                batch_index, non_attended_tokens = torch.where(first_layer_past_key_value.float().sum(-2) == 0)

                # Get the target length
                target_length = input_ids.shape[1]
                past_length = first_layer_past_key_value.shape[-1]

                extended_attention_mask = torch.ones(
                    (attention_mask.shape[0], past_length),
                    dtype=attention_mask.dtype,
                    device=attention_mask.device,
                )

                # Filter out only the tokens that can be un-attended, this can happen
                # if one uses Llava + Fused modules where the cache on the
                # first iteration is already big enough, or if one passes custom cache
                valid_indices = non_attended_tokens < extended_attention_mask.size(-1)
                new_batch_index = batch_index[valid_indices]
                new_non_attended_tokens = non_attended_tokens[valid_indices]

                # Zero-out the places where we don't need to attend
                extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0

                attention_mask = torch.cat((extended_attention_mask, attention_mask[:, -target_length:]), dim=1)
                position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
                cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device)[-target_length:]

        # TODO: @raushan retain only the new behavior after v4.47
        elif image_features is not None:
            n_image_tokens = (input_ids == self.config.image_token_index).sum().item()
            n_image_features = image_features.shape[0] * image_features.shape[1]

            if n_image_tokens != n_image_features:
                raise ValueError(
                    f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
                )
            special_image_mask = (
                (input_ids == self.config.image_token_index)
                .unsqueeze(-1)
                .expand_as(inputs_embeds)
                .to(inputs_embeds.device)
            )
            image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
            inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)

        outputs = self.language_model(
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
            num_logits_to_keep=num_logits_to_keep,
        )

        logits = outputs[0]

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n
            if attention_mask is not None:
                # we use the input attention mask to shift the logits and labels, because it is 2D.
                # we also crop attn mask in case it is longer, which happens in PrefixTuning with peft
                shift_attention_mask = attention_mask[:, -(logits.shape[1] - 1) :].to(logits.device)
                shift_logits = logits[..., :-1, :][shift_attention_mask.to(logits.device) != 0].contiguous()
                shift_labels = labels[..., 1:][shift_attention_mask.to(labels.device) != 0].contiguous()
            else:
                shift_logits = logits[..., :-1, :].contiguous()
                shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(
                shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1).to(shift_logits.device)
            )

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return LlavaCausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            image_hidden_states=image_features if pixel_values is not None else None,
        )

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        inputs_embeds=None,
        pixel_values=None,
        attention_mask=None,
        cache_position=None,
        num_logits_to_keep=None,
        **kwargs,
    ):
        # Overwritten -- in specific circumstances we don't want to forward image inputs to the model

        model_inputs = self.language_model.prepare_inputs_for_generation(
            input_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            cache_position=cache_position,
            num_logits_to_keep=num_logits_to_keep,
            **kwargs,
        )

        if cache_position[0] == 0:
            # If we're in the cached decoding stage, pixel values should be None because input ids do not contain
            # special image tokens anymore. Otherwise we need pixel values to be passed to the model.
            model_inputs["pixel_values"] = pixel_values

        return model_inputs
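
    # Illustrative trace of the `pixel_values` gating above (step indices assumed, not taken from a real run):
    #   prefill step : cache_position = [0, 1, ..., L-1] -> cache_position[0] == 0 -> pixel_values are forwarded
    #   decode steps : cache_position = [L], [L+1], ...  -> cache_position[0] != 0 -> pixel_values are dropped,
    #                  since the image features were already merged into the sequence during the prefill pass.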