modeling_instructblip.py 74 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113111411151116111711181119112011211122112311241125112611271128112911301131113211331134113511361137113811391140114111421143114411451146114711481149115011511152115311541155115611571158115911601161116211631164116511
6611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432143314341435143614371438143914401441144214431444144514461447144814491450145114521453145414551456145714581459146014611462146314641465146614671468146914701471147214731474147514761477147814791480148114821483148414851486148714881489149014911492149314941495149614971498149915001501150215031504150515061507150815091510151115121513151415151516151715181519152015211522152315241525152615271528152915301531153215331534153515361537153815391540154115421543154415451546154715481549155015511552155315541555155615571558155915601561156215631564156515661567156815691570157115721573157415751576157715781579158015811582158315841585158615871588158915901591159215931594159515961597159815991600160116021603160416051606160716081609161016111612161316141615161616171618161916201621162216231624162516261627162816291630163116321633163416351636163716381639164016411642164316441645164616471648164916501651
  1. # coding=utf-8
  2. # Copyright 2023 The Salesforce Authors and The HuggingFace Team. All rights reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. """PyTorch InstructBLIP model."""
  16. import math
  17. from dataclasses import dataclass
  18. from typing import Any, Optional, Tuple, Union
  19. import torch
  20. import torch.utils.checkpoint
  21. from torch import nn
  22. from torch.nn import CrossEntropyLoss
  23. from ...activations import ACT2FN
  24. from ...generation import GenerationMixin
  25. from ...modeling_outputs import (
  26. BaseModelOutput,
  27. BaseModelOutputWithPastAndCrossAttentions,
  28. BaseModelOutputWithPooling,
  29. BaseModelOutputWithPoolingAndCrossAttentions,
  30. )
  31. from ...modeling_utils import PreTrainedModel
  32. from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
  33. from ...utils import (
  34. ModelOutput,
  35. add_start_docstrings,
  36. add_start_docstrings_to_model_forward,
  37. logging,
  38. replace_return_docstrings,
  39. torch_int,
  40. )
  41. from ..auto import AutoModelForCausalLM, AutoModelForSeq2SeqLM
  42. from .configuration_instructblip import InstructBlipConfig, InstructBlipQFormerConfig, InstructBlipVisionConfig
  43. logger = logging.get_logger(__name__)
  44. _CHECKPOINT_FOR_DOC = "Salesforce/instructblip-flan-t5-xl"
  45. @dataclass
  46. # Copied from transformers.models.blip_2.modeling_blip_2.Blip2ForConditionalGenerationModelOutput with Blip2->InstructBlip
  47. class InstructBlipForConditionalGenerationModelOutput(ModelOutput):
  48. """
  49. Class defining the outputs of [`InstructBlipForConditionalGeneration`].
  50. Args:
  51. loss (`torch.FloatTensor`, *optional*, returned when `labels` is provided, `torch.FloatTensor` of shape `(1,)`):
  52. Language modeling loss from the language model.
  53. logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
  54. Prediction scores of the language modeling head of the language model.
  55. vision_outputs (`BaseModelOutputWithPooling`):
  56. Outputs of the vision encoder.
  57. qformer_outputs (`BaseModelOutputWithPoolingAndCrossAttentions`):
  58. Outputs of the Q-Former (Querying Transformer).
  59. language_model_outputs (`CausalLMOutputWithPast` or `Seq2SeqLMOutput`):
  60. Outputs of the language model.
  61. """
  62. loss: Optional[Tuple[torch.FloatTensor]] = None
  63. logits: Optional[Tuple[torch.FloatTensor]] = None
  64. vision_outputs: Optional[torch.FloatTensor] = None
  65. qformer_outputs: Optional[Tuple[torch.FloatTensor]] = None
  66. language_model_outputs: Optional[Tuple[torch.FloatTensor]] = None
  67. def to_tuple(self) -> Tuple[Any]:
  68. return tuple(
  69. self[k]
  70. if k not in ["vision_outputs", "qformer_outputs", "language_model_outputs"]
  71. else getattr(self, k).to_tuple()
  72. for k in self.keys()
  73. )
  74. # Copied from transformers.models.blip.modeling_blip.BlipVisionEmbeddings with Blip->InstructBlip
  75. class InstructBlipVisionEmbeddings(nn.Module):
  76. def __init__(self, config: InstructBlipVisionConfig):
  77. super().__init__()
  78. self.config = config
  79. self.embed_dim = config.hidden_size
  80. self.image_size = config.image_size
  81. self.patch_size = config.patch_size
  82. self.class_embedding = nn.Parameter(torch.randn(1, 1, self.embed_dim))
  83. self.patch_embedding = nn.Conv2d(
  84. in_channels=3, out_channels=self.embed_dim, kernel_size=self.patch_size, stride=self.patch_size
  85. )
  86. self.num_patches = (self.image_size // self.patch_size) ** 2
  87. self.num_positions = self.num_patches + 1
  88. self.position_embedding = nn.Parameter(torch.randn(1, self.num_positions, self.embed_dim))
  89. def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
  90. """
  91. This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution
  92. images. This method is also adapted to support torch.jit tracing.
  93. Adapted from:
  94. - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and
  95. - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211
  96. """
  97. num_patches = embeddings.shape[1] - 1
  98. num_positions = self.position_embedding.shape[1] - 1
  99. # always interpolate when tracing to ensure the exported model works for dynamic input shapes
  100. if not torch.jit.is_tracing() and num_patches == num_positions and height == width:
  101. return self.position_embedding
  102. class_pos_embed = self.position_embedding[:, :1]
  103. patch_pos_embed = self.position_embedding[:, 1:]
  104. dim = embeddings.shape[-1]
  105. new_height = height // self.patch_size
  106. new_width = width // self.patch_size
  107. sqrt_num_positions = torch_int(num_positions**0.5)
  108. patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim)
  109. patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
  110. patch_pos_embed = nn.functional.interpolate(
  111. patch_pos_embed,
  112. size=(new_height, new_width),
  113. mode="bicubic",
  114. align_corners=False,
  115. )
  116. patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
  117. return torch.cat((class_pos_embed, patch_pos_embed), dim=1)
  118. def forward(self, pixel_values: torch.FloatTensor, interpolate_pos_encoding: bool = False) -> torch.Tensor:
  119. batch_size, _, height, width = pixel_values.shape
  120. target_dtype = self.patch_embedding.weight.dtype
  121. patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype)) # shape = [*, width, grid, grid]
  122. patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
  123. class_embeds = self.class_embedding.expand(batch_size, 1, -1).to(target_dtype)
  124. embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
  125. if interpolate_pos_encoding:
  126. position_embedding = self.interpolate_pos_encoding(embeddings, height, width)
  127. else:
  128. position_embedding = self.position_embedding
  129. embeddings = embeddings + position_embedding[:, : embeddings.size(1), :].to(target_dtype)
  130. return embeddings
  131. # Copied from transformers.models.blip_2.modeling_blip_2.Blip2Attention with Blip2->InstructBlip
  132. class InstructBlipAttention(nn.Module):
  133. """Multi-headed attention from 'Attention Is All You Need' paper"""
  134. def __init__(self, config):
  135. super().__init__()
  136. self.config = config
  137. self.embed_dim = config.hidden_size
  138. self.num_heads = config.num_attention_heads
  139. self.head_dim = self.embed_dim // self.num_heads
  140. if self.head_dim * self.num_heads != self.embed_dim:
  141. raise ValueError(
  142. f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
  143. f" {self.num_heads})."
  144. )
  145. self.scale = self.head_dim**-0.5
  146. self.dropout = nn.Dropout(config.attention_dropout)
  147. # small tweak here compared to CLIP, no bias here
  148. self.qkv = nn.Linear(self.embed_dim, 3 * self.embed_dim, bias=False)
  149. if config.qkv_bias:
  150. q_bias = nn.Parameter(torch.zeros(self.embed_dim))
  151. v_bias = nn.Parameter(torch.zeros(self.embed_dim))
  152. else:
  153. q_bias = None
  154. v_bias = None
  155. if q_bias is not None:
  156. qkv_bias = torch.cat((q_bias, torch.zeros_like(v_bias, requires_grad=False), v_bias))
  157. self.qkv.bias = nn.Parameter(qkv_bias)
  158. self.projection = nn.Linear(self.embed_dim, self.embed_dim)
  159. def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
  160. return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
  161. def forward(
  162. self,
  163. hidden_states: torch.Tensor,
  164. head_mask: Optional[torch.Tensor] = None,
  165. output_attentions: Optional[bool] = False,
  166. ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
  167. """Input shape: Batch x Time x Channel"""
  168. bsz, tgt_len, embed_dim = hidden_states.size()
  169. mixed_qkv = self.qkv(hidden_states)
  170. mixed_qkv = mixed_qkv.reshape(bsz, tgt_len, 3, self.num_heads, embed_dim // self.num_heads).permute(
  171. 2, 0, 3, 1, 4
  172. )
  173. query_states, key_states, value_states = mixed_qkv[0], mixed_qkv[1], mixed_qkv[2]
  174. # Take the dot product between "query" and "key" to get the raw attention scores.
  175. attention_scores = torch.matmul(query_states, key_states.transpose(-1, -2))
  176. attention_scores = attention_scores * self.scale
  177. # Normalize the attention scores to probabilities.
  178. attention_probs = nn.functional.softmax(attention_scores, dim=-1)
  179. # This is actually dropping out entire tokens to attend to, which might
  180. # seem a bit unusual, but is taken from the original Transformer paper.
  181. attention_probs = self.dropout(attention_probs)
  182. # Mask heads if we want to
  183. if head_mask is not None:
  184. attention_probs = attention_probs * head_mask
  185. context_layer = torch.matmul(attention_probs, value_states).permute(0, 2, 1, 3)
  186. new_context_layer_shape = context_layer.size()[:-2] + (self.embed_dim,)
  187. context_layer = context_layer.reshape(new_context_layer_shape)
  188. output = self.projection(context_layer)
  189. outputs = (output, attention_probs) if output_attentions else (output, None)
  190. return outputs
  191. # Copied from transformers.models.blip.modeling_blip.BlipMLP
  192. class InstructBlipMLP(nn.Module):
  193. def __init__(self, config):
  194. super().__init__()
  195. self.config = config
  196. self.activation_fn = ACT2FN[config.hidden_act]
  197. self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
  198. self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
  199. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  200. hidden_states = self.fc1(hidden_states)
  201. hidden_states = self.activation_fn(hidden_states)
  202. hidden_states = self.fc2(hidden_states)
  203. return hidden_states
  204. # Copied from transformers.models.blip.modeling_blip.BlipEncoderLayer with Blip->InstructBlip
  205. class InstructBlipEncoderLayer(nn.Module):
  206. def __init__(self, config: InstructBlipConfig):
  207. super().__init__()
  208. self.embed_dim = config.hidden_size
  209. self.self_attn = InstructBlipAttention(config)
  210. self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
  211. self.mlp = InstructBlipMLP(config)
  212. self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
  213. def forward(
  214. self,
  215. hidden_states: torch.Tensor,
  216. attention_mask: torch.Tensor,
  217. output_attentions: Optional[bool] = False,
  218. ) -> Tuple[torch.FloatTensor]:
  219. """
  220. Args:
  221. hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
  222. attention_mask (`torch.FloatTensor`): attention mask of size
  223. `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
  224. `(config.encoder_attention_heads,)`.
  225. output_attentions (`bool`, *optional*):
  226. Whether or not to return the attentions tensors of all attention layers. See `attentions` under
  227. returned tensors for more detail.
  228. """
  229. residual = hidden_states
  230. hidden_states = self.layer_norm1(hidden_states)
  231. hidden_states, attn_weights = self.self_attn(
  232. hidden_states=hidden_states,
  233. head_mask=attention_mask,
  234. output_attentions=output_attentions,
  235. )
  236. hidden_states = hidden_states + residual
  237. residual = hidden_states
  238. hidden_states = self.layer_norm2(hidden_states)
  239. hidden_states = self.mlp(hidden_states)
  240. hidden_states = hidden_states + residual
  241. outputs = (hidden_states,)
  242. if output_attentions:
  243. outputs += (attn_weights,)
  244. return outputs
  245. class InstructBlipPreTrainedModel(PreTrainedModel):
  246. """
  247. An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
  248. models.
  249. """
  250. config_class = InstructBlipConfig
  251. base_model_prefix = "blip"
  252. supports_gradient_checkpointing = True
  253. _no_split_modules = [
  254. "InstructBlipQFormerEmbeddings",
  255. "InstructBlipAttention",
  256. "InstructBlipQFormerMultiHeadAttention",
  257. "InstructBlipQFormerSelfOutput",
  258. ]
  259. _keep_in_fp32_modules = []
  260. # Copied from transformers.models.blip_2.modeling_blip_2.Blip2PreTrainedModel._init_weights with Blip2->InstructBlip
  261. def _init_weights(self, module):
  262. """Initialize the weights"""
  263. factor = self.config.initializer_range
  264. if isinstance(module, nn.Conv2d) or isinstance(module, nn.Embedding) or isinstance(module, nn.Linear):
  265. module.weight.data.normal_(mean=0.0, std=factor)
  266. if hasattr(module, "bias") and module.bias is not None:
  267. module.bias.data.zero_()
  268. if isinstance(module, InstructBlipVisionEmbeddings):
  269. if hasattr(self.config, "vision_config") and not isinstance(self.config, InstructBlipVisionConfig):
  270. factor = self.config.vision_config.initializer_range
  271. nn.init.trunc_normal_(module.position_embedding, mean=0.0, std=factor)
  272. nn.init.trunc_normal_(module.class_embedding, mean=0.0, std=factor)
  273. elif isinstance(module, nn.LayerNorm):
  274. module.bias.data.zero_()
  275. module.weight.data.fill_(1.0)
  276. elif isinstance(module, nn.Linear) and module.bias is not None:
  277. module.bias.data.zero_()
  278. INSTRUCTBLIP_START_DOCSTRING = r"""
  279. This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
  280. library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
  281. etc.)
  282. This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
  283. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
  284. and behavior.
  285. Parameters:
  286. config ([`InstructBlipConfig`]): Model configuration class with all the parameters of the model.
  287. Initializing with a config file does not load the weights associated with the model, only the
  288. configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
  289. """
  290. INSTRUCTBLIP_VISION_INPUTS_DOCSTRING = r"""
  291. Args:
  292. pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
  293. Pixel values. Pixel values can be obtained using [`InstructBlipProcessor`]. See
  294. [`InstructBlipProcessor.__call__`] for details.
  295. output_attentions (`bool`, *optional*):
  296. Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
  297. tensors for more detail.
  298. output_hidden_states (`bool`, *optional*):
  299. Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
  300. more detail.
  301. return_dict (`bool`, *optional*):
  302. Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
  303. interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
  304. Whether to interpolate the pre-trained position encodings.
  305. """
  306. INSTRUCTBLIP_INPUTS_DOCSTRING = r"""
  307. Args:
  308. pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
  309. Pixel values. Pixel values can be obtained using [`InstructBlipProcessor`]. See
  310. [`InstructBlipProcessor.__call__`] for details.
  311. qformer_input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  312. Indices of input sequence tokens in the vocabulary of the Q-Former. Input tokens can optionally be provided
  313. to serve as text prompt, which the Q-Former model will encode.
  314. Indices can be obtained using [`InstructBlipProcessor`]. See [`InstructBlipProcessor.__call__`] for
  315. details.
  316. [What are input IDs?](../glossary#input-ids)
  317. qformer_attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
  318. Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
  319. - 1 for tokens that are **not masked**,
  320. - 0 for tokens that are **masked**.
  321. [What are attention masks?](../glossary#attention-mask)
  322. input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  323. Indices of input sequence tokens in the vocabulary of the language model. Input tokens can optionally be
  324. provided to serve as text prompt, which the language model can continue.
  325. Indices can be obtained using [`InstructBlipProcessor`]. See [`InstructBlipProcessor.__call__`] for
  326. details.
  327. [What are input IDs?](../glossary#input-ids)
  328. attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
  329. Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
  330. - 1 for tokens that are **not masked**,
  331. - 0 for tokens that are **masked**.
  332. [What are attention masks?](../glossary#attention-mask)
  333. decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
  334. Indices of decoder input sequence tokens in the vocabulary of the language model. Only relevant in case an
  335. encoder-decoder language model (like T5) is used.
  336. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  337. [`PreTrainedTokenizer.__call__`] for details. [What are decoder input IDs?](../glossary#decoder-input-ids)
  338. decoder_attention_mask (`torch.BoolTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
  339. Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
  340. be used by default.
  341. Only relevant in case an encoder-decoder language model (like T5) is used.
  342. output_attentions (`bool`, *optional*):
  343. Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
  344. tensors for more detail.
  345. output_hidden_states (`bool`, *optional*):
  346. Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
  347. more detail.
  348. return_dict (`bool`, *optional*):
  349. Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
  350. interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
  351. Whether to interpolate the pre-trained position encodings.
  352. """
  353. # Copied from transformers.models.blip.modeling_blip.BlipEncoder with Blip->InstructBlip
  354. class InstructBlipEncoder(nn.Module):
  355. """
  356. Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
  357. [`InstructBlipEncoderLayer`].
  358. Args:
  359. config (`InstructBlipConfig`):
  360. The corresponding vision configuration for the `InstructBlipEncoder`.
  361. """
  362. def __init__(self, config: InstructBlipConfig):
  363. super().__init__()
  364. self.config = config
  365. self.layers = nn.ModuleList([InstructBlipEncoderLayer(config) for _ in range(config.num_hidden_layers)])
  366. self.gradient_checkpointing = False
  367. def forward(
  368. self,
  369. inputs_embeds,
  370. attention_mask: Optional[torch.Tensor] = None,
  371. output_attentions: Optional[bool] = None,
  372. output_hidden_states: Optional[bool] = None,
  373. return_dict: Optional[bool] = None,
  374. ) -> Union[Tuple, BaseModelOutput]:
  375. r"""
  376. Args:
  377. inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
  378. Embedded representation of the inputs. Should be float, not int tokens.
  379. attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
  380. Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
  381. - 1 for tokens that are **not masked**,
  382. - 0 for tokens that are **masked**.
  383. [What are attention masks?](../glossary#attention-mask)
  384. output_attentions (`bool`, *optional*):
  385. Whether or not to return the attentions tensors of all attention layers. See `attentions` under
  386. returned tensors for more detail.
  387. output_hidden_states (`bool`, *optional*):
  388. Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
  389. for more detail.
  390. return_dict (`bool`, *optional*):
  391. Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
  392. """
  393. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  394. output_hidden_states = (
  395. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  396. )
  397. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  398. encoder_states = () if output_hidden_states else None
  399. all_attentions = () if output_attentions else None
  400. hidden_states = inputs_embeds
  401. for idx, encoder_layer in enumerate(self.layers):
  402. if output_hidden_states:
  403. encoder_states = encoder_states + (hidden_states,)
  404. if self.gradient_checkpointing and self.training:
  405. layer_outputs = self._gradient_checkpointing_func(
  406. encoder_layer.__call__,
  407. hidden_states,
  408. attention_mask,
  409. output_attentions,
  410. )
  411. else:
  412. layer_outputs = encoder_layer(
  413. hidden_states,
  414. attention_mask,
  415. output_attentions=output_attentions,
  416. )
  417. hidden_states = layer_outputs[0]
  418. if output_attentions:
  419. all_attentions = all_attentions + (layer_outputs[1],)
  420. if output_hidden_states:
  421. encoder_states = encoder_states + (hidden_states,)
  422. if not return_dict:
  423. return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
  424. return BaseModelOutput(
  425. last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
  426. )
  427. # Copied from transformers.models.blip.modeling_blip.BlipVisionModel with Blip->InstructBlip, BLIP->INSTRUCTBLIP
  428. class InstructBlipVisionModel(InstructBlipPreTrainedModel):
  429. main_input_name = "pixel_values"
  430. config_class = InstructBlipVisionConfig
  431. def __init__(self, config: InstructBlipVisionConfig):
  432. super().__init__(config)
  433. self.config = config
  434. embed_dim = config.hidden_size
  435. self.embeddings = InstructBlipVisionEmbeddings(config)
  436. self.encoder = InstructBlipEncoder(config)
  437. self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
  438. self.post_init()
  439. @add_start_docstrings_to_model_forward(INSTRUCTBLIP_VISION_INPUTS_DOCSTRING)
  440. @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=InstructBlipVisionConfig)
  441. def forward(
  442. self,
  443. pixel_values: Optional[torch.FloatTensor] = None,
  444. output_attentions: Optional[bool] = None,
  445. output_hidden_states: Optional[bool] = None,
  446. return_dict: Optional[bool] = None,
  447. interpolate_pos_encoding: bool = False,
  448. ) -> Union[Tuple, BaseModelOutputWithPooling]:
  449. r"""
  450. Returns:
  451. """
  452. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  453. output_hidden_states = (
  454. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  455. )
  456. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  457. if pixel_values is None:
  458. raise ValueError("You have to specify pixel_values")
  459. hidden_states = self.embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)
  460. encoder_outputs = self.encoder(
  461. inputs_embeds=hidden_states,
  462. output_attentions=output_attentions,
  463. output_hidden_states=output_hidden_states,
  464. return_dict=return_dict,
  465. )
  466. last_hidden_state = encoder_outputs[0]
  467. last_hidden_state = self.post_layernorm(last_hidden_state)
  468. pooled_output = last_hidden_state[:, 0, :]
  469. pooled_output = self.post_layernorm(pooled_output)
  470. if not return_dict:
  471. return (last_hidden_state, pooled_output) + encoder_outputs[1:]
  472. return BaseModelOutputWithPooling(
  473. last_hidden_state=last_hidden_state,
  474. pooler_output=pooled_output,
  475. hidden_states=encoder_outputs.hidden_states,
  476. attentions=encoder_outputs.attentions,
  477. )
  478. def get_input_embeddings(self):
  479. return self.embeddings
  480. class InstructBlipQFormerMultiHeadAttention(nn.Module):
  481. def __init__(self, config, is_cross_attention=False):
  482. super().__init__()
  483. self.config = config
  484. if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
  485. raise ValueError(
  486. "The hidden size (%d) is not a multiple of the number of attention heads (%d)"
  487. % (config.hidden_size, config.num_attention_heads)
  488. )
  489. self.num_attention_heads = config.num_attention_heads
  490. self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
  491. self.all_head_size = self.num_attention_heads * self.attention_head_size
  492. self.query = nn.Linear(config.hidden_size, self.all_head_size)
  493. if is_cross_attention:
  494. self.key = nn.Linear(config.encoder_hidden_size, self.all_head_size)
  495. self.value = nn.Linear(config.encoder_hidden_size, self.all_head_size)
  496. else:
  497. self.key = nn.Linear(config.hidden_size, self.all_head_size)
  498. self.value = nn.Linear(config.hidden_size, self.all_head_size)
  499. self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
  500. self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
  501. if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
  502. self.max_position_embeddings = config.max_position_embeddings
  503. self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
  504. self.save_attention = False
  505. def save_attn_gradients(self, attn_gradients):
  506. self.attn_gradients = attn_gradients
  507. def get_attn_gradients(self):
  508. return self.attn_gradients
  509. def save_attention_map(self, attention_map):
  510. self.attention_map = attention_map
  511. def get_attention_map(self):
  512. return self.attention_map
  513. def transpose_for_scores(self, x):
  514. new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
  515. x = x.view(*new_x_shape)
  516. return x.permute(0, 2, 1, 3)
  517. def forward(
  518. self,
  519. hidden_states,
  520. attention_mask=None,
  521. head_mask=None,
  522. encoder_hidden_states=None,
  523. encoder_attention_mask=None,
  524. past_key_value=None,
  525. output_attentions=False,
  526. ):
  527. # If this is instantiated as a cross-attention module, the keys
  528. # and values come from an encoder; the attention mask needs to be
  529. # such that the encoder's padding tokens are not attended to.
  530. is_cross_attention = encoder_hidden_states is not None
  531. if is_cross_attention:
  532. key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
  533. value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
  534. attention_mask = encoder_attention_mask
  535. elif past_key_value is not None:
  536. key_layer = self.transpose_for_scores(self.key(hidden_states))
  537. value_layer = self.transpose_for_scores(self.value(hidden_states))
  538. key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
  539. value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
  540. else:
  541. key_layer = self.transpose_for_scores(self.key(hidden_states))
  542. value_layer = self.transpose_for_scores(self.value(hidden_states))
  543. mixed_query_layer = self.query(hidden_states)
  544. query_layer = self.transpose_for_scores(mixed_query_layer)
  545. past_key_value = (key_layer, value_layer)
  546. # Take the dot product between "query" and "key" to get the raw attention scores.
  547. attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
  548. if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
  549. seq_length = hidden_states.size()[1]
  550. position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
  551. position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
  552. distance = position_ids_l - position_ids_r
  553. positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
  554. positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility
  555. if self.position_embedding_type == "relative_key":
  556. relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
  557. attention_scores = attention_scores + relative_position_scores
  558. elif self.position_embedding_type == "relative_key_query":
  559. relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
  560. relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
  561. attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key
  562. attention_scores = attention_scores / math.sqrt(self.attention_head_size)
  563. attention_scores_dtype = attention_scores.dtype
  564. if attention_mask is not None:
  565. # Apply the attention mask is (precomputed for all layers in BertModel forward() function)
  566. attention_scores = attention_scores + attention_mask
  567. # Normalize the attention scores to probabilities.
  568. attention_probs = nn.Softmax(dim=-1)(attention_scores).to(attention_scores_dtype)
  569. if is_cross_attention and self.save_attention:
  570. self.save_attention_map(attention_probs)
  571. attention_probs.register_hook(self.save_attn_gradients)
  572. # This is actually dropping out entire tokens to attend to, which might
  573. # seem a bit unusual, but is taken from the original Transformer paper.
  574. attention_probs_dropped = self.dropout(attention_probs)
  575. # Mask heads if we want to
  576. if head_mask is not None:
  577. attention_probs_dropped = attention_probs_dropped * head_mask
  578. context_layer = torch.matmul(attention_probs_dropped, value_layer)
  579. context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
  580. new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
  581. context_layer = context_layer.view(*new_context_layer_shape)
  582. outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
  583. outputs = outputs + (past_key_value,)
  584. return outputs
  585. # Copied from transformers.models.bert.modeling_bert.BertSelfOutput with Bert->InstructBlipQFormer
  586. class InstructBlipQFormerSelfOutput(nn.Module):
  587. def __init__(self, config):
  588. super().__init__()
  589. self.dense = nn.Linear(config.hidden_size, config.hidden_size)
  590. self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  591. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  592. def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
  593. hidden_states = self.dense(hidden_states)
  594. hidden_states = self.dropout(hidden_states)
  595. hidden_states = self.LayerNorm(hidden_states + input_tensor)
  596. return hidden_states
  597. # Copied from transformers.models.blip_2.modeling_blip_2.Blip2QFormerAttention with Blip2->InstructBlip
  598. class InstructBlipQFormerAttention(nn.Module):
  599. def __init__(self, config, is_cross_attention=False):
  600. super().__init__()
  601. self.attention = InstructBlipQFormerMultiHeadAttention(config, is_cross_attention)
  602. self.output = InstructBlipQFormerSelfOutput(config)
  603. self.pruned_heads = set()
  604. def prune_heads(self, heads):
  605. if len(heads) == 0:
  606. return
  607. heads, index = find_pruneable_heads_and_indices(
  608. heads, self.attention.num_attention_heads, self.attention.attention_head_size, self.pruned_heads
  609. )
  610. # Prune linear layers
  611. self.attention.query = prune_linear_layer(self.attention.query, index)
  612. self.attention.key = prune_linear_layer(self.attention.key, index)
  613. self.attention.value = prune_linear_layer(self.attention.value, index)
  614. self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
  615. # Update hyper params and store pruned heads
  616. self.attention.num_attention_heads = self.attention.num_attention_heads - len(heads)
  617. self.attention.all_head_size = self.attention.attention_head_size * self.attention.num_attention_heads
  618. self.pruned_heads = self.pruned_heads.union(heads)
  619. def forward(
  620. self,
  621. hidden_states: torch.Tensor,
  622. attention_mask: Optional[torch.FloatTensor] = None,
  623. head_mask: Optional[torch.FloatTensor] = None,
  624. encoder_hidden_states: Optional[torch.FloatTensor] = None,
  625. encoder_attention_mask: Optional[torch.FloatTensor] = None,
  626. past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
  627. output_attentions: Optional[bool] = False,
  628. ) -> Tuple[torch.Tensor]:
  629. self_outputs = self.attention(
  630. hidden_states,
  631. attention_mask,
  632. head_mask,
  633. encoder_hidden_states,
  634. encoder_attention_mask,
  635. past_key_value,
  636. output_attentions,
  637. )
  638. attention_output = self.output(self_outputs[0], hidden_states)
  639. outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
  640. return outputs
  641. # Copied from transformers.models.bert.modeling_bert.BertIntermediate with Bert->InstructBlipQFormer
  642. class InstructBlipQFormerIntermediate(nn.Module):
  643. def __init__(self, config):
  644. super().__init__()
  645. self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
  646. if isinstance(config.hidden_act, str):
  647. self.intermediate_act_fn = ACT2FN[config.hidden_act]
  648. else:
  649. self.intermediate_act_fn = config.hidden_act
  650. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  651. hidden_states = self.dense(hidden_states)
  652. hidden_states = self.intermediate_act_fn(hidden_states)
  653. return hidden_states
  654. # Copied from transformers.models.bert.modeling_bert.BertOutput with Bert->InstructBlipQFormer
  655. class InstructBlipQFormerOutput(nn.Module):
  656. def __init__(self, config):
  657. super().__init__()
  658. self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
  659. self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  660. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  661. def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
  662. hidden_states = self.dense(hidden_states)
  663. hidden_states = self.dropout(hidden_states)
  664. hidden_states = self.LayerNorm(hidden_states + input_tensor)
  665. return hidden_states
  666. class InstructBlipQFormerLayer(nn.Module):
  667. def __init__(self, config, layer_idx):
  668. super().__init__()
  669. self.chunk_size_feed_forward = config.chunk_size_feed_forward
  670. self.seq_len_dim = 1
  671. self.attention = InstructBlipQFormerAttention(config)
  672. self.layer_idx = layer_idx
  673. if layer_idx % config.cross_attention_frequency == 0:
  674. self.crossattention = InstructBlipQFormerAttention(config, is_cross_attention=True)
  675. self.has_cross_attention = True
  676. else:
  677. self.has_cross_attention = False
  678. self.intermediate = InstructBlipQFormerIntermediate(config)
  679. self.output = InstructBlipQFormerOutput(config)
  680. self.intermediate_query = InstructBlipQFormerIntermediate(config)
  681. self.output_query = InstructBlipQFormerOutput(config)
  682. def forward(
  683. self,
  684. hidden_states,
  685. attention_mask=None,
  686. head_mask=None,
  687. encoder_hidden_states=None,
  688. encoder_attention_mask=None,
  689. past_key_value=None,
  690. output_attentions=False,
  691. query_length=0,
  692. ):
  693. # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
  694. self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
  695. self_attention_outputs = self.attention(
  696. hidden_states,
  697. attention_mask,
  698. head_mask,
  699. output_attentions=output_attentions,
  700. past_key_value=self_attn_past_key_value,
  701. )
  702. attention_output = self_attention_outputs[0]
  703. outputs = self_attention_outputs[1:-1]
  704. present_key_value = self_attention_outputs[-1]
  705. if query_length > 0:
  706. query_attention_output = attention_output[:, :query_length, :]
  707. if self.has_cross_attention:
  708. if encoder_hidden_states is None:
  709. raise ValueError("encoder_hidden_states must be given for cross-attention layers")
  710. cross_attention_outputs = self.crossattention(
  711. query_attention_output,
  712. attention_mask,
  713. head_mask,
  714. encoder_hidden_states,
  715. encoder_attention_mask,
  716. output_attentions=output_attentions,
  717. )
  718. query_attention_output = cross_attention_outputs[0]
  719. # add cross attentions if we output attention weights
  720. outputs = outputs + cross_attention_outputs[1:-1]
  721. layer_output = apply_chunking_to_forward(
  722. self.feed_forward_chunk_query,
  723. self.chunk_size_feed_forward,
  724. self.seq_len_dim,
  725. query_attention_output,
  726. )
  727. if attention_output.shape[1] > query_length:
  728. layer_output_text = apply_chunking_to_forward(
  729. self.feed_forward_chunk,
  730. self.chunk_size_feed_forward,
  731. self.seq_len_dim,
  732. attention_output[:, query_length:, :],
  733. )
  734. layer_output = torch.cat([layer_output, layer_output_text], dim=1)
  735. else:
  736. layer_output = apply_chunking_to_forward(
  737. self.feed_forward_chunk,
  738. self.chunk_size_feed_forward,
  739. self.seq_len_dim,
  740. attention_output,
  741. )
  742. outputs = (layer_output,) + outputs
  743. outputs = outputs + (present_key_value,)
  744. return outputs
  745. def feed_forward_chunk(self, attention_output):
  746. intermediate_output = self.intermediate(attention_output)
  747. layer_output = self.output(intermediate_output, attention_output)
  748. return layer_output
  749. def feed_forward_chunk_query(self, attention_output):
  750. intermediate_output = self.intermediate_query(attention_output)
  751. layer_output = self.output_query(intermediate_output, attention_output)
  752. return layer_output
  753. # Copied from transformers.models.blip_2.modeling_blip_2.Blip2QFormerEncoder with Blip2->InstructBlip
  754. class InstructBlipQFormerEncoder(nn.Module):
  755. def __init__(self, config):
  756. super().__init__()
  757. self.config = config
  758. self.layer = nn.ModuleList(
  759. [InstructBlipQFormerLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
  760. )
  761. self.gradient_checkpointing = False
  762. def forward(
  763. self,
  764. hidden_states,
  765. attention_mask=None,
  766. head_mask=None,
  767. encoder_hidden_states=None,
  768. encoder_attention_mask=None,
  769. past_key_values=None,
  770. use_cache=None,
  771. output_attentions=False,
  772. output_hidden_states=False,
  773. return_dict=True,
  774. query_length=0,
  775. ):
  776. all_hidden_states = () if output_hidden_states else None
  777. all_self_attentions = () if output_attentions else None
  778. all_cross_attentions = () if output_attentions else None
  779. next_decoder_cache = () if use_cache else None
  780. for i in range(self.config.num_hidden_layers):
  781. layer_module = self.layer[i]
  782. if output_hidden_states:
  783. all_hidden_states = all_hidden_states + (hidden_states,)
  784. layer_head_mask = head_mask[i] if head_mask is not None else None
  785. past_key_value = past_key_values[i] if past_key_values is not None else None
  786. if getattr(self.config, "gradient_checkpointing", False) and self.training:
  787. if use_cache:
  788. logger.warning(
  789. "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
  790. )
  791. use_cache = False
  792. layer_outputs = self._gradient_checkpointing_func(
  793. layer_module.__call__,
  794. hidden_states,
  795. attention_mask,
  796. layer_head_mask,
  797. encoder_hidden_states,
  798. encoder_attention_mask,
  799. )
  800. else:
  801. layer_outputs = layer_module(
  802. hidden_states,
  803. attention_mask,
  804. layer_head_mask,
  805. encoder_hidden_states,
  806. encoder_attention_mask,
  807. past_key_value,
  808. output_attentions,
  809. query_length,
  810. )
  811. hidden_states = layer_outputs[0]
  812. if use_cache:
  813. next_decoder_cache += (layer_outputs[-1],)
  814. if output_attentions:
  815. all_self_attentions = all_self_attentions + (layer_outputs[1],)
  816. if layer_module.has_cross_attention:
  817. all_cross_attentions = all_cross_attentions + (layer_outputs[2],)
  818. if output_hidden_states:
  819. all_hidden_states = all_hidden_states + (hidden_states,)
  820. if not return_dict:
  821. return tuple(
  822. v
  823. for v in [
  824. hidden_states,
  825. next_decoder_cache,
  826. all_hidden_states,
  827. all_self_attentions,
  828. all_cross_attentions,
  829. ]
  830. if v is not None
  831. )
  832. return BaseModelOutputWithPastAndCrossAttentions(
  833. last_hidden_state=hidden_states,
  834. past_key_values=next_decoder_cache,
  835. hidden_states=all_hidden_states,
  836. attentions=all_self_attentions,
  837. cross_attentions=all_cross_attentions,
  838. )
  839. class InstructBlipQFormerEmbeddings(nn.Module):
  840. """Construct the embeddings from word and position embeddings."""
  841. def __init__(self, config):
  842. super().__init__()
  843. self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
  844. self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
  845. self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  846. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  847. # position_ids (1, len position emb) is contiguous in memory and exported when serialized
  848. self.register_buffer(
  849. "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
  850. )
  851. self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
  852. self.config = config
  853. def forward(
  854. self,
  855. input_ids=None,
  856. position_ids=None,
  857. query_embeds=None,
  858. past_key_values_length=0,
  859. ):
  860. if input_ids is not None:
  861. seq_length = input_ids.size()[1]
  862. else:
  863. seq_length = 0
  864. if position_ids is None:
  865. position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length].clone()
  866. if input_ids is not None:
  867. embeddings = self.word_embeddings(input_ids)
  868. if self.position_embedding_type == "absolute":
  869. position_embeddings = self.position_embeddings(position_ids.to(embeddings.device))
  870. embeddings = embeddings + position_embeddings
  871. if query_embeds is not None:
  872. embeddings = torch.cat((query_embeds, embeddings), dim=1)
  873. else:
  874. embeddings = query_embeds
  875. embeddings = embeddings.to(self.layernorm.weight.dtype)
  876. embeddings = self.layernorm(embeddings)
  877. embeddings = self.dropout(embeddings)
  878. return embeddings
  879. class InstructBlipQFormerModel(InstructBlipPreTrainedModel):
  880. """
  881. Querying Transformer (Q-Former), used in InstructBLIP. Slightly modified from BLIP-2 as it also takes the
  882. instruction as input.
  883. """
  884. def __init__(self, config: InstructBlipQFormerConfig):
  885. super().__init__(config)
  886. self.config = config
  887. self.embeddings = InstructBlipQFormerEmbeddings(config)
  888. self.encoder = InstructBlipQFormerEncoder(config)
  889. self.post_init()
  890. def get_input_embeddings(self):
  891. return self.embeddings.word_embeddings
  892. def set_input_embeddings(self, value):
  893. self.embeddings.word_embeddings = value
  894. def _prune_heads(self, heads_to_prune):
  895. """
  896. Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
  897. class PreTrainedModel
  898. """
  899. for layer, heads in heads_to_prune.items():
  900. self.encoder.layer[layer].attention.prune_heads(heads)
  901. def get_extended_attention_mask(
  902. self,
  903. attention_mask: torch.Tensor,
  904. input_shape: Tuple[int],
  905. device: torch.device,
  906. has_query: bool = False,
  907. ) -> torch.Tensor:
  908. """
  909. Makes broadcastable attention and causal masks so that future and masked tokens are ignored.
  910. Arguments:
  911. attention_mask (`torch.Tensor`):
  912. Mask with ones indicating tokens to attend to, zeros for tokens to ignore.
  913. input_shape (`Tuple[int]`):
  914. The shape of the input to the model.
  915. device: (`torch.device`):
  916. The device of the input to the model.
  917. Returns:
  918. `torch.Tensor` The extended attention mask, with a the same dtype as `attention_mask.dtype`.
  919. """
  920. # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
  921. # ourselves in which case we just need to make it broadcastable to all heads.
  922. if attention_mask.dim() == 3:
  923. extended_attention_mask = attention_mask[:, None, :, :]
  924. elif attention_mask.dim() == 2:
  925. # Provided a padding mask of dimensions [batch_size, seq_length]
  926. # - the model is an encoder, so make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
  927. extended_attention_mask = attention_mask[:, None, None, :]
  928. else:
  929. raise ValueError(
  930. f"Wrong shape for input_ids (shape {input_shape}) or attention_mask (shape {attention_mask.shape})",
  931. )
  932. # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
  933. # masked positions, this operation will create a tensor which is 0.0 for
  934. # positions we want to attend and -10000.0 for masked positions.
  935. # Since we are adding it to the raw scores before the softmax, this is
  936. # effectively the same as removing these entirely.
  937. extended_attention_mask = extended_attention_mask.to(dtype=self.dtype) # fp16 compatibility
  938. extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
  939. return extended_attention_mask
  940. def forward(
  941. self,
  942. input_ids: torch.LongTensor,
  943. attention_mask: Optional[torch.FloatTensor] = None,
  944. position_ids: Optional[torch.LongTensor] = None,
  945. query_embeds: Optional[torch.Tensor] = None,
  946. head_mask: Optional[torch.FloatTensor] = None,
  947. encoder_hidden_states: Optional[torch.FloatTensor] = None,
  948. encoder_attention_mask: Optional[torch.FloatTensor] = None,
  949. past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
  950. use_cache: Optional[bool] = None,
  951. output_attentions: Optional[bool] = None,
  952. output_hidden_states: Optional[bool] = None,
  953. return_dict: Optional[bool] = None,
  954. ) -> Union[Tuple[torch.FloatTensor], BaseModelOutputWithPoolingAndCrossAttentions]:
  955. r"""
  956. encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
  957. Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
  958. the model is configured as a decoder.
  959. encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
  960. Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
  961. the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`:
  962. - 1 for tokens that are **not masked**,
  963. - 0 for tokens that are **masked**.
  964. past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of:
  965. shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): Contains precomputed key and
  966. value hidden states of the attention blocks. Can be used to speed up decoding. If `past_key_values` are
  967. used, the user can optionally input only the last `decoder_input_ids` (those that don't have their past key
  968. value states given to this model) of shape `(batch_size, 1)` instead of all `decoder_input_ids` of shape
  969. `(batch_size, sequence_length)`.
  970. use_cache (`bool`, *optional*):
  971. If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
  972. `past_key_values`).
  973. """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if input_ids is None and query_embeds is None:
            raise ValueError("You have to specify query_embeds when input_ids is None")

        # past_key_values_length
        past_key_values_length = (
            past_key_values[0][0].shape[2] - self.config.query_length if past_key_values is not None else 0
        )

        query_length = query_embeds.shape[1] if query_embeds is not None else 0

        embedding_output = self.embeddings(
            input_ids=input_ids,
            position_ids=position_ids,
            query_embeds=query_embeds,
            past_key_values_length=past_key_values_length,
        )

        input_shape = embedding_output.size()[:-1]
        batch_size, seq_length = input_shape
        device = embedding_output.device

        if attention_mask is None:
            attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)

        # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
        # ourselves in which case we just need to make it broadcastable to all heads.
        extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape, device)
        # If a 2D or 3D attention mask is provided for the cross-attention,
        # we need to make it broadcastable to [batch_size, num_heads, seq_length, seq_length]
        if encoder_hidden_states is not None:
            if isinstance(encoder_hidden_states, list):
                encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states[0].size()
            else:
                encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
            encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)

            if isinstance(encoder_attention_mask, list):
                encoder_extended_attention_mask = [self.invert_attention_mask(mask) for mask in encoder_attention_mask]
            elif encoder_attention_mask is None:
                encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
                encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
            else:
                encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
        else:
            encoder_extended_attention_mask = None
        # Prepare head mask if needed
        # 1.0 in head_mask indicates we keep the head
        # attention_probs has shape bsz x n_heads x N x N
        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
        encoder_outputs = self.encoder(
            embedding_output,
            attention_mask=extended_attention_mask,
            head_mask=head_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_extended_attention_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            query_length=query_length,
        )
        sequence_output = encoder_outputs[0]
        pooled_output = sequence_output[:, 0, :]

        if not return_dict:
            return (sequence_output, pooled_output) + encoder_outputs[1:]

        return BaseModelOutputWithPoolingAndCrossAttentions(
            last_hidden_state=sequence_output,
            pooler_output=pooled_output,
            past_key_values=encoder_outputs.past_key_values,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
            cross_attentions=encoder_outputs.cross_attentions,
        )
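

# A minimal sketch (kept in comment form so nothing runs at import time, and added here purely for
# clarity) of how the Q-Former defined above is typically driven by the conditional-generation
# wrapper below; the tensor shapes are illustrative assumptions, not values fixed by this file:
#
#     query_tokens = model.query_tokens.expand(batch_size, -1, -1)     # (B, num_query_tokens, H_qformer)
#     qformer_out = model.qformer(
#         input_ids=qformer_input_ids,            # tokenized instruction, (B, L_instr)
#         query_embeds=query_tokens,              # learned queries, prepended to the text embeddings
#         encoder_hidden_states=image_embeds,     # frozen ViT features, (B, L_img, H_vision)
#     )
#     visual_prefix = qformer_out.last_hidden_state[:, : query_tokens.size(1), :]

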
@add_start_docstrings(
    """
    InstructBLIP Model for generating text given an image and an optional text prompt. The model consists of a vision
    encoder, Querying Transformer (Q-Former) and a language model.

    One can optionally pass `input_ids` to the model, which serve as a text prompt, to make the language model continue
    the prompt. Otherwise, the language model starts generating text from the [BOS] (beginning-of-sequence) token.
    """,
    INSTRUCTBLIP_START_DOCSTRING,
)
class InstructBlipForConditionalGeneration(InstructBlipPreTrainedModel, GenerationMixin):
    config_class = InstructBlipConfig
    main_input_name = "pixel_values"

    def __init__(self, config: InstructBlipConfig):
        super().__init__(config)

        self.vision_model = InstructBlipVisionModel(config.vision_config)

        self.query_tokens = nn.Parameter(torch.zeros(1, config.num_query_tokens, config.qformer_config.hidden_size))
        self.qformer = InstructBlipQFormerModel(config.qformer_config)

        self.language_projection = nn.Linear(config.qformer_config.hidden_size, config.text_config.hidden_size)
        if config.use_decoder_only_language_model:
            language_model = AutoModelForCausalLM.from_config(config.text_config)
        else:
            language_model = AutoModelForSeq2SeqLM.from_config(config.text_config)

        if language_model._no_split_modules is not None:
            self._no_split_modules.extend(language_model._no_split_modules)

        if language_model._keep_in_fp32_modules is not None:
            self._keep_in_fp32_modules.extend(language_model._keep_in_fp32_modules)

        self.language_model = language_model

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.language_model.get_input_embeddings()

    def set_input_embeddings(self, value):
        self.language_model.set_input_embeddings(value)

    def set_output_embeddings(self, new_embeddings):
        self.language_model.set_output_embeddings(new_embeddings)

    def get_output_embeddings(self) -> nn.Module:
        return self.language_model.get_output_embeddings()

    def get_encoder(self):
        return self.language_model.get_encoder()

    def get_decoder(self):
        return self.language_model.get_decoder()

    def _tie_weights(self):
        if not self.config.use_decoder_only_language_model:
            self.language_model.encoder.embed_tokens = self.language_model.shared
            self.language_model.decoder.embed_tokens = self.language_model.shared

    def _preprocess_accelerate(self):
        r"""
        Some pre-processing hacks to make the model `accelerate` compatible. Check
        https://github.com/huggingface/transformers/pull/21707 for more details.
        """
        hf_device_map = self.hf_device_map

        if len(hf_device_map) > 1 and "language_model" not in hf_device_map and torch.cuda.device_count() > 1:
            # warn users about unexpected behavior when using multi-GPU + InstructBLIP + `accelerate`.
            logger.warning(
                "The `language_model` is not in the `hf_device_map` dictionary and you are running your script"
                " in a multi-GPU environment. This may lead to unexpected behavior when using `accelerate`."
                " Please pass a `device_map` that contains `language_model` to remove this warning."
                " Please refer to https://github.com/huggingface/blog/blob/main/accelerate-large-models.md for"
                " more details on creating a `device_map` for large models.",
            )

        if hasattr(self.language_model, "_hf_hook"):
            self.language_model._hf_hook.io_same_device = True  # For `generate` compatibility
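
    # A hedged sketch of the kind of `device_map` the warning in `_preprocess_accelerate` asks for.
    # The module names are the attributes defined in `__init__`; the specific GPU assignment is an
    # illustrative assumption, not a recommendation baked into this file:
    #
    #     device_map = {"vision_model": 0, "qformer": 0, "language_projection": 0, "language_model": 1}
    #     model = InstructBlipForConditionalGeneration.from_pretrained(
    #         "Salesforce/instructblip-vicuna-7b", device_map=device_map
    #     )
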
    @add_start_docstrings_to_model_forward(INSTRUCTBLIP_INPUTS_DOCSTRING)
    @replace_return_docstrings(
        output_type=InstructBlipForConditionalGenerationModelOutput, config_class=InstructBlipVisionConfig
    )
    def forward(
        self,
        pixel_values: torch.FloatTensor,
        qformer_input_ids: torch.FloatTensor,
        qformer_attention_mask: Optional[torch.LongTensor] = None,
        input_ids: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.LongTensor] = None,
        decoder_input_ids: Optional[torch.LongTensor] = None,
        decoder_attention_mask: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        labels: Optional[torch.LongTensor] = None,
        return_dict: Optional[bool] = None,
        interpolate_pos_encoding: bool = False,
    ) -> Union[Tuple, InstructBlipForConditionalGenerationModelOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the language modeling loss. Indices should be in `[-100, 0, ..., config.vocab_size -
            1]`. All labels set to `-100` are ignored (masked), the loss is only computed for labels in `[0, ...,
            config.vocab_size]`

        Returns:

        Examples:

        ```python
        >>> from transformers import InstructBlipProcessor, InstructBlipForConditionalGeneration
        >>> import torch
        >>> from PIL import Image
        >>> import requests

        >>> model = InstructBlipForConditionalGeneration.from_pretrained("Salesforce/instructblip-vicuna-7b")
        >>> processor = InstructBlipProcessor.from_pretrained("Salesforce/instructblip-vicuna-7b")

        >>> device = "cuda" if torch.cuda.is_available() else "cpu"
        >>> model.to(device)  # doctest: +IGNORE_RESULT

        >>> url = "https://raw.githubusercontent.com/salesforce/LAVIS/main/docs/_static/Confusing-Pictures.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw).convert("RGB")
        >>> prompt = "What is unusual about this image?"
        >>> inputs = processor(images=image, text=prompt, return_tensors="pt").to(device)

        >>> outputs = model.generate(
        ...     **inputs,
        ...     do_sample=False,
        ...     num_beams=5,
        ...     max_length=256,
        ...     min_length=1,
        ...     top_p=0.9,
        ...     repetition_penalty=1.5,
        ...     length_penalty=1.0,
        ...     temperature=1,
        ... )
        >>> generated_text = processor.batch_decode(outputs, skip_special_tokens=True)[0].strip()
        >>> print(generated_text)
        The unusual aspect of this image is that a man is ironing clothes on the back of a yellow SUV, which is parked in the middle of a busy city street. This is an unconventional approach to ironing clothes, as it requires the man to balance himself and his ironing equipment on top of the vehicle while navigating through traffic. Additionally, the presence of taxis and other vehicles in the scene further emphasizes the unusual nature of this situation.
        ```"""
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # step 1: forward the images through the vision encoder,
        # to get image embeddings of shape (batch_size, seq_len, hidden_size)
        vision_outputs = self.vision_model(
            pixel_values=pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            interpolate_pos_encoding=interpolate_pos_encoding,
        )
        image_embeds = vision_outputs[0]

        # step 2: forward the query tokens through the QFormer, using the image embeddings for cross-attention
        image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device)

        # difference with BLIP-2 here: we also feed the instruction prompt to the Q-Former
        query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
        query_attention_mask = torch.ones(query_tokens.size()[:-1], dtype=torch.long, device=image_embeds.device)
        if qformer_attention_mask is None:
            qformer_attention_mask = torch.ones_like(qformer_input_ids)
        qformer_attention_mask = torch.cat([query_attention_mask, qformer_attention_mask], dim=1)
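        # note (added for clarity): the concatenated mask now covers [query tokens | instruction tokens],
        # i.e. it has shape (batch_size, num_query_tokens + instruction_length)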
        query_outputs = self.qformer(
            input_ids=qformer_input_ids,
            attention_mask=qformer_attention_mask,
            query_embeds=query_tokens,
            encoder_hidden_states=image_embeds,
            encoder_attention_mask=image_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
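        # note (added for clarity): the Q-Former output also contains hidden states for the instruction
        # tokens; only the first `num_query_tokens` positions (the learned queries) are kept below and
        # projected into the language model's embedding space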
        query_output = query_outputs[0][:, : query_tokens.size(1), :]

        # step 3: use the language model, conditioned on the query outputs and the prompt
        language_model_inputs = self.language_projection(query_output)
        language_model_attention_mask = torch.ones(
            language_model_inputs.size()[:-1], dtype=torch.long, device=language_model_inputs.device
        )

        inputs_embeds = self.language_model.get_input_embeddings()(input_ids)
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids)

        # if the model already has "image_token_index" then the input is expanded to account for image embeds
        # otherwise we expand manually by concatenating
        if getattr(self.config, "image_token_index", None) is not None:
            special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1).expand_as(inputs_embeds)
            inputs_embeds[special_image_mask] = language_model_inputs.flatten()
        else:
            logger.warning_once(
                "Expanding inputs for image tokens in InstructBLIP should be done in processing. "
                "Please follow instruction here (https://gist.github.com/zucchini-nlp/e9f20b054fa322f84ac9311d9ab67042) to update your InstructBLIP model. "
                "Using processors without these attributes in the config is deprecated and will throw an error in v4.47."
            )
            inputs_embeds = torch.cat([language_model_inputs, inputs_embeds.to(language_model_inputs.device)], dim=1)
            attention_mask = torch.cat(
                [language_model_attention_mask, attention_mask.to(language_model_attention_mask.device)], dim=1
            )
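
        # note (added for clarity): at this point `inputs_embeds` holds the projected Q-Former output
        # alongside the prompt embeddings (scattered into the image-token positions, or prepended on
        # the legacy path above), and `attention_mask` has one entry per embedded position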
        if self.config.use_decoder_only_language_model:
            outputs = self.language_model(
                inputs_embeds=inputs_embeds,
                attention_mask=attention_mask,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )
            logits = outputs.logits if return_dict else outputs[0]
            loss = None
            # we compute the loss here since we need to take into account the sequence length of the query embeds
            if labels is not None:
                labels = labels.to(logits.device)
                logits = logits[:, -labels.size(1) :, :]
                # Shift so that tokens < n predict n
                shift_logits = logits[..., :-1, :].contiguous()
                shift_labels = labels[..., 1:].contiguous().to(logits.device)

                # Flatten the tokens
                loss_fct = CrossEntropyLoss(reduction="mean")

                loss = loss_fct(shift_logits.view(-1, self.config.text_config.vocab_size), shift_labels.view(-1))
        else:
            outputs = self.language_model(
                inputs_embeds=inputs_embeds,
                attention_mask=attention_mask,
                decoder_input_ids=decoder_input_ids,
                decoder_attention_mask=decoder_attention_mask,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
                labels=labels,
            )
            loss = outputs.loss if return_dict else outputs[0]
            logits = outputs.logits if return_dict else outputs[1]

        if not return_dict:
            output = (logits, vision_outputs, query_outputs, outputs)
            return ((loss,) + output) if loss is not None else output

        return InstructBlipForConditionalGenerationModelOutput(
            loss=loss,
            logits=logits,
            vision_outputs=vision_outputs,
            qformer_outputs=query_outputs,
            language_model_outputs=outputs,
        )
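
    # A minimal training-style sketch (kept in comment form and added for clarity; masking the padding
    # positions of the labels to -100 is a common convention, not something this file prescribes):
    #
    #     inputs = processor(images=image, text=prompt, return_tensors="pt")
    #     labels = inputs["input_ids"].masked_fill(
    #         inputs["input_ids"] == processor.tokenizer.pad_token_id, -100
    #     )
    #     outputs = model(**inputs, labels=labels)
    #     outputs.loss.backward()
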
    @torch.no_grad()
    def generate(
        self,
        pixel_values: torch.FloatTensor,
        qformer_input_ids: Optional[torch.LongTensor] = None,
        qformer_attention_mask: Optional[torch.LongTensor] = None,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.LongTensor] = None,
        interpolate_pos_encoding: bool = False,
        **generate_kwargs,
    ) -> torch.LongTensor:
        """
        Overrides `generate` function to be able to use the model as a conditional generator.

        Args:
            pixel_values (`torch.FloatTensor` of shape (batch_size, num_channels, height, width)):
                Input images to be processed.
            qformer_input_ids (`torch.LongTensor` of shape (batch_size, sequence_length), *optional*):
                The sequence used as a prompt to be fed to the Q-Former module.
            qformer_attention_mask (`torch.LongTensor` of shape (batch_size, sequence_length), *optional*):
                Mask to avoid performing attention on padding token indices.
            input_ids (`torch.LongTensor` of shape (batch_size, sequence_length), *optional*):
                The sequence used as a prompt for the generation.
            attention_mask (`torch.LongTensor` of shape (batch_size, sequence_length), *optional*):
                Mask to avoid performing attention on padding token indices.
            interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
                Whether to interpolate the positional encoding of the image embeddings.

        Returns:
            captions (list): A list of strings of length batch_size * num_captions.
        """
        if hasattr(self, "hf_device_map"):
            # preprocess for `accelerate`
            self._preprocess_accelerate()

        batch_size = pixel_values.shape[0]
        image_embeds = self.vision_model(
            pixel_values,
            return_dict=True,
            interpolate_pos_encoding=interpolate_pos_encoding,
        ).last_hidden_state

        image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device)

        query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
        query_attention_mask = torch.ones(query_tokens.size()[:-1], dtype=torch.long, device=image_embeds.device)
        if qformer_attention_mask is None:
            qformer_attention_mask = torch.ones_like(qformer_input_ids)
        qformer_attention_mask = torch.cat([query_attention_mask, qformer_attention_mask], dim=1)
        query_outputs = self.qformer(
            input_ids=qformer_input_ids,
            attention_mask=qformer_attention_mask,
            query_embeds=query_tokens,
            encoder_hidden_states=image_embeds,
            encoder_attention_mask=image_attention_mask,
            return_dict=True,
        )
        query_output = query_outputs.last_hidden_state[:, : query_tokens.size(1), :]

        language_model_inputs = self.language_projection(query_output)
        language_attention_mask = torch.ones(
            language_model_inputs.size()[:-1], dtype=torch.long, device=language_model_inputs.device
        )

        if input_ids is None:
            input_ids = (
                torch.LongTensor([[self.config.text_config.bos_token_id]])
                .repeat(batch_size, 1)
                .to(image_embeds.device)
            )
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids)

        inputs_embeds = self.get_input_embeddings()(input_ids)

        # if the model already has "image_token_index" then the input is expanded to account for image embeds
        # otherwise we expand manually by concatenating
        if getattr(self.config, "image_token_index", None) is not None:
            special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1).expand_as(inputs_embeds)
            inputs_embeds[special_image_mask] = language_model_inputs.flatten()
        else:
            logger.warning_once(
                "Expanding inputs for image tokens in InstructBLIP should be done in processing. "
                "Please follow instruction here (https://gist.github.com/zucchini-nlp/e9f20b054fa322f84ac9311d9ab67042) to update your InstructBLIP model. "
                "Using processors without these attributes in the config is deprecated and will throw an error in v4.47."
            )
            inputs_embeds = torch.cat([language_model_inputs, inputs_embeds.to(language_model_inputs.device)], dim=1)
            attention_mask = torch.cat(
                [language_attention_mask, attention_mask.to(language_attention_mask.device)], dim=1
            )

            # add image_embeds length to max_length, so that the final max_length is counted only on token embeds
            # -1 is to account for the prepended BOS after `generate`.
            if not self.language_model.config.is_encoder_decoder:
                generate_kwargs["max_length"] = (
                    generate_kwargs.get("max_length", 20) + language_model_inputs.shape[1] - 1
                )
                generate_kwargs["min_length"] = generate_kwargs.get("min_length", 0) + language_model_inputs.shape[1]

        outputs = self.language_model.generate(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            **generate_kwargs,
        )

        # this is a temporary workaround to be consistent with other generation models and
        # have BOS as the first token, even though under the hood we are calling LM with embeds
        if not self.language_model.config.is_encoder_decoder:
            # the InstructBLIP authors used inconsistent tokenizer/model files during training,
            # with the tokenizer's bos token being set to </s> which has ID=2,
            # whereas the model's text config has bos token id = 0
            bos_token_id = (
                2
                if self.config.text_config.architectures[0] == "LLaMAForCausalLM"
                else self.config.text_config.bos_token_id
            )
            bos_tokens = torch.LongTensor([[bos_token_id]]).repeat(batch_size, 1).to(image_embeds.device)
            if not isinstance(outputs, torch.Tensor):
                outputs.sequences = torch.cat([bos_tokens, outputs.sequences], dim=-1)
            else:
                outputs = torch.cat([bos_tokens, outputs], dim=-1)

        return outputs
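
    # note (added for clarity): because a BOS id is prepended above on the decoder-only path, callers
    # typically decode the returned ids with `processor.batch_decode(outputs, skip_special_tokens=True)`,
    # which strips that token along with any other special tokens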