# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Image processor class for SuperPoint."""

from typing import Dict, Optional, Union

import numpy as np

from ... import is_vision_available
from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
from ...image_transforms import resize, to_channel_dimension_format
from ...image_utils import (
    ChannelDimension,
    ImageInput,
    infer_channel_dimension_format,
    is_scaled_image,
    make_list_of_images,
    to_numpy_array,
    valid_images,
)
from ...utils import TensorType, logging, requires_backends


if is_vision_available():
    import PIL


logger = logging.get_logger(__name__)


def is_grayscale(
    image: ImageInput,
    input_data_format: Optional[Union[str, ChannelDimension]] = None,
):
    """
    Checks whether all three channels of an image are identical, i.e. whether an RGB-shaped image is effectively
    grayscale.

    Args:
        image (`ImageInput`):
            The image to check.
        input_data_format (`ChannelDimension` or `str`, *optional*):
            The channel dimension format for the input image.
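
    Example (a minimal doctest-style sketch of the intended behavior; the array values are illustrative):

        >>> import numpy as np
        >>> rgb = np.arange(12).reshape(3, 2, 2)
        >>> bool(is_grayscale(rgb, input_data_format=ChannelDimension.FIRST))
        False
        >>> gray = np.stack([rgb[0]] * 3, axis=0)
        >>> bool(is_grayscale(gray, input_data_format=ChannelDimension.FIRST))
        True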
    """
    if input_data_format == ChannelDimension.FIRST:
        return np.all(image[0, ...] == image[1, ...]) and np.all(image[1, ...] == image[2, ...])
    elif input_data_format == ChannelDimension.LAST:
        return np.all(image[..., 0] == image[..., 1]) and np.all(image[..., 1] == image[..., 2])


def convert_to_grayscale(
    image: ImageInput,
    input_data_format: Optional[Union[str, ChannelDimension]] = None,
) -> ImageInput:
    """
    Converts an image to grayscale format using the NTSC formula. Only supports numpy arrays and PIL images. TODO:
    support torch and tensorflow grayscale conversion.

    This function is supposed to return a single-channel image, but it returns a 3-channel image with the same value
    in each channel, because of an issue that is discussed in:
    https://github.com/huggingface/transformers/pull/25786#issuecomment-1730176446

    Args:
        image (`ImageInput`):
            The image to convert.
        input_data_format (`ChannelDimension` or `str`, *optional*):
            The channel dimension format for the input image.
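
    Example (a minimal doctest-style sketch; note that, per the issue linked above, the output deliberately keeps
    three identical channels rather than collapsing to one):

        >>> import numpy as np
        >>> rgb = np.zeros((3, 2, 2))
        >>> rgb[0] = 1.0  # a pure-red image
        >>> gray = convert_to_grayscale(rgb, input_data_format=ChannelDimension.FIRST)
        >>> gray.shape
        (3, 2, 2)
        >>> float(gray[0, 0, 0])  # the NTSC weight of the red channel
        0.2989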
  56. """
  57. requires_backends(convert_to_grayscale, ["vision"])
  58. if isinstance(image, np.ndarray):
  59. if input_data_format == ChannelDimension.FIRST:
  60. gray_image = image[0, ...] * 0.2989 + image[1, ...] * 0.5870 + image[2, ...] * 0.1140
  61. gray_image = np.stack([gray_image] * 3, axis=0)
  62. elif input_data_format == ChannelDimension.LAST:
  63. gray_image = image[..., 0] * 0.2989 + image[..., 1] * 0.5870 + image[..., 2] * 0.1140
  64. gray_image = np.stack([gray_image] * 3, axis=-1)
  65. return gray_image
  66. if not isinstance(image, PIL.Image.Image):
  67. return image
  68. image = image.convert("L")
  69. return image


class SuperPointImageProcessor(BaseImageProcessor):
    r"""
    Constructs a SuperPoint image processor.

    Args:
        do_resize (`bool`, *optional*, defaults to `True`):
            Controls whether to resize the image's (height, width) dimensions to the specified `size`. Can be
            overridden by `do_resize` in the `preprocess` method.
        size (`Dict[str, int]`, *optional*, defaults to `{"height": 480, "width": 640}`):
            Resolution of the output image after `resize` is applied. Only has an effect if `do_resize` is set to
            `True`. Can be overridden by `size` in the `preprocess` method.
        do_rescale (`bool`, *optional*, defaults to `True`):
            Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by `do_rescale`
            in the `preprocess` method.
        rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
            Scale factor to use if rescaling the image. Can be overridden by `rescale_factor` in the `preprocess`
            method.
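
    Example (a minimal usage sketch; the array values are illustrative, and loading from a pretrained checkpoint via
    `AutoImageProcessor.from_pretrained` works the same way):

        >>> import numpy as np
        >>> processor = SuperPointImageProcessor()
        >>> image = np.random.randint(0, 256, (480, 640, 3), dtype=np.uint8)
        >>> inputs = processor(image, return_tensors="np")
        >>> inputs["pixel_values"].shape
        (1, 3, 480, 640)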
  86. """
  87. model_input_names = ["pixel_values"]
  88. def __init__(
  89. self,
  90. do_resize: bool = True,
  91. size: Dict[str, int] = None,
  92. do_rescale: bool = True,
  93. rescale_factor: float = 1 / 255,
  94. **kwargs,
  95. ) -> None:
  96. super().__init__(**kwargs)
  97. size = size if size is not None else {"height": 480, "width": 640}
  98. size = get_size_dict(size, default_to_square=False)
  99. self.do_resize = do_resize
  100. self.size = size
  101. self.do_rescale = do_rescale
  102. self.rescale_factor = rescale_factor

    def resize(
        self,
        image: np.ndarray,
        size: Dict[str, int],
        data_format: Optional[Union[str, ChannelDimension]] = None,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
        **kwargs,
    ):
        """
        Resize an image.

        Args:
            image (`np.ndarray`):
                Image to resize.
            size (`Dict[str, int]`):
                Dictionary of the form `{"height": int, "width": int}`, specifying the size of the output image.
            data_format (`ChannelDimension` or `str`, *optional*):
                The channel dimension format of the output image. If not provided, it will be inferred from the input
                image. Can be one of:
                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
                - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
            input_data_format (`ChannelDimension` or `str`, *optional*):
                The channel dimension format for the input image. If unset, the channel dimension format is inferred
                from the input image. Can be one of:
                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
                - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
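
        Example (a minimal doctest-style sketch; the image values are illustrative):

            >>> import numpy as np
            >>> processor = SuperPointImageProcessor()
            >>> image = np.zeros((960, 1280, 3), dtype=np.uint8)
            >>> processor.resize(image, size={"height": 480, "width": 640}).shape
            (480, 640, 3)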
  130. """
  131. size = get_size_dict(size, default_to_square=False)
  132. return resize(
  133. image,
  134. size=(size["height"], size["width"]),
  135. data_format=data_format,
  136. input_data_format=input_data_format,
  137. **kwargs,
  138. )

    def preprocess(
        self,
        images: ImageInput,
        do_resize: bool = None,
        size: Dict[str, int] = None,
        do_rescale: bool = None,
        rescale_factor: float = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
        data_format: ChannelDimension = ChannelDimension.FIRST,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
        **kwargs,
    ) -> BatchFeature:
        """
        Preprocess an image or batch of images.

        Args:
            images (`ImageInput`):
                Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
                passing in images with pixel values between 0 and 1, set `do_rescale=False`.
            do_resize (`bool`, *optional*, defaults to `self.do_resize`):
                Whether to resize the image.
            size (`Dict[str, int]`, *optional*, defaults to `self.size`):
                Size of the output image after `resize` has been applied. The image is resized to
                `(size["height"], size["width"])`. Only has an effect if `do_resize` is set to `True`.
            do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
                Whether to rescale the image values to the `[0, 1]` range.
            rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
                Rescale factor to rescale the image by if `do_rescale` is set to `True`.
            return_tensors (`str` or `TensorType`, *optional*):
                The type of tensors to return. Can be one of:
                - Unset: Return a list of `np.ndarray`.
                - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
                - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
                - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
                - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
            data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
                The channel dimension format for the output image. Can be one of:
                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
                - Unset: Use the channel dimension format of the input image.
            input_data_format (`ChannelDimension` or `str`, *optional*):
                The channel dimension format for the input image. If unset, the channel dimension format is inferred
                from the input image. Can be one of:
                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
                - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
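
        Example (a minimal doctest-style sketch; the image values are illustrative):

            >>> import numpy as np
            >>> processor = SuperPointImageProcessor()
            >>> images = [np.random.randint(0, 256, (480, 640, 3), dtype=np.uint8) for _ in range(2)]
            >>> batch = processor.preprocess(images, return_tensors="np")
            >>> batch["pixel_values"].shape
            (2, 3, 480, 640)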
  186. """
  187. do_resize = do_resize if do_resize is not None else self.do_resize
  188. do_rescale = do_rescale if do_rescale is not None else self.do_rescale
  189. rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
  190. size = size if size is not None else self.size
  191. size = get_size_dict(size, default_to_square=False)
  192. images = make_list_of_images(images)
  193. if not valid_images(images):
  194. raise ValueError(
  195. "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
  196. "torch.Tensor, tf.Tensor or jax.ndarray."
  197. )
  198. if do_resize and size is None:
  199. raise ValueError("Size must be specified if do_resize is True.")
  200. if do_rescale and rescale_factor is None:
  201. raise ValueError("Rescale factor must be specified if do_rescale is True.")
  202. # All transformations expect numpy arrays.
  203. images = [to_numpy_array(image) for image in images]
  204. if is_scaled_image(images[0]) and do_rescale:
  205. logger.warning_once(
  206. "It looks like you are trying to rescale already rescaled images. If the input"
  207. " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
  208. )
  209. if input_data_format is None:
  210. # We assume that all images have the same channel dimension format.
  211. input_data_format = infer_channel_dimension_format(images[0])
  212. if do_resize:
  213. images = [self.resize(image=image, size=size, input_data_format=input_data_format) for image in images]
  214. if do_rescale:
  215. images = [
  216. self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format)
  217. for image in images
  218. ]
        # SuperPoint expects grayscale input, so convert any RGB image to (3-channel) grayscale.
        for i in range(len(images)):
            if not is_grayscale(images[i], input_data_format):
                images[i] = convert_to_grayscale(images[i], input_data_format=input_data_format)

        images = [
            to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) for image in images
        ]

        data = {"pixel_values": images}

        return BatchFeature(data=data, tensor_type=return_tensors)