processing_instructblipvideo.py

# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
  15. """
  16. Processor class for InstructBLIP. Largely copy of Blip2Processor with addition of a tokenizer for the Q-Former.
  17. """

import os
from typing import List, Optional, Union

from ...image_processing_utils import BatchFeature
from ...image_utils import VideoInput
from ...processing_utils import ProcessorMixin
from ...tokenization_utils_base import (
    AddedToken,
    BatchEncoding,
    PaddingStrategy,
    PreTokenizedInput,
    TextInput,
    TruncationStrategy,
)
from ...utils import TensorType, logging
from ..auto import AutoTokenizer


logger = logging.get_logger(__name__)


class InstructBlipVideoProcessor(ProcessorMixin):
    r"""
    Constructs an InstructBLIPVideo processor which wraps an InstructBLIPVideo image processor and a LLaMa/T5
    tokenizer into a single processor.

    [`InstructBlipVideoProcessor`] offers all the functionalities of [`InstructBlipVideoImageProcessor`] and [`AutoTokenizer`]. See the
    docstring of [`~InstructBlipVideoProcessor.__call__`] and [`~InstructBlipVideoProcessor.decode`] for more information.

    Args:
        image_processor (`InstructBlipVideoImageProcessor`):
            An instance of [`InstructBlipVideoImageProcessor`]. The image processor is a required input.
        tokenizer (`AutoTokenizer`):
            An instance of [`PreTrainedTokenizer`]. The tokenizer is a required input.
        qformer_tokenizer (`AutoTokenizer`):
            An instance of [`PreTrainedTokenizer`]. The Q-Former tokenizer is a required input.
        num_query_tokens (`int`, *optional*):
            Number of tokens used by the Q-Former as queries; should be the same as in the model's config.
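
    Example (a minimal usage sketch; the checkpoint name below is illustrative, not guaranteed to exist):

    ```python
    >>> from transformers import InstructBlipVideoProcessor

    >>> # loads the image processor, the tokenizer and the Q-Former tokenizer in one call
    >>> processor = InstructBlipVideoProcessor.from_pretrained("Salesforce/instructblipvideo-vicuna-7b")
    ```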
  49. """

    attributes = ["image_processor", "tokenizer", "qformer_tokenizer"]
    valid_kwargs = ["num_query_tokens"]
    image_processor_class = "InstructBlipVideoImageProcessor"
    tokenizer_class = "AutoTokenizer"
    qformer_tokenizer_class = "AutoTokenizer"

    def __init__(self, image_processor, tokenizer, qformer_tokenizer, num_query_tokens=None, **kwargs):
        self.video_token = AddedToken("<video>", normalized=False, special=True)
        tokenizer.add_tokens([self.video_token], special_tokens=True)
        self.num_query_tokens = num_query_tokens
        super().__init__(image_processor, tokenizer, qformer_tokenizer)

    def __call__(
        self,
        images: VideoInput = None,
        text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
        add_special_tokens: bool = True,
        padding: Union[bool, str, PaddingStrategy] = False,
        truncation: Union[bool, str, TruncationStrategy] = None,
        max_length: Optional[int] = None,
        stride: int = 0,
        pad_to_multiple_of: Optional[int] = None,
        return_attention_mask: Optional[bool] = None,
        return_overflowing_tokens: bool = False,
        return_special_tokens_mask: bool = False,
        return_offsets_mapping: bool = False,
        return_token_type_ids: bool = False,
        return_length: bool = False,
        verbose: bool = True,
        return_tensors: Optional[Union[str, TensorType]] = None,
        **kwargs,
    ) -> BatchFeature:
  80. """
  81. This method uses [`InstructBlipVideoImageProcessor.__call__`] method to prepare image(s) or video(s) for the model, and
  82. [`BertTokenizerFast.__call__`] to prepare text for the model.
  83. Please refer to the docstring of the above two methods for more information.
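
        Example (illustrative sketch; `processor` and `frames`, a list of 4 video frames, are assumed to be defined):

        ```python
        >>> inputs = processor(images=frames, text="What is happening in the video?", return_tensors="pt")
        >>> # `inputs` typically holds input_ids, attention_mask, qformer_input_ids,
        >>> # qformer_attention_mask and pixel_values
        ```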
  84. """
        if images is None and text is None:
            raise ValueError("You have to specify at least one of images or text.")

        encoding = BatchFeature()

        if text is not None:
            if isinstance(text, str):
                text = [text]
            elif not isinstance(text, list) or not isinstance(text[0], str):
                raise ValueError("Invalid input text. Please provide a string, or a list of strings")

            _text_encoding = self.tokenizer(
                text=text,
                add_special_tokens=add_special_tokens,
                padding=padding,
                truncation=truncation,
                max_length=max_length,
                stride=stride,
                pad_to_multiple_of=pad_to_multiple_of,
                return_attention_mask=return_attention_mask,
                return_overflowing_tokens=return_overflowing_tokens,
                return_special_tokens_mask=return_special_tokens_mask,
                return_offsets_mapping=return_offsets_mapping,
                return_token_type_ids=return_token_type_ids,
                return_length=return_length,
                verbose=verbose,
                return_tensors=None,  # required to concatenate below
                **kwargs,
            )

            # if we know how many query tokens, expand text inside processor. We need this hacky manipulation
            # because BLIP expects image tokens to be at the beginning even before BOS token
            if self.num_query_tokens is not None and images is not None:
                text_encoding = {}
                video_tokens = (
                    self.video_token.content * self.num_query_tokens * 4
                )  # InstructBLIPVideo works with 4 frames only
                video_token_encoding = self.tokenizer(
                    [video_tokens] * len(text), add_special_tokens=False, return_tensors=None
                )
                for k in _text_encoding:
                    text_encoding[k] = [
                        img_encoding + txt_encoding
                        for img_encoding, txt_encoding in zip(video_token_encoding[k], _text_encoding[k])
                    ]
            else:
                text_encoding = _text_encoding
                if images is not None:
                    logger.warning_once(
                        "Expanding inputs for video tokens in InstructBLIPVideo should be done in processing. "
                        "Please follow the instructions here (https://gist.github.com/zucchini-nlp/65f22892b054dc0d68228af56fbeaac2) to update your InstructBLIPVideo model. "
                        "Using processors without these attributes in the config is deprecated and will throw an error in v4.47."
                    )

            # cast to desired return tensors type after concatenating
            text_encoding = BatchEncoding(text_encoding, tensor_type=return_tensors)
            encoding.update(text_encoding)

            qformer_text_encoding = self.qformer_tokenizer(
                text=text,
                add_special_tokens=add_special_tokens,
                padding=padding,
                truncation=truncation,
                max_length=max_length,
                stride=stride,
                pad_to_multiple_of=pad_to_multiple_of,
                return_attention_mask=return_attention_mask,
                return_overflowing_tokens=return_overflowing_tokens,
                return_special_tokens_mask=return_special_tokens_mask,
                return_offsets_mapping=return_offsets_mapping,
                return_token_type_ids=return_token_type_ids,
                return_length=return_length,
                verbose=verbose,
                return_tensors=return_tensors,
                **kwargs,
            )
            encoding["qformer_input_ids"] = qformer_text_encoding.pop("input_ids")
            encoding["qformer_attention_mask"] = qformer_text_encoding.pop("attention_mask")

        if images is not None:
            image_encoding = self.image_processor(images, return_tensors=return_tensors)
            encoding.update(image_encoding)

        return encoding

    # Copied from transformers.models.blip.processing_blip.BlipProcessor.batch_decode with BertTokenizerFast->PreTrainedTokenizer
    def batch_decode(self, *args, **kwargs):
        """
        This method forwards all its arguments to PreTrainedTokenizer's [`~PreTrainedTokenizer.batch_decode`]. Please
        refer to the docstring of this method for more information.
        """
        return self.tokenizer.batch_decode(*args, **kwargs)

    # Copied from transformers.models.blip.processing_blip.BlipProcessor.decode with BertTokenizerFast->PreTrainedTokenizer
    def decode(self, *args, **kwargs):
        """
        This method forwards all its arguments to PreTrainedTokenizer's [`~PreTrainedTokenizer.decode`]. Please refer
        to the docstring of this method for more information.
        """
        return self.tokenizer.decode(*args, **kwargs)

    @property
    # Copied from transformers.models.blip.processing_blip.BlipProcessor.model_input_names
    def model_input_names(self):
        tokenizer_input_names = self.tokenizer.model_input_names
        image_processor_input_names = self.image_processor.model_input_names
        return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))

    # overwrite to save the Q-Former tokenizer in a separate folder
    def save_pretrained(self, save_directory, **kwargs):
        if os.path.isfile(save_directory):
            raise ValueError(f"Provided path ({save_directory}) should be a directory, not a file")
        os.makedirs(save_directory, exist_ok=True)
        qformer_tokenizer_path = os.path.join(save_directory, "qformer_tokenizer")
        self.qformer_tokenizer.save_pretrained(qformer_tokenizer_path)

        # We modify the attributes so that only the tokenizer and image processor are saved in the main folder
        qformer_present = "qformer_tokenizer" in self.attributes
        if qformer_present:
            self.attributes.remove("qformer_tokenizer")

        outputs = super().save_pretrained(save_directory, **kwargs)

        if qformer_present:
            self.attributes += ["qformer_tokenizer"]
        return outputs

    # overwrite to load the Q-Former tokenizer from a separate folder
    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
        processor = super().from_pretrained(pretrained_model_name_or_path, **kwargs)

        # if `return_unused_kwargs` is set, a tuple is returned where the second element is `unused_kwargs`
        if isinstance(processor, tuple):
            processor = processor[0]
        qformer_tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder="qformer_tokenizer")
        processor.qformer_tokenizer = qformer_tokenizer
        return processor
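

# A minimal save/load round-trip sketch (illustrative only; the local path and checkpoint name are assumptions):
#
#     processor = InstructBlipVideoProcessor.from_pretrained("Salesforce/instructblipvideo-vicuna-7b")
#     processor.save_pretrained("./my_processor")  # the Q-Former tokenizer lands in ./my_processor/qformer_tokenizer
#     reloaded = InstructBlipVideoProcessor.from_pretrained("./my_processor")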