processing_align.py 7.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161
  1. # coding=utf-8
  2. # Copyright 2023 The HuggingFace Inc. team.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. """
  16. Image/Text processor class for ALIGN
  17. """
  18. from typing import List, Union
  19. from ...image_utils import ImageInput
  20. from ...processing_utils import ProcessingKwargs, ProcessorMixin, Unpack, _validate_images_text_input_order
  21. from ...tokenization_utils_base import BatchEncoding, PreTokenizedInput, TextInput
  22. class AlignProcessorKwargs(ProcessingKwargs, total=False):
  23. # see processing_utils.ProcessingKwargs documentation for usage.
  24. _defaults = {
  25. "text_kwargs": {
  26. "padding": "max_length",
  27. "max_length": 64,
  28. },
  29. }
  30. class AlignProcessor(ProcessorMixin):
  31. r"""
  32. Constructs an ALIGN processor which wraps [`EfficientNetImageProcessor`] and
  33. [`BertTokenizer`]/[`BertTokenizerFast`] into a single processor that interits both the image processor and
  34. tokenizer functionalities. See the [`~AlignProcessor.__call__`] and [`~OwlViTProcessor.decode`] for more
  35. information.
  36. The preferred way of passing kwargs is as a dictionary per modality, see usage example below.
  37. ```python
  38. from transformers import AlignProcessor
  39. from PIL import Image
  40. model_id = "kakaobrain/align-base"
  41. processor = AlignProcessor.from_pretrained(model_id)
  42. processor(
  43. images=your_pil_image,
  44. text=["What is that?"],
  45. images_kwargs = {"crop_size": {"height": 224, "width": 224}},
  46. text_kwargs = {"padding": "do_not_pad"},
  47. common_kwargs = {"return_tensors": "pt"},
  48. )
  49. ```
  50. Args:
  51. image_processor ([`EfficientNetImageProcessor`]):
  52. The image processor is a required input.
  53. tokenizer ([`BertTokenizer`, `BertTokenizerFast`]):
  54. The tokenizer is a required input.
  55. """
  56. attributes = ["image_processor", "tokenizer"]
  57. image_processor_class = "EfficientNetImageProcessor"
  58. tokenizer_class = ("BertTokenizer", "BertTokenizerFast")
  59. def __init__(self, image_processor, tokenizer):
  60. super().__init__(image_processor, tokenizer)
  61. def __call__(
  62. self,
  63. images: ImageInput = None,
  64. text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
  65. audio=None,
  66. videos=None,
  67. **kwargs: Unpack[AlignProcessorKwargs],
  68. ) -> BatchEncoding:
  69. """
  70. Main method to prepare text(s) and image(s) to be fed as input to the model. This method forwards the `text`
  71. arguments to BertTokenizerFast's [`~BertTokenizerFast.__call__`] if `text` is not `None` to encode
  72. the text. To prepare the image(s), this method forwards the `images` arguments to
  73. EfficientNetImageProcessor's [`~EfficientNetImageProcessor.__call__`] if `images` is not `None`. Please refer
  74. to the doctsring of the above two methods for more information.
  75. Args:
  76. images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
  77. The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
  78. tensor. Both channels-first and channels-last formats are supported.
  79. text (`str`, `List[str]`):
  80. The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
  81. (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
  82. `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
  83. return_tensors (`str` or [`~utils.TensorType`], *optional*):
  84. If set, will return tensors of a particular framework. Acceptable values are:
  85. - `'tf'`: Return TensorFlow `tf.constant` objects.
  86. - `'pt'`: Return PyTorch `torch.Tensor` objects.
  87. - `'np'`: Return NumPy `np.ndarray` objects.
  88. - `'jax'`: Return JAX `jnp.ndarray` objects.
  89. Returns:
  90. [`BatchEncoding`]: A [`BatchEncoding`] with the following fields:
  91. - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
  92. - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
  93. `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
  94. `None`).
  95. - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
  96. """
  97. if text is None and images is None:
  98. raise ValueError("You must specify either text or images.")
  99. # check if images and text inputs are reversed for BC
  100. images, text = _validate_images_text_input_order(images, text)
  101. output_kwargs = self._merge_kwargs(
  102. AlignProcessorKwargs,
  103. tokenizer_init_kwargs=self.tokenizer.init_kwargs,
  104. **kwargs,
  105. )
  106. # then, we can pass correct kwargs to each processor
  107. if text is not None:
  108. encoding = self.tokenizer(text, **output_kwargs["text_kwargs"])
  109. if images is not None:
  110. image_features = self.image_processor(images, **output_kwargs["images_kwargs"])
  111. # BC for explicit return_tensors
  112. if "return_tensors" in output_kwargs["common_kwargs"]:
  113. return_tensors = output_kwargs["common_kwargs"].pop("return_tensors", None)
  114. if text is not None and images is not None:
  115. encoding["pixel_values"] = image_features.pixel_values
  116. return encoding
  117. elif text is not None:
  118. return encoding
  119. else:
  120. return BatchEncoding(data=dict(**image_features), tensor_type=return_tensors)
  121. def batch_decode(self, *args, **kwargs):
  122. """
  123. This method forwards all its arguments to BertTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
  124. refer to the docstring of this method for more information.
  125. """
  126. return self.tokenizer.batch_decode(*args, **kwargs)
  127. def decode(self, *args, **kwargs):
  128. """
  129. This method forwards all its arguments to BertTokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to
  130. the docstring of this method for more information.
  131. """
  132. return self.tokenizer.decode(*args, **kwargs)
  133. @property
  134. def model_input_names(self):
  135. tokenizer_input_names = self.tokenizer.model_input_names
  136. image_processor_input_names = self.image_processor.model_input_names
  137. return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
  138. __all__ = ["AlignProcessor"]