# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Image processor class for Pix2Struct."""
import io
import math
from typing import Dict, Optional, Union

import numpy as np
from huggingface_hub import hf_hub_download

from ...image_processing_utils import BaseImageProcessor, BatchFeature
from ...image_transforms import convert_to_rgb, normalize, to_channel_dimension_format, to_pil_image
from ...image_utils import (
    ChannelDimension,
    ImageInput,
    get_image_size,
    infer_channel_dimension_format,
    make_list_of_images,
    to_numpy_array,
    valid_images,
)
from ...utils import TensorType, is_torch_available, is_vision_available, logging
from ...utils.import_utils import requires_backends


if is_vision_available():
    import textwrap

    from PIL import Image, ImageDraw, ImageFont

if is_torch_available():
    import torch


logger = logging.get_logger(__name__)

DEFAULT_FONT_PATH = "ybelkada/fonts"


# adapted from: https://discuss.pytorch.org/t/tf-image-extract-patches-in-pytorch/171409/2
def torch_extract_patches(image_tensor, patch_height, patch_width):
    """
    Utility function to extract patches from a given image tensor. Returns a tensor of shape (1, `image_height` //
    `patch_height`, `image_width` // `patch_width`, `num_channels` x `patch_height` x `patch_width`).

    Args:
        image_tensor (torch.Tensor):
            The image tensor to extract patches from.
        patch_height (int):
            The height of the patches to extract.
        patch_width (int):
            The width of the patches to extract.
    """
    requires_backends(torch_extract_patches, ["torch"])

    image_tensor = image_tensor.unsqueeze(0)
    patches = torch.nn.functional.unfold(image_tensor, (patch_height, patch_width), stride=(patch_height, patch_width))
    patches = patches.reshape(image_tensor.size(0), image_tensor.size(1), patch_height, patch_width, -1)
    patches = patches.permute(0, 4, 2, 3, 1).reshape(
        image_tensor.size(2) // patch_height,
        image_tensor.size(3) // patch_width,
        image_tensor.size(1) * patch_height * patch_width,
    )
    return patches.unsqueeze(0)
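

# Illustrative shape check (not part of the original module): for a 3x64x48 image and
# 16x16 patches, `torch_extract_patches` yields a (1, rows, columns, channels * patch_height
# * patch_width) tensor, here (1, 4, 3, 768).
#
#     dummy = torch.rand(3, 64, 48)
#     patches = torch_extract_patches(dummy, 16, 16)
#     assert patches.shape == (1, 4, 3, 16 * 16 * 3)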


# Adapted from https://github.com/google-research/pix2struct/blob/0e1779af0f4db4b652c1d92b3bbd2550a7399123/pix2struct/preprocessing/preprocessing_utils.py#L106
def render_text(
    text: str,
    text_size: int = 36,
    text_color: str = "black",
    background_color: str = "white",
    left_padding: int = 5,
    right_padding: int = 5,
    top_padding: int = 5,
    bottom_padding: int = 5,
    font_bytes: Optional[bytes] = None,
    font_path: Optional[str] = None,
) -> Image.Image:
    """
    Render text. This function is entirely adapted from the original script that can be found here:
    https://github.com/google-research/pix2struct/blob/main/pix2struct/preprocessing/preprocessing_utils.py

    Args:
        text (`str`):
            Text to render.
        text_size (`int`, *optional*, defaults to 36):
            Size of the text.
        text_color (`str`, *optional*, defaults to `"black"`):
            Color of the text.
        background_color (`str`, *optional*, defaults to `"white"`):
            Color of the background.
        left_padding (`int`, *optional*, defaults to 5):
            Padding on the left.
        right_padding (`int`, *optional*, defaults to 5):
            Padding on the right.
        top_padding (`int`, *optional*, defaults to 5):
            Padding on the top.
        bottom_padding (`int`, *optional*, defaults to 5):
            Padding on the bottom.
        font_bytes (`bytes`, *optional*):
            Bytes of the font to use. If `None`, the default font will be used.
        font_path (`str`, *optional*):
            Path to the font to use. If `None`, the default font will be used.
    """
    requires_backends(render_text, "vision")
    # Add new lines so that each line is no more than 80 characters.
    wrapper = textwrap.TextWrapper(width=80)
    lines = wrapper.wrap(text=text)
    wrapped_text = "\n".join(lines)

    if font_bytes is not None and font_path is None:
        font = io.BytesIO(font_bytes)
    elif font_path is not None:
        font = font_path
    else:
        font = hf_hub_download(DEFAULT_FONT_PATH, "Arial.TTF")
    font = ImageFont.truetype(font, encoding="UTF-8", size=text_size)

    # Use a temporary canvas to determine the width and height in pixels when
    # rendering the text.
    temp_draw = ImageDraw.Draw(Image.new("RGB", (1, 1), background_color))
    _, _, text_width, text_height = temp_draw.textbbox((0, 0), wrapped_text, font)

    # Create the actual image with a bit of padding around the text.
    image_width = text_width + left_padding + right_padding
    image_height = text_height + top_padding + bottom_padding
    image = Image.new("RGB", (image_width, image_height), background_color)
    draw = ImageDraw.Draw(image)
    draw.text(xy=(left_padding, top_padding), text=wrapped_text, fill=text_color, font=font)
    return image
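

# Illustrative usage sketch (not part of the original module): rendering a question to a
# PIL image. Requires Pillow and network access to fetch the default font from the Hub.
#
#     header = render_text("What is the total amount?", text_size=36)
#     header.save("header.png")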


# Adapted from https://github.com/google-research/pix2struct/blob/0e1779af0f4db4b652c1d92b3bbd2550a7399123/pix2struct/preprocessing/preprocessing_utils.py#L87
def render_header(
    image: np.ndarray, header: str, input_data_format: Optional[Union[str, ChannelDimension]] = None, **kwargs
):
    """
    Renders the input text as a header on the input image.

    Args:
        image (`np.ndarray`):
            The image to render the header on.
        header (`str`):
            The header text.
        input_data_format (`Union[ChannelDimension, str]`, *optional*):
            The channel dimension format of the input image. Can be either `ChannelDimension.FIRST` or
            `ChannelDimension.LAST`.

    Returns:
        `np.ndarray`: The image with the header rendered.
    """
    requires_backends(render_header, "vision")

    # Convert to PIL image if necessary
    image = to_pil_image(image, input_data_format=input_data_format)

    header_image = render_text(header, **kwargs)
    new_width = max(header_image.width, image.width)

    new_height = int(image.height * (new_width / image.width))
    new_header_height = int(header_image.height * (new_width / header_image.width))

    new_image = Image.new("RGB", (new_width, new_height + new_header_height), "white")
    new_image.paste(header_image.resize((new_width, new_header_height)), (0, 0))
    new_image.paste(image.resize((new_width, new_height)), (0, new_header_height))

    # Convert back to the original framework if necessary
    new_image = to_numpy_array(new_image)

    if infer_channel_dimension_format(new_image) == ChannelDimension.LAST:
        new_image = to_channel_dimension_format(new_image, ChannelDimension.LAST)

    return new_image
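

# Illustrative usage sketch (not part of the original module): stacking a rendered
# question above a document image. `np_image` is assumed to be an (H, W, 3) uint8 array;
# extra kwargs such as `text_size` are forwarded to `render_text`.
#
#     combined = render_header(np_image, "What is the invoice date?", text_size=36)
#     # `combined` is a numpy array with the rendered question pasted above the image.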


class Pix2StructImageProcessor(BaseImageProcessor):
    r"""
    Constructs a Pix2Struct image processor.

    Args:
        do_convert_rgb (`bool`, *optional*, defaults to `True`):
            Whether to convert the image to RGB.
        do_normalize (`bool`, *optional*, defaults to `True`):
            Whether to normalize the image. Can be overridden by the `do_normalize` parameter in the `preprocess`
            method. According to Pix2Struct paper and code, the image is normalized with its own mean and standard
            deviation.
        patch_size (`Dict[str, int]`, *optional*, defaults to `{"height": 16, "width": 16}`):
            The patch size to use for the image. According to Pix2Struct paper and code, the patch size is 16x16.
        max_patches (`int`, *optional*, defaults to 2048):
            The maximum number of patches to extract from the image as per the [Pix2Struct
            paper](https://arxiv.org/pdf/2210.03347.pdf).
        is_vqa (`bool`, *optional*, defaults to `False`):
            Whether or not the image processor is for the VQA task. If `True` and `header_text` is passed in, text is
            rendered onto the input images.
    """

    model_input_names = ["flattened_patches"]

    def __init__(
        self,
        do_convert_rgb: bool = True,
        do_normalize: bool = True,
        patch_size: Dict[str, int] = None,
        max_patches: int = 2048,
        is_vqa: bool = False,
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)
        self.patch_size = patch_size if patch_size is not None else {"height": 16, "width": 16}
        self.do_normalize = do_normalize
        self.do_convert_rgb = do_convert_rgb
        self.max_patches = max_patches
        self.is_vqa = is_vqa

    def extract_flattened_patches(
        self,
        image: np.ndarray,
        max_patches: int,
        patch_size: dict,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
        **kwargs,
    ) -> np.ndarray:
        """
        Extract flattened patches from an image.

        Args:
            image (`np.ndarray`):
                Image to extract flattened patches from.
            max_patches (`int`):
                Maximum number of patches to extract.
            patch_size (`dict`):
                Dictionary containing the patch height and width.

        Returns:
            result (`np.ndarray`):
                A sequence of `max_patches` flattened patches.
        """
        requires_backends(self.extract_flattened_patches, "torch")

        # convert to torch
        image = to_channel_dimension_format(image, ChannelDimension.FIRST, input_data_format)
        image = torch.from_numpy(image)

        patch_height, patch_width = patch_size["height"], patch_size["width"]
        image_height, image_width = get_image_size(image, ChannelDimension.FIRST)

        # maximize scale s.t. the number of extracted patches (rows * columns) does not exceed `max_patches`
        scale = math.sqrt(max_patches * (patch_height / image_height) * (patch_width / image_width))
        num_feasible_rows = max(min(math.floor(scale * image_height / patch_height), max_patches), 1)
        num_feasible_cols = max(min(math.floor(scale * image_width / patch_width), max_patches), 1)
        resized_height = max(num_feasible_rows * patch_height, 1)
        resized_width = max(num_feasible_cols * patch_width, 1)

        image = torch.nn.functional.interpolate(
            image.unsqueeze(0),
            size=(resized_height, resized_width),
            mode="bilinear",
            align_corners=False,
            antialias=True,
        ).squeeze(0)

        # [1, rows, columns, patch_height * patch_width * image_channels]
        patches = torch_extract_patches(image, patch_height, patch_width)

        patches_shape = patches.shape
        rows = patches_shape[1]
        columns = patches_shape[2]
        depth = patches_shape[3]

        # [rows * columns, patch_height * patch_width * image_channels]
        patches = patches.reshape([rows * columns, depth])

        # [rows * columns, 1]
        row_ids = torch.arange(rows).reshape([rows, 1]).repeat(1, columns).reshape([rows * columns, 1])
        col_ids = torch.arange(columns).reshape([1, columns]).repeat(rows, 1).reshape([rows * columns, 1])

        # Offset by 1 so the ids do not contain zeros, which represent padding.
        row_ids += 1
        col_ids += 1

        # Prepare additional patch features.
        # [rows * columns, 1]
        row_ids = row_ids.to(torch.float32)
        col_ids = col_ids.to(torch.float32)

        # [rows * columns, 2 + patch_height * patch_width * image_channels]
        result = torch.cat([row_ids, col_ids, patches], -1)

        # [max_patches, 2 + patch_height * patch_width * image_channels]
        result = torch.nn.functional.pad(result, [0, 0, 0, max_patches - (rows * columns)]).float()

        result = to_numpy_array(result)

        return result
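
    # Illustrative note (not part of the original module): with the default 16x16 patches
    # and `max_patches=2048`, a single RGB image yields a (2048, 2 + 16 * 16 * 3) = (2048, 770)
    # array, where the first two columns hold the 1-indexed row/column ids and any trailing
    # rows are zero padding.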

    def normalize(
        self,
        image: np.ndarray,
        data_format: Optional[Union[str, ChannelDimension]] = None,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
        **kwargs,
    ) -> np.ndarray:
        """
        Normalize an image. image = (image - image_mean) / image_std.

        The image std is to mimic the tensorflow implementation of the `per_image_standardization`:
        https://www.tensorflow.org/api_docs/python/tf/image/per_image_standardization

        Args:
            image (`np.ndarray`):
                Image to normalize.
            data_format (`str` or `ChannelDimension`, *optional*):
                The channel dimension format for the output image. If unset, the channel dimension format of the input
                image is used.
            input_data_format (`str` or `ChannelDimension`, *optional*):
                The channel dimension format of the input image. If not provided, it will be inferred.
        """
        if image.dtype == np.uint8:
            image = image.astype(np.float32)

        # take mean across the whole `image`
        mean = np.mean(image)
        std = np.std(image)
        adjusted_stddev = max(std, 1.0 / math.sqrt(np.prod(image.shape)))

        return normalize(
            image,
            mean=mean,
            std=adjusted_stddev,
            data_format=data_format,
            input_data_format=input_data_format,
            **kwargs,
        )
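
    # Illustrative note (not part of the original module): clamping the divisor to
    # 1 / sqrt(num_elements) avoids division by zero for constant images, e.g. a 16x16x3
    # all-zero image is divided by 1 / sqrt(768) rather than by its zero std.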

    def preprocess(
        self,
        images: ImageInput,
        header_text: Optional[str] = None,
        do_convert_rgb: Optional[bool] = None,
        do_normalize: Optional[bool] = None,
        max_patches: Optional[int] = None,
        patch_size: Optional[Dict[str, int]] = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
        data_format: ChannelDimension = ChannelDimension.FIRST,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
        **kwargs,
    ) -> ImageInput:
        """
        Preprocess an image or batch of images. The processor first computes the maximum possible number of
        aspect-ratio preserving patches of size `patch_size` that can be extracted from the image. It then pads the
        image with zeros to make the image respect the constraint of `max_patches`. Before extracting the patches the
        images are standardized following the tensorflow implementation of `per_image_standardization`
        (https://www.tensorflow.org/api_docs/python/tf/image/per_image_standardization).

        Args:
            images (`ImageInput`):
                Image to preprocess. Expects a single or batch of images.
            header_text (`Union[List[str], str]`, *optional*):
                Text to render as a header. Only has an effect if `image_processor.is_vqa` is `True`.
            do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
                Whether to convert the image to RGB.
            do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
                Whether to normalize the image.
            max_patches (`int`, *optional*, defaults to `self.max_patches`):
                Maximum number of patches to extract.
            patch_size (`dict`, *optional*, defaults to `self.patch_size`):
                Dictionary containing the patch height and width.
            return_tensors (`str` or `TensorType`, *optional*):
                The type of tensors to return. Can be one of:
                - Unset: Return a list of `np.ndarray`.
                - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
                - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
                - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
                - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
            data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
                The channel dimension format for the output image. Can be one of:
                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
                - Unset: Use the channel dimension format of the input image.
            input_data_format (`ChannelDimension` or `str`, *optional*):
                The channel dimension format for the input image. If unset, the channel dimension format is inferred
                from the input image. Can be one of:
                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
                - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
        """
        do_normalize = do_normalize if do_normalize is not None else self.do_normalize
        do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb
        patch_size = patch_size if patch_size is not None else self.patch_size
        max_patches = max_patches if max_patches is not None else self.max_patches
        is_vqa = self.is_vqa

        if kwargs.get("data_format", None) is not None:
            raise ValueError("data_format is not an accepted input as the outputs are always flattened patches.")

        images = make_list_of_images(images)

        if not valid_images(images):
            raise ValueError(
                "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
                "torch.Tensor, tf.Tensor or jax.ndarray."
            )

        # PIL RGBA images are converted to RGB
        if do_convert_rgb:
            images = [convert_to_rgb(image) for image in images]

        # All transformations expect numpy arrays.
        images = [to_numpy_array(image) for image in images]

        if input_data_format is None:
            # We assume that all images have the same channel dimension format.
            input_data_format = infer_channel_dimension_format(images[0])

        if is_vqa:
            if header_text is None:
                raise ValueError("A header text must be provided for VQA models.")

            font_bytes = kwargs.pop("font_bytes", None)
            font_path = kwargs.pop("font_path", None)

            if isinstance(header_text, str):
                header_text = [header_text] * len(images)

            images = [
                render_header(image, header_text[i], font_bytes=font_bytes, font_path=font_path)
                for i, image in enumerate(images)
            ]

        if do_normalize:
            images = [self.normalize(image=image, input_data_format=input_data_format) for image in images]

        # convert to torch tensor and permute
        images = [
            self.extract_flattened_patches(
                image=image, max_patches=max_patches, patch_size=patch_size, input_data_format=input_data_format
            )
            for image in images
        ]

        # create attention mask in numpy
        attention_masks = [(image.sum(axis=-1) != 0).astype(np.float32) for image in images]

        encoded_outputs = BatchFeature(
            data={"flattened_patches": images, "attention_mask": attention_masks}, tensor_type=return_tensors
        )

        return encoded_outputs
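

# Illustrative end-to-end usage sketch (not part of the original module), assuming the
# "google/pix2struct-textcaps-base" checkpoint and a local image are available:
#
#     from PIL import Image
#     from transformers import Pix2StructImageProcessor
#
#     processor = Pix2StructImageProcessor.from_pretrained("google/pix2struct-textcaps-base")
#     image = Image.open("document.png")
#     inputs = processor.preprocess(image, max_patches=2048, return_tensors="pt")
#     print(inputs["flattened_patches"].shape)  # (1, 2048, 770) for RGB inputs
#     print(inputs["attention_mask"].shape)     # (1, 2048)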