image_processing_utils.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287
  1. # coding=utf-8
  2. # Copyright 2022 The HuggingFace Inc. team.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. from typing import Dict, Iterable, Optional, Union
  16. import numpy as np
  17. from .image_processing_base import BatchFeature, ImageProcessingMixin
  18. from .image_transforms import center_crop, normalize, rescale
  19. from .image_utils import ChannelDimension
  20. from .utils import logging
  21. logger = logging.get_logger(__name__)
  22. INIT_SERVICE_KWARGS = [
  23. "processor_class",
  24. "image_processor_type",
  25. ]
  26. class BaseImageProcessor(ImageProcessingMixin):
  27. def __init__(self, **kwargs):
  28. super().__init__(**kwargs)
  29. def __call__(self, images, **kwargs) -> BatchFeature:
  30. """Preprocess an image or a batch of images."""
  31. return self.preprocess(images, **kwargs)
  32. def preprocess(self, images, **kwargs) -> BatchFeature:
  33. raise NotImplementedError("Each image processor must implement its own preprocess method")
  34. def rescale(
  35. self,
  36. image: np.ndarray,
  37. scale: float,
  38. data_format: Optional[Union[str, ChannelDimension]] = None,
  39. input_data_format: Optional[Union[str, ChannelDimension]] = None,
  40. **kwargs,
  41. ) -> np.ndarray:
  42. """
  43. Rescale an image by a scale factor. image = image * scale.
  44. Args:
  45. image (`np.ndarray`):
  46. Image to rescale.
  47. scale (`float`):
  48. The scaling factor to rescale pixel values by.
  49. data_format (`str` or `ChannelDimension`, *optional*):
  50. The channel dimension format for the output image. If unset, the channel dimension format of the input
  51. image is used. Can be one of:
  52. - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
  53. - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
  54. input_data_format (`ChannelDimension` or `str`, *optional*):
  55. The channel dimension format for the input image. If unset, the channel dimension format is inferred
  56. from the input image. Can be one of:
  57. - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
  58. - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
  59. Returns:
  60. `np.ndarray`: The rescaled image.
  61. """
  62. return rescale(image, scale=scale, data_format=data_format, input_data_format=input_data_format, **kwargs)
  63. def normalize(
  64. self,
  65. image: np.ndarray,
  66. mean: Union[float, Iterable[float]],
  67. std: Union[float, Iterable[float]],
  68. data_format: Optional[Union[str, ChannelDimension]] = None,
  69. input_data_format: Optional[Union[str, ChannelDimension]] = None,
  70. **kwargs,
  71. ) -> np.ndarray:
  72. """
  73. Normalize an image. image = (image - image_mean) / image_std.
  74. Args:
  75. image (`np.ndarray`):
  76. Image to normalize.
  77. mean (`float` or `Iterable[float]`):
  78. Image mean to use for normalization.
  79. std (`float` or `Iterable[float]`):
  80. Image standard deviation to use for normalization.
  81. data_format (`str` or `ChannelDimension`, *optional*):
  82. The channel dimension format for the output image. If unset, the channel dimension format of the input
  83. image is used. Can be one of:
  84. - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
  85. - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
  86. input_data_format (`ChannelDimension` or `str`, *optional*):
  87. The channel dimension format for the input image. If unset, the channel dimension format is inferred
  88. from the input image. Can be one of:
  89. - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
  90. - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
  91. Returns:
  92. `np.ndarray`: The normalized image.
  93. """
  94. return normalize(
  95. image, mean=mean, std=std, data_format=data_format, input_data_format=input_data_format, **kwargs
  96. )
  97. def center_crop(
  98. self,
  99. image: np.ndarray,
  100. size: Dict[str, int],
  101. data_format: Optional[Union[str, ChannelDimension]] = None,
  102. input_data_format: Optional[Union[str, ChannelDimension]] = None,
  103. **kwargs,
  104. ) -> np.ndarray:
  105. """
  106. Center crop an image to `(size["height"], size["width"])`. If the input size is smaller than `crop_size` along
  107. any edge, the image is padded with 0's and then center cropped.
  108. Args:
  109. image (`np.ndarray`):
  110. Image to center crop.
  111. size (`Dict[str, int]`):
  112. Size of the output image.
  113. data_format (`str` or `ChannelDimension`, *optional*):
  114. The channel dimension format for the output image. If unset, the channel dimension format of the input
  115. image is used. Can be one of:
  116. - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
  117. - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
  118. input_data_format (`ChannelDimension` or `str`, *optional*):
  119. The channel dimension format for the input image. If unset, the channel dimension format is inferred
  120. from the input image. Can be one of:
  121. - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
  122. - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
  123. """
  124. size = get_size_dict(size)
  125. if "height" not in size or "width" not in size:
  126. raise ValueError(f"The size dictionary must have keys 'height' and 'width'. Got {size.keys()}")
  127. return center_crop(
  128. image,
  129. size=(size["height"], size["width"]),
  130. data_format=data_format,
  131. input_data_format=input_data_format,
  132. **kwargs,
  133. )
  134. def to_dict(self):
  135. encoder_dict = super().to_dict()
  136. encoder_dict.pop("_valid_processor_keys", None)
  137. return encoder_dict
  138. VALID_SIZE_DICT_KEYS = (
  139. {"height", "width"},
  140. {"shortest_edge"},
  141. {"shortest_edge", "longest_edge"},
  142. {"longest_edge"},
  143. {"max_height", "max_width"},
  144. )
  145. def is_valid_size_dict(size_dict):
  146. if not isinstance(size_dict, dict):
  147. return False
  148. size_dict_keys = set(size_dict.keys())
  149. for allowed_keys in VALID_SIZE_DICT_KEYS:
  150. if size_dict_keys == allowed_keys:
  151. return True
  152. return False
  153. def convert_to_size_dict(
  154. size, max_size: Optional[int] = None, default_to_square: bool = True, height_width_order: bool = True
  155. ):
  156. # By default, if size is an int we assume it represents a tuple of (size, size).
  157. if isinstance(size, int) and default_to_square:
  158. if max_size is not None:
  159. raise ValueError("Cannot specify both size as an int, with default_to_square=True and max_size")
  160. return {"height": size, "width": size}
  161. # In other configs, if size is an int and default_to_square is False, size represents the length of
  162. # the shortest edge after resizing.
  163. elif isinstance(size, int) and not default_to_square:
  164. size_dict = {"shortest_edge": size}
  165. if max_size is not None:
  166. size_dict["longest_edge"] = max_size
  167. return size_dict
  168. # Otherwise, if size is a tuple it's either (height, width) or (width, height)
  169. elif isinstance(size, (tuple, list)) and height_width_order:
  170. return {"height": size[0], "width": size[1]}
  171. elif isinstance(size, (tuple, list)) and not height_width_order:
  172. return {"height": size[1], "width": size[0]}
  173. elif size is None and max_size is not None:
  174. if default_to_square:
  175. raise ValueError("Cannot specify both default_to_square=True and max_size")
  176. return {"longest_edge": max_size}
  177. raise ValueError(f"Could not convert size input to size dict: {size}")
  178. def get_size_dict(
  179. size: Union[int, Iterable[int], Dict[str, int]] = None,
  180. max_size: Optional[int] = None,
  181. height_width_order: bool = True,
  182. default_to_square: bool = True,
  183. param_name="size",
  184. ) -> dict:
  185. """
  186. Converts the old size parameter in the config into the new dict expected in the config. This is to ensure backwards
  187. compatibility with the old image processor configs and removes ambiguity over whether the tuple is in (height,
  188. width) or (width, height) format.
  189. - If `size` is tuple, it is converted to `{"height": size[0], "width": size[1]}` or `{"height": size[1], "width":
  190. size[0]}` if `height_width_order` is `False`.
  191. - If `size` is an int, and `default_to_square` is `True`, it is converted to `{"height": size, "width": size}`.
  192. - If `size` is an int and `default_to_square` is False, it is converted to `{"shortest_edge": size}`. If `max_size`
  193. is set, it is added to the dict as `{"longest_edge": max_size}`.
  194. Args:
  195. size (`Union[int, Iterable[int], Dict[str, int]]`, *optional*):
  196. The `size` parameter to be cast into a size dictionary.
  197. max_size (`Optional[int]`, *optional*):
  198. The `max_size` parameter to be cast into a size dictionary.
  199. height_width_order (`bool`, *optional*, defaults to `True`):
  200. If `size` is a tuple, whether it's in (height, width) or (width, height) order.
  201. default_to_square (`bool`, *optional*, defaults to `True`):
  202. If `size` is an int, whether to default to a square image or not.
  203. """
  204. if not isinstance(size, dict):
  205. size_dict = convert_to_size_dict(size, max_size, default_to_square, height_width_order)
  206. logger.info(
  207. f"{param_name} should be a dictionary on of the following set of keys: {VALID_SIZE_DICT_KEYS}, got {size}."
  208. f" Converted to {size_dict}.",
  209. )
  210. else:
  211. size_dict = size
  212. if not is_valid_size_dict(size_dict):
  213. raise ValueError(
  214. f"{param_name} must have one of the following set of keys: {VALID_SIZE_DICT_KEYS}, got {size_dict.keys()}"
  215. )
  216. return size_dict
  217. def select_best_resolution(original_size: tuple, possible_resolutions: list) -> tuple:
  218. """
  219. Selects the best resolution from a list of possible resolutions based on the original size.
  220. This is done by calculating the effective and wasted resolution for each possible resolution.
  221. The best fit resolution is the one that maximizes the effective resolution and minimizes the wasted resolution.
  222. Args:
  223. original_size (tuple):
  224. The original size of the image in the format (height, width).
  225. possible_resolutions (list):
  226. A list of possible resolutions in the format [(height1, width1), (height2, width2), ...].
  227. Returns:
  228. tuple: The best fit resolution in the format (height, width).
  229. """
  230. original_height, original_width = original_size
  231. best_fit = None
  232. max_effective_resolution = 0
  233. min_wasted_resolution = float("inf")
  234. for height, width in possible_resolutions:
  235. scale = min(width / original_width, height / original_height)
  236. downscaled_width, downscaled_height = int(original_width * scale), int(original_height * scale)
  237. effective_resolution = min(downscaled_width * downscaled_height, original_width * original_height)
  238. wasted_resolution = (width * height) - effective_resolution
  239. if effective_resolution > max_effective_resolution or (
  240. effective_resolution == max_effective_resolution and wasted_resolution < min_wasted_resolution
  241. ):
  242. max_effective_resolution = effective_resolution
  243. min_wasted_resolution = wasted_resolution
  244. best_fit = (height, width)
  245. return best_fit