image_processing_idefics3.py 42 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890
  1. # coding=utf-8
  2. # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. import math
  16. from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
  17. import numpy as np
  18. from ...image_processing_utils import BaseImageProcessor, BatchFeature
  19. from ...image_transforms import PaddingMode, pad, to_channel_dimension_format, to_pil_image
  20. from ...image_utils import (
  21. IMAGENET_STANDARD_MEAN,
  22. IMAGENET_STANDARD_STD,
  23. ChannelDimension,
  24. ImageInput,
  25. PILImageResampling,
  26. get_image_size,
  27. infer_channel_dimension_format,
  28. is_scaled_image,
  29. is_valid_image,
  30. to_numpy_array,
  31. valid_images,
  32. validate_preprocess_arguments,
  33. )
  34. from ...utils import TensorType, is_vision_available, logging
  35. logger = logging.get_logger(__name__)
  36. if is_vision_available():
  37. import PIL
  38. from PIL import Image
  39. def _resize_output_size_rescale_to_max_len(
  40. height: int, width: int, min_len: Optional[int] = 1, max_len: Optional[int] = None
  41. ) -> Tuple[int, int]:
  42. """
  43. Get the output size of the image after resizing given a dictionary specifying the max and min sizes.
  44. Args:
  45. height (`int`):
  46. Height of the input image.
  47. width (`int`):
  48. Width of the input image.
  49. min_len (`int`, *optional*, defaults to 1):
  50. Minimum size of the output image.
  51. max_len (`int`, *optional*, defaults to the maximum size of the image):
  52. Maximum size of the output image.
  53. Returns:
  54. The output size of the image after resizing.
  55. """
  56. max_len = max(height, width) if max_len is None else max_len
  57. aspect_ratio = width / height
  58. if width >= height:
  59. width = max_len
  60. height = int(width / aspect_ratio)
  61. if height % 2 != 0:
  62. height += 1
  63. elif height > width:
  64. height = max_len
  65. width = int(height * aspect_ratio)
  66. if width % 2 != 0:
  67. width += 1
  68. # Avoid resizing to a size smaller than min_len
  69. height = max(height, min_len)
  70. width = max(width, min_len)
  71. return height, width
  72. def _resize_output_size_scale_below_upper_bound(
  73. height: int, width: int, max_len: Optional[Dict[str, int]] = None
  74. ) -> Tuple[int, int]:
  75. """
  76. Get the output size of the image after resizing given a dictionary specifying the max and min sizes.
  77. Args:
  78. height (`int`):
  79. Height of the input image.
  80. width (`int`):
  81. Width of the input image.
  82. max_len (`Dict[str, int]`, *optional*, defaults to the maximum size of the image):
  83. Defines the maximum dimensions of the image.
  84. Returns:
  85. The output size of the image after resizing.
  86. """
  87. max_len = max(height, width) if max_len is None else max_len
  88. aspect_ratio = width / height
  89. if width >= height and width > max_len:
  90. width = max_len
  91. height = int(width / aspect_ratio)
  92. elif height > width and height > max_len:
  93. height = max_len
  94. width = int(height * aspect_ratio)
  95. # Avoid resizing to a size smaller than 1
  96. height = max(height, 1)
  97. width = max(width, 1)
  98. return height, width
  99. def get_resize_output_image_size(
  100. image,
  101. resolution_max_side: int,
  102. max_image_size: int = 1820,
  103. input_data_format: Optional[Union[str, ChannelDimension]] = None,
  104. ) -> Tuple[int, int]:
  105. """
  106. Get the output size of the image after resizing given a dictionary specifying the max and min sizes.
  107. Args:
  108. image (`np.ndarray`):
  109. Image to resize.
  110. resolution_max_side (`int`):
  111. The longest edge of the image will be resized to this value. The shortest edge will be resized to keep the
  112. input aspect ratio, with a lower bound of `min_image_size`.
  113. max_image_size (`int`, *optional*, defaults to 1820):
  114. Maximum image resolution. If the image is larger than this size, the longest edge will be resized to this
  115. value, with the shortest edge resized to keep the input aspect ratio, with a lower bound of `min_image_size`.
  116. input_data_format (`ChannelDimension` or `str`):
  117. The channel dimension format of the input image.
  118. Returns:
  119. The output size of the image after resizing.
  120. """
  121. if resolution_max_side > max_image_size:
  122. raise ValueError("`resolution_max_side` cannot be larger than `max_image_size`")
  123. height, width = get_image_size(image, channel_dim=input_data_format)
  124. # Find the output size, when rescaling the longest edge to max_len and preserving the aspect ratio
  125. height, width = _resize_output_size_rescale_to_max_len(height, width, max_len=resolution_max_side)
  126. # Find the output size when scaling the image to be below the max_image_size
  127. height, width = _resize_output_size_scale_below_upper_bound(height, width, max_len=max_image_size)
  128. return height, width
  129. # Copied from transformers.models.idefics2.image_processing_idefics2.make_list_of_images
  130. def make_list_of_images(images: ImageInput) -> List[List[np.ndarray]]:
  131. """
  132. Convert a single image or a list of images to a list of numpy arrays.
  133. Args:
  134. images (`ImageInput`):
  135. A single image or a list of images.
  136. Returns:
  137. A list of numpy arrays.
  138. """
  139. # If it's a single image, convert it to a list of lists
  140. if is_valid_image(images):
  141. images = [[images]]
  142. # If it's a list of images, it's a single batch, so convert it to a list of lists
  143. elif isinstance(images, (list, tuple)) and len(images) > 0 and is_valid_image(images[0]):
  144. images = [images]
  145. # If it's a list of batches, it's already in the right format
  146. elif (
  147. isinstance(images, (list, tuple))
  148. and len(images) > 0
  149. and isinstance(images[0], (list, tuple))
  150. and is_valid_image(images[0][0])
  151. ):
  152. pass
  153. else:
  154. raise ValueError(
  155. "Invalid input type. Must be a single image, a list of images, or a list of batches of images."
  156. )
  157. return images
  158. # Copied from transformers.models.detr.image_processing_detr.max_across_indices
  159. def max_across_indices(values: Iterable[Any]) -> List[Any]:
  160. """
  161. Return the maximum value across all indices of an iterable of values.
  162. """
  163. return [max(values_i) for values_i in zip(*values)]
  164. def get_max_height_width(
  165. images_list: List[List[np.ndarray]], input_data_format: Optional[Union[str, ChannelDimension]] = None
  166. ) -> List[int]:
  167. """
  168. Get the maximum height and width across all images in a batch.
  169. """
  170. if input_data_format is None:
  171. input_data_format = infer_channel_dimension_format(images_list[0][0], num_channels=(1, 3, 4))
  172. max_height = max_width = float("-inf")
  173. for images in images_list:
  174. for image in images:
  175. height, width = get_image_size(image, channel_dim=input_data_format)
  176. max_height = max(height, max_height)
  177. max_width = max(width, max_width)
  178. return (max_height, max_width)
  179. # Copied from transformers.models.detr.image_processing_detr.make_pixel_mask
  180. def make_pixel_mask(
  181. image: np.ndarray, output_size: Tuple[int, int], input_data_format: Optional[Union[str, ChannelDimension]] = None
  182. ) -> np.ndarray:
  183. """
  184. Make a pixel mask for the image, where 1 indicates a valid pixel and 0 indicates padding.
  185. Args:
  186. image (`np.ndarray`):
  187. Image to make the pixel mask for.
  188. output_size (`Tuple[int, int]`):
  189. Output size of the mask.
  190. """
  191. input_height, input_width = get_image_size(image, channel_dim=input_data_format)
  192. mask = np.zeros(output_size, dtype=np.int64)
  193. mask[:input_height, :input_width] = 1
  194. return mask
def convert_to_rgb(
    image: np.ndarray,
    palette: Optional[PIL.ImagePalette.ImagePalette] = None,
    data_format: Optional[Union[str, ChannelDimension]] = None,
    input_data_format: Optional[Union[str, ChannelDimension]] = None,
) -> ImageInput:
    """
    Converts an image to RGB format, compositing any transparency onto a white background.

    Args:
        image (`np.ndarray`):
            The image to convert.
        palette (List[int], *optional*):
            The palette to use if given. When provided, the image is interpreted as palette-mode ("P").
        data_format (ChannelDimension or str, *optional*):
            The channel dimension format for the output image. If not provided, it will be the same as the input image.
        input_data_format (ChannelDimension or str, *optional*):
            The channel dimension format of the input image.
    """
    if input_data_format is None:
        input_data_format = infer_channel_dimension_format(image, num_channels=(1, 3, 4))
    # For all transformations, we want to keep the same data format as the input image unless otherwise specified.
    # The resized image from PIL will always have channels last, so find the input format first.
    data_format = input_data_format if data_format is None else data_format
    # Treat the array as a palette image only when a palette was supplied.
    mode = "P" if palette is not None else None
    image = to_pil_image(image, image_mode=mode)
    if image.mode == "P" and palette is not None:
        image.putpalette(palette)
    # Go through RGBA and composite onto a white background so transparent pixels become white
    # rather than arbitrary values.
    image_rgba = image.convert("RGBA")
    background = Image.new("RGBA", image_rgba.size, (255, 255, 255))
    alpha_composite = Image.alpha_composite(background, image_rgba)
    alpha_composite = alpha_composite.convert("RGB")
    output_array = np.array(alpha_composite)
    # The image is always in channels last format after converting from a PIL image
    output_array = to_channel_dimension_format(output_array, data_format, input_channel_dim=ChannelDimension.LAST)
    return output_array
  230. # FIXME Amy: make a more general crop function that isn't just centre crop
  231. def _crop(
  232. image: np.ndarray,
  233. w1: int,
  234. h1: int,
  235. w2: int,
  236. h2: int,
  237. data_format: Optional[Union[str, ChannelDimension]] = None,
  238. ) -> np.ndarray:
  239. if data_format is None:
  240. data_format = infer_channel_dimension_format(image, num_channels=(1, 3, 4))
  241. if data_format == ChannelDimension.FIRST:
  242. image = image[:, h1:h2, w1:w2]
  243. elif data_format == ChannelDimension.LAST:
  244. image = image[h1:h2, w1:w2, :]
  245. else:
  246. raise ValueError("Invalid channel dimension format.")
  247. return image
  248. class Idefics3ImageProcessor(BaseImageProcessor):
  249. r"""
  250. Constructs a Idefics3 image processor.
  251. Args:
  252. do_convert_rgb (`bool`, *optional*, defaults to `True`):
  253. Whether to convert the image to RGB. This is useful if the input image is of a different format e.g. RGBA.
  254. Only has an effect if the input image is in the PIL format.
  255. do_resize (`bool`, *optional*, defaults to `True`):
  256. Whether to resize the image. The longest edge of the image is resized to be <= `size["longest_edge"]`, with the
  257. shortest edge resized to keep the input aspect ratio.
  258. size (`Dict`, *optional*, defaults to `{"longest_edge": 4 * 364}`):
  259. Controls the size of the output image. This is a dictionary containing the key "longest_edge".
  260. The image will be resized such that the longest edge is <= `size["longest_edge"]` and the shortest edge is resized
  261. to keep the input aspect ratio.
  262. resample (`Resampling`, *optional*, defaults to `Resampling.LANCZOS`):
  263. Resampling filter to use when resizing the image.
  264. do_image_splitting (`bool`, *optional*, defaults to `True`):
  265. Whether to split the image into sub-images concatenated with the original image. They are split into patches
  266. such that each patch has a size of `max_image_size["height"]` x `max_image_size["width"]`.
  267. max_image_size (`Dict`, *optional*, defaults to `{"longest_edge": 364}`):
  268. Maximum resolution of the patches of images accepted by the model. This is a dictionary containing the key "longest_edge".
  269. do_rescale (`bool`, *optional*, defaults to `True`):
  270. Whether to rescale the image. If set to `True`, the image is rescaled to have pixel values between 0 and 1.
  271. rescale_factor (`float`, *optional*, defaults to `1/255`):
  272. Rescale factor to rescale the image by if `do_rescale` is set to `True`.
  273. do_normalize (`bool`, *optional*, defaults to `True`):
  274. Whether to normalize the image. If set to `True`, the image is normalized to have a mean of `image_mean` and
  275. a standard deviation of `image_std`.
  276. image_mean (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_MEAN`):
  277. Mean to use if normalizing the image. This is a float or list of floats the length of the number of
  278. channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method.
  280. image_std (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_STD`):
  281. Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
  282. number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
  284. do_pad (`bool`, *optional*, defaults to `True`):
  285. Whether or not to pad the images to the largest height and width in the batch and number of images per
  286. sample in the batch, such that the returned tensor is of shape (batch_size, max_num_images, num_channels, max_height, max_width).
  287. """
    # Keys of the batch feature this processor produces.
    model_input_names = ["pixel_values"]

    def __init__(
        self,
        do_convert_rgb: bool = True,
        do_resize: bool = True,
        size: Optional[Dict[str, int]] = None,
        resample: PILImageResampling = PILImageResampling.LANCZOS,
        do_image_splitting: bool = True,
        max_image_size: Optional[Dict[str, int]] = None,
        do_rescale: bool = True,
        rescale_factor: float = 1 / 255,
        do_normalize: bool = True,
        image_mean: Optional[Union[float, List[float]]] = None,
        image_std: Optional[Union[float, List[float]]] = None,
        do_pad: bool = True,
        **kwargs,
    ) -> None:
        """Initialize the processor; see the class docstring for argument details."""
        super().__init__(**kwargs)
        self.do_convert_rgb = do_convert_rgb
        self.do_resize = do_resize
        # Default: resize so the longest edge is at most 4 patches of 364px.
        self.size = size if size is not None else {"longest_edge": 4 * 364}
        self.resample = resample
        self.do_image_splitting = do_image_splitting
        # Default patch size accepted by the vision encoder.
        self.max_image_size = max_image_size if max_image_size is not None else {"longest_edge": 364}
        self.do_rescale = do_rescale
        self.rescale_factor = rescale_factor
        self.do_normalize = do_normalize
        # Note: the actual defaults are the IMAGENET standard statistics.
        self.image_mean = image_mean if image_mean is not None else IMAGENET_STANDARD_MEAN
        self.image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD
        self.do_pad = do_pad
    def resize(
        self,
        image: np.ndarray,
        size: Dict[str, int],
        resample: PILImageResampling = PILImageResampling.LANCZOS,
        data_format: Optional[Union[str, ChannelDimension]] = None,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
        **kwargs,
    ) -> np.ndarray:
        """
        Resize an image. The longest edge of the image is resized to size["longest_edge"], with the shortest edge
        resized to keep the input aspect ratio. Can also be used with size["height"] and size["width"].

        Args:
            image (`np.ndarray`):
                Image to resize.
            size (`Dict[str, int]`):
                Size of the output image.
            resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.LANCZOS`):
                Resampling filter to use when resizing the image.
            data_format (`ChannelDimension` or `str`, *optional*):
                The channel dimension format of the output image. If not provided, it will be the same as the input image.
            input_data_format (`ChannelDimension` or `str`, *optional*):
                The channel dimension format of the input image. If not provided, it will be inferred.
        """
        if input_data_format is None:
            input_data_format = infer_channel_dimension_format(image, num_channels=(1, 3, 4))
        # For all transformations, we want to keep the same data format as the input image unless otherwise specified.
        # The resized image from PIL will always have channels last, so find the input format first.
        data_format = input_data_format if data_format is None else data_format
        if "longest_edge" in size:
            size = get_resize_output_image_size(
                image, resolution_max_side=size["longest_edge"], input_data_format=input_data_format
            )
        elif "height" in size and "width" in size:
            size = (size["height"], size["width"])
        else:
            raise ValueError("size must be a dictionary with key 'longest_edge' or 'height' and 'width'.")
        # Single-channel arrays are converted as palette ("P") images so PIL accepts them.
        # NOTE(review): `image.shape[-1] == 1` assumes channels-last layout here — confirm.
        image_mode = None
        if image.ndim == 2 or image.shape[-1] == 1:
            image_mode = "P"
        image = to_pil_image(image, image_mode=image_mode)
        # PIL expects (width, height), while `size` is (height, width).
        resized_image = image.resize((size[1], size[0]), resample=resample)
        resized_image = np.array(resized_image)
        # If the input image channel dimension was of size 1, then it is dropped when converting to a PIL image
        # so we need to add it back if necessary.
        resized_image = np.expand_dims(resized_image, axis=-1) if resized_image.ndim == 2 else resized_image
        # The image is always in channels last format after converting from a PIL image
        resized_image = to_channel_dimension_format(
            resized_image, data_format, input_channel_dim=ChannelDimension.LAST
        )
        return resized_image
    def split_image(
        self,
        image,
        max_image_size: Dict[str, int],
        resample: PILImageResampling = PILImageResampling.LANCZOS,
        data_format: Optional[Union[str, ChannelDimension]] = None,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
    ):
        """
        Split an image into squares of side max_image_size and the original image resized to max_image_size.
        That means that a single image becomes a sequence of images.
        This is a "trick" to spend more compute on each image with no changes in the vision encoder.
        1) If one side of the original image is larger than `max_image_size`, resize it to `max_image_size` while preserving the aspect ratio.
        2) Divide the resulting image into `ceil(height / max_image_size)` x `ceil(width / max_image_size)`
        sub-images of the same size each (image_size, image_size). Typically, 364x364.
        3) Returns the list of the crops and the original image, in addition to the number of splits for the height and the width.

        Args:
            image (`np.ndarray`):
                Images to split.
            max_image_size (`Dict[str, int]`):
                Maximum size of the output image. If the image is larger than this size, it will be split into
                patches of this size, and the original image will be concatenated with the patches, resized to max_size.
            resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.LANCZOS`):
                Resampling filter to use when resizing the image.
            data_format (`ChannelDimension` or `str`, *optional*):
                The channel dimension format of the output image. If not provided, it will be the same as the input image.
            input_data_format (`ChannelDimension` or `str`, *optional*):
                The channel dimension format of the input image. If not provided, it will be inferred.
        """
        height, width = get_image_size(image, channel_dim=input_data_format)
        max_height = max_width = max_image_size["longest_edge"]
        frames = []
        if height > max_height or width > max_width:
            # Calculate the number of splits
            num_splits_h = math.ceil(height / max_height)
            num_splits_w = math.ceil(width / max_width)
            # Calculate the optimal width and height for the sub-images
            optimal_height = math.ceil(height / num_splits_h)
            optimal_width = math.ceil(width / num_splits_w)
            # Iterate through each row and column
            for r in range(num_splits_h):
                for c in range(num_splits_w):
                    # Calculate the starting point of the crop
                    start_x = c * optimal_width
                    start_y = r * optimal_height
                    # Calculate the ending point of the crop
                    end_x = min(start_x + optimal_width, width)
                    end_y = min(start_y + optimal_height, height)
                    # Crop the image
                    # NOTE(review): `_crop` is given `data_format` (the OUTPUT format) to describe the
                    # layout of the input `image` — looks like it should be `input_data_format`; confirm.
                    cropped_image = _crop(
                        image,
                        start_x,
                        start_y,
                        end_x,
                        end_y,
                        data_format=data_format,
                    )
                    frames.append(cropped_image)
            # For the global image at the end, we resize it to match the max_image_size, for cpu memory efficiency
            global_image_height, global_image_width = max_height, max_width
            if height != global_image_height or width != global_image_width:
                # NOTE(review): `data_format` is passed as `input_data_format` to resize — suspicious for
                # the same reason as above; verify against the idefics2/idefics3 reference implementation.
                image = self.resize(
                    image,
                    {"height": global_image_height, "width": global_image_width},
                    resample=resample,
                    input_data_format=data_format,
                )
        else:
            # No splitting needed: the only frame is the original image.
            num_splits_h, num_splits_w = 0, 0
        frames.append(image)
        return frames, num_splits_h, num_splits_w
  440. def resize_for_vision_encoder(
  441. self,
  442. image: np.ndarray,
  443. vision_encoder_max_size: int,
  444. resample: PILImageResampling = PILImageResampling.LANCZOS,
  445. data_format: Optional[Union[str, ChannelDimension]] = None,
  446. input_data_format: Optional[Union[str, ChannelDimension]] = None,
  447. ):
  448. """
  449. Resize images to be multiples of `vision_encoder_max_size` while preserving the aspect ratio.
  450. Args:
  451. image (`np.ndarray`):
  452. Images to resize.
  453. vision_encoder_max_size (`int`):
  454. Maximum size of the output image. If the image is larger than this size, it will be split into
  455. patches of this size, and the original image will be concatenated with the patches, resized to max_size.
  456. resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.LANCZOS`):
  457. Resampling filter to use when resizing the image.
  458. data_format (`ChannelDimension` or `str`, *optional*):
  459. The channel dimension format of the output image. If not provided, it will be the same as the input image.
  460. input_data_format (`ChannelDimension` or `str`, *optional*):
  461. The channel dimension format of the input image. If not provided, it will be inferred
  462. """
  463. height, width = get_image_size(image, channel_dim=input_data_format)
  464. aspect_ratio = width / height
  465. if width >= height:
  466. width = math.ceil(width / vision_encoder_max_size) * vision_encoder_max_size
  467. height = int(width / aspect_ratio)
  468. height = math.ceil(height / vision_encoder_max_size) * vision_encoder_max_size
  469. elif height > width:
  470. height = math.ceil(height / vision_encoder_max_size) * vision_encoder_max_size
  471. width = int(height * aspect_ratio)
  472. width = math.ceil(width / vision_encoder_max_size) * vision_encoder_max_size
  473. new_size = {"height": height, "width": width}
  474. return self.resize(
  475. image, size=new_size, resample=resample, input_data_format=input_data_format, data_format=data_format
  476. )
  477. def _pad_image(
  478. self,
  479. image: np.ndarray,
  480. output_size: Tuple[int, int],
  481. constant_values: Union[float, Iterable[float]] = 0,
  482. data_format: Optional[ChannelDimension] = None,
  483. input_data_format: Optional[Union[str, ChannelDimension]] = None,
  484. ) -> np.ndarray:
  485. """
  486. Pad an image with zeros to the given size.
  487. """
  488. input_height, input_width = get_image_size(image, channel_dim=input_data_format)
  489. output_height, output_width = output_size
  490. pad_bottom = output_height - input_height
  491. pad_right = output_width - input_width
  492. padding = ((0, pad_bottom), (0, pad_right))
  493. padded_image = pad(
  494. image,
  495. padding,
  496. mode=PaddingMode.CONSTANT,
  497. constant_values=constant_values,
  498. data_format=data_format,
  499. input_data_format=input_data_format,
  500. )
  501. return padded_image
    def pad(
        self,
        images: List[np.ndarray],
        constant_values: Union[float, Iterable[float]] = 0,
        return_pixel_mask: bool = True,
        return_tensors: Optional[Union[str, TensorType]] = None,
        data_format: Optional[ChannelDimension] = None,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
    ) -> Tuple[List[List[np.ndarray]], Optional[List[List[np.ndarray]]]]:
        """
        For a list of images, for each images, pads a batch of images to the bottom and right of the image with zeros to the size of largest height and width.
        For each sample in the batch, pads the sample with empty images to the max_number of images per sample in the batch. Optionally returns a pixel mask.

        Args:
            images (`List[np.ndarray]`):
                List of list of images to pad. Pads to the largest height and width in the batch.
            constant_values (`float` or `Iterable[float]`, *optional*):
                The value to use for the padding if `mode` is `"constant"`.
            return_pixel_mask (`bool`, *optional*, defaults to `True`):
                Whether to return a pixel mask.
            return_tensors (`str` or `TensorType`, *optional*):
                The type of tensors to return. Can be one of:
                    - Unset: Return a list of `np.ndarray`.
                    - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
                    - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
                    - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
                    - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
            data_format (`str` or `ChannelDimension`, *optional*):
                The channel dimension format of the image. If not provided, it will be the same as the input image.
            input_data_format (`ChannelDimension` or `str`, *optional*):
                The channel dimension format of the input image. If not provided, it will be inferred.

        Returns:
            A tuple `(padded_images_list, padded_masks)` of nested lists; `padded_masks` is `None`
            when `return_pixel_mask` is `False`.
        """
        # NOTE(review): `return_tensors` is accepted but never used in this method — tensor
        # conversion presumably happens in the caller; confirm.
        pad_size = get_max_height_width(images, input_data_format=input_data_format)
        batch_size = len(images)
        max_num_images = max(len(images_) for images_ in images)
        input_data_format = (
            infer_channel_dimension_format(images[0][0], num_channels=(1, 3, 4))
            if input_data_format is None
            else input_data_format
        )
        data_format = input_data_format if data_format is None else data_format
        # The channel count is read from the first image and assumed for the whole batch.
        if input_data_format == ChannelDimension.FIRST:
            n_channels = images[0][0].shape[0]
        elif input_data_format == ChannelDimension.LAST:
            n_channels = images[0][0].shape[-1]
        else:
            raise ValueError("Invalid channel dimension format.")

        def empty_image(size, input_data_format):
            # All-zero placeholder used to pad samples up to `max_num_images` images.
            # NOTE(review): returns None for any format other than FIRST/LAST — relies on
            # `data_format` having been validated above; confirm no other values can reach here.
            if input_data_format == ChannelDimension.FIRST:
                return np.zeros((n_channels, *size), dtype=np.uint8)
            elif input_data_format == ChannelDimension.LAST:
                return np.zeros((*size, n_channels), dtype=np.uint8)

        # Pre-fill with empty images/masks, then overwrite the slots that have real images.
        padded_images_list = [
            [empty_image(pad_size, data_format) for _ in range(max_num_images)] for _ in range(batch_size)
        ]
        padded_masks = [[np.zeros(pad_size) for _ in range(max_num_images)] for _ in range(batch_size)]
        for batch_idx in range(batch_size):
            for sample_idx, image in enumerate(images[batch_idx]):
                padded_images_list[batch_idx][sample_idx] = self._pad_image(
                    image,
                    pad_size,
                    constant_values=constant_values,
                    data_format=data_format,
                    input_data_format=input_data_format,
                )
                padded_masks[batch_idx][sample_idx] = make_pixel_mask(
                    image, output_size=pad_size, input_data_format=input_data_format
                )
        padded_masks = padded_masks if return_pixel_mask else None
        return padded_images_list, padded_masks
    def preprocess(
        self,
        images: ImageInput,
        do_convert_rgb: Optional[bool] = None,
        do_resize: Optional[bool] = None,
        size: Optional[Dict[str, int]] = None,
        resample: PILImageResampling = None,
        do_image_splitting: Optional[bool] = None,
        do_rescale: Optional[bool] = None,
        max_image_size: Optional[Dict[str, int]] = None,
        rescale_factor: Optional[float] = None,
        do_normalize: Optional[bool] = None,
        image_mean: Optional[Union[float, List[float]]] = None,
        image_std: Optional[Union[float, List[float]]] = None,
        do_pad: Optional[bool] = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
        return_row_col_info: bool = False,
        data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
    ):
        """
        Preprocess a batch of images.

        Args:
            images (`ImageInput`):
                A list of images to preprocess.
            do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
                Whether to convert the image to RGB.
            do_resize (`bool`, *optional*, defaults to `self.do_resize`):
                Whether to resize the image.
            size (`Dict[str, int]`, *optional*, defaults to `self.size`):
                Size of the image after resizing. With the longest edge resized to keep the input aspect ratio.
            resample (`int`, *optional*, defaults to `self.resample`):
                Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`. Only
                has an effect if `do_resize` is set to `True`.
            do_image_splitting (`bool`, *optional*, defaults to `self.do_image_splitting`):
                Whether to split the image into sub-images concatenated with the original image. They are split into patches
                such that each patch has a size of `max_image_size["height"]` x `max_image_size["width"]`.
            do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
                Whether to rescale the image.
            max_image_size (`Dict`, *optional*, defaults to `self.max_image_size`):
                Maximum resolution of the images. If the image is larger than this size, the image is split into patches.
            rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
                Rescale factor to rescale the image by if `do_rescale` is set to `True`.
            do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
                Whether to normalize the image.
            image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
                Image mean to use for normalization. Only has an effect if `do_normalize` is set to `True`.
            image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
                Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to
                `True`.
            do_pad (`bool`, *optional*, defaults to `self.do_pad`):
                Whether or not to pad the images to the largest height and width in the batch.
            return_tensors (`str` or `TensorType`, *optional*):
                The type of tensors to return. Can be one of:
                - Unset: Return a list of `np.ndarray`.
                - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
                - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
                - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
                - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
            return_row_col_info (`bool`, *optional*, defaults to `False`):
                Whether to return the number of rows and columns of the split images. This is used for the
                `Idefics3Processor` to generate prompt strings based on the number of rows and columns.
            data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
                The channel dimension format for the output image. Can be one of:
                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
                - Unset: Use the channel dimension format of the input image.
            input_data_format (`ChannelDimension` or `str`, *optional*):
                The channel dimension format for the input image. If unset, the channel dimension format is inferred
                from the input image. Can be one of:
                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
                - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
        """
        # Resolve every unset argument against the processor-level defaults.
        do_resize = do_resize if do_resize is not None else self.do_resize
        size = size if size is not None else self.size
        resample = resample if resample is not None else self.resample
        do_rescale = do_rescale if do_rescale is not None else self.do_rescale
        rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
        do_image_splitting = do_image_splitting if do_image_splitting is not None else self.do_image_splitting
        max_image_size = max_image_size if max_image_size is not None else self.max_image_size
        do_normalize = do_normalize if do_normalize is not None else self.do_normalize
        image_mean = image_mean if image_mean is not None else self.image_mean
        image_std = image_std if image_std is not None else self.image_std
        do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb
        do_pad = do_pad if do_pad is not None else self.do_pad

        # Normalize the input into a nested list: batch of samples, each a list of images.
        # Only the first sample is type-checked; all samples are assumed homogeneous.
        images_list = make_list_of_images(images)

        if not valid_images(images_list[0]):
            raise ValueError(
                "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
                "torch.Tensor, tf.Tensor or jax.ndarray."
            )

        validate_preprocess_arguments(
            do_rescale=do_rescale,
            rescale_factor=rescale_factor,
            do_normalize=do_normalize,
            image_mean=image_mean,
            image_std=image_std,
            do_resize=do_resize,
            size=size,
            resample=resample,
        )

        # save the palettes for conversion to RGB
        # Palette info is lost once a "P"-mode PIL image becomes a numpy array, so it is
        # captured here and carried alongside each image for the later `convert_to_rgb` step.
        palettes_list = [
            [im.getpalette() if isinstance(im, Image.Image) and im.mode == "P" else None for im in images]
            for images in images_list
        ]

        # All transformations expect numpy arrays.
        images_list = [[to_numpy_array(image) for image in images] for images in images_list]

        if is_scaled_image(images_list[0][0]) and do_rescale:
            logger.warning_once(
                "It looks like you are trying to rescale already rescaled images. If the input"
                " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
            )

        # We assume that all images have the same channel dimension format.
        if input_data_format is None:
            input_data_format = infer_channel_dimension_format(images_list[0][0], num_channels=(1, 3, 4))

        # Extra channel dimension for grayscale images
        # 2-D arrays get an explicit channel axis so every downstream op sees a 3-D image.
        if input_data_format == ChannelDimension.LAST:
            images_list = [
                [np.expand_dims(img, axis=-1) if img.ndim == 2 else img for img in images] for images in images_list
            ]
        elif input_data_format == ChannelDimension.FIRST:
            images_list = [
                [np.expand_dims(img, axis=0) if img.ndim == 2 else img for img in images] for images in images_list
            ]
        else:
            raise ValueError(f"Invalid channel dimension format {input_data_format}.")

        if do_resize:
            images_list = [
                [
                    self.resize(image=image, size=size, resample=resample, input_data_format=input_data_format)
                    for image in images
                ]
                for images in images_list
            ]

        if do_image_splitting:
            # We first resize both height and width of each image to the nearest max_image_size multiple, disregarding the aspect ratio
            # for size=(10, max_image_size) -> rescaled_size=(max_image_size, max_image_size)
            # for size=(11, max_image_size+1) -> rescaled_size=(max_image_size, max_image_size*2)
            images_list = [
                [
                    self.resize_for_vision_encoder(
                        image, max_image_size["longest_edge"], resample=resample, input_data_format=input_data_format
                    )
                    for image in images
                ]
                for images in images_list
            ]
            # Split each image into max_image_size patches; each patch inherits its source
            # image's palette so RGB conversion later works per-patch.
            images_list_split_arrays = []
            palettes_list_split_arrays = []
            images_list_rows = []
            images_list_cols = []
            for images, palettes in zip(images_list, palettes_list):
                split_image_arrays = []
                split_palettes_arrays = []
                image_rows = []
                image_cols = []
                for image, palette in zip(images, palettes):
                    split_image_array, rows, cols = self.split_image(
                        image,
                        max_image_size=max_image_size,
                        input_data_format=input_data_format,
                    )
                    split_image_arrays.extend(split_image_array)
                    split_palettes_arrays.extend([palette] * len(split_image_array))
                    image_rows.append(rows)
                    image_cols.append(cols)
                images_list_split_arrays.append(split_image_arrays)
                palettes_list_split_arrays.append(split_palettes_arrays)
                images_list_rows.append(image_rows)
                images_list_cols.append(image_cols)
            images_list = images_list_split_arrays
            palettes_list = palettes_list_split_arrays
        else:
            # We square the images to max_image_size
            images_list = [
                [
                    self.resize(
                        image=image,
                        size={"height": max_image_size["longest_edge"], "width": max_image_size["longest_edge"]},
                        resample=resample,
                        input_data_format=input_data_format,
                    )
                    for image in images
                ]
                for images in images_list
            ]
            # rows/cols of 0 signal "not split" to the processor's prompt construction.
            images_list_rows = [[0] * len(images) for images in images_list]
            images_list_cols = [[0] * len(images) for images in images_list]

        if do_convert_rgb:
            images_list = [
                [convert_to_rgb(img, palette) for img, palette in zip(images, palettes)]
                for images, palettes in zip(images_list, palettes_list)
            ]

        if do_rescale:
            images_list = [
                [self.rescale(image, rescale_factor, input_data_format=input_data_format) for image in images]
                for images in images_list
            ]

        if do_normalize:
            images_list = [
                [
                    self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=input_data_format)
                    for image in images
                ]
                for images in images_list
            ]

        pixel_attention_mask = None
        if do_pad:
            # Pads every sample to the batch's max image count and max spatial size, and
            # produces the pixel mask distinguishing real pixels from padding.
            images_list, pixel_attention_mask = self.pad(
                images_list, return_pixel_mask=True, return_tensors=return_tensors, input_data_format=input_data_format
            )

        if data_format is not None:
            images_list = [
                [
                    to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format)
                    for image in images
                ]
                for images in images_list
            ]

        # Faster tensor conversion
        # After padding, all arrays share one shape, so a single np.array call is cheaper than
        # letting BatchFeature convert the nested lists element by element.
        data = {"pixel_values": np.array(images_list) if do_pad and return_tensors is not None else images_list}
        if pixel_attention_mask is not None:
            data["pixel_attention_mask"] = (
                np.array(pixel_attention_mask) if do_pad and return_tensors is not None else pixel_attention_mask
            )

        encoding = BatchFeature(data=data, tensor_type=return_tensors)

        # This is needed for generating correct text inputs in the processor - we don't pad to the max number of images
        if return_row_col_info:
            encoding["rows"] = images_list_rows
            encoding["cols"] = images_list_cols

        return encoding