processing_pixtral.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283
  1. # coding=utf-8
  2. # Copyright 2024 The HuggingFace Inc. team.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. """
  16. Processor class for Pixtral.
  17. """
  18. from typing import List, Union
  19. from ...feature_extraction_utils import BatchFeature
  20. from ...image_utils import ImageInput, is_valid_image, load_image
  21. from ...processing_utils import ProcessingKwargs, ProcessorMixin, Unpack, _validate_images_text_input_order
  22. from ...tokenization_utils_base import PreTokenizedInput, TextInput
  23. from ...utils import is_torch_device, is_torch_dtype, is_torch_tensor, logging, requires_backends
  24. logger = logging.get_logger(__name__)
  25. class PixtralProcessorKwargs(ProcessingKwargs, total=False):
  26. _defaults = {
  27. "text_kwargs": {
  28. "padding": False,
  29. },
  30. "images_kwargs": {},
  31. "common_kwargs": {
  32. "return_tensors": "pt",
  33. },
  34. }
  35. # Copied from transformers.models.idefics2.processing_idefics2.is_url
  36. def is_url(val) -> bool:
  37. return isinstance(val, str) and val.startswith("http")
  38. # Copied from transformers.models.idefics2.processing_idefics2.is_image_or_image_url
  39. def is_image_or_image_url(elem):
  40. return is_url(elem) or is_valid_image(elem)
  41. # Copied from transformers.models.pixtral.image_processing_pixtral.BatchMixFeature
  42. class BatchMixFeature(BatchFeature):
  43. def to(self, *args, **kwargs) -> "BatchMixFeature":
  44. """
  45. Send all values to device by calling `v.to(*args, **kwargs)` (PyTorch only). This should support casting in
  46. different `dtypes` and sending the `BatchFeature` to a different `device`.
  47. Args:
  48. args (`Tuple`):
  49. Will be passed to the `to(...)` function of the tensors.
  50. kwargs (`Dict`, *optional*):
  51. Will be passed to the `to(...)` function of the tensors.
  52. Returns:
  53. [`BatchFeature`]: The same instance after modification.
  54. """
  55. requires_backends(self, ["torch"])
  56. import torch # noqa
  57. new_data = {}
  58. device = kwargs.get("device")
  59. # Check if the args are a device or a dtype
  60. if device is None and len(args) > 0:
  61. # device should be always the first argument
  62. arg = args[0]
  63. if is_torch_dtype(arg):
  64. # The first argument is a dtype
  65. pass
  66. elif isinstance(arg, str) or is_torch_device(arg) or isinstance(arg, int):
  67. device = arg
  68. else:
  69. # it's something else
  70. raise ValueError(f"Attempting to cast a BatchFeature to type {str(arg)}. This is not supported.")
  71. # We cast only floating point tensors to avoid issues with tokenizers casting `LongTensor` to `FloatTensor`
  72. for k, v in self.items():
  73. # check if v is a floating point
  74. if isinstance(v, list):
  75. new_data[k] = [
  76. element.to(*args, **kwargs) for sample in v for element in sample if is_torch_tensor(element)
  77. ]
  78. elif isinstance(v, torch.Tensor) and torch.is_floating_point(v):
  79. # cast and send to device
  80. new_data[k] = v.to(*args, **kwargs)
  81. elif isinstance(v, torch.Tensor) and device is not None:
  82. new_data[k] = v.to(device=device)
  83. else:
  84. new_data[k] = v
  85. self.data = new_data
  86. return self
  87. class PixtralProcessor(ProcessorMixin):
  88. r"""
  89. Constructs a Pixtral processor which wraps a Pixtral image processor and a Pixtral tokenizer into a single processor.
  90. [`PixtralProcessor`] offers all the functionalities of [`CLIPImageProcessor`] and [`LlamaTokenizerFast`]. See the
  91. [`~PixtralProcessor.__call__`] and [`~PixtralProcessor.decode`] for more information.
  92. Args:
  93. image_processor ([`PixtralImageProcessor`], *optional*):
  94. The image processor is a required input.
  95. tokenizer ([`LlamaTokenizerFast`], *optional*):
  96. The tokenizer is a required input.
  97. patch_size (`int`, *optional*, defaults to 16):
  98. Patch size from the vision tower.
  99. chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages
  100. in a chat into a tokenizable string.
  101. image_token (`str`, *optional*, defaults to `"[IMG]"`):
  102. Special token used to denote image location.
  103. image_break_token (`str`, *optional*, defaults to `"[IMG_BREAK]"`):
  104. Special token used to denote the end of a line of pixels in an image.
  105. image_end_token (`str`, *optional*, defaults to `"[IMG_END]"`):
  106. Special token used to denote the end of an image input.
  107. """
  108. attributes = ["image_processor", "tokenizer"]
  109. valid_kwargs = [
  110. "chat_template",
  111. "patch_size",
  112. "image_token",
  113. "image_break_token",
  114. "image_end_token",
  115. ]
  116. image_processor_class = "AutoImageProcessor"
  117. tokenizer_class = "AutoTokenizer"
  118. def __init__(
  119. self,
  120. image_processor=None,
  121. tokenizer=None,
  122. patch_size: int = 16,
  123. chat_template=None,
  124. image_token="[IMG]", # set the default and let users change if they have peculiar special tokens in rare cases
  125. image_break_token="[IMG_BREAK]",
  126. image_end_token="[IMG_END]",
  127. **kwargs,
  128. ):
  129. self.patch_size = patch_size
  130. self.image_token = image_token
  131. self.image_break_token = image_break_token
  132. self.image_end_token = image_end_token
  133. super().__init__(image_processor, tokenizer, chat_template=chat_template)
  134. def __call__(
  135. self,
  136. images: ImageInput = None,
  137. text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
  138. audio=None,
  139. videos=None,
  140. **kwargs: Unpack[PixtralProcessorKwargs],
  141. ) -> BatchMixFeature:
  142. """
  143. Main method to prepare for the model one or several sequences(s) and image(s). This method forwards the `text`
  144. and `kwargs` arguments to LlamaTokenizerFast's [`~LlamaTokenizerFast.__call__`] if `text` is not `None` to encode
  145. the text. To prepare the image(s), this method forwards the `images` and `kwrags` arguments to
  146. CLIPImageProcessor's [`~CLIPImageProcessor.__call__`] if `images` is not `None`. Please refer to the doctsring
  147. of the above two methods for more information.
  148. Args:
  149. images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
  150. The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
  151. tensor. Both channels-first and channels-last formats are supported.
  152. text (`str`, `List[str]`, `List[List[str]]`):
  153. The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
  154. (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
  155. `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
  156. return_tensors (`str` or [`~utils.TensorType`], *optional*):
  157. If set, will return tensors of a particular framework. Acceptable values are:
  158. - `'tf'`: Return TensorFlow `tf.constant` objects.
  159. - `'pt'`: Return PyTorch `torch.Tensor` objects.
  160. - `'np'`: Return NumPy `np.ndarray` objects.
  161. - `'jax'`: Return JAX `jnp.ndarray` objects.
  162. Returns:
  163. [`BatchFeature`]: A [`BatchFeature`] with the following fields:
  164. - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
  165. - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
  166. `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
  167. `None`).
  168. - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
  169. """
  170. # check if images and text inputs are reversed for BC
  171. images, text = _validate_images_text_input_order(images, text)
  172. output_kwargs = self._merge_kwargs(
  173. PixtralProcessorKwargs,
  174. tokenizer_init_kwargs=self.tokenizer.init_kwargs,
  175. **kwargs,
  176. )
  177. if images is not None:
  178. if is_image_or_image_url(images):
  179. images = [[images]]
  180. elif isinstance(images, list) and is_image_or_image_url(images[0]):
  181. if isinstance(text, list):
  182. images = [[im] for im in images]
  183. else:
  184. images = [images]
  185. elif isinstance(images, list) and isinstance(images[0], list) and is_image_or_image_url(images[0][0]):
  186. pass
  187. else:
  188. raise ValueError(
  189. "Invalid input images. Please provide a single image, a list of images, or a list of lists of images."
  190. )
  191. images = [[load_image(im) for im in sample] for sample in images]
  192. image_inputs = self.image_processor(images, patch_size=self.patch_size, **output_kwargs["images_kwargs"])
  193. else:
  194. image_inputs = {}
  195. if isinstance(text, str):
  196. text = [text]
  197. elif not isinstance(text, list) and not isinstance(text[0], str):
  198. raise ValueError("Invalid input text. Please provide a string, or a list of strings")
  199. # try to expand inputs in processing if we have the necessary parts
  200. prompt_strings = text
  201. if image_inputs.get("pixel_values") is not None:
  202. # Replace the image token with the expanded image token sequence
  203. images = image_inputs["pixel_values"]
  204. image_sizes = image_inputs.pop("image_sizes")
  205. prompt_strings = []
  206. for sample_images, sample_image_sizes, sample in zip(images, image_sizes, text):
  207. replace_strings = []
  208. # First calculate the number of tokens needed for each image and put in a placeholder
  209. for image, image_size in zip(sample_images, sample_image_sizes):
  210. height, width = image_size
  211. num_height_tokens = height // self.patch_size
  212. num_width_tokens = width // self.patch_size
  213. replace_tokens = [
  214. [self.image_token] * num_width_tokens + [self.image_break_token]
  215. ] * num_height_tokens
  216. # Flatten list
  217. replace_tokens = [item for sublist in replace_tokens for item in sublist]
  218. replace_tokens[-1] = self.image_end_token
  219. replace_str = "".join(replace_tokens)
  220. replace_strings.append(replace_str)
  221. sample = sample.replace(self.image_token, "<placeholder>", 1)
  222. while "<placeholder>" in sample:
  223. replace_str = replace_strings.pop(0)
  224. sample = sample.replace("<placeholder>", replace_str, 1)
  225. prompt_strings.append(sample)
  226. text_inputs = self.tokenizer(prompt_strings, **output_kwargs["text_kwargs"])
  227. return BatchMixFeature(data={**text_inputs, **image_inputs})
  228. # Copied from transformers.models.clip.processing_clip.CLIPProcessor.batch_decode with CLIP->Llama
  229. def batch_decode(self, *args, **kwargs):
  230. """
  231. This method forwards all its arguments to LlamaTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
  232. refer to the docstring of this method for more information.
  233. """
  234. return self.tokenizer.batch_decode(*args, **kwargs)
  235. # Copied from transformers.models.clip.processing_clip.CLIPProcessor.decode with CLIP->Llama
  236. def decode(self, *args, **kwargs):
  237. """
  238. This method forwards all its arguments to LlamaTokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to
  239. the docstring of this method for more information.
  240. """
  241. return self.tokenizer.decode(*args, **kwargs)
  242. @property
  243. # Copied from transformers.models.clip.processing_clip.CLIPProcessor.model_input_names
  244. def model_input_names(self):
  245. tokenizer_input_names = self.tokenizer.model_input_names
  246. image_processor_input_names = self.image_processor.model_input_names
  247. return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))