# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Processor class for Llava.
"""

from typing import List, Union

from ...feature_extraction_utils import BatchFeature
from ...image_utils import ImageInput, get_image_size, to_numpy_array
from ...processing_utils import ProcessingKwargs, ProcessorMixin, Unpack, _validate_images_text_input_order
from ...tokenization_utils_base import PreTokenizedInput, TextInput
from ...utils import logging


logger = logging.get_logger(__name__)


class LlavaProcessorKwargs(ProcessingKwargs, total=False):
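    """
    Keyword arguments accepted by [`LlavaProcessor.__call__`], grouped by modality. `_defaults` disables padding of
    the text inputs unless the caller asks for it.

    The sketch below is a simplified, hypothetical illustration of the precedence used when these groups are merged:
    kwargs passed at call time override tokenizer init kwargs, which in turn override `_defaults`. The helper
    `merge_text_kwargs` is not part of the library; the actual merging is done by `ProcessorMixin._merge_kwargs`,
    which handles more cases than this sketch.

    ```python
    def merge_text_kwargs(defaults, tokenizer_init_kwargs, call_kwargs):
        # Hypothetical helper: later dicts win, so call-time kwargs have the highest priority.
        return {**defaults, **tokenizer_init_kwargs, **call_kwargs}


    defaults = {"padding": False}
    merged = merge_text_kwargs(defaults, {}, {"padding": "longest"})
    print(merged)  # {'padding': 'longest'}
    ```
    """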
    _defaults = {
        "text_kwargs": {
            "padding": False,
        },
        "images_kwargs": {},
    }


class LlavaProcessor(ProcessorMixin):
    r"""
    Constructs a Llava processor which wraps a Llava image processor and a Llava tokenizer into a single processor.

    [`LlavaProcessor`] offers all the functionalities of [`CLIPImageProcessor`] and [`LlamaTokenizerFast`]. See
    [`~LlavaProcessor.__call__`] and [`~LlavaProcessor.decode`] for more information.

    Args:
        image_processor ([`CLIPImageProcessor`], *optional*):
            The image processor is a required input.
        tokenizer ([`LlamaTokenizerFast`], *optional*):
            The tokenizer is a required input.
        patch_size (`int`, *optional*):
            Patch size from the vision tower.
        vision_feature_select_strategy (`str`, *optional*):
            The feature selection strategy used to select the vision feature from the vision backbone.
            Should be the same as in the model's config.
        chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages
            in a chat into a tokenizable string.
        image_token (`str`, *optional*, defaults to `"<image>"`):
            Special token used to denote image location.
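
    `patch_size` and `vision_feature_select_strategy` determine how many times `image_token` is repeated in each
    prompt. A minimal sketch of that count, mirroring the arithmetic in [`~LlavaProcessor.__call__`] and assuming a
    ViT-style vision tower that emits one feature per patch plus a CLS feature; `count_image_tokens` is a
    hypothetical helper, not part of the library:

    ```python
    def count_image_tokens(height, width, patch_size, vision_feature_select_strategy):
        # One token per patch, plus one for the CLS feature of the vision tower.
        num_image_tokens = (height // patch_size) * (width // patch_size) + 1
        # The "default" strategy drops the CLS feature, so one fewer placeholder is needed.
        if vision_feature_select_strategy == "default":
            num_image_tokens -= 1
        return num_image_tokens


    # A 336x336 image with 14x14 patches gives 24 * 24 = 576 patches.
    print(count_image_tokens(336, 336, 14, "default"))  # 576
    print(count_image_tokens(336, 336, 14, "full"))  # 577
    ```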
    """

    attributes = ["image_processor", "tokenizer"]
    valid_kwargs = ["chat_template", "patch_size", "vision_feature_select_strategy", "image_token"]
    image_processor_class = "AutoImageProcessor"
    tokenizer_class = "AutoTokenizer"

    def __init__(
        self,
        image_processor=None,
        tokenizer=None,
        patch_size=None,
        vision_feature_select_strategy=None,
        chat_template=None,
        image_token="<image>",  # default placeholder; override it if a checkpoint uses a different special token
        **kwargs,
    ):
        self.patch_size = patch_size
        self.vision_feature_select_strategy = vision_feature_select_strategy
        self.image_token = image_token
        super().__init__(image_processor, tokenizer, chat_template=chat_template)

    def __call__(
        self,
        images: ImageInput = None,
        text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
        audio=None,
        videos=None,
        **kwargs: Unpack[LlavaProcessorKwargs],
    ) -> BatchFeature:
        """
        Main method to prepare one or several sequence(s) and image(s) for the model. This method forwards the `text`
        and `kwargs` arguments to LlamaTokenizerFast's [`~LlamaTokenizerFast.__call__`] if `text` is not `None` to
        encode the text. To prepare the image(s), this method forwards the `images` and `kwargs` arguments to
        CLIPImageProcessor's [`~CLIPImageProcessor.__call__`] if `images` is not `None`. Please refer to the
        docstrings of these two methods for more information.

        Args:
            images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
                The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
                tensor. Both channels-first and channels-last formats are supported.
            text (`str`, `List[str]`, `List[List[str]]`):
                The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
                (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
                `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
            return_tensors (`str` or [`~utils.TensorType`], *optional*):
                If set, will return tensors of a particular framework. Acceptable values are:

                - `'tf'`: Return TensorFlow `tf.constant` objects.
                - `'pt'`: Return PyTorch `torch.Tensor` objects.
                - `'np'`: Return NumPy `np.ndarray` objects.
                - `'jax'`: Return JAX `jnp.ndarray` objects.

        Returns:
            [`BatchFeature`]: A [`BatchFeature`] with the following fields:

            - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
            - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
              `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
              `None`).
            - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
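
        Example (a minimal sketch of the prompt expansion performed when `patch_size` and
        `vision_feature_select_strategy` are both set; `expand_image_tokens` is a hypothetical helper, and the count of
        576 assumes a 336x336 input, a patch size of 14 and the `"default"` strategy):

        ```python
        def expand_image_tokens(prompts, image_token, num_image_tokens):
            # Each placeholder becomes `num_image_tokens` copies, one per image feature the model will insert.
            return [prompt.replace(image_token, image_token * num_image_tokens) for prompt in prompts]


        prompts = ["USER: <image> What is shown here? ASSISTANT:"]
        expanded = expand_image_tokens(prompts, "<image>", 576)
        print(expanded[0].count("<image>"))  # 576
        ```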
        """
        if images is None and text is None:
            raise ValueError("You have to specify at least one of `images` or `text`.")

        # check if images and text inputs are reversed for BC
        images, text = _validate_images_text_input_order(images, text)

        output_kwargs = self._merge_kwargs(
            LlavaProcessorKwargs,
            tokenizer_init_kwargs=self.tokenizer.init_kwargs,
            **kwargs,
        )
        if images is not None:
            image_inputs = self.image_processor(images, **output_kwargs["images_kwargs"])
        else:
            image_inputs = {}

        if isinstance(text, str):
            text = [text]
        elif not isinstance(text, list) and not isinstance(text[0], str):
            raise ValueError("Invalid input text. Please provide a string, or a list of strings")

        # try to expand inputs in processing if we have the necessary parts
        prompt_strings = text
        if image_inputs.get("pixel_values") is not None:
            if self.patch_size is not None and self.vision_feature_select_strategy is not None:
                # Replace each image token with one image token per image feature
                pixel_values = image_inputs["pixel_values"]
                height, width = get_image_size(to_numpy_array(pixel_values[0]))
                num_image_tokens = (height // self.patch_size) * (width // self.patch_size) + 1
                if self.vision_feature_select_strategy == "default":
                    num_image_tokens -= 1

                prompt_strings = []
                for sample in text:
                    sample = sample.replace(self.image_token, self.image_token * num_image_tokens)
                    prompt_strings.append(sample)
            else:
                logger.warning_once(
                    "Expanding inputs for image tokens in LLaVa should be done in processing. "
                    "Please add `patch_size` and `vision_feature_select_strategy` to the model's processing config or set directly "
                    "with `processor.patch_size = {patch_size}` and `processor.vision_feature_select_strategy = {vision_feature_select_strategy}`. "
                    "Using processors without these attributes in the config is deprecated and will throw an error in v4.47."
                )

        text_inputs = self.tokenizer(prompt_strings, **output_kwargs["text_kwargs"])
        return BatchFeature(data={**text_inputs, **image_inputs})

    # Copied from transformers.models.clip.processing_clip.CLIPProcessor.batch_decode with CLIP->Llama
    def batch_decode(self, *args, **kwargs):
        """
        This method forwards all its arguments to LlamaTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
        refer to the docstring of this method for more information.
        """
        return self.tokenizer.batch_decode(*args, **kwargs)

    # Copied from transformers.models.clip.processing_clip.CLIPProcessor.decode with CLIP->Llama
    def decode(self, *args, **kwargs):
        """
        This method forwards all its arguments to LlamaTokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to
        the docstring of this method for more information.
        """
        return self.tokenizer.decode(*args, **kwargs)

    @property
    # Copied from transformers.models.clip.processing_clip.CLIPProcessor.model_input_names
    def model_input_names(self):
        tokenizer_input_names = self.tokenizer.model_input_names
        image_processor_input_names = self.image_processor.model_input_names
        return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))