# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Processor class for Pixtral.
"""

from typing import List, Union

from ...feature_extraction_utils import BatchFeature
from ...image_utils import ImageInput, is_valid_image, load_image
from ...processing_utils import ProcessingKwargs, ProcessorMixin, Unpack, _validate_images_text_input_order
from ...tokenization_utils_base import PreTokenizedInput, TextInput
from ...utils import is_torch_device, is_torch_dtype, is_torch_tensor, logging, requires_backends


logger = logging.get_logger(__name__)


class PixtralProcessorKwargs(ProcessingKwargs, total=False):
    _defaults = {
        "text_kwargs": {
            "padding": False,
        },
        "images_kwargs": {},
        "common_kwargs": {
            "return_tensors": "pt",
        },
    }

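
# For illustration (a hypothetical user-side call, not part of this module): with
# the defaults above, `processor(images=img, text=txt)` tokenizes without padding
# and returns PyTorch tensors; call-time kwargs such as `padding=True` or
# `return_tensors="np"` override the corresponding defaults via `_merge_kwargs`.

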
# Copied from transformers.models.idefics2.processing_idefics2.is_url
def is_url(val) -> bool:
    return isinstance(val, str) and val.startswith("http")


# Copied from transformers.models.idefics2.processing_idefics2.is_image_or_image_url
def is_image_or_image_url(elem):
    return is_url(elem) or is_valid_image(elem)
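    # e.g. is_image_or_image_url("https://example.com/cat.png") -> True
    #      is_image_or_image_url(Image.new("RGB", (4, 4)))      -> True (for a PIL image)
    #      is_image_or_image_url(42)                            -> False

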
# Copied from transformers.models.pixtral.image_processing_pixtral.BatchMixFeature
class BatchMixFeature(BatchFeature):
    def to(self, *args, **kwargs) -> "BatchMixFeature":
        """
        Send all values to device by calling `v.to(*args, **kwargs)` (PyTorch only). This should support casting in
        different `dtypes` and sending the `BatchFeature` to a different `device`.

        Args:
            args (`Tuple`):
                Will be passed to the `to(...)` function of the tensors.
            kwargs (`Dict`, *optional*):
                Will be passed to the `to(...)` function of the tensors.

        Returns:
            [`BatchFeature`]: The same instance after modification.
        """
        requires_backends(self, ["torch"])
        import torch  # noqa
        new_data = {}
        device = kwargs.get("device")
        # Check if the args are a device or a dtype
        if device is None and len(args) > 0:
            # device should always be the first argument
            arg = args[0]
            if is_torch_dtype(arg):
                # The first argument is a dtype
                pass
            elif isinstance(arg, str) or is_torch_device(arg) or isinstance(arg, int):
                device = arg
            else:
                # it's something else
                raise ValueError(f"Attempting to cast a BatchFeature to type {str(arg)}. This is not supported.")
        # We cast only floating point tensors to avoid issues with tokenizers casting `LongTensor` to `FloatTensor`
        for k, v in self.items():
            if isinstance(v, list):
                # `pixel_values` arrives as a list of per-sample lists of tensors; preserve
                # that nesting instead of flattening it, and pass non-tensor elements through
                new_data[k] = [
                    [element.to(*args, **kwargs) if is_torch_tensor(element) else element for element in sample]
                    for sample in v
                ]
            elif isinstance(v, torch.Tensor) and torch.is_floating_point(v):
                # cast and send to device
                new_data[k] = v.to(*args, **kwargs)
            elif isinstance(v, torch.Tensor) and device is not None:
                new_data[k] = v.to(device=device)
            else:
                new_data[k] = v
        self.data = new_data
        return self

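
# Usage sketch (hypothetical, assuming a CUDA device and that `processor(...)`
# returned a BatchMixFeature):
#
#     batch = processor(images=image, text=prompt)
#     batch = batch.to("cuda", torch.float16)
#
# Floating-point tensors (e.g. `pixel_values`) are cast and moved; integer tensors
# (e.g. `input_ids`) are moved to the device but keep their dtype.

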
class PixtralProcessor(ProcessorMixin):
    r"""
    Constructs a Pixtral processor which wraps a Pixtral image processor and a Pixtral tokenizer into a single processor.

    [`PixtralProcessor`] offers all the functionalities of [`PixtralImageProcessor`] and [`LlamaTokenizerFast`]. See the
    [`~PixtralProcessor.__call__`] and [`~PixtralProcessor.decode`] for more information.

    Args:
        image_processor ([`PixtralImageProcessor`], *optional*):
            The image processor is a required input.
        tokenizer ([`LlamaTokenizerFast`], *optional*):
            The tokenizer is a required input.
        patch_size (`int`, *optional*, defaults to 16):
            Patch size from the vision tower.
        chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages
            in a chat into a tokenizable string.
        image_token (`str`, *optional*, defaults to `"[IMG]"`):
            Special token used to denote image location.
        image_break_token (`str`, *optional*, defaults to `"[IMG_BREAK]"`):
            Special token used to denote the end of a line of pixels in an image.
        image_end_token (`str`, *optional*, defaults to `"[IMG_END]"`):
            Special token used to denote the end of an image input.
    """

    attributes = ["image_processor", "tokenizer"]
    valid_kwargs = [
        "chat_template",
        "patch_size",
        "image_token",
        "image_break_token",
        "image_end_token",
    ]
    image_processor_class = "AutoImageProcessor"
    tokenizer_class = "AutoTokenizer"

    def __init__(
        self,
        image_processor=None,
        tokenizer=None,
        patch_size: int = 16,
        chat_template=None,
        image_token="[IMG]",  # set the default and let users change if they have peculiar special tokens in rare cases
        image_break_token="[IMG_BREAK]",
        image_end_token="[IMG_END]",
        **kwargs,
    ):
        self.patch_size = patch_size
        self.image_token = image_token
        self.image_break_token = image_break_token
        self.image_end_token = image_end_token
        super().__init__(image_processor, tokenizer, chat_template=chat_template)

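    # Construction sketch (the checkpoint name is an assumption; any repo that
    # ships a Pixtral processor config works the same way):
    #
    #     processor = PixtralProcessor.from_pretrained("mistral-community/pixtral-12b")
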
    def __call__(
        self,
        images: ImageInput = None,
        text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
        audio=None,
        videos=None,
        **kwargs: Unpack[PixtralProcessorKwargs],
    ) -> BatchMixFeature:
        """
        Main method to prepare for the model one or several sequence(s) and image(s). This method forwards the `text`
        and `kwargs` arguments to LlamaTokenizerFast's [`~LlamaTokenizerFast.__call__`] if `text` is not `None` to encode
        the text. To prepare the image(s), this method forwards the `images` and `kwargs` arguments to
        PixtralImageProcessor's [`~PixtralImageProcessor.__call__`] if `images` is not `None`. Please refer to the docstring
        of the above two methods for more information.

        Args:
            images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
                The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
                tensor. Both channels-first and channels-last formats are supported.
            text (`str`, `List[str]`, `List[List[str]]`):
                The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
                (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
                `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
            return_tensors (`str` or [`~utils.TensorType`], *optional*):
                If set, will return tensors of a particular framework. Acceptable values are:
                - `'tf'`: Return TensorFlow `tf.constant` objects.
                - `'pt'`: Return PyTorch `torch.Tensor` objects.
                - `'np'`: Return NumPy `np.ndarray` objects.
                - `'jax'`: Return JAX `jnp.ndarray` objects.

        Returns:
            [`BatchFeature`]: A [`BatchFeature`] with the following fields:

            - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
            - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
              `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
              `None`).
            - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
        """
        # check if images and text inputs are reversed for BC
        images, text = _validate_images_text_input_order(images, text)

        output_kwargs = self._merge_kwargs(
            PixtralProcessorKwargs,
            tokenizer_init_kwargs=self.tokenizer.init_kwargs,
            **kwargs,
        )

        if images is not None:
            if is_image_or_image_url(images):
                images = [[images]]
            elif isinstance(images, list) and is_image_or_image_url(images[0]):
                if isinstance(text, list):
                    images = [[im] for im in images]
                else:
                    images = [images]
            elif isinstance(images, list) and isinstance(images[0], list) and is_image_or_image_url(images[0][0]):
                pass
            else:
                raise ValueError(
                    "Invalid input images. Please provide a single image, a list of images, or a list of lists of images."
                )
            images = [[load_image(im) for im in sample] for sample in images]
            image_inputs = self.image_processor(images, patch_size=self.patch_size, **output_kwargs["images_kwargs"])
        else:
            image_inputs = {}

        if isinstance(text, str):
            text = [text]
        elif not isinstance(text, list) or not isinstance(text[0], str):
            # `or`, not `and`: reject anything that is neither a string nor a list of strings
            raise ValueError("Invalid input text. Please provide a string, or a list of strings")

        # try to expand inputs in processing if we have the necessary parts
        prompt_strings = text
        if image_inputs.get("pixel_values") is not None:
            # Replace the image token with the expanded image token sequence
            images = image_inputs["pixel_values"]
            image_sizes = image_inputs.pop("image_sizes")
            prompt_strings = []
            for sample_images, sample_image_sizes, sample in zip(images, image_sizes, text):
                replace_strings = []
                # First calculate the number of tokens needed for each image and put in a placeholder
                for image, image_size in zip(sample_images, sample_image_sizes):
                    height, width = image_size
                    num_height_tokens = height // self.patch_size
                    num_width_tokens = width // self.patch_size
                    replace_tokens = [
                        [self.image_token] * num_width_tokens + [self.image_break_token]
                    ] * num_height_tokens
                    # Flatten list
                    replace_tokens = [item for sublist in replace_tokens for item in sublist]
                    replace_tokens[-1] = self.image_end_token
                    replace_str = "".join(replace_tokens)
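                    # Worked example: a 64x48 image with patch_size=16 gives a 4x3 patch grid, so
                    # replace_str == "[IMG][IMG][IMG][IMG_BREAK]" * 3 + "[IMG][IMG][IMG][IMG_END]"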
                    replace_strings.append(replace_str)
                    sample = sample.replace(self.image_token, "<placeholder>", 1)

                while "<placeholder>" in sample:
                    replace_str = replace_strings.pop(0)
                    sample = sample.replace("<placeholder>", replace_str, 1)
                prompt_strings.append(sample)

        text_inputs = self.tokenizer(prompt_strings, **output_kwargs["text_kwargs"])
        return BatchMixFeature(data={**text_inputs, **image_inputs})

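    # Call sketch (hypothetical prompt; include one "[IMG]" per image in each sample):
    #
    #     inputs = processor(images=[image], text="Describe this image: [IMG]")
    #
    # returns tokenizer outputs (input_ids, attention_mask) plus pixel_values, with
    # each "[IMG]" expanded in place to the per-patch token grid built above.
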
    # Copied from transformers.models.clip.processing_clip.CLIPProcessor.batch_decode with CLIP->Llama
    def batch_decode(self, *args, **kwargs):
        """
        This method forwards all its arguments to LlamaTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
        refer to the docstring of this method for more information.
        """
        return self.tokenizer.batch_decode(*args, **kwargs)

    # Copied from transformers.models.clip.processing_clip.CLIPProcessor.decode with CLIP->Llama
    def decode(self, *args, **kwargs):
        """
        This method forwards all its arguments to LlamaTokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to
        the docstring of this method for more information.
        """
        return self.tokenizer.decode(*args, **kwargs)

    @property
    # Copied from transformers.models.clip.processing_clip.CLIPProcessor.model_input_names
    def model_input_names(self):
        tokenizer_input_names = self.tokenizer.model_input_names
        image_processor_input_names = self.image_processor.model_input_names
        return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
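

# End-to-end sketch (hypothetical; run from user code rather than this module, since
# this module uses relative imports, and the checkpoint name is an assumption):
#
#     from PIL import Image
#     from transformers import PixtralProcessor
#
#     processor = PixtralProcessor.from_pretrained("mistral-community/pixtral-12b")
#     inputs = processor(images=Image.new("RGB", (64, 48)), text="Describe: [IMG]")
#     sorted(inputs.keys())  # e.g. ['attention_mask', 'input_ids', 'pixel_values']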