image_processing_donut.py 21 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459
  1. # coding=utf-8
  2. # Copyright 2022 The HuggingFace Inc. team. All rights reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. """Image processor class for Donut."""
  16. from typing import Dict, List, Optional, Union
  17. import numpy as np
  18. from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
  19. from ...image_transforms import (
  20. get_resize_output_image_size,
  21. pad,
  22. resize,
  23. to_channel_dimension_format,
  24. )
  25. from ...image_utils import (
  26. IMAGENET_STANDARD_MEAN,
  27. IMAGENET_STANDARD_STD,
  28. ChannelDimension,
  29. ImageInput,
  30. PILImageResampling,
  31. get_image_size,
  32. infer_channel_dimension_format,
  33. is_scaled_image,
  34. make_list_of_images,
  35. to_numpy_array,
  36. valid_images,
  37. validate_preprocess_arguments,
  38. )
  39. from ...utils import TensorType, filter_out_non_signature_kwargs, logging
  40. from ...utils.import_utils import is_vision_available
# Module-level logger, named after this module.
logger = logging.get_logger(__name__)

# PIL is imported only when `is_vision_available()` returns a truthy value.
if is_vision_available():
    import PIL
class DonutImageProcessor(BaseImageProcessor):
    r"""
    Constructs a Donut image processor.

    Args:
        do_resize (`bool`, *optional*, defaults to `True`):
            Whether to resize the image's (height, width) dimensions to the specified `size`. Can be overridden by
            `do_resize` in the `preprocess` method.
        size (`Dict[str, int]` *optional*, defaults to `{"height": 2560, "width": 1920}`):
            Target size of the image. When resizing, the shortest edge of the image is resized to
            `min(size["height"], size["width"])`, with the longest edge resized to keep the input aspect ratio.
            A tuple or list is interpreted as `(width, height)`. Can be overridden by `size` in the `preprocess`
            method.
        resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BILINEAR`):
            Resampling filter to use if resizing the image. Can be overridden by `resample` in the `preprocess` method.
        do_thumbnail (`bool`, *optional*, defaults to `True`):
            Whether to resize the image using thumbnail method. Can be overridden by `do_thumbnail` in the
            `preprocess` method.
        do_align_long_axis (`bool`, *optional*, defaults to `False`):
            Whether to align the long axis of the image with the long axis of `size` by rotating by 90 degrees. Can
            be overridden by `do_align_long_axis` in the `preprocess` method.
        do_pad (`bool`, *optional*, defaults to `True`):
            Whether to pad the image to `size`. If `random_padding` is set to `True` in `preprocess`, the padding is
            split randomly between the two sides of each axis. Otherwise, the padding is split evenly between the
            two sides. Can be overridden by `do_pad` in the `preprocess` method.
        do_rescale (`bool`, *optional*, defaults to `True`):
            Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by `do_rescale` in
            the `preprocess` method.
        rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
            Scale factor to use if rescaling the image. Can be overridden by `rescale_factor` in the `preprocess`
            method.
        do_normalize (`bool`, *optional*, defaults to `True`):
            Whether to normalize the image. Can be overridden by `do_normalize` in the `preprocess` method.
        image_mean (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_MEAN`):
            Mean to use if normalizing the image. This is a float or list of floats the length of the number of
            channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method.
        image_std (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_STD`):
            Standard deviation to use if normalizing the image. Can be overridden by the `image_std` parameter in
            the `preprocess` method.
    """

    # Key under which `preprocess` stores the processed images in its output.
    model_input_names = ["pixel_values"]
  80. def __init__(
  81. self,
  82. do_resize: bool = True,
  83. size: Dict[str, int] = None,
  84. resample: PILImageResampling = PILImageResampling.BILINEAR,
  85. do_thumbnail: bool = True,
  86. do_align_long_axis: bool = False,
  87. do_pad: bool = True,
  88. do_rescale: bool = True,
  89. rescale_factor: Union[int, float] = 1 / 255,
  90. do_normalize: bool = True,
  91. image_mean: Optional[Union[float, List[float]]] = None,
  92. image_std: Optional[Union[float, List[float]]] = None,
  93. **kwargs,
  94. ) -> None:
  95. super().__init__(**kwargs)
  96. size = size if size is not None else {"height": 2560, "width": 1920}
  97. if isinstance(size, (tuple, list)):
  98. # The previous feature extractor size parameter was in (width, height) format
  99. size = size[::-1]
  100. size = get_size_dict(size)
  101. self.do_resize = do_resize
  102. self.size = size
  103. self.resample = resample
  104. self.do_thumbnail = do_thumbnail
  105. self.do_align_long_axis = do_align_long_axis
  106. self.do_pad = do_pad
  107. self.do_rescale = do_rescale
  108. self.rescale_factor = rescale_factor
  109. self.do_normalize = do_normalize
  110. self.image_mean = image_mean if image_mean is not None else IMAGENET_STANDARD_MEAN
  111. self.image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD
  112. def align_long_axis(
  113. self,
  114. image: np.ndarray,
  115. size: Dict[str, int],
  116. data_format: Optional[Union[str, ChannelDimension]] = None,
  117. input_data_format: Optional[Union[str, ChannelDimension]] = None,
  118. ) -> np.ndarray:
  119. """
  120. Align the long axis of the image to the longest axis of the specified size.
  121. Args:
  122. image (`np.ndarray`):
  123. The image to be aligned.
  124. size (`Dict[str, int]`):
  125. The size `{"height": h, "width": w}` to align the long axis to.
  126. data_format (`str` or `ChannelDimension`, *optional*):
  127. The data format of the output image. If unset, the same format as the input image is used.
  128. input_data_format (`ChannelDimension` or `str`, *optional*):
  129. The channel dimension format of the input image. If not provided, it will be inferred.
  130. Returns:
  131. `np.ndarray`: The aligned image.
  132. """
  133. input_height, input_width = get_image_size(image, channel_dim=input_data_format)
  134. output_height, output_width = size["height"], size["width"]
  135. if (output_width < output_height and input_width > input_height) or (
  136. output_width > output_height and input_width < input_height
  137. ):
  138. image = np.rot90(image, 3)
  139. if data_format is not None:
  140. image = to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format)
  141. return image
  142. def pad_image(
  143. self,
  144. image: np.ndarray,
  145. size: Dict[str, int],
  146. random_padding: bool = False,
  147. data_format: Optional[Union[str, ChannelDimension]] = None,
  148. input_data_format: Optional[Union[str, ChannelDimension]] = None,
  149. ) -> np.ndarray:
  150. """
  151. Pad the image to the specified size.
  152. Args:
  153. image (`np.ndarray`):
  154. The image to be padded.
  155. size (`Dict[str, int]`):
  156. The size `{"height": h, "width": w}` to pad the image to.
  157. random_padding (`bool`, *optional*, defaults to `False`):
  158. Whether to use random padding or not.
  159. data_format (`str` or `ChannelDimension`, *optional*):
  160. The data format of the output image. If unset, the same format as the input image is used.
  161. input_data_format (`ChannelDimension` or `str`, *optional*):
  162. The channel dimension format of the input image. If not provided, it will be inferred.
  163. """
  164. output_height, output_width = size["height"], size["width"]
  165. input_height, input_width = get_image_size(image, channel_dim=input_data_format)
  166. delta_width = output_width - input_width
  167. delta_height = output_height - input_height
  168. if random_padding:
  169. pad_top = np.random.randint(low=0, high=delta_height + 1)
  170. pad_left = np.random.randint(low=0, high=delta_width + 1)
  171. else:
  172. pad_top = delta_height // 2
  173. pad_left = delta_width // 2
  174. pad_bottom = delta_height - pad_top
  175. pad_right = delta_width - pad_left
  176. padding = ((pad_top, pad_bottom), (pad_left, pad_right))
  177. return pad(image, padding, data_format=data_format, input_data_format=input_data_format)
    def pad(self, *args, **kwargs):
        """
        Deprecated alias of `pad_image`.

        Logs a deprecation notice at INFO level, then forwards all positional and keyword arguments unchanged to
        `pad_image` and returns its result.
        """
        logger.info("pad is deprecated and will be removed in version 4.27. Please use pad_image instead.")
        return self.pad_image(*args, **kwargs)
  181. def thumbnail(
  182. self,
  183. image: np.ndarray,
  184. size: Dict[str, int],
  185. resample: PILImageResampling = PILImageResampling.BICUBIC,
  186. data_format: Optional[Union[str, ChannelDimension]] = None,
  187. input_data_format: Optional[Union[str, ChannelDimension]] = None,
  188. **kwargs,
  189. ) -> np.ndarray:
  190. """
  191. Resize the image to make a thumbnail. The image is resized so that no dimension is larger than any
  192. corresponding dimension of the specified size.
  193. Args:
  194. image (`np.ndarray`):
  195. The image to be resized.
  196. size (`Dict[str, int]`):
  197. The size `{"height": h, "width": w}` to resize the image to.
  198. resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`):
  199. The resampling filter to use.
  200. data_format (`Optional[Union[str, ChannelDimension]]`, *optional*):
  201. The data format of the output image. If unset, the same format as the input image is used.
  202. input_data_format (`ChannelDimension` or `str`, *optional*):
  203. The channel dimension format of the input image. If not provided, it will be inferred.
  204. """
  205. input_height, input_width = get_image_size(image, channel_dim=input_data_format)
  206. output_height, output_width = size["height"], size["width"]
  207. # We always resize to the smallest of either the input or output size.
  208. height = min(input_height, output_height)
  209. width = min(input_width, output_width)
  210. if height == input_height and width == input_width:
  211. return image
  212. if input_height > input_width:
  213. width = int(input_width * height / input_height)
  214. elif input_width > input_height:
  215. height = int(input_height * width / input_width)
  216. return resize(
  217. image,
  218. size=(height, width),
  219. resample=resample,
  220. reducing_gap=2.0,
  221. data_format=data_format,
  222. input_data_format=input_data_format,
  223. **kwargs,
  224. )
  225. def resize(
  226. self,
  227. image: np.ndarray,
  228. size: Dict[str, int],
  229. resample: PILImageResampling = PILImageResampling.BICUBIC,
  230. data_format: Optional[Union[str, ChannelDimension]] = None,
  231. input_data_format: Optional[Union[str, ChannelDimension]] = None,
  232. **kwargs,
  233. ) -> np.ndarray:
  234. """
  235. Resizes `image` to `(height, width)` specified by `size` using the PIL library.
  236. Args:
  237. image (`np.ndarray`):
  238. Image to resize.
  239. size (`Dict[str, int]`):
  240. Size of the output image.
  241. resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`):
  242. Resampling filter to use when resiizing the image.
  243. data_format (`str` or `ChannelDimension`, *optional*):
  244. The channel dimension format of the image. If not provided, it will be the same as the input image.
  245. input_data_format (`ChannelDimension` or `str`, *optional*):
  246. The channel dimension format of the input image. If not provided, it will be inferred.
  247. """
  248. size = get_size_dict(size)
  249. shortest_edge = min(size["height"], size["width"])
  250. output_size = get_resize_output_image_size(
  251. image, size=shortest_edge, default_to_square=False, input_data_format=input_data_format
  252. )
  253. resized_image = resize(
  254. image,
  255. size=output_size,
  256. resample=resample,
  257. data_format=data_format,
  258. input_data_format=input_data_format,
  259. **kwargs,
  260. )
  261. return resized_image
  262. @filter_out_non_signature_kwargs()
  263. def preprocess(
  264. self,
  265. images: ImageInput,
  266. do_resize: bool = None,
  267. size: Dict[str, int] = None,
  268. resample: PILImageResampling = None,
  269. do_thumbnail: bool = None,
  270. do_align_long_axis: bool = None,
  271. do_pad: bool = None,
  272. random_padding: bool = False,
  273. do_rescale: bool = None,
  274. rescale_factor: float = None,
  275. do_normalize: bool = None,
  276. image_mean: Optional[Union[float, List[float]]] = None,
  277. image_std: Optional[Union[float, List[float]]] = None,
  278. return_tensors: Optional[Union[str, TensorType]] = None,
  279. data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
  280. input_data_format: Optional[Union[str, ChannelDimension]] = None,
  281. ) -> PIL.Image.Image:
  282. """
  283. Preprocess an image or batch of images.
  284. Args:
  285. images (`ImageInput`):
  286. Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
  287. passing in images with pixel values between 0 and 1, set `do_rescale=False`.
  288. do_resize (`bool`, *optional*, defaults to `self.do_resize`):
  289. Whether to resize the image.
  290. size (`Dict[str, int]`, *optional*, defaults to `self.size`):
  291. Size of the image after resizing. Shortest edge of the image is resized to min(size["height"],
  292. size["width"]) with the longest edge resized to keep the input aspect ratio.
  293. resample (`int`, *optional*, defaults to `self.resample`):
  294. Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`. Only
  295. has an effect if `do_resize` is set to `True`.
  296. do_thumbnail (`bool`, *optional*, defaults to `self.do_thumbnail`):
  297. Whether to resize the image using thumbnail method.
  298. do_align_long_axis (`bool`, *optional*, defaults to `self.do_align_long_axis`):
  299. Whether to align the long axis of the image with the long axis of `size` by rotating by 90 degrees.
  300. do_pad (`bool`, *optional*, defaults to `self.do_pad`):
  301. Whether to pad the image. If `random_padding` is set to `True`, each image is padded with a random
  302. amont of padding on each size, up to the largest image size in the batch. Otherwise, all images are
  303. padded to the largest image size in the batch.
  304. random_padding (`bool`, *optional*, defaults to `self.random_padding`):
  305. Whether to use random padding when padding the image. If `True`, each image in the batch with be padded
  306. with a random amount of padding on each side up to the size of the largest image in the batch.
  307. do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
  308. Whether to rescale the image pixel values.
  309. rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
  310. Rescale factor to rescale the image by if `do_rescale` is set to `True`.
  311. do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
  312. Whether to normalize the image.
  313. image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
  314. Image mean to use for normalization.
  315. image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
  316. Image standard deviation to use for normalization.
  317. return_tensors (`str` or `TensorType`, *optional*):
  318. The type of tensors to return. Can be one of:
  319. - Unset: Return a list of `np.ndarray`.
  320. - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
  321. - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
  322. - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
  323. - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
  324. data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
  325. The channel dimension format for the output image. Can be one of:
  326. - `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
  327. - `ChannelDimension.LAST`: image in (height, width, num_channels) format.
  328. - Unset: defaults to the channel dimension format of the input image.
  329. input_data_format (`ChannelDimension` or `str`, *optional*):
  330. The channel dimension format for the input image. If unset, the channel dimension format is inferred
  331. from the input image. Can be one of:
  332. - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
  333. - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
  334. - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
  335. """
  336. do_resize = do_resize if do_resize is not None else self.do_resize
  337. size = size if size is not None else self.size
  338. if isinstance(size, (tuple, list)):
  339. # Previous feature extractor had size in (width, height) format
  340. size = size[::-1]
  341. size = get_size_dict(size)
  342. resample = resample if resample is not None else self.resample
  343. do_thumbnail = do_thumbnail if do_thumbnail is not None else self.do_thumbnail
  344. do_align_long_axis = do_align_long_axis if do_align_long_axis is not None else self.do_align_long_axis
  345. do_pad = do_pad if do_pad is not None else self.do_pad
  346. do_rescale = do_rescale if do_rescale is not None else self.do_rescale
  347. rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
  348. do_normalize = do_normalize if do_normalize is not None else self.do_normalize
  349. image_mean = image_mean if image_mean is not None else self.image_mean
  350. image_std = image_std if image_std is not None else self.image_std
  351. images = make_list_of_images(images)
  352. if not valid_images(images):
  353. raise ValueError(
  354. "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
  355. "torch.Tensor, tf.Tensor or jax.ndarray."
  356. )
  357. validate_preprocess_arguments(
  358. do_rescale=do_rescale,
  359. rescale_factor=rescale_factor,
  360. do_normalize=do_normalize,
  361. image_mean=image_mean,
  362. image_std=image_std,
  363. do_pad=do_pad,
  364. size_divisibility=size, # There is no pad divisibility in this processor, but pad requires the size arg.
  365. do_resize=do_resize,
  366. size=size,
  367. resample=resample,
  368. )
  369. # All transformations expect numpy arrays.
  370. images = [to_numpy_array(image) for image in images]
  371. if is_scaled_image(images[0]) and do_rescale:
  372. logger.warning_once(
  373. "It looks like you are trying to rescale already rescaled images. If the input"
  374. " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
  375. )
  376. if input_data_format is None:
  377. # We assume that all images have the same channel dimension format.
  378. input_data_format = infer_channel_dimension_format(images[0])
  379. if do_align_long_axis:
  380. images = [self.align_long_axis(image, size=size, input_data_format=input_data_format) for image in images]
  381. if do_resize:
  382. images = [
  383. self.resize(image=image, size=size, resample=resample, input_data_format=input_data_format)
  384. for image in images
  385. ]
  386. if do_thumbnail:
  387. images = [self.thumbnail(image=image, size=size, input_data_format=input_data_format) for image in images]
  388. if do_pad:
  389. images = [
  390. self.pad_image(
  391. image=image, size=size, random_padding=random_padding, input_data_format=input_data_format
  392. )
  393. for image in images
  394. ]
  395. if do_rescale:
  396. images = [
  397. self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format)
  398. for image in images
  399. ]
  400. if do_normalize:
  401. images = [
  402. self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=input_data_format)
  403. for image in images
  404. ]
  405. images = [
  406. to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) for image in images
  407. ]
  408. data = {"pixel_values": images}
  409. return BatchFeature(data=data, tensor_type=return_tensors)