# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# This file was automatically generated from src/transformers/models/instructblipvideo/modular_instructblipvideo.py.
# Do NOT edit this file manually as any edits will be overwritten by the generation of
# the file from the modular. If any change should be done, please apply the change to the
# modular_instructblipvideo.py file directly. One of our CI enforces this.
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# coding=utf-8
# Copyright 2024 HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from dataclasses import dataclass
from typing import Any, Optional, Tuple, Union

import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import CrossEntropyLoss

from ...activations import ACT2FN
from ...generation import GenerationMixin
from ...modeling_outputs import (
    BaseModelOutput,
    BaseModelOutputWithPastAndCrossAttentions,
    BaseModelOutputWithPooling,
    BaseModelOutputWithPoolingAndCrossAttentions,
)
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import (
    ModelOutput,
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    logging,
    replace_return_docstrings,
    torch_int,
)
from ..auto import AutoModelForCausalLM, AutoModelForSeq2SeqLM
from .configuration_instructblipvideo import (
    InstructBlipVideoConfig,
    InstructBlipVideoQFormerConfig,
    InstructBlipVideoVisionConfig,
)


logger = logging.get_logger(__name__)


@dataclass
class InstructBlipVideoForConditionalGenerationModelOutput(ModelOutput):
    """
    Class defining the outputs of [`InstructBlipVideoForConditionalGeneration`].

    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
            Language modeling loss from the language model.
        logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
            Prediction scores of the language modeling head of the language model.
        vision_outputs (`BaseModelOutputWithPooling`):
            Outputs of the vision encoder.
        qformer_outputs (`BaseModelOutputWithPoolingAndCrossAttentions`):
            Outputs of the Q-Former (Querying Transformer).
        language_model_outputs (`CausalLMOutputWithPast` or `Seq2SeqLMOutput`):
            Outputs of the language model.
    """

    loss: Optional[Tuple[torch.FloatTensor]] = None
    logits: Optional[Tuple[torch.FloatTensor]] = None
    vision_outputs: Optional[torch.FloatTensor] = None
    qformer_outputs: Optional[Tuple[torch.FloatTensor]] = None
    language_model_outputs: Optional[Tuple[torch.FloatTensor]] = None

    def to_tuple(self) -> Tuple[Any]:
        return tuple(
            self[k]
            if k not in ["vision_outputs", "qformer_outputs", "language_model_outputs"]
            else getattr(self, k).to_tuple()
            for k in self.keys()
        )


class InstructBlipVideoVisionEmbeddings(nn.Module):
    def __init__(self, config: InstructBlipVideoVisionConfig):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.image_size = config.image_size
        self.patch_size = config.patch_size

        self.class_embedding = nn.Parameter(torch.randn(1, 1, self.embed_dim))

        self.patch_embedding = nn.Conv2d(
            in_channels=3, out_channels=self.embed_dim, kernel_size=self.patch_size, stride=self.patch_size
        )

        self.num_patches = (self.image_size // self.patch_size) ** 2
        self.num_positions = self.num_patches + 1

        self.position_embedding = nn.Parameter(torch.randn(1, self.num_positions, self.embed_dim))

    def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
        """
        This method interpolates the pre-trained position encodings so that the model can be used on
        higher-resolution images. It is also adapted to support torch.jit tracing.

        Adapted from:
        - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and
        - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211
        """
        num_patches = embeddings.shape[1] - 1
        num_positions = self.position_embedding.shape[1] - 1

        # always interpolate when tracing to ensure the exported model works for dynamic input shapes
        if not torch.jit.is_tracing() and num_patches == num_positions and height == width:
            return self.position_embedding

        class_pos_embed = self.position_embedding[:, :1]
        patch_pos_embed = self.position_embedding[:, 1:]

        dim = embeddings.shape[-1]

        new_height = height // self.patch_size
        new_width = width // self.patch_size

        sqrt_num_positions = torch_int(num_positions**0.5)
        patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim)
        patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)

        patch_pos_embed = nn.functional.interpolate(
            patch_pos_embed,
            size=(new_height, new_width),
            mode="bicubic",
            align_corners=False,
        )

        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)

        return torch.cat((class_pos_embed, patch_pos_embed), dim=1)

    def forward(self, pixel_values: torch.FloatTensor, interpolate_pos_encoding: bool = False) -> torch.Tensor:
        batch_size, _, height, width = pixel_values.shape
        target_dtype = self.patch_embedding.weight.dtype
        patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype))  # shape = [*, width, grid, grid]
        patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
        class_embeds = self.class_embedding.expand(batch_size, 1, -1).to(target_dtype)
        embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
        if interpolate_pos_encoding:
            position_embedding = self.interpolate_pos_encoding(embeddings, height, width)
        else:
            position_embedding = self.position_embedding
        embeddings = embeddings + position_embedding[:, : embeddings.size(1), :].to(target_dtype)
        return embeddings
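

# The sketch below is illustrative only and is not part of the generated model code: it mirrors, with toy
# sizes, what `interpolate_pos_encoding` above does. The class-token position is split off, the remaining patch
# positions are reshaped into a 2D grid, the grid is bicubic-resized to the new patch grid, and the result is
# flattened back and re-joined with the class-token position. All numbers below are arbitrary assumptions.
def _sketch_interpolate_positions():
    dim, old_grid, new_grid = 8, 16, 24  # toy sizes; the real model derives these from its config
    position_embedding = torch.randn(1, old_grid * old_grid + 1, dim)  # [CLS] position + 16x16 patch positions
    class_pos = position_embedding[:, :1]
    patch_pos = position_embedding[:, 1:].reshape(1, old_grid, old_grid, dim).permute(0, 3, 1, 2)
    patch_pos = nn.functional.interpolate(patch_pos, size=(new_grid, new_grid), mode="bicubic", align_corners=False)
    patch_pos = patch_pos.permute(0, 2, 3, 1).view(1, -1, dim)
    resized = torch.cat((class_pos, patch_pos), dim=1)
    return resized.shape  # torch.Size([1, 24 * 24 + 1, 8])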


class InstructBlipVideoAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.embed_dim // self.num_heads
        if self.head_dim * self.num_heads != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
                f" {self.num_heads})."
            )
        self.scale = self.head_dim**-0.5
        self.dropout = nn.Dropout(config.attention_dropout)

        # small tweak here compared to CLIP, no bias here
        self.qkv = nn.Linear(self.embed_dim, 3 * self.embed_dim, bias=False)

        if config.qkv_bias:
            q_bias = nn.Parameter(torch.zeros(self.embed_dim))
            v_bias = nn.Parameter(torch.zeros(self.embed_dim))
        else:
            q_bias = None
            v_bias = None

        if q_bias is not None:
            qkv_bias = torch.cat((q_bias, torch.zeros_like(v_bias, requires_grad=False), v_bias))
            self.qkv.bias = nn.Parameter(qkv_bias)

        self.projection = nn.Linear(self.embed_dim, self.embed_dim)

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

    def forward(
        self,
        hidden_states: torch.Tensor,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """Input shape: Batch x Time x Channel"""

        bsz, tgt_len, embed_dim = hidden_states.size()

        mixed_qkv = self.qkv(hidden_states)

        mixed_qkv = mixed_qkv.reshape(bsz, tgt_len, 3, self.num_heads, embed_dim // self.num_heads).permute(
            2, 0, 3, 1, 4
        )
        query_states, key_states, value_states = mixed_qkv[0], mixed_qkv[1], mixed_qkv[2]

        # Take the dot product between "query" and "key" to get the raw attention scores.
        attention_scores = torch.matmul(query_states, key_states.transpose(-1, -2))

        attention_scores = attention_scores * self.scale

        # Normalize the attention scores to probabilities.
        attention_probs = nn.functional.softmax(attention_scores, dim=-1)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs = self.dropout(attention_probs)

        # Mask heads if we want to
        if head_mask is not None:
            attention_probs = attention_probs * head_mask

        context_layer = torch.matmul(attention_probs, value_states).permute(0, 2, 1, 3)

        new_context_layer_shape = context_layer.size()[:-2] + (self.embed_dim,)
        context_layer = context_layer.reshape(new_context_layer_shape)

        output = self.projection(context_layer)

        outputs = (output, attention_probs) if output_attentions else (output, None)

        return outputs
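

# Illustrative sketch only, not part of the generated model code: it reproduces, with toy shapes, the fused-QKV
# reshape/permute used in `InstructBlipVideoAttention.forward`, where a single `(batch, seq, 3 * embed_dim)`
# projection is split into per-head query/key/value tensors of shape `(batch, num_heads, seq, head_dim)`.
# All sizes below are arbitrary assumptions.
def _sketch_split_fused_qkv():
    bsz, seq_len, num_heads, head_dim = 2, 5, 4, 8
    embed_dim = num_heads * head_dim
    mixed_qkv = torch.randn(bsz, seq_len, 3 * embed_dim)  # stands in for the output of the fused `qkv` projection
    mixed_qkv = mixed_qkv.reshape(bsz, seq_len, 3, num_heads, head_dim).permute(2, 0, 3, 1, 4)
    query_states, key_states, value_states = mixed_qkv[0], mixed_qkv[1], mixed_qkv[2]
    return query_states.shape, key_states.shape, value_states.shape  # each (2, 4, 5, 8)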


class InstructBlipVideoMLP(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.activation_fn = ACT2FN[config.hidden_act]
        self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
        self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = self.fc1(hidden_states)
        hidden_states = self.activation_fn(hidden_states)
        hidden_states = self.fc2(hidden_states)
        return hidden_states


class InstructBlipVideoEncoderLayer(nn.Module):
    def __init__(self, config: InstructBlipVideoConfig):
        super().__init__()
        self.embed_dim = config.hidden_size
        self.self_attn = InstructBlipVideoAttention(config)
        self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
        self.mlp = InstructBlipVideoMLP(config)
        self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.FloatTensor]:
        """
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`): attention mask of size
                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
        """
        residual = hidden_states

        hidden_states = self.layer_norm1(hidden_states)
        hidden_states, attn_weights = self.self_attn(
            hidden_states=hidden_states,
            head_mask=attention_mask,
            output_attentions=output_attentions,
        )
        hidden_states = hidden_states + residual

        residual = hidden_states
        hidden_states = self.layer_norm2(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = hidden_states + residual

        outputs = (hidden_states,)

        if output_attentions:
            outputs += (attn_weights,)

        return outputs


class InstructBlipVideoPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = InstructBlipVideoConfig
    base_model_prefix = "blip"
    supports_gradient_checkpointing = True

    _no_split_modules = [
        "InstructBlipVideoQFormerEmbeddings",
        "InstructBlipVideoAttention",
        "InstructBlipVideoQFormerMultiHeadAttention",
        "InstructBlipVideoQFormerSelfOutput",
    ]
    _keep_in_fp32_modules = []

    def _init_weights(self, module):
        """Initialize the weights"""
        factor = self.config.initializer_range
        if isinstance(module, nn.Conv2d) or isinstance(module, nn.Embedding) or isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=factor)
            if hasattr(module, "bias") and module.bias is not None:
                module.bias.data.zero_()

        if isinstance(module, InstructBlipVideoVisionEmbeddings):
            if hasattr(self.config, "vision_config") and not isinstance(self.config, InstructBlipVideoVisionConfig):
                factor = self.config.vision_config.initializer_range
            nn.init.trunc_normal_(module.position_embedding, mean=0.0, std=factor)
            nn.init.trunc_normal_(module.class_embedding, mean=0.0, std=factor)
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        elif isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()


INSTRUCTBLIPVIDEO_START_DOCSTRING = r"""
    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
    etc.)

    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage
    and behavior.

    Parameters:
        config ([`InstructBlipVideoConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""

INSTRUCTBLIPVIDEO_VISION_INPUTS_DOCSTRING = r"""
    Args:
        pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            Pixel values. Pixel values can be obtained using [`InstructBlipVideoProcessor`]. See
            [`InstructBlipVideoProcessor.__call__`] for details.
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
        interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
            Whether to interpolate the pre-trained position encodings.
"""

INSTRUCTBLIPVIDEO_INPUTS_DOCSTRING = r"""
    Args:
        pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            Pixel values. Pixel values can be obtained using [`InstructBlipVideoProcessor`]. See
            [`InstructBlipVideoProcessor.__call__`] for details.
        qformer_input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Indices of input sequence tokens in the vocabulary of the Q-Former. Input tokens can optionally be provided
            to serve as text prompt, which the Q-Former model will encode.

            Indices can be obtained using [`InstructBlipVideoProcessor`]. See [`InstructBlipVideoProcessor.__call__`]
            for details.

            [What are input IDs?](../glossary#input-ids)
        qformer_attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Indices of input sequence tokens in the vocabulary of the language model. Input tokens can optionally be
            provided to serve as text prompt, which the language model can continue.

            Indices can be obtained using [`InstructBlipVideoProcessor`]. See [`InstructBlipVideoProcessor.__call__`]
            for details.

            [What are input IDs?](../glossary#input-ids)
        attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)
        decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
            Indices of decoder input sequence tokens in the vocabulary of the language model. Only relevant in case an
            encoder-decoder language model (like T5) is used.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details. [What are decoder input IDs?](../glossary#decoder-input-ids)
        decoder_attention_mask (`torch.BoolTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
            Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
            be used by default.

            Only relevant in case an encoder-decoder language model (like T5) is used.
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
        interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
            Whether to interpolate the pre-trained position encodings.
"""


class InstructBlipVideoEncoder(nn.Module):
    """
    Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
    [`InstructBlipVideoEncoderLayer`].

    Args:
        config (`InstructBlipVideoConfig`):
            The corresponding vision configuration for the `InstructBlipVideoEncoder`.
    """

    def __init__(self, config: InstructBlipVideoConfig):
        super().__init__()
        self.config = config
        self.layers = nn.ModuleList([InstructBlipVideoEncoderLayer(config) for _ in range(config.num_hidden_layers)])
        self.gradient_checkpointing = False

    def forward(
        self,
        inputs_embeds,
        attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutput]:
        r"""
        Args:
            inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
                Embedded representation of the inputs. Should be float, not int tokens.
            attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
                Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

                - 1 for tokens that are **not masked**,
                - 0 for tokens that are **masked**.

                [What are attention masks?](../glossary#attention-mask)
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            output_hidden_states (`bool`, *optional*):
                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
                for more detail.
            return_dict (`bool`, *optional*):
                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        encoder_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None

        hidden_states = inputs_embeds
        for idx, encoder_layer in enumerate(self.layers):
            if output_hidden_states:
                encoder_states = encoder_states + (hidden_states,)
            if self.gradient_checkpointing and self.training:
                layer_outputs = self._gradient_checkpointing_func(
                    encoder_layer.__call__,
                    hidden_states,
                    attention_mask,
                    output_attentions,
                )
            else:
                layer_outputs = encoder_layer(
                    hidden_states,
                    attention_mask,
                    output_attentions=output_attentions,
                )

            hidden_states = layer_outputs[0]

            if output_attentions:
                all_attentions = all_attentions + (layer_outputs[1],)

        if output_hidden_states:
            encoder_states = encoder_states + (hidden_states,)

        if not return_dict:
            return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
        return BaseModelOutput(
            last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
        )


class InstructBlipVideoVisionModel(InstructBlipVideoPreTrainedModel):
    main_input_name = "pixel_values"
    config_class = InstructBlipVideoVisionConfig

    def __init__(self, config: InstructBlipVideoVisionConfig):
        super().__init__(config)
        self.config = config
        embed_dim = config.hidden_size

        self.embeddings = InstructBlipVideoVisionEmbeddings(config)
        self.encoder = InstructBlipVideoEncoder(config)
        self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)

        self.post_init()

    @add_start_docstrings_to_model_forward(INSTRUCTBLIPVIDEO_VISION_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=InstructBlipVideoVisionConfig)
    def forward(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        interpolate_pos_encoding: bool = False,
    ) -> Union[Tuple, BaseModelOutputWithPooling]:
        r"""
        Returns:

        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if pixel_values is None:
            raise ValueError("You have to specify pixel_values")

        hidden_states = self.embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)

        encoder_outputs = self.encoder(
            inputs_embeds=hidden_states,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        last_hidden_state = encoder_outputs[0]
        last_hidden_state = self.post_layernorm(last_hidden_state)

        pooled_output = last_hidden_state[:, 0, :]
        pooled_output = self.post_layernorm(pooled_output)

        if not return_dict:
            return (last_hidden_state, pooled_output) + encoder_outputs[1:]

        return BaseModelOutputWithPooling(
            last_hidden_state=last_hidden_state,
            pooler_output=pooled_output,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )

    def get_input_embeddings(self):
        return self.embeddings
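

# Illustrative usage sketch only, not part of the generated model code: it runs a small, randomly initialized
# vision tower on dummy video frames. For video inputs, frames are typically folded into the batch dimension
# before the image encoder, so a clip of shape (batch, frames, 3, H, W) becomes (batch * frames, 3, H, W).
# The config values, tensor sizes, and the absence of a pretrained checkpoint or processor are all assumptions
# made for the sketch.
def _sketch_vision_model_on_frames():
    config = InstructBlipVideoVisionConfig(
        hidden_size=64, intermediate_size=128, num_hidden_layers=2, num_attention_heads=4, image_size=224, patch_size=14
    )
    model = InstructBlipVideoVisionModel(config).eval()
    batch, frames = 1, 4
    pixel_values = torch.randn(batch, frames, 3, config.image_size, config.image_size)
    pixel_values = pixel_values.reshape(batch * frames, 3, config.image_size, config.image_size)
    with torch.no_grad():
        outputs = model(pixel_values=pixel_values)
    return outputs.last_hidden_state.shape  # (batch * frames, num_patches + 1, hidden_size)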


class InstructBlipVideoQFormerMultiHeadAttention(nn.Module):
    def __init__(self, config, is_cross_attention=False):
        super().__init__()
        self.config = config
        if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
            raise ValueError(
                "The hidden size (%d) is not a multiple of the number of attention heads (%d)"
                % (config.hidden_size, config.num_attention_heads)
            )

        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = nn.Linear(config.hidden_size, self.all_head_size)
        if is_cross_attention:
            self.key = nn.Linear(config.encoder_hidden_size, self.all_head_size)
            self.value = nn.Linear(config.encoder_hidden_size, self.all_head_size)
        else:
            self.key = nn.Linear(config.hidden_size, self.all_head_size)
            self.value = nn.Linear(config.hidden_size, self.all_head_size)

        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
        self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
        if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
            self.max_position_embeddings = config.max_position_embeddings
            self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
        self.save_attention = False

    def save_attn_gradients(self, attn_gradients):
        self.attn_gradients = attn_gradients

    def get_attn_gradients(self):
        return self.attn_gradients

    def save_attention_map(self, attention_map):
        self.attention_map = attention_map

    def get_attention_map(self):
        return self.attention_map

    def transpose_for_scores(self, x):
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_value=None,
        output_attentions=False,
    ):
        # If this is instantiated as a cross-attention module, the keys
        # and values come from an encoder; the attention mask needs to be
        # such that the encoder's padding tokens are not attended to.
        is_cross_attention = encoder_hidden_states is not None

        if is_cross_attention:
            key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
            value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
            attention_mask = encoder_attention_mask
        elif past_key_value is not None:
            key_layer = self.transpose_for_scores(self.key(hidden_states))
            value_layer = self.transpose_for_scores(self.value(hidden_states))
            key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
            value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
        else:
            key_layer = self.transpose_for_scores(self.key(hidden_states))
            value_layer = self.transpose_for_scores(self.value(hidden_states))

        mixed_query_layer = self.query(hidden_states)

        query_layer = self.transpose_for_scores(mixed_query_layer)

        past_key_value = (key_layer, value_layer)

        # Take the dot product between "query" and "key" to get the raw attention scores.
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))

        if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
            seq_length = hidden_states.size()[1]
            position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
            position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
            distance = position_ids_l - position_ids_r
            positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
            positional_embedding = positional_embedding.to(dtype=query_layer.dtype)  # fp16 compatibility

            if self.position_embedding_type == "relative_key":
                relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
                attention_scores = attention_scores + relative_position_scores
            elif self.position_embedding_type == "relative_key_query":
                relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
                relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
                attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key

        attention_scores = attention_scores / math.sqrt(self.attention_head_size)

        attention_scores_dtype = attention_scores.dtype

        if attention_mask is not None:
            # Apply the attention mask (precomputed for all layers in the model's forward() function)
            attention_scores = attention_scores + attention_mask

        # Normalize the attention scores to probabilities.
        attention_probs = nn.Softmax(dim=-1)(attention_scores).to(attention_scores_dtype)

        if is_cross_attention and self.save_attention:
            self.save_attention_map(attention_probs)
            attention_probs.register_hook(self.save_attn_gradients)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs_dropped = self.dropout(attention_probs)

        # Mask heads if we want to
        if head_mask is not None:
            attention_probs_dropped = attention_probs_dropped * head_mask

        context_layer = torch.matmul(attention_probs_dropped, value_layer)

        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(*new_context_layer_shape)

        outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)

        outputs = outputs + (past_key_value,)
        return outputs


class InstructBlipVideoQFormerSelfOutput(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.LayerNorm(hidden_states + input_tensor)
        return hidden_states


class InstructBlipVideoQFormerAttention(nn.Module):
    def __init__(self, config, is_cross_attention=False):
        super().__init__()
        self.attention = InstructBlipVideoQFormerMultiHeadAttention(config, is_cross_attention)
        self.output = InstructBlipVideoQFormerSelfOutput(config)
        self.pruned_heads = set()

    def prune_heads(self, heads):
        if len(heads) == 0:
            return
        heads, index = find_pruneable_heads_and_indices(
            heads, self.attention.num_attention_heads, self.attention.attention_head_size, self.pruned_heads
        )

        # Prune linear layers
        self.attention.query = prune_linear_layer(self.attention.query, index)
        self.attention.key = prune_linear_layer(self.attention.key, index)
        self.attention.value = prune_linear_layer(self.attention.value, index)
        self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)

        # Update hyper params and store pruned heads
        self.attention.num_attention_heads = self.attention.num_attention_heads - len(heads)
        self.attention.all_head_size = self.attention.attention_head_size * self.attention.num_attention_heads
        self.pruned_heads = self.pruned_heads.union(heads)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor]:
        self_outputs = self.attention(
            hidden_states,
            attention_mask,
            head_mask,
            encoder_hidden_states,
            encoder_attention_mask,
            past_key_value,
            output_attentions,
        )
        attention_output = self.output(self_outputs[0], hidden_states)
        outputs = (attention_output,) + self_outputs[1:]  # add attentions if we output them
        return outputs


class InstructBlipVideoQFormerIntermediate(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
        if isinstance(config.hidden_act, str):
            self.intermediate_act_fn = ACT2FN[config.hidden_act]
        else:
            self.intermediate_act_fn = config.hidden_act

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = self.dense(hidden_states)
        hidden_states = self.intermediate_act_fn(hidden_states)
        return hidden_states


class InstructBlipVideoQFormerOutput(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.LayerNorm(hidden_states + input_tensor)
        return hidden_states


class InstructBlipVideoQFormerLayer(nn.Module):
    def __init__(self, config, layer_idx):
        super().__init__()
        self.chunk_size_feed_forward = config.chunk_size_feed_forward
        self.seq_len_dim = 1
        self.attention = InstructBlipVideoQFormerAttention(config)

        self.layer_idx = layer_idx

        if layer_idx % config.cross_attention_frequency == 0:
            self.crossattention = InstructBlipVideoQFormerAttention(config, is_cross_attention=True)
            self.has_cross_attention = True
        else:
            self.has_cross_attention = False

        self.intermediate = InstructBlipVideoQFormerIntermediate(config)
        self.output = InstructBlipVideoQFormerOutput(config)

        self.intermediate_query = InstructBlipVideoQFormerIntermediate(config)
        self.output_query = InstructBlipVideoQFormerOutput(config)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_value=None,
        output_attentions=False,
        query_length=0,
    ):
        # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
        self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
        self_attention_outputs = self.attention(
            hidden_states,
            attention_mask,
            head_mask,
            output_attentions=output_attentions,
            past_key_value=self_attn_past_key_value,
        )
        attention_output = self_attention_outputs[0]
        outputs = self_attention_outputs[1:-1]

        present_key_value = self_attention_outputs[-1]

        if query_length > 0:
            query_attention_output = attention_output[:, :query_length, :]

            if self.has_cross_attention:
                if encoder_hidden_states is None:
                    raise ValueError("encoder_hidden_states must be given for cross-attention layers")
                cross_attention_outputs = self.crossattention(
                    query_attention_output,
                    attention_mask,
                    head_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                    output_attentions=output_attentions,
                )
                query_attention_output = cross_attention_outputs[0]
                # add cross attentions if we output attention weights
                outputs = outputs + cross_attention_outputs[1:-1]

            layer_output = apply_chunking_to_forward(
                self.feed_forward_chunk_query,
                self.chunk_size_feed_forward,
                self.seq_len_dim,
                query_attention_output,
            )

            if attention_output.shape[1] > query_length:
                layer_output_text = apply_chunking_to_forward(
                    self.feed_forward_chunk,
                    self.chunk_size_feed_forward,
                    self.seq_len_dim,
                    attention_output[:, query_length:, :],
                )
                layer_output = torch.cat([layer_output, layer_output_text], dim=1)
        else:
            layer_output = apply_chunking_to_forward(
                self.feed_forward_chunk,
                self.chunk_size_feed_forward,
                self.seq_len_dim,
                attention_output,
            )
        outputs = (layer_output,) + outputs

        outputs = outputs + (present_key_value,)

        return outputs

    def feed_forward_chunk(self, attention_output):
        intermediate_output = self.intermediate(attention_output)
        layer_output = self.output(intermediate_output, attention_output)
        return layer_output

    def feed_forward_chunk_query(self, attention_output):
        intermediate_output = self.intermediate_query(attention_output)
        layer_output = self.output_query(intermediate_output, attention_output)
        return layer_output
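

# Illustrative sketch only, not part of the generated model code: a plain-torch illustration of what the
# `apply_chunking_to_forward` calls above do when `chunk_size_feed_forward > 0` — split the input along the
# sequence dimension, run the feed-forward on each chunk, and concatenate the results, trading a little speed for
# a smaller activation-memory peak. The toy feed-forward and all sizes below are assumptions made for the sketch.
def _sketch_chunked_feed_forward():
    hidden_size, chunk_size, seq_len_dim = 16, 8, 1
    feed_forward = nn.Sequential(
        nn.Linear(hidden_size, 4 * hidden_size), nn.GELU(), nn.Linear(4 * hidden_size, hidden_size)
    )
    hidden_states = torch.randn(2, 32, hidden_size)
    chunks = hidden_states.chunk(hidden_states.shape[seq_len_dim] // chunk_size, dim=seq_len_dim)
    chunked_output = torch.cat([feed_forward(chunk) for chunk in chunks], dim=seq_len_dim)
    full_output = feed_forward(hidden_states)
    return torch.allclose(chunked_output, full_output, atol=1e-6)  # True: chunking does not change the result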


class InstructBlipVideoQFormerEncoder(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.layer = nn.ModuleList(
            [InstructBlipVideoQFormerLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_values=None,
        use_cache=None,
        output_attentions=False,
        output_hidden_states=False,
        return_dict=True,
        query_length=0,
    ):
        all_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None
        all_cross_attentions = () if output_attentions else None

        next_decoder_cache = () if use_cache else None

        for i in range(self.config.num_hidden_layers):
            layer_module = self.layer[i]
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            layer_head_mask = head_mask[i] if head_mask is not None else None
            past_key_value = past_key_values[i] if past_key_values is not None else None

            if getattr(self.config, "gradient_checkpointing", False) and self.training:
                if use_cache:
                    logger.warning(
                        "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                    )
                    use_cache = False
                layer_outputs = self._gradient_checkpointing_func(
                    layer_module.__call__,
                    hidden_states,
                    attention_mask,
                    layer_head_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                )
            else:
                layer_outputs = layer_module(
                    hidden_states,
                    attention_mask,
                    layer_head_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                    past_key_value,
                    output_attentions,
                    query_length,
                )

            hidden_states = layer_outputs[0]
            if use_cache:
                next_decoder_cache += (layer_outputs[-1],)
            if output_attentions:
                all_self_attentions = all_self_attentions + (layer_outputs[1],)
                if layer_module.has_cross_attention:
                    all_cross_attentions = all_cross_attentions + (layer_outputs[2],)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(
                v
                for v in [
                    hidden_states,
                    next_decoder_cache,
                    all_hidden_states,
                    all_self_attentions,
                    all_cross_attentions,
                ]
                if v is not None
            )
        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=next_decoder_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
            cross_attentions=all_cross_attentions,
        )


class InstructBlipVideoQFormerEmbeddings(nn.Module):
    """Construct the embeddings from word and position embeddings."""

    def __init__(self, config):
        super().__init__()
        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)

        self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

        # position_ids (1, len position emb) is contiguous in memory and exported when serialized
        self.register_buffer(
            "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
        )
        self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")

        self.config = config

    def forward(
        self,
        input_ids=None,
        position_ids=None,
        query_embeds=None,
        past_key_values_length=0,
    ):
        if input_ids is not None:
            seq_length = input_ids.size()[1]
        else:
            seq_length = 0

        if position_ids is None:
            position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length].clone()

        if input_ids is not None:
            embeddings = self.word_embeddings(input_ids)
            if self.position_embedding_type == "absolute":
                position_embeddings = self.position_embeddings(position_ids.to(embeddings.device))
                embeddings = embeddings + position_embeddings

            if query_embeds is not None:
                embeddings = torch.cat((query_embeds, embeddings), dim=1)
        else:
            embeddings = query_embeds

        embeddings = embeddings.to(self.layernorm.weight.dtype)
        embeddings = self.layernorm(embeddings)
        embeddings = self.dropout(embeddings)
        return embeddings
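

# Illustrative sketch only, not part of the generated model code: shape arithmetic for the concatenation performed
# in `InstructBlipVideoQFormerEmbeddings.forward`. The learned query embeddings are prepended to the embedded
# instruction tokens, so the Q-Former encoder sees a sequence of length `num_query_tokens + instruction_length`.
# All sizes below are arbitrary assumptions.
def _sketch_query_and_text_concat():
    batch, num_query_tokens, instruction_length, hidden_size = 2, 32, 10, 16
    query_embeds = torch.randn(batch, num_query_tokens, hidden_size)
    text_embeds = torch.randn(batch, instruction_length, hidden_size)  # stands in for word + position embeddings
    embeddings = torch.cat((query_embeds, text_embeds), dim=1)
    return embeddings.shape  # (2, 32 + 10, 16)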
  874. class InstructBlipVideoQFormerModel(InstructBlipVideoPreTrainedModel):
  875. """
  876. Querying Transformer (Q-Former), used in InstructBlipVideo. Slightly modified from BLIP-2 as it also takes the
  877. instruction as input.
  878. """
  879. def __init__(self, config: InstructBlipVideoQFormerConfig):
  880. super().__init__(config)
  881. self.config = config
  882. self.embeddings = InstructBlipVideoQFormerEmbeddings(config)
  883. self.encoder = InstructBlipVideoQFormerEncoder(config)
  884. self.post_init()
  885. def get_input_embeddings(self):
  886. return self.embeddings.word_embeddings
  887. def set_input_embeddings(self, value):
  888. self.embeddings.word_embeddings = value
  889. def _prune_heads(self, heads_to_prune):
  890. """
  891. Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
  892. class PreTrainedModel
  893. """
  894. for layer, heads in heads_to_prune.items():
  895. self.encoder.layer[layer].attention.prune_heads(heads)
  896. def get_extended_attention_mask(
  897. self,
  898. attention_mask: torch.Tensor,
  899. input_shape: Tuple[int],
  900. device: torch.device,
  901. has_query: bool = False,
  902. ) -> torch.Tensor:
  903. """
  904. Makes broadcastable attention and causal masks so that future and masked tokens are ignored.
  905. Arguments:
  906. attention_mask (`torch.Tensor`):
  907. Mask with ones indicating tokens to attend to, zeros for tokens to ignore.
  908. input_shape (`Tuple[int]`):
  909. The shape of the input to the model.
  910. device: (`torch.device`):
  911. The device of the input to the model.
  912. Returns:
  913. `torch.Tensor` The extended attention mask, with a the same dtype as `attention_mask.dtype`.
  914. """
  915. # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
  916. # ourselves in which case we just need to make it broadcastable to all heads.
  917. if attention_mask.dim() == 3:
  918. extended_attention_mask = attention_mask[:, None, :, :]
  919. elif attention_mask.dim() == 2:
  920. # Provided a padding mask of dimensions [batch_size, seq_length]
  921. # - the model is an encoder, so make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
  922. extended_attention_mask = attention_mask[:, None, None, :]
  923. else:
  924. raise ValueError(
  925. f"Wrong shape for input_ids (shape {input_shape}) or attention_mask (shape {attention_mask.shape})",
  926. )
  927. # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
  928. # masked positions, this operation will create a tensor which is 0.0 for
  929. # positions we want to attend and -10000.0 for masked positions.
  930. # Since we are adding it to the raw scores before the softmax, this is
  931. # effectively the same as removing these entirely.
  932. extended_attention_mask = extended_attention_mask.to(dtype=self.dtype) # fp16 compatibility
  933. extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
  934. return extended_attention_mask
  935. def forward(
  936. self,
  937. input_ids: torch.LongTensor,
  938. attention_mask: Optional[torch.FloatTensor] = None,
  939. position_ids: Optional[torch.LongTensor] = None,
  940. query_embeds: Optional[torch.Tensor] = None,
  941. head_mask: Optional[torch.FloatTensor] = None,
  942. encoder_hidden_states: Optional[torch.FloatTensor] = None,
  943. encoder_attention_mask: Optional[torch.FloatTensor] = None,
  944. past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
  945. use_cache: Optional[bool] = None,
  946. output_attentions: Optional[bool] = None,
  947. output_hidden_states: Optional[bool] = None,
  948. return_dict: Optional[bool] = None,
  949. ) -> Union[Tuple[torch.FloatTensor], BaseModelOutputWithPoolingAndCrossAttentions]:
  950. r"""
  951. encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
  952. Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
  953. the model is configured as a decoder.
  954. encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
  955. Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
  956. the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`:
  957. - 1 for tokens that are **not masked**,
  958. - 0 for tokens that are **masked**.
  959. past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of:
  960. shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): Contains precomputed key and
  961. value hidden states of the attention blocks. Can be used to speed up decoding. If `past_key_values` are
  962. used, the user can optionally input only the last `decoder_input_ids` (those that don't have their past key
  963. value states given to this model) of shape `(batch_size, 1)` instead of all `decoder_input_ids` of shape
  964. `(batch_size, sequence_length)`.
  965. use_cache (`bool`, *optional*):
  966. If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
  967. `past_key_values`).
  968. """
  969. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  970. output_hidden_states = (
  971. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  972. )
  973. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  974. if input_ids is None and query_embeds is None:
  975. raise ValueError("You have to specify query_embeds when input_ids is None")
  976. # past_key_values_length
  977. past_key_values_length = (
  978. past_key_values[0][0].shape[2] - self.config.query_length if past_key_values is not None else 0
  979. )
        query_length = query_embeds.shape[1] if query_embeds is not None else 0

        embedding_output = self.embeddings(
            input_ids=input_ids,
            position_ids=position_ids,
            query_embeds=query_embeds,
            past_key_values_length=past_key_values_length,
        )

        input_shape = embedding_output.size()[:-1]
        batch_size, seq_length = input_shape
        device = embedding_output.device

        if attention_mask is None:
            attention_mask = torch.ones((batch_size, seq_length + past_key_values_length), device=device)

        # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
        # ourselves in which case we just need to make it broadcastable to all heads.
        extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape, device)

        # If a 2D or 3D attention mask is provided for the cross-attention
        # we need to make it broadcastable to [batch_size, num_heads, seq_length, seq_length]
        if encoder_hidden_states is not None:
            if isinstance(encoder_hidden_states, list):
                encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states[0].size()
            else:
                encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
            encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)

            if isinstance(encoder_attention_mask, list):
                encoder_extended_attention_mask = [self.invert_attention_mask(mask) for mask in encoder_attention_mask]
            elif encoder_attention_mask is None:
                encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
                encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
            else:
                encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
        else:
            encoder_extended_attention_mask = None

        # Prepare head mask if needed
        # 1.0 in head_mask indicates we keep the head
        # attention_probs has shape bsz x n_heads x N x N
        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

        encoder_outputs = self.encoder(
            embedding_output,
            attention_mask=extended_attention_mask,
            head_mask=head_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_extended_attention_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            query_length=query_length,
        )
        sequence_output = encoder_outputs[0]
        pooled_output = sequence_output[:, 0, :]
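        # the pooled output is simply the hidden state of the first position (the first learned query token when
        # `query_embeds` are provided); this Q-Former has no dedicated pooler layer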
        if not return_dict:
            return (sequence_output, pooled_output) + encoder_outputs[1:]

        return BaseModelOutputWithPoolingAndCrossAttentions(
            last_hidden_state=sequence_output,
            pooler_output=pooled_output,
            past_key_values=encoder_outputs.past_key_values,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
            cross_attentions=encoder_outputs.cross_attentions,
        )

@add_start_docstrings(
    """
    InstructBlipVideo Model for generating text given a video and an optional text prompt. The model consists of a
    vision encoder, Querying Transformer (Q-Former) and a language model.

    One can optionally pass `input_ids` to the model, which serve as a text prompt, to make the language model continue
    the prompt. Otherwise, the language model starts generating text from the [BOS] (beginning-of-sequence) token.
    """,
    INSTRUCTBLIPVIDEO_START_DOCSTRING,
)
class InstructBlipVideoForConditionalGeneration(InstructBlipVideoPreTrainedModel, GenerationMixin):
    config_class = InstructBlipVideoConfig
    main_input_name = "pixel_values"
    def __init__(self, config: InstructBlipVideoConfig):
        super().__init__(config)

        self.vision_model = InstructBlipVideoVisionModel(config.vision_config)

        self.query_tokens = nn.Parameter(torch.zeros(1, config.num_query_tokens, config.qformer_config.hidden_size))
        self.qformer = InstructBlipVideoQFormerModel(config.qformer_config)

        self.language_projection = nn.Linear(config.qformer_config.hidden_size, config.text_config.hidden_size)
        if config.use_decoder_only_language_model:
            language_model = AutoModelForCausalLM.from_config(config.text_config)
        else:
            language_model = AutoModelForSeq2SeqLM.from_config(config.text_config)

        if language_model._no_split_modules is not None:
            self._no_split_modules.extend(language_model._no_split_modules)

        if language_model._keep_in_fp32_modules is not None:
            self._keep_in_fp32_modules.extend(language_model._keep_in_fp32_modules)

        self.language_model = language_model

        # Initialize weights and apply final processing
        self.post_init()
    def get_input_embeddings(self):
        return self.language_model.get_input_embeddings()

    def set_input_embeddings(self, value):
        self.language_model.set_input_embeddings(value)

    def set_output_embeddings(self, new_embeddings):
        self.language_model.set_output_embeddings(new_embeddings)

    def get_output_embeddings(self) -> nn.Module:
        return self.language_model.get_output_embeddings()

    def get_encoder(self):
        return self.language_model.get_encoder()

    def get_decoder(self):
        return self.language_model.get_decoder()

    def _tie_weights(self):
        if not self.config.use_decoder_only_language_model:
            self.language_model.encoder.embed_tokens = self.language_model.shared
            self.language_model.decoder.embed_tokens = self.language_model.shared
    def _preprocess_accelerate(self):
        r"""
        Some pre-processing hacks to make the model `accelerate` compatible. Check
        https://github.com/huggingface/transformers/pull/21707 for more details.
        """
        hf_device_map = self.hf_device_map

        if len(hf_device_map) > 1 and "language_model" not in hf_device_map and torch.cuda.device_count() > 1:
            # warn users about unexpected behavior when using multi-GPU + InstructBlipVideo + `accelerate`.
            logger.warning(
                "The `language_model` is not in the `hf_device_map` dictionary and you are running your script"
                " in a multi-GPU environment. This may lead to unexpected behavior when using `accelerate`."
                " Please pass a `device_map` that contains `language_model` to remove this warning."
                " Please refer to https://github.com/huggingface/blog/blob/main/accelerate-large-models.md for"
                " more details on creating a `device_map` for large models.",
            )

        if hasattr(self.language_model, "_hf_hook"):
            self.language_model._hf_hook.io_same_device = True  # For `generate` compatibility
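    # A minimal sketch of a `device_map` that satisfies the warning above by placing `language_model` explicitly.
    # The device indices are illustrative only, not a recommendation:
    #
    #   device_map = {"vision_model": 0, "qformer": 0, "query_tokens": 0, "language_projection": 0, "language_model": 1}
    #   model = InstructBlipVideoForConditionalGeneration.from_pretrained(checkpoint, device_map=device_map)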
    @add_start_docstrings_to_model_forward(INSTRUCTBLIPVIDEO_INPUTS_DOCSTRING)
    @replace_return_docstrings(
        output_type=InstructBlipVideoForConditionalGenerationModelOutput, config_class=InstructBlipVideoVisionConfig
    )
    def forward(
        self,
        pixel_values: torch.FloatTensor,
        qformer_input_ids: torch.FloatTensor,
        qformer_attention_mask: Optional[torch.LongTensor] = None,
        input_ids: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.LongTensor] = None,
        decoder_input_ids: Optional[torch.LongTensor] = None,
        decoder_attention_mask: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        labels: Optional[torch.LongTensor] = None,
        return_dict: Optional[bool] = None,
        interpolate_pos_encoding: bool = False,
    ) -> Union[Tuple, InstructBlipVideoForConditionalGenerationModelOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the language modeling loss. Indices should be in `[-100, 0, ..., config.vocab_size -
            1]`. All labels set to `-100` are ignored (masked), the loss is only computed for labels in `[0, ...,
            config.vocab_size]`

        Returns:

        Examples:

        ```python
        >>> from transformers import InstructBlipVideoProcessor, InstructBlipVideoForConditionalGeneration
        >>> import torch
        >>> from huggingface_hub import hf_hub_download
        >>> import av
        >>> import numpy as np

        >>> def read_video_pyav(container, indices):
        ...     '''
        ...     Decode the video with PyAV decoder.
        ...     Args:
        ...         container (`av.container.input.InputContainer`): PyAV container.
        ...         indices (`List[int]`): List of frame indices to decode.
        ...     Returns:
        ...         result (np.ndarray): np array of decoded frames of shape (num_frames, height, width, 3).
        ...     '''
        ...     frames = []
        ...     container.seek(0)
        ...     start_index = indices[0]
        ...     end_index = indices[-1]
        ...     for i, frame in enumerate(container.decode(video=0)):
        ...         if i > end_index:
        ...             break
        ...         if i >= start_index and i in indices:
        ...             frames.append(frame)
        ...     return np.stack([x.to_ndarray(format="rgb24") for x in frames])

        >>> model = InstructBlipVideoForConditionalGeneration.from_pretrained("Salesforce/instructblip-vicuna-7b", device_map="auto")
        >>> processor = InstructBlipVideoProcessor.from_pretrained("Salesforce/instructblip-vicuna-7b")

        >>> file_path = hf_hub_download(
        ...     repo_id="nielsr/video-demo", filename="eating_spaghetti.mp4", repo_type="dataset"
        ... )
        >>> container = av.open(file_path)

        >>> # sample uniformly 4 frames from the video
        >>> total_frames = container.streams.video[0].frames
        >>> indices = np.arange(0, total_frames, total_frames / 4).astype(int)
        >>> clip = read_video_pyav(container, indices)

        >>> prompt = "What is happening in the video?"
        >>> inputs = processor(text=prompt, images=clip, return_tensors="pt").to(model.device)

        >>> outputs = model.generate(
        ...     **inputs,
        ...     do_sample=False,
        ...     num_beams=5,
        ...     max_length=256,
        ...     repetition_penalty=1.5,
        ...     length_penalty=1.0,
        ... )
        >>> generated_text = processor.batch_decode(outputs, skip_special_tokens=True)[0].strip()
        >>> print(generated_text)
        "A person is eating a bowl of pasta, and they are using a fork to eat it. The person is sitting at a table, and the plate of pasta is on the table in front"
        ```"""
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # step 1: forward the images through the vision encoder,
        # we process in a batched way, later unbatch it back (video has frames=4 always)
        batch_size, frames, channel, height, width = pixel_values.shape
        pixel_values = pixel_values.reshape(batch_size * frames, channel, height, width)
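        # pixel_values is now (batch_size * frames, channel, height, width): every frame is encoded independently
        # by the image encoder, and the per-frame results are folded back into the batch afterwards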
        vision_outputs = self.vision_model(
            pixel_values=pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            interpolate_pos_encoding=interpolate_pos_encoding,
        )
        image_embeds = vision_outputs[0]

        # step 2: forward the query tokens through the QFormer, using the image embeddings for cross-attention
        image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device)

        # difference with BLIP-2 here: we also feed the instruction prompt to the Q-Former
        query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
        query_attention_mask = torch.ones(query_tokens.size()[:-1], dtype=torch.long, device=image_embeds.device)
        if qformer_attention_mask is None:
            qformer_attention_mask = torch.ones_like(qformer_input_ids)

        qformer_input_ids = qformer_input_ids.repeat_interleave(frames, dim=0)
        qformer_attention_mask = qformer_attention_mask.repeat_interleave(frames, dim=0)
        qformer_attention_mask = torch.cat([query_attention_mask, qformer_attention_mask], dim=1)
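        # each frame gets its own copy of the instruction tokens, and the Q-Former attention mask now covers
        # [query tokens ; instruction tokens] for every frame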
        query_outputs = self.qformer(
            input_ids=qformer_input_ids,
            attention_mask=qformer_attention_mask,
            query_embeds=query_tokens,
            encoder_hidden_states=image_embeds,
            encoder_attention_mask=image_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        query_output = query_outputs[0][:, : query_tokens.size(1), :]
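        # keep only the positions of the learned query tokens; the instruction tokens were only used to condition
        # the Q-Former and are not passed on to the language model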
        # step 3: use the language model, conditioned on the query outputs and the prompt
        language_model_inputs = self.language_projection(query_output)

        # unbatch inputs back, each video-frame gets `num_query_tokens` seq length
        language_model_inputs = language_model_inputs.reshape(batch_size, self.config.num_query_tokens * frames, -1)
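        # shape: (batch_size, num_query_tokens * frames, language hidden size); the per-frame query outputs are
        # concatenated along the sequence dimension so the language model sees all frames as a single prefix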
        language_model_attention_mask = torch.ones(
            language_model_inputs.size()[:-1], dtype=torch.long, device=language_model_inputs.device
        )

        inputs_embeds = self.language_model.get_input_embeddings()(input_ids)
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids)

        # if the model already has "video_token_index" then the input is expanded to account for image embeds
        # otherwise we expand manually by concatenating
        if getattr(self.config, "video_token_index", None) is not None:
            special_image_mask = (input_ids == self.config.video_token_index).unsqueeze(-1).expand_as(inputs_embeds)
            inputs_embeds[special_image_mask] = language_model_inputs.flatten()
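            # the processor is expected to have inserted `num_query_tokens * frames` video placeholder tokens per
            # sample, so the flattened visual embeddings slot exactly into those positions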
        else:
            logger.warning_once(
                "Expanding inputs for video tokens in InstructBLIPVideo should be done in processing. "
                "Please follow the instructions here (https://gist.github.com/zucchini-nlp/65f22892b054dc0d68228af56fbeaac2) to update your InstructBLIPVideo model. "
                "Using processors without these attributes in the config is deprecated and will throw an error in v4.47."
            )
            inputs_embeds = torch.cat([language_model_inputs, inputs_embeds.to(language_model_inputs.device)], dim=1)
            attention_mask = torch.cat(
                [language_model_attention_mask, attention_mask.to(language_model_attention_mask.device)], dim=1
            )

        if self.config.use_decoder_only_language_model:
            outputs = self.language_model(
                inputs_embeds=inputs_embeds,
                attention_mask=attention_mask,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )
            logits = outputs.logits if return_dict else outputs[0]
            loss = None
            # we compute the loss here since we need to take into account the sequence length of the query embeds
            if labels is not None:
                labels = labels.to(logits.device)
                logits = logits[:, -labels.size(1) :, :]
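                # only the logits for the text tokens are kept: the positions corresponding to the prepended visual
                # prefix have no labels, so they are dropped before the shift below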
                # Shift so that tokens < n predict n
                shift_logits = logits[..., :-1, :].contiguous()
                shift_labels = labels[..., 1:].contiguous().to(logits.device)

                # Flatten the tokens
                loss_fct = CrossEntropyLoss(reduction="mean")
                loss = loss_fct(shift_logits.view(-1, self.config.text_config.vocab_size), shift_labels.view(-1))
        else:
            outputs = self.language_model(
                inputs_embeds=inputs_embeds,
                attention_mask=attention_mask,
                decoder_input_ids=decoder_input_ids,
                decoder_attention_mask=decoder_attention_mask,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
                labels=labels,
            )
            loss = outputs.loss if return_dict else outputs[0]
            logits = outputs.logits if return_dict else outputs[1]

        if not return_dict:
            output = (logits, vision_outputs, query_outputs, outputs)
            return ((loss,) + output) if loss is not None else output

        return InstructBlipVideoForConditionalGenerationModelOutput(
            loss=loss,
            logits=logits,
            vision_outputs=vision_outputs,
            qformer_outputs=query_outputs,
            language_model_outputs=outputs,
        )
    @torch.no_grad()
    def generate(
        self,
        pixel_values: torch.FloatTensor,
        qformer_input_ids: Optional[torch.LongTensor] = None,
        qformer_attention_mask: Optional[torch.LongTensor] = None,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.LongTensor] = None,
        interpolate_pos_encoding: bool = False,
        **generate_kwargs,
    ) -> torch.LongTensor:
        r"""
        Overrides the `generate` function to be able to use the model as a conditional generator.

        Args:
            pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` or
                `(batch_size, num_frames, num_channels, height, width)`):
                Input images or videos to be processed.
            qformer_input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                The sequence used as a prompt to be fed to the Q-Former module.
            qformer_attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Mask to avoid performing attention on padding token indices.
            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                The sequence used as a prompt for the generation.
            attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Mask to avoid performing attention on padding token indices.
            interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
                Whether to interpolate the positional encoding of the image embeddings.

        Returns:
            captions (list): A list of strings of length batch_size * num_captions.
        """
        if hasattr(self, "hf_device_map"):
            # preprocess for `accelerate`
            self._preprocess_accelerate()

        # we process in a batched way, later unbatch it back (video has frames=4)
        batch_size, frames, channel, height, width = pixel_values.shape
        pixel_values = pixel_values.reshape(batch_size * frames, channel, height, width)

        image_embeds = self.vision_model(
            pixel_values,
            return_dict=True,
            interpolate_pos_encoding=interpolate_pos_encoding,
        ).last_hidden_state
        image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device)

        query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
        query_attention_mask = torch.ones(query_tokens.size()[:-1], dtype=torch.long, device=image_embeds.device)
        if qformer_attention_mask is None:
            qformer_attention_mask = torch.ones_like(qformer_input_ids)

        qformer_input_ids = qformer_input_ids.repeat_interleave(frames, dim=0)
        qformer_attention_mask = qformer_attention_mask.repeat_interleave(frames, dim=0)
        qformer_attention_mask = torch.cat([query_attention_mask, qformer_attention_mask], dim=1)
        query_outputs = self.qformer(
            input_ids=qformer_input_ids,
            attention_mask=qformer_attention_mask,
            query_embeds=query_tokens,
            encoder_hidden_states=image_embeds,
            encoder_attention_mask=image_attention_mask,
            return_dict=True,
        )
        query_output = query_outputs.last_hidden_state[:, : query_tokens.size(1), :]

        language_model_inputs = self.language_projection(query_output)

        # unbatch the embeddings back by moving frames to seq-len
        language_model_inputs = language_model_inputs.reshape(batch_size, self.config.num_query_tokens * frames, -1)
        language_attention_mask = torch.ones(
            language_model_inputs.size()[:-1], dtype=torch.long, device=language_model_inputs.device
        )
        if input_ids is None:
            input_ids = (
                torch.LongTensor([[self.config.text_config.bos_token_id]])
                .repeat(batch_size, 1)
                .to(image_embeds.device)
            )
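        # when no text prompt is given, generation is conditioned only on the visual prefix and starts from BOS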
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids)

        inputs_embeds = self.get_input_embeddings()(input_ids)

        # if the model already has "video_token_index" then the input is expanded to account for image embeds
        # otherwise we expand manually by concatenating
        if getattr(self.config, "video_token_index", None) is not None:
            special_image_mask = (input_ids == self.config.video_token_index).unsqueeze(-1).expand_as(inputs_embeds)
            inputs_embeds[special_image_mask] = language_model_inputs.flatten()
        else:
            logger.warning_once(
                "Expanding inputs for video tokens in InstructBLIPVideo should be done in processing. "
                "Please follow the instructions here (https://gist.github.com/zucchini-nlp/65f22892b054dc0d68228af56fbeaac2) to update your InstructBLIPVideo model. "
                "Using processors without these attributes in the config is deprecated and will throw an error in v4.47."
            )
            inputs_embeds = torch.cat([language_model_inputs, inputs_embeds.to(language_model_inputs.device)], dim=1)
            attention_mask = torch.cat(
                [language_attention_mask, attention_mask.to(language_attention_mask.device)], dim=1
            )
        # add image_embeds length to max_length, so that the final max_length is counted only on token embeds
        # -1 is to account for the prepended BOS after `generate`.
        if not self.language_model.config.is_encoder_decoder:
            generate_kwargs["max_length"] = (
                generate_kwargs.get("max_length", 20) + language_model_inputs.shape[1] - 1
            )
            generate_kwargs["min_length"] = generate_kwargs.get("min_length", 0) + language_model_inputs.shape[1]
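        # NOTE: only the legacy `max_length`/`min_length` kwargs are adjusted here; `max_new_tokens`, if passed,
        # should not need this offset since it counts newly generated tokens only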
        outputs = self.language_model.generate(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            **generate_kwargs,
        )

        # this is a temporary workaround to be consistent with other generation models and
        # have BOS as the first token, even though under the hood we are calling LM with embeds
        if not self.language_model.config.is_encoder_decoder:
            # the InstructBLIP authors used inconsistent tokenizer/model files during training,
            # with the tokenizer's bos token being set to </s> which has ID=2,
            # whereas the model's text config has bos token id = 0
            bos_token_id = (
                2
                if self.config.text_config.architectures[0] == "LLaMAForCausalLM"
                else self.config.text_config.bos_token_id
            )
            bos_tokens = torch.LongTensor([[bos_token_id]]).repeat(batch_size, 1).to(image_embeds.device)
            if not isinstance(outputs, torch.Tensor):
                outputs.sequences = torch.cat([bos_tokens, outputs.sequences], dim=-1)
            else:
                outputs = torch.cat([bos_tokens, outputs], dim=-1)

        return outputs