processing_owlv2.py 9.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190
  1. # coding=utf-8
  2. # Copyright 2023 The HuggingFace Inc. team.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. """
  16. Image/Text processor class for OWLv2
  17. """
  18. from typing import List
  19. import numpy as np
  20. from ...processing_utils import ProcessorMixin
  21. from ...tokenization_utils_base import BatchEncoding
  22. from ...utils import is_flax_available, is_tf_available, is_torch_available
  23. class Owlv2Processor(ProcessorMixin):
  24. r"""
  25. Constructs an Owlv2 processor which wraps [`Owlv2ImageProcessor`] and [`CLIPTokenizer`]/[`CLIPTokenizerFast`] into
  26. a single processor that interits both the image processor and tokenizer functionalities. See the
  27. [`~OwlViTProcessor.__call__`] and [`~OwlViTProcessor.decode`] for more information.
  28. Args:
  29. image_processor ([`Owlv2ImageProcessor`]):
  30. The image processor is a required input.
  31. tokenizer ([`CLIPTokenizer`, `CLIPTokenizerFast`]):
  32. The tokenizer is a required input.
  33. """
  34. attributes = ["image_processor", "tokenizer"]
  35. image_processor_class = "Owlv2ImageProcessor"
  36. tokenizer_class = ("CLIPTokenizer", "CLIPTokenizerFast")
  37. def __init__(self, image_processor, tokenizer, **kwargs):
  38. super().__init__(image_processor, tokenizer)
  39. # Copied from transformers.models.owlvit.processing_owlvit.OwlViTProcessor.__call__ with OWLViT->OWLv2
  40. def __call__(self, text=None, images=None, query_images=None, padding="max_length", return_tensors="np", **kwargs):
  41. """
  42. Main method to prepare for the model one or several text(s) and image(s). This method forwards the `text` and
  43. `kwargs` arguments to CLIPTokenizerFast's [`~CLIPTokenizerFast.__call__`] if `text` is not `None` to encode:
  44. the text. To prepare the image(s), this method forwards the `images` and `kwrags` arguments to
  45. CLIPImageProcessor's [`~CLIPImageProcessor.__call__`] if `images` is not `None`. Please refer to the doctsring
  46. of the above two methods for more information.
  47. Args:
  48. text (`str`, `List[str]`, `List[List[str]]`):
  49. The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
  50. (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
  51. `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
  52. images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`,
  53. `List[torch.Tensor]`):
  54. The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
  55. tensor. Both channels-first and channels-last formats are supported.
  56. query_images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
  57. The query image to be prepared, one query image is expected per target image to be queried. Each image
  58. can be a PIL image, NumPy array or PyTorch tensor. In case of a NumPy array/PyTorch tensor, each image
  59. should be of shape (C, H, W), where C is a number of channels, H and W are image height and width.
  60. return_tensors (`str` or [`~utils.TensorType`], *optional*):
  61. If set, will return tensors of a particular framework. Acceptable values are:
  62. - `'tf'`: Return TensorFlow `tf.constant` objects.
  63. - `'pt'`: Return PyTorch `torch.Tensor` objects.
  64. - `'np'`: Return NumPy `np.ndarray` objects.
  65. - `'jax'`: Return JAX `jnp.ndarray` objects.
  66. Returns:
  67. [`BatchEncoding`]: A [`BatchEncoding`] with the following fields:
  68. - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
  69. - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
  70. `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
  71. `None`).
  72. - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
  73. """
  74. if text is None and query_images is None and images is None:
  75. raise ValueError(
  76. "You have to specify at least one text or query image or image. All three cannot be none."
  77. )
  78. if text is not None:
  79. if isinstance(text, str) or (isinstance(text, List) and not isinstance(text[0], List)):
  80. encodings = [self.tokenizer(text, padding=padding, return_tensors=return_tensors, **kwargs)]
  81. elif isinstance(text, List) and isinstance(text[0], List):
  82. encodings = []
  83. # Maximum number of queries across batch
  84. max_num_queries = max([len(t) for t in text])
  85. # Pad all batch samples to max number of text queries
  86. for t in text:
  87. if len(t) != max_num_queries:
  88. t = t + [" "] * (max_num_queries - len(t))
  89. encoding = self.tokenizer(t, padding=padding, return_tensors=return_tensors, **kwargs)
  90. encodings.append(encoding)
  91. else:
  92. raise TypeError("Input text should be a string, a list of strings or a nested list of strings")
  93. if return_tensors == "np":
  94. input_ids = np.concatenate([encoding["input_ids"] for encoding in encodings], axis=0)
  95. attention_mask = np.concatenate([encoding["attention_mask"] for encoding in encodings], axis=0)
  96. elif return_tensors == "jax" and is_flax_available():
  97. import jax.numpy as jnp
  98. input_ids = jnp.concatenate([encoding["input_ids"] for encoding in encodings], axis=0)
  99. attention_mask = jnp.concatenate([encoding["attention_mask"] for encoding in encodings], axis=0)
  100. elif return_tensors == "pt" and is_torch_available():
  101. import torch
  102. input_ids = torch.cat([encoding["input_ids"] for encoding in encodings], dim=0)
  103. attention_mask = torch.cat([encoding["attention_mask"] for encoding in encodings], dim=0)
  104. elif return_tensors == "tf" and is_tf_available():
  105. import tensorflow as tf
  106. input_ids = tf.stack([encoding["input_ids"] for encoding in encodings], axis=0)
  107. attention_mask = tf.stack([encoding["attention_mask"] for encoding in encodings], axis=0)
  108. else:
  109. raise ValueError("Target return tensor type could not be returned")
  110. encoding = BatchEncoding()
  111. encoding["input_ids"] = input_ids
  112. encoding["attention_mask"] = attention_mask
  113. if query_images is not None:
  114. encoding = BatchEncoding()
  115. query_pixel_values = self.image_processor(
  116. query_images, return_tensors=return_tensors, **kwargs
  117. ).pixel_values
  118. encoding["query_pixel_values"] = query_pixel_values
  119. if images is not None:
  120. image_features = self.image_processor(images, return_tensors=return_tensors, **kwargs)
  121. if text is not None and images is not None:
  122. encoding["pixel_values"] = image_features.pixel_values
  123. return encoding
  124. elif query_images is not None and images is not None:
  125. encoding["pixel_values"] = image_features.pixel_values
  126. return encoding
  127. elif text is not None or query_images is not None:
  128. return encoding
  129. else:
  130. return BatchEncoding(data=dict(**image_features), tensor_type=return_tensors)
  131. # Copied from transformers.models.owlvit.processing_owlvit.OwlViTProcessor.post_process_object_detection with OWLViT->OWLv2
  132. def post_process_object_detection(self, *args, **kwargs):
  133. """
  134. This method forwards all its arguments to [`OwlViTImageProcessor.post_process_object_detection`]. Please refer
  135. to the docstring of this method for more information.
  136. """
  137. return self.image_processor.post_process_object_detection(*args, **kwargs)
  138. # Copied from transformers.models.owlvit.processing_owlvit.OwlViTProcessor.post_process_image_guided_detection with OWLViT->OWLv2
  139. def post_process_image_guided_detection(self, *args, **kwargs):
  140. """
  141. This method forwards all its arguments to [`OwlViTImageProcessor.post_process_one_shot_object_detection`].
  142. Please refer to the docstring of this method for more information.
  143. """
  144. return self.image_processor.post_process_image_guided_detection(*args, **kwargs)
  145. # Copied from transformers.models.owlvit.processing_owlvit.OwlViTProcessor.batch_decode
  146. def batch_decode(self, *args, **kwargs):
  147. """
  148. This method forwards all its arguments to CLIPTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
  149. refer to the docstring of this method for more information.
  150. """
  151. return self.tokenizer.batch_decode(*args, **kwargs)
  152. # Copied from transformers.models.owlvit.processing_owlvit.OwlViTProcessor.decode
  153. def decode(self, *args, **kwargs):
  154. """
  155. This method forwards all its arguments to CLIPTokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to
  156. the docstring of this method for more information.
  157. """
  158. return self.tokenizer.decode(*args, **kwargs)