# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Image processor class for Pixtral."""
import math
from fractions import Fraction
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import numpy as np

from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
from ...image_transforms import (
    resize,
    to_channel_dimension_format,
)
from ...image_utils import (
    ChannelDimension,
    ImageInput,
    PILImageResampling,
    get_image_size,
    infer_channel_dimension_format,
    is_scaled_image,
    is_valid_image,
    to_numpy_array,
    valid_images,
    validate_kwargs,
    validate_preprocess_arguments,
)
from ...utils import TensorType, is_torch_device, is_torch_dtype, is_torch_tensor, is_vision_available, logging
from ...utils.import_utils import requires_backends


logger = logging.get_logger(__name__)

if is_vision_available():
    import PIL
  41. class BatchMixFeature(BatchFeature):
  42. def to(self, *args, **kwargs) -> "BatchMixFeature":
  43. """
  44. Send all values to device by calling `v.to(*args, **kwargs)` (PyTorch only). This should support casting in
  45. different `dtypes` and sending the `BatchFeature` to a different `device`.
  46. Args:
  47. args (`Tuple`):
  48. Will be passed to the `to(...)` function of the tensors.
  49. kwargs (`Dict`, *optional*):
  50. Will be passed to the `to(...)` function of the tensors.
  51. Returns:
  52. [`BatchFeature`]: The same instance after modification.
  53. """
  54. requires_backends(self, ["torch"])
  55. import torch # noqa
  56. new_data = {}
  57. device = kwargs.get("device")
  58. # Check if the args are a device or a dtype
  59. if device is None and len(args) > 0:
  60. # device should be always the first argument
  61. arg = args[0]
  62. if is_torch_dtype(arg):
  63. # The first argument is a dtype
  64. pass
  65. elif isinstance(arg, str) or is_torch_device(arg) or isinstance(arg, int):
  66. device = arg
  67. else:
  68. # it's something else
  69. raise ValueError(f"Attempting to cast a BatchFeature to type {str(arg)}. This is not supported.")
  70. # We cast only floating point tensors to avoid issues with tokenizers casting `LongTensor` to `FloatTensor`
  71. for k, v in self.items():
  72. # check if v is a floating point
  73. if isinstance(v, list):
  74. new_data[k] = [
  75. element.to(*args, **kwargs) for sample in v for element in sample if is_torch_tensor(element)
  76. ]
  77. elif isinstance(v, torch.Tensor) and torch.is_floating_point(v):
  78. # cast and send to device
  79. new_data[k] = v.to(*args, **kwargs)
  80. elif isinstance(v, torch.Tensor) and device is not None:
  81. new_data[k] = v.to(device=device)
  82. else:
  83. new_data[k] = v
  84. self.data = new_data
  85. return self
  86. # Copied from transformers.models.idefics2.image_processing_idefics2.make_list_of_images
  87. def make_list_of_images(images: ImageInput) -> List[List[np.ndarray]]:
  88. """
  89. Convert a single image or a list of images to a list of numpy arrays.
  90. Args:
  91. images (`ImageInput`):
  92. A single image or a list of images.
  93. Returns:
  94. A list of numpy arrays.
  95. """
  96. # If it's a single image, convert it to a list of lists
  97. if is_valid_image(images):
  98. images = [[images]]
  99. # If it's a list of images, it's a single batch, so convert it to a list of lists
  100. elif isinstance(images, (list, tuple)) and len(images) > 0 and is_valid_image(images[0]):
  101. images = [images]
  102. # If it's a list of batches, it's already in the right format
  103. elif (
  104. isinstance(images, (list, tuple))
  105. and len(images) > 0
  106. and isinstance(images[0], (list, tuple))
  107. and is_valid_image(images[0][0])
  108. ):
  109. pass
  110. else:
  111. raise ValueError(
  112. "Invalid input type. Must be a single image, a list of images, or a list of batches of images."
  113. )
  114. return images
  115. # Adapted from function in image_transforms.py to ensure any transparent pixels are converted to white.
  116. def convert_to_rgb(image: ImageInput) -> ImageInput:
  117. """
  118. Converts an image to RGB format. Only converts if the image is of type PIL.Image.Image, otherwise returns the image
  119. as is.
  120. Args:
  121. image (Image):
  122. The image to convert.
  123. """
  124. requires_backends(convert_to_rgb, ["vision"])
  125. if not isinstance(image, PIL.Image.Image):
  126. return image
  127. if image.mode == "RGB":
  128. return image
  129. # First we convert to RGBA to set background to white.
  130. image = image.convert("RGBA")
  131. # Create a new image with a white background.
  132. new_image = PIL.Image.new("RGBA", image.size, "WHITE")
  133. new_image.paste(image, (0, 0), image)
  134. new_image = new_image.convert("RGB")
  135. return new_image
  136. def _num_image_tokens(image_size: Tuple[int, int], patch_size: Tuple[int, int]) -> int:
  137. """
  138. Calculate the number of image tokens given the image size and patch size.
  139. Args:
  140. image_size (`Tuple[int, int]`):
  141. The size of the image as `(height, width)`.
  142. patch_size (`Tuple[int, int]`):
  143. The patch size as `(height, width)`.
  144. Returns:
  145. `int`: The number of image tokens.
  146. """
  147. height, width = image_size
  148. patch_height, patch_width = patch_size if isinstance(patch_size, (tuple, list)) else (patch_size, patch_size)
  149. num_width_tokens = (width - 1) // patch_width + 1
  150. num_height_tokens = (height - 1) // patch_height + 1
  151. return num_height_tokens, num_width_tokens
  152. def get_resize_output_image_size(
  153. input_image: np.ndarray,
  154. size: Union[int, Tuple[int, int], List[int], Tuple[int]],
  155. patch_size: Union[int, Tuple[int, int], List[int], Tuple[int]],
  156. input_data_format: Optional[Union[str, ChannelDimension]] = None,
  157. ) -> tuple:
  158. """
  159. Find the target (height, width) dimension of the output image after resizing given the input image and the desired
  160. size.
  161. Args:
  162. input_image (`np.ndarray`):
  163. The image to resize.
  164. size (`int` or `Tuple[int, int]`):
  165. Max image size an input image can be. Must be a dictionary with the key "longest_edge".
  166. patch_size (`int` or `Tuple[int, int]`):
  167. The patch_size as `(height, width)` to use for resizing the image. If patch_size is an integer, `(patch_size, patch_size)`
  168. will be used
  169. input_data_format (`ChannelDimension`, *optional*):
  170. The channel dimension format of the input image. If unset, will use the inferred format from the input.
  171. Returns:
  172. `tuple`: The target (height, width) dimension of the output image after resizing.
  173. """
  174. max_height, max_width = size if isinstance(size, (tuple, list)) else (size, size)
  175. patch_height, patch_width = patch_size if isinstance(patch_size, (tuple, list)) else (patch_size, patch_size)
  176. height, width = get_image_size(input_image, input_data_format)
  177. ratio = max(height / max_height, width / max_width)
  178. if ratio > 1:
  179. # Orgiginal implementation uses `round` which utilises bankers rounding, which can lead to surprising results
  180. height = int(np.ceil(height / ratio))
  181. width = int(np.ceil(width / ratio))
  182. num_height_tokens, num_width_tokens = _num_image_tokens((height, width), (patch_height, patch_width))
  183. return num_height_tokens * patch_height, num_width_tokens * patch_width
  184. # Hack to get tensor conversion used in BatchFeature without batching the images
  185. def _get_is_as_tensor_fns(tensor_type: Union[str, TensorType]) -> Tuple[Callable, Callable]:
  186. return BatchFeature()._get_is_as_tensor_fns(tensor_type)
  187. def convert_to_tensor(array, tensor_type: Union[str, TensorType]) -> Any:
  188. is_tensor, as_tensor = _get_is_as_tensor_fns(tensor_type)
  189. if is_tensor(array):
  190. return array
  191. return as_tensor(array)
  192. class PixtralImageProcessor(BaseImageProcessor):
  193. r"""
  194. Constructs a Pixtral image processor.
  195. Args:
  196. do_resize (`bool`, *optional*, defaults to `True`):
  197. Whether to resize the image's (height, width) dimensions to the specified `size`. Can be overridden by
  198. `do_resize` in the `preprocess` method.
  199. size (`Dict[str, int]` *optional*, defaults to `{"longest_edge": 1024}`):
  200. Size of the maximum dimension of either the height or width dimension of the image. Used to control how
  201. images are resized. If either the height or width are greater than `size["longest_edge"]` then both the height and width are rescaled by `height / ratio`, `width /ratio` where `ratio = max(height / longest_edge, width / longest_edge)`
  202. patch_size (`Dict[str, int]` *optional*, defaults to `{"height": 16, "width": 16}`):
  203. Size of the patches in the model, used to calculate the output image size. Can be overridden by `patch_size` in the `preprocess` method.
  204. resample (`PILImageResampling`, *optional*, defaults to `Resampling.BICUBIC`):
  205. Resampling filter to use if resizing the image. Can be overridden by `resample` in the `preprocess` method.
  206. do_rescale (`bool`, *optional*, defaults to `True`):
  207. Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by `do_rescale` in
  208. the `preprocess` method.
  209. rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
  210. Scale factor to use if rescaling the image. Can be overridden by `rescale_factor` in the `preprocess`
  211. method.
  212. do_normalize (`bool`, *optional*, defaults to `True`):
  213. Whether to normalize the image. Can be overridden by `do_normalize` in the `preprocess` method.
  214. image_mean (`float` or `List[float]`, *optional*, defaults to `[0.48145466, 0.4578275, 0.40821073]`):
  215. Mean to use if normalizing the image. This is a float or list of floats the length of the number of
  216. channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method.
  217. image_std (`float` or `List[float]`, *optional*, defaults to `[0.26862954, 0.26130258, 0.27577711]`):
  218. Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
  219. number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
  220. Can be overridden by the `image_std` parameter in the `preprocess` method.
  221. do_convert_rgb (`bool`, *optional*, defaults to `True`):
  222. Whether to convert the image to RGB.
  223. """
  224. model_input_names = ["pixel_values"]
  225. def __init__(
  226. self,
  227. do_resize: bool = True,
  228. size: Dict[str, int] = None,
  229. patch_size: Dict[str, int] = None,
  230. resample: PILImageResampling = PILImageResampling.BICUBIC,
  231. do_rescale: bool = True,
  232. rescale_factor: Union[int, float] = 1 / 255,
  233. do_normalize: bool = True,
  234. image_mean: Optional[Union[float, List[float]]] = None,
  235. image_std: Optional[Union[float, List[float]]] = None,
  236. do_convert_rgb: bool = True,
  237. **kwargs,
  238. ) -> None:
  239. super().__init__(**kwargs)
  240. size = size if size is not None else {"longest_edge": 1024}
  241. patch_size = patch_size if patch_size is not None else {"height": 16, "width": 16}
  242. patch_size = get_size_dict(patch_size, default_to_square=True)
  243. self.do_resize = do_resize
  244. self.size = size
  245. self.patch_size = patch_size
  246. self.resample = resample
  247. self.do_rescale = do_rescale
  248. self.rescale_factor = rescale_factor
  249. self.do_normalize = do_normalize
  250. self.image_mean = image_mean if image_mean is not None else [0.48145466, 0.4578275, 0.40821073]
  251. self.image_std = image_std if image_std is not None else [0.26862954, 0.26130258, 0.27577711]
  252. self.do_convert_rgb = do_convert_rgb
  253. self._valid_processor_keys = [
  254. "images",
  255. "do_resize",
  256. "size",
  257. "patch_size",
  258. "resample",
  259. "do_rescale",
  260. "rescale_factor",
  261. "do_normalize",
  262. "image_mean",
  263. "image_std",
  264. "do_convert_rgb",
  265. "return_tensors",
  266. "data_format",
  267. "input_data_format",
  268. ]
  269. def resize(
  270. self,
  271. image: np.ndarray,
  272. size: Dict[str, int],
  273. patch_size: Dict[str, int],
  274. resample: PILImageResampling = PILImageResampling.BICUBIC,
  275. data_format: Optional[Union[str, ChannelDimension]] = None,
  276. input_data_format: Optional[Union[str, ChannelDimension]] = None,
  277. **kwargs,
  278. ) -> np.ndarray:
  279. """
  280. Resize an image. The shortest edge of the image is resized to size["shortest_edge"], with the longest edge
  281. resized to keep the input aspect ratio.
  282. Args:
  283. image (`np.ndarray`):
  284. Image to resize.
  285. size (`Dict[str, int]`):
  286. Dict containing the longest possible edge of the image.
  287. patch_size (`Dict[str, int]`):
  288. Patch size used to calculate the size of the output image.
  289. resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`):
  290. Resampling filter to use when resiizing the image.
  291. data_format (`str` or `ChannelDimension`, *optional*):
  292. The channel dimension format of the image. If not provided, it will be the same as the input image.
  293. input_data_format (`ChannelDimension` or `str`, *optional*):
  294. The channel dimension format of the input image. If not provided, it will be inferred.
  295. """
  296. if "longest_edge" in size:
  297. size = (size["longest_edge"], size["longest_edge"])
  298. elif "height" in size and "width" in size:
  299. size = (size["height"], size["width"])
  300. else:
  301. raise ValueError("size must contain either 'longest_edge' or 'height' and 'width'.")
  302. if "height" in patch_size and "width" in patch_size:
  303. patch_size = (patch_size["height"], patch_size["width"])
  304. else:
  305. raise ValueError("patch_size must contain either 'shortest_edge' or 'height' and 'width'.")
  306. output_size = get_resize_output_image_size(
  307. image,
  308. size=size,
  309. patch_size=patch_size,
  310. input_data_format=input_data_format,
  311. )
  312. return resize(
  313. image,
  314. size=output_size,
  315. resample=resample,
  316. data_format=data_format,
  317. input_data_format=input_data_format,
  318. **kwargs,
  319. )
  320. def preprocess(
  321. self,
  322. images: ImageInput,
  323. do_resize: bool = None,
  324. size: Dict[str, int] = None,
  325. patch_size: Dict[str, int] = None,
  326. resample: PILImageResampling = None,
  327. do_rescale: bool = None,
  328. rescale_factor: float = None,
  329. do_normalize: bool = None,
  330. image_mean: Optional[Union[float, List[float]]] = None,
  331. image_std: Optional[Union[float, List[float]]] = None,
  332. do_convert_rgb: bool = None,
  333. return_tensors: Optional[Union[str, TensorType]] = None,
  334. data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
  335. input_data_format: Optional[Union[str, ChannelDimension]] = None,
  336. **kwargs,
  337. ) -> PIL.Image.Image:
  338. """
  339. Preprocess an image or batch of images.
  340. Args:
  341. images (`ImageInput`):
  342. Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
  343. passing in images with pixel values between 0 and 1, set `do_rescale=False`.
  344. do_resize (`bool`, *optional*, defaults to `self.do_resize`):
  345. Whether to resize the image.
  346. size (`Dict[str, int]`, *optional*, defaults to `self.size`):
  347. Describes the maximum input dimensions to the model.
  348. patch_size (`Dict[str, int]`, *optional*, defaults to `self.patch_size`):
  349. Patch size in the model. Used to calculate the image after resizing.
  350. resample (`int`, *optional*, defaults to `self.resample`):
  351. Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`. Only
  352. has an effect if `do_resize` is set to `True`.
  353. do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
  354. Whether to rescale the image.
  355. rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
  356. Rescale factor to rescale the image by if `do_rescale` is set to `True`.
  357. do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
  358. Whether to normalize the image.
  359. image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
  360. Image mean to use for normalization. Only has an effect if `do_normalize` is set to `True`.
  361. image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
  362. Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to
  363. `True`.
  364. do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
  365. Whether to convert the image to RGB.
  366. return_tensors (`str` or `TensorType`, *optional*):
  367. The type of tensors to return. Can be one of:
  368. - Unset: Return a list of `np.ndarray`.
  369. - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
  370. - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
  371. - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
  372. - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
  373. data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
  374. The channel dimension format for the output image. Can be one of:
  375. - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
  376. - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
  377. - Unset: Use the channel dimension format of the input image.
  378. input_data_format (`ChannelDimension` or `str`, *optional*):
  379. The channel dimension format for the input image. If unset, the channel dimension format is inferred
  380. from the input image. Can be one of:
  381. - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
  382. - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
  383. - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
  384. """
  385. patch_size = patch_size if patch_size is not None else self.patch_size
  386. patch_size = get_size_dict(patch_size, default_to_square=True)
  387. do_resize = do_resize if do_resize is not None else self.do_resize
  388. size = size if size is not None else self.size
  389. resample = resample if resample is not None else self.resample
  390. do_rescale = do_rescale if do_rescale is not None else self.do_rescale
  391. rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
  392. do_normalize = do_normalize if do_normalize is not None else self.do_normalize
  393. image_mean = image_mean if image_mean is not None else self.image_mean
  394. image_std = image_std if image_std is not None else self.image_std
  395. do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb
  396. validate_kwargs(captured_kwargs=kwargs.keys(), valid_processor_keys=self._valid_processor_keys)
  397. images_list = make_list_of_images(images)
  398. if not valid_images(images_list[0]):
  399. raise ValueError(
  400. "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
  401. "torch.Tensor, tf.Tensor or jax.ndarray."
  402. )
  403. validate_preprocess_arguments(
  404. do_rescale=do_rescale,
  405. rescale_factor=rescale_factor,
  406. do_normalize=do_normalize,
  407. image_mean=image_mean,
  408. image_std=image_std,
  409. do_resize=do_resize,
  410. size=size,
  411. resample=resample,
  412. )
  413. if do_convert_rgb:
  414. images_list = [[convert_to_rgb(image) for image in images] for images in images_list]
  415. # All transformations expect numpy arrays.
  416. images_list = [[to_numpy_array(image) for image in images] for images in images_list]
  417. if is_scaled_image(images_list[0][0]) and do_rescale:
  418. logger.warning_once(
  419. "It looks like you are trying to rescale already rescaled images. If the input"
  420. " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
  421. )
  422. if input_data_format is None:
  423. # We assume that all images have the same channel dimension format.
  424. input_data_format = infer_channel_dimension_format(images_list[0][0])
  425. batch_images = []
  426. batch_image_sizes = []
  427. for sample_images in images_list:
  428. images = []
  429. image_sizes = []
  430. for image in sample_images:
  431. if do_resize:
  432. image = self.resize(
  433. image=image,
  434. size=size,
  435. patch_size=patch_size,
  436. resample=resample,
  437. input_data_format=input_data_format,
  438. )
  439. if do_rescale:
  440. image = self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format)
  441. if do_normalize:
  442. image = self.normalize(
  443. image=image, mean=image_mean, std=image_std, input_data_format=input_data_format
  444. )
  445. images.append(image)
  446. image_sizes.append(get_image_size(image, input_data_format))
  447. batch_images.append(images)
  448. batch_image_sizes.append(image_sizes)
  449. images_list = [
  450. [to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) for image in images]
  451. for images in batch_images
  452. ]
  453. # Convert to tensor type outside of BatchFeature to avoid batching the images of different sizes
  454. images_list = [[convert_to_tensor(image, return_tensors) for image in images] for images in images_list]
  455. return BatchMixFeature(data={"pixel_values": images_list, "image_sizes": batch_image_sizes}, tensor_type=None)