modular_instructblipvideo.py

# coding=utf-8
# Copyright 2024 HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass
from typing import Optional, Tuple, Union

import torch
import torch.utils.checkpoint
from torch.nn import CrossEntropyLoss

from transformers.models.instructblip.configuration_instructblip import (
    InstructBlipQFormerConfig,
    InstructBlipVisionConfig,
)
from transformers.models.instructblip.modeling_instructblip import (
    InstructBlipForConditionalGeneration,
    InstructBlipForConditionalGenerationModelOutput,
)

from ...configuration_utils import PretrainedConfig
from ...models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
from ...utils import logging
from ..auto import CONFIG_MAPPING


logger = logging.get_logger(__name__)
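
# The vision and Q-Former configs are inherited unchanged from InstructBlip: the video variant
# only renames the classes, it does not alter any of their hyperparameters.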
class InstructBlipVideoVisionConfig(InstructBlipVisionConfig):
    pass


class InstructBlipVideoQFormerConfig(InstructBlipQFormerConfig):
    pass


class InstructBlipVideoConfig(PretrainedConfig):
    r"""
    [`InstructBlipVideoConfig`] is the configuration class to store the configuration of a
    [`InstructBlipVideoForConditionalGeneration`]. It is used to instantiate an Instructblipvideo model according to the
    specified arguments, defining the vision model, Q-Former model and language model configs. Instantiating a
    configuration with the defaults will yield a similar configuration to that of the Instructblipvideo
    [Salesforce/instruct-blip-flan-t5](https://huggingface.co/Salesforce/instruct-blip-flan-t5) architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        vision_config (`dict`, *optional*):
            Dictionary of configuration options used to initialize [`InstructBlipVideoVisionConfig`].
        qformer_config (`dict`, *optional*):
            Dictionary of configuration options used to initialize [`InstructBlipVideoQFormerConfig`].
        text_config (`dict`, *optional*):
            Dictionary of configuration options used to initialize any [`PretrainedConfig`].
        num_query_tokens (`int`, *optional*, defaults to 32):
            The number of query tokens passed through the Transformer.
        video_token_index (`int`, *optional*):
            Token index of the special video token.
        kwargs (*optional*):
            Dictionary of keyword arguments.

    Example:

    ```python
    >>> from transformers import (
    ...     InstructBlipVideoVisionConfig,
    ...     InstructBlipVideoQFormerConfig,
    ...     OPTConfig,
    ...     InstructBlipVideoConfig,
    ...     InstructBlipVideoForConditionalGeneration,
    ... )

    >>> # Initializing an InstructBlipVideoConfig with Salesforce/instruct-blip-flan-t5 style configuration
    >>> configuration = InstructBlipVideoConfig()

    >>> # Initializing an InstructBlipVideoForConditionalGeneration (with random weights) from the Salesforce/instruct-blip-flan-t5 style configuration
    >>> model = InstructBlipVideoForConditionalGeneration(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config

    >>> # We can also initialize an InstructBlipVideoConfig from an InstructBlipVideoVisionConfig, InstructBlipVideoQFormerConfig and any PretrainedConfig
    >>> # Initializing Instructblipvideo vision, Instructblipvideo Q-Former and language model configurations
    >>> vision_config = InstructBlipVideoVisionConfig()
    >>> qformer_config = InstructBlipVideoQFormerConfig()
    >>> text_config = OPTConfig()

    >>> config = InstructBlipVideoConfig.from_vision_qformer_text_configs(vision_config, qformer_config, text_config)
    ```"""
    model_type = "instructblipvideo"

    def __init__(
        self,
        vision_config=None,
        qformer_config=None,
        text_config=None,
        num_query_tokens=32,
        video_token_index=None,
        **kwargs,
    ):
        super().__init__(**kwargs)

        if vision_config is None:
            vision_config = {}
            logger.info("vision_config is None. Initializing the InstructBlipVideoVisionConfig with default values.")

        if qformer_config is None:
            qformer_config = {}
            logger.info("qformer_config is None. Initializing the InstructBlipVideoQFormerConfig with default values.")

        if text_config is None:
            text_config = {}
            logger.info("text_config is None. Initializing the text config with default values (`OPTConfig`).")

        self.vision_config = InstructBlipVideoVisionConfig(**vision_config)
        self.qformer_config = InstructBlipVideoQFormerConfig(**qformer_config)
        text_model_type = text_config["model_type"] if "model_type" in text_config else "opt"
        self.text_config = CONFIG_MAPPING[text_model_type](**text_config)

        self.tie_word_embeddings = self.text_config.tie_word_embeddings
        self.is_encoder_decoder = self.text_config.is_encoder_decoder
        self.num_query_tokens = num_query_tokens
        self.video_token_index = video_token_index
        self.qformer_config.encoder_hidden_size = self.vision_config.hidden_size
        self.use_decoder_only_language_model = self.text_config.model_type in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
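        # decoder-only language models (e.g. OPT, LLaMA) have their loss computed manually in
        # `forward`, while encoder-decoder models (e.g. T5) receive `labels` directly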
        self.initializer_factor = 1.0
        self.initializer_range = 0.02

    @classmethod
    def from_vision_qformer_text_configs(
        cls,
        vision_config: InstructBlipVideoVisionConfig,
        qformer_config: InstructBlipVideoQFormerConfig,
        text_config: PretrainedConfig,
        **kwargs,
    ):
        r"""
        Instantiate an [`InstructBlipVideoConfig`] (or a derived class) from an InstructBlipVideo vision model, Q-Former and
        language model configurations.

        Returns:
            [`InstructBlipVideoConfig`]: An instance of a configuration object
        """

        return cls(
            vision_config=vision_config.to_dict(),
            qformer_config=qformer_config.to_dict(),
            text_config=text_config.to_dict(),
            **kwargs,
        )


@dataclass
class InstructBlipVideoForConditionalGenerationModelOutput(InstructBlipForConditionalGenerationModelOutput):
    pass


class InstructBlipVideoForConditionalGeneration(InstructBlipForConditionalGeneration):
    def forward(
        self,
        pixel_values: torch.FloatTensor,
        qformer_input_ids: torch.LongTensor,
        qformer_attention_mask: Optional[torch.LongTensor] = None,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.LongTensor] = None,
        decoder_input_ids: Optional[torch.LongTensor] = None,
        decoder_attention_mask: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        labels: Optional[torch.LongTensor] = None,
        return_dict: Optional[bool] = None,
        interpolate_pos_encoding: bool = False,
    ) -> Union[Tuple, InstructBlipVideoForConditionalGenerationModelOutput]:
  152. r"""
  153. ```python
  154. >>> from transformers import InstructBlipVideoProcessor, InstructBlipVideoForConditionalGeneration
  155. >>> import torch
  156. >>> from huggingface_hub import hf_hub_download
  157. >>> import av
  158. >>> import numpy as np
  159. >>> def read_video_pyav(container, indices):
  160. ... '''
  161. ... Decode the video with PyAV decoder.
  162. ... Args:
  163. ... container (`av.container.input.InputContainer`): PyAV container.
  164. ... indices (`List[int]`): List of frame indices to decode.
  165. ... Returns:
  166. ... result (np.ndarray): np array of decoded frames of shape (num_frames, height, width, 3).
  167. ... '''
  168. ... frames = []
  169. ... container.seek(0)
  170. ... start_index = indices[0]
  171. ... end_index = indices[-1]
  172. ... for i, frame in enumerate(container.decode(video=0)):
  173. ... if i > end_index:
  174. ... break
  175. ... if i >= start_index and i in indices:
  176. ... frames.append(frame)
  177. ... return np.stack([x.to_ndarray(format="rgb24") for x in frames])
  178. >>> model = InstructBlipVideoForConditionalGeneration.from_pretrained("Salesforce/instructblip-vicuna-7b", device_map="auto")
  179. >>> processor = InstructBlipVideoProcessor.from_pretrained("Salesforce/instructblip-vicuna-7b")
  180. >>> file_path = hf_hub_download(
  181. ... repo_id="nielsr/video-demo", filename="eating_spaghetti.mp4", repo_type="dataset"
  182. ... )
  183. >>> container = av.open(file_path)
  184. >>> # sample uniformly 4 frames from the videWhy is this video funny?o
  185. >>> total_frames = container.streams.video[0].frames
  186. >>> indices = np.arange(0, total_frames, total_frames / 4).astype(int)
  187. >>> clip = read_video_pyav(container, indices)
  188. >>> prompt = "What is happening in the video?"
  189. >>> inputs = processor(text=prompt, images=clip, return_tensors="pt").to(model.device)
  190. >>> outputs = model.generate(
  191. ... **inputs,
  192. ... do_sample=False,
  193. ... num_beams=5,
  194. ... max_length=256,
  195. ... repetition_penalty=1.5,
  196. ... length_penalty=1.0,
  197. ... )
  198. >>> generated_text = processor.batch_decode(outputs, skip_special_tokens=True)[0].strip()
  199. >>> print(generated_text)
  200. "A person is eating a bowl of pasta, and they are using a fork to eat it. The person is sitting at a table, and the plate of pasta is on the table in front"
  201. ```"""
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # step 1: forward the images through the vision encoder,
        # we process in a batched way, later unbatch it back (video has frames=4 always)
        batch_size, frames, channel, height, width = pixel_values.shape
        pixel_values = pixel_values.reshape(batch_size * frames, channel, height, width)

        vision_outputs = self.vision_model(
            pixel_values=pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            interpolate_pos_encoding=interpolate_pos_encoding,
        )
        image_embeds = vision_outputs[0]
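        # image_embeds has shape (batch_size * frames, vision_seq_len, vision_hidden_size):
        # every frame is encoded independently as a regular image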

        # step 2: forward the query tokens through the QFormer, using the image embeddings for cross-attention
        image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device)

        # difference with BLIP-2 here: we also feed the instruction prompt to the Q-Former
        query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
        query_attention_mask = torch.ones(query_tokens.size()[:-1], dtype=torch.long, device=image_embeds.device)
        if qformer_attention_mask is None:
            qformer_attention_mask = torch.ones_like(qformer_input_ids)

        qformer_input_ids = qformer_input_ids.repeat_interleave(frames, dim=0)
        qformer_attention_mask = qformer_attention_mask.repeat_interleave(frames, dim=0)
        qformer_attention_mask = torch.cat([query_attention_mask, qformer_attention_mask], dim=1)
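        # the instruction ids/mask are repeated per frame so each frame-level Q-Former pass sees the
        # same prompt; the combined mask covers the learned query tokens followed by the text tokens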
        query_outputs = self.qformer(
            input_ids=qformer_input_ids,
            attention_mask=qformer_attention_mask,
            query_embeds=query_tokens,
            encoder_hidden_states=image_embeds,
            encoder_attention_mask=image_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        query_output = query_outputs[0][:, : query_tokens.size(1), :]
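        # keep only the query-token positions; the instruction-text positions of the Q-Former
        # output are discarded before projecting into the language model space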

        # step 3: use the language model, conditioned on the query outputs and the prompt
        language_model_inputs = self.language_projection(query_output)

        # unbatch inputs back, each video-frame gets `num_query_tokens` seq length
        language_model_inputs = language_model_inputs.reshape(batch_size, self.config.num_query_tokens * frames, -1)

        language_model_attention_mask = torch.ones(
            language_model_inputs.size()[:-1], dtype=torch.long, device=language_model_inputs.device
        )

        inputs_embeds = self.language_model.get_input_embeddings()(input_ids)
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids)

        # if the model already has "video_token_index" then the input is expanded to account for image embeds
        # otherwise we expand manually by concatenating
        if getattr(self.config, "video_token_index", None) is not None:
            special_image_mask = (input_ids == self.config.video_token_index).unsqueeze(-1).expand_as(inputs_embeds)
            inputs_embeds[special_image_mask] = language_model_inputs.flatten()
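            # the projected query embeddings are scattered in place over the video placeholder
            # tokens, so the processor must have inserted `num_query_tokens * frames` of them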
        else:
            logger.warning_once(
                "Expanding inputs for video tokens in InstructBLIPVideo should be done in processing. "
                "Please follow the instructions here (https://gist.github.com/zucchini-nlp/65f22892b054dc0d68228af56fbeaac2) to update your InstructBLIPVideo model. "
                "Using processors without these attributes in the config is deprecated and will throw an error in v4.47."
            )
            inputs_embeds = torch.cat([language_model_inputs, inputs_embeds.to(language_model_inputs.device)], dim=1)
            attention_mask = torch.cat(
                [language_model_attention_mask, attention_mask.to(language_model_attention_mask.device)], dim=1
            )

        if self.config.use_decoder_only_language_model:
            outputs = self.language_model(
                inputs_embeds=inputs_embeds,
                attention_mask=attention_mask,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )
            logits = outputs.logits if return_dict else outputs[0]
            loss = None
            # we compute the loss here since we need to take into account the sequence length of the query embeds
            if labels is not None:
                labels = labels.to(logits.device)
                logits = logits[:, -labels.size(1) :, :]
                # Shift so that tokens < n predict n
                shift_logits = logits[..., :-1, :].contiguous()
                shift_labels = labels[..., 1:].contiguous().to(logits.device)

                # Flatten the tokens
                loss_fct = CrossEntropyLoss(reduction="mean")

                loss = loss_fct(shift_logits.view(-1, self.config.text_config.vocab_size), shift_labels.view(-1))
        else:
            outputs = self.language_model(
                inputs_embeds=inputs_embeds,
                attention_mask=attention_mask,
                decoder_input_ids=decoder_input_ids,
                decoder_attention_mask=decoder_attention_mask,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
                labels=labels,
            )
            loss = outputs.loss if return_dict else outputs[0]
            logits = outputs.logits if return_dict else outputs[1]

        if not return_dict:
            output = (logits, vision_outputs, query_outputs, outputs)
            return ((loss,) + output) if loss is not None else output

        return InstructBlipVideoForConditionalGenerationModelOutput(
            loss=loss,
            logits=logits,
            vision_outputs=vision_outputs,
            qformer_outputs=query_outputs,
            language_model_outputs=outputs,
        )

    @torch.no_grad()
    def generate(
        self,
        pixel_values: torch.FloatTensor,
        qformer_input_ids: Optional[torch.LongTensor] = None,
        qformer_attention_mask: Optional[torch.LongTensor] = None,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.LongTensor] = None,
        interpolate_pos_encoding: bool = False,
        **generate_kwargs,
    ) -> torch.LongTensor:
        r"""
        Overrides the `generate` function to be able to use the model as a conditional generator.

        Args:
            pixel_values (`torch.FloatTensor` of shape (batch_size, num_channels, height, width) or
                (batch_size, num_frames, num_channels, height, width)): Input images or videos to be processed.
            qformer_input_ids (`torch.LongTensor` of shape (batch_size, sequence_length), *optional*):
                The sequence used as a prompt to be fed to the Q-Former module.
            qformer_attention_mask (`torch.LongTensor` of shape (batch_size, sequence_length), *optional*):
                Mask to avoid performing attention on padding token indices.
            input_ids (`torch.LongTensor` of shape (batch_size, sequence_length), *optional*):
                The sequence used as a prompt for the generation.
            attention_mask (`torch.LongTensor` of shape (batch_size, sequence_length), *optional*):
                Mask to avoid performing attention on padding token indices.
            interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
                Whether to interpolate the positional encoding of the image embeddings.

        Returns:
            captions (list): A list of strings of length batch_size * num_captions.
        """
        if hasattr(self, "hf_device_map"):
            # preprocess for `accelerate`
            self._preprocess_accelerate()

        # we process in a batched way, later unbatch it back (video has frames=4)
        batch_size, frames, channel, height, width = pixel_values.shape
        pixel_values = pixel_values.reshape(batch_size * frames, channel, height, width)

        image_embeds = self.vision_model(
            pixel_values,
            return_dict=True,
            interpolate_pos_encoding=interpolate_pos_encoding,
        ).last_hidden_state
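        # as in `forward`, image_embeds is (batch_size * frames, vision_seq_len, vision_hidden_size)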
        image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device)

        query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
        query_attention_mask = torch.ones(query_tokens.size()[:-1], dtype=torch.long, device=image_embeds.device)
        if qformer_attention_mask is None:
            qformer_attention_mask = torch.ones_like(qformer_input_ids)

        qformer_input_ids = qformer_input_ids.repeat_interleave(frames, dim=0)
        qformer_attention_mask = qformer_attention_mask.repeat_interleave(frames, dim=0)
        qformer_attention_mask = torch.cat([query_attention_mask, qformer_attention_mask], dim=1)
        query_outputs = self.qformer(
            input_ids=qformer_input_ids,
            attention_mask=qformer_attention_mask,
            query_embeds=query_tokens,
            encoder_hidden_states=image_embeds,
            encoder_attention_mask=image_attention_mask,
            return_dict=True,
        )
        query_output = query_outputs.last_hidden_state[:, : query_tokens.size(1), :]

        language_model_inputs = self.language_projection(query_output)

        # unbatch the embeddings back by moving frames to seq-len
        language_model_inputs = language_model_inputs.reshape(batch_size, self.config.num_query_tokens * frames, -1)
        language_attention_mask = torch.ones(
            language_model_inputs.size()[:-1], dtype=torch.long, device=language_model_inputs.device
        )

        if input_ids is None:
            input_ids = (
                torch.LongTensor([[self.config.text_config.bos_token_id]])
                .repeat(batch_size, 1)
                .to(image_embeds.device)
            )
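            # with no text prompt, generation is conditioned on the visual tokens plus a single BOS per sample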
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids)

        inputs_embeds = self.get_input_embeddings()(input_ids)

        # if the model already has "video_token_index" then the input is expanded to account for image embeds
        # otherwise we expand manually by concatenating
        if getattr(self.config, "video_token_index", None) is not None:
            special_image_mask = (input_ids == self.config.video_token_index).unsqueeze(-1).expand_as(inputs_embeds)
            inputs_embeds[special_image_mask] = language_model_inputs.flatten()
        else:
            logger.warning_once(
                "Expanding inputs for video tokens in InstructBLIPVideo should be done in processing. "
                "Please follow the instructions here (https://gist.github.com/zucchini-nlp/65f22892b054dc0d68228af56fbeaac2) to update your InstructBLIPVideo model. "
                "Using processors without these attributes in the config is deprecated and will throw an error in v4.47."
            )
            inputs_embeds = torch.cat([language_model_inputs, inputs_embeds.to(language_model_inputs.device)], dim=1)
            attention_mask = torch.cat(
                [language_attention_mask, attention_mask.to(language_attention_mask.device)], dim=1
            )

            # add image_embeds length to max_length, so that the final max_length is counted only on token embeds
            # -1 is to account for the prepended BOS after `generate`.
            if not self.language_model.config.is_encoder_decoder:
                generate_kwargs["max_length"] = (
                    generate_kwargs.get("max_length", 20) + language_model_inputs.shape[1] - 1
                )
                generate_kwargs["min_length"] = generate_kwargs.get("min_length", 0) + language_model_inputs.shape[1]

        outputs = self.language_model.generate(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            **generate_kwargs,
        )

        # this is a temporary workaround to be consistent with other generation models and
        # have BOS as the first token, even though under the hood we are calling LM with embeds
        if not self.language_model.config.is_encoder_decoder:
            # the InstructBLIP authors used inconsistent tokenizer/model files during training,
            # with the tokenizer's bos token being set to </s> which has ID=2,
            # whereas the model's text config has bos token id = 0
            bos_token_id = (
                2
                if self.config.text_config.architectures[0] == "LLaMAForCausalLM"
                else self.config.text_config.bos_token_id
            )
            bos_tokens = torch.LongTensor([[bos_token_id]]).repeat(batch_size, 1).to(image_embeds.device)
            if not isinstance(outputs, torch.Tensor):
                outputs.sequences = torch.cat([bos_tokens, outputs.sequences], dim=-1)
            else:
                outputs = torch.cat([bos_tokens, outputs], dim=-1)

        return outputs