
# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Image processor class for Flava."""

import math
import random
from functools import lru_cache
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union

import numpy as np

from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
from ...image_transforms import resize, to_channel_dimension_format
from ...image_utils import (
    OPENAI_CLIP_MEAN,
    OPENAI_CLIP_STD,
    ChannelDimension,
    ImageInput,
    PILImageResampling,
    infer_channel_dimension_format,
    is_scaled_image,
    make_list_of_images,
    to_numpy_array,
    valid_images,
    validate_preprocess_arguments,
)
from ...utils import TensorType, filter_out_non_signature_kwargs, is_vision_available, logging


if is_vision_available():
    import PIL


logger = logging.get_logger(__name__)


# These values are taken from CLIP
FLAVA_IMAGE_MEAN = OPENAI_CLIP_MEAN
FLAVA_IMAGE_STD = OPENAI_CLIP_STD
FLAVA_CODEBOOK_MEAN = [0.0, 0.0, 0.0]
FLAVA_CODEBOOK_STD = [1.0, 1.0, 1.0]
LOGIT_LAPLACE_EPS: float = 0.1


# Inspired by https://github.com/microsoft/unilm/blob/master/beit/masking_generator.py
class FlavaMaskingGenerator:
    def __init__(
        self,
        input_size: Union[int, Tuple[int, int]] = 14,
        total_mask_patches: int = 75,
        mask_group_max_patches: Optional[int] = None,
        mask_group_min_patches: int = 16,
        mask_group_min_aspect_ratio: Optional[float] = 0.3,
        mask_group_max_aspect_ratio: Optional[float] = None,
    ):
        if not isinstance(input_size, tuple):
            input_size = (input_size,) * 2
        self.height, self.width = input_size

        self.num_patches = self.height * self.width
        self.total_mask_patches = total_mask_patches

        self.mask_group_min_patches = mask_group_min_patches
        self.mask_group_max_patches = total_mask_patches if mask_group_max_patches is None else mask_group_max_patches

        mask_group_max_aspect_ratio = mask_group_max_aspect_ratio or 1 / mask_group_min_aspect_ratio
        self.log_aspect_ratio = (math.log(mask_group_min_aspect_ratio), math.log(mask_group_max_aspect_ratio))

    def __repr__(self):
        repr_str = "MaskingGenerator(%d, %d -> [%d ~ %d], max = %d, %.3f ~ %.3f)" % (
            self.height,
            self.width,
            self.mask_group_min_patches,
            self.mask_group_max_patches,
            self.total_mask_patches,
            self.log_aspect_ratio[0],
            self.log_aspect_ratio[1],
        )
        return repr_str

    def get_shape(self):
        return self.height, self.width

    def _mask(self, mask, max_mask_patches):
        delta = 0
        for _attempt in range(10):
            # Sample a rectangular group of patches with a random area and aspect ratio.
            target_area = random.uniform(self.mask_group_min_patches, max_mask_patches)
            aspect_ratio = math.exp(random.uniform(*self.log_aspect_ratio))
            height = int(round(math.sqrt(target_area * aspect_ratio)))
            width = int(round(math.sqrt(target_area / aspect_ratio)))
            if width < self.width and height < self.height:
                top = random.randint(0, self.height - height)
                left = random.randint(0, self.width - width)

                # Accept the group only if the newly masked patches (excluding overlap
                # with previously masked groups) fit within the remaining budget.
                num_masked = mask[top : top + height, left : left + width].sum()
                if 0 < height * width - num_masked <= max_mask_patches:
                    for i in range(top, top + height):
                        for j in range(left, left + width):
                            if mask[i, j] == 0:
                                mask[i, j] = 1
                                delta += 1

                if delta > 0:
                    break
        return delta

    def __call__(self):
        mask = np.zeros(shape=self.get_shape(), dtype=int)
        mask_count = 0
        while mask_count < self.total_mask_patches:
            max_mask_patches = self.total_mask_patches - mask_count
            max_mask_patches = min(max_mask_patches, self.mask_group_max_patches)

            delta = self._mask(mask, max_mask_patches)
            if delta == 0:
                break
            else:
                mask_count += delta

        return mask
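

# A minimal usage sketch (illustrative only): the generator masks rectangular groups
# of patches until roughly `total_mask_patches` of the patch grid is covered, and the
# budget is never exceeded because `_mask` rejects groups that would overshoot it.
#
#     generator = FlavaMaskingGenerator(input_size=14, total_mask_patches=75)
#     mask = generator()       # np.ndarray of shape (14, 14) with 0/1 entries
#     assert mask.sum() <= 75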


class FlavaImageProcessor(BaseImageProcessor):
    r"""
    Constructs a Flava image processor.

    Args:
        do_resize (`bool`, *optional*, defaults to `True`):
            Whether to resize the image's (height, width) dimensions to the specified `size`. Can be overridden by the
            `do_resize` parameter in `preprocess`.
        size (`Dict[str, int]`, *optional*, defaults to `{"height": 224, "width": 224}`):
            Size of the image after resizing. Can be overridden by the `size` parameter in `preprocess`.
        resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`):
            Resampling filter to use if resizing the image. Can be overridden by the `resample` parameter in
            `preprocess`.
        do_center_crop (`bool`, *optional*, defaults to `True`):
            Whether to center crop the images. Can be overridden by the `do_center_crop` parameter in `preprocess`.
        crop_size (`Dict[str, int]`, *optional*, defaults to `{"height": 224, "width": 224}`):
            Size of image after the center crop `(crop_size["height"], crop_size["width"])`. Can be overridden by the
            `crop_size` parameter in `preprocess`.
        do_rescale (`bool`, *optional*, defaults to `True`):
            Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by the `do_rescale`
            parameter in `preprocess`.
        rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
            Scale factor to use if rescaling the image. Can be overridden by the `rescale_factor` parameter in
            `preprocess`.
        do_normalize (`bool`, *optional*, defaults to `True`):
            Whether to normalize the image. Can be overridden by the `do_normalize` parameter in `preprocess`.
        image_mean (`float` or `List[float]`, *optional*, defaults to `FLAVA_IMAGE_MEAN`):
            Mean to use if normalizing the image. This is a float or list of floats the length of the number of
            channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method.
        image_std (`float` or `List[float]`, *optional*, defaults to `FLAVA_IMAGE_STD`):
            Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
            number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
        return_image_mask (`bool`, *optional*, defaults to `False`):
            Whether to return the image mask. Can be overridden by the `return_image_mask` parameter in `preprocess`.
        input_size_patches (`int`, *optional*, defaults to 14):
            Number of patches in the image in height and width direction. 14x14 = 196 total patches. Can be overridden
            by the `input_size_patches` parameter in `preprocess`.
        total_mask_patches (`int`, *optional*, defaults to 75):
            Total number of patches that should be masked. Can be overridden by the `total_mask_patches` parameter in
            `preprocess`.
        mask_group_min_patches (`int`, *optional*, defaults to 16):
            Minimum number of patches that should be masked. Can be overridden by the `mask_group_min_patches`
            parameter in `preprocess`.
        mask_group_max_patches (`int`, *optional*):
            Maximum number of patches that should be masked. Can be overridden by the `mask_group_max_patches`
            parameter in `preprocess`.
        mask_group_min_aspect_ratio (`float`, *optional*, defaults to 0.3):
            Minimum aspect ratio of the mask window. Can be overridden by the `mask_group_min_aspect_ratio` parameter
            in `preprocess`.
        mask_group_max_aspect_ratio (`float`, *optional*):
            Maximum aspect ratio of the mask window. Can be overridden by the `mask_group_max_aspect_ratio` parameter
            in `preprocess`.
        return_codebook_pixels (`bool`, *optional*, defaults to `False`):
            Whether to return the codebook pixel values. Can be overridden by the `return_codebook_pixels` parameter
            in `preprocess`.
        codebook_do_resize (`bool`, *optional*, defaults to `True`):
            Whether to resize the input for codebook to a certain `codebook_size`. Can be overridden by the
            `codebook_do_resize` parameter in `preprocess`.
        codebook_size (`Dict[str, int]`, *optional*, defaults to `{"height": 112, "width": 112}`):
            Resize the input for codebook to the given size. Can be overridden by the `codebook_size` parameter in
            `preprocess`.
        codebook_resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.LANCZOS`):
            Resampling filter to use if resizing the codebook image. Can be overridden by the `codebook_resample`
            parameter in `preprocess`.
        codebook_do_center_crop (`bool`, *optional*, defaults to `True`):
            Whether to crop the input for codebook at the center. If the input size is smaller than
            `codebook_crop_size` along any edge, the image is padded with 0's and then center cropped. Can be
            overridden by the `codebook_do_center_crop` parameter in `preprocess`.
        codebook_crop_size (`Dict[str, int]`, *optional*, defaults to `{"height": 112, "width": 112}`):
            Desired output size for codebook input when applying center-cropping. Can be overridden by the
            `codebook_crop_size` parameter in `preprocess`.
        codebook_do_rescale (`bool`, *optional*, defaults to `True`):
            Whether to rescale the input for codebook by the specified scale `codebook_rescale_factor`. Can be
            overridden by the `codebook_do_rescale` parameter in `preprocess`.
        codebook_rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
            Defines the scale factor to use if rescaling the codebook image. Can be overridden by the
            `codebook_rescale_factor` parameter in `preprocess`.
        codebook_do_map_pixels (`bool`, *optional*, defaults to `True`):
            Whether to map the pixel values of the codebook input to (1 - 2e)x + e. Can be overridden by the
            `codebook_do_map_pixels` parameter in `preprocess`.
        codebook_do_normalize (`bool`, *optional*, defaults to `True`):
            Whether or not to normalize the input for codebook with `codebook_image_mean` and `codebook_image_std`. Can
            be overridden by the `codebook_do_normalize` parameter in `preprocess`.
        codebook_image_mean (`Optional[Union[float, Iterable[float]]]`, *optional*, defaults to `[0, 0, 0]`):
            The sequence of means for each channel, to be used when normalizing images for codebook. Can be overridden
            by the `codebook_image_mean` parameter in `preprocess`.
        codebook_image_std (`Optional[Union[float, Iterable[float]]]`, *optional*, defaults to `[1, 1, 1]`):
            The sequence of standard deviations for each channel, to be used when normalizing images for codebook. Can
            be overridden by the `codebook_image_std` parameter in `preprocess`.
    """

    model_input_names = ["pixel_values"]

    def __init__(
        self,
        do_resize: bool = True,
        size: Optional[Dict[str, int]] = None,
        resample: PILImageResampling = PILImageResampling.BICUBIC,
        do_center_crop: bool = True,
        crop_size: Optional[Dict[str, int]] = None,
        do_rescale: bool = True,
        rescale_factor: Union[int, float] = 1 / 255,
        do_normalize: bool = True,
        image_mean: Optional[Union[float, Iterable[float]]] = None,
        image_std: Optional[Union[float, Iterable[float]]] = None,
        # Mask related params
        return_image_mask: bool = False,
        input_size_patches: int = 14,
        total_mask_patches: int = 75,
        mask_group_min_patches: int = 16,
        mask_group_max_patches: Optional[int] = None,
        mask_group_min_aspect_ratio: float = 0.3,
        mask_group_max_aspect_ratio: Optional[float] = None,
        # Codebook related params
        return_codebook_pixels: bool = False,
        codebook_do_resize: bool = True,
        codebook_size: Optional[Dict[str, int]] = None,
        codebook_resample: PILImageResampling = PILImageResampling.LANCZOS,
        codebook_do_center_crop: bool = True,
        codebook_crop_size: Optional[Dict[str, int]] = None,
        codebook_do_rescale: bool = True,
        codebook_rescale_factor: Union[int, float] = 1 / 255,
        codebook_do_map_pixels: bool = True,
        codebook_do_normalize: bool = True,
        codebook_image_mean: Optional[Union[float, Iterable[float]]] = None,
        codebook_image_std: Optional[Union[float, Iterable[float]]] = None,
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)
        size = size if size is not None else {"height": 224, "width": 224}
        size = get_size_dict(size)
        crop_size = crop_size if crop_size is not None else {"height": 224, "width": 224}
        crop_size = get_size_dict(crop_size, param_name="crop_size")
        codebook_size = codebook_size if codebook_size is not None else {"height": 112, "width": 112}
        codebook_size = get_size_dict(codebook_size, param_name="codebook_size")
        codebook_crop_size = codebook_crop_size if codebook_crop_size is not None else {"height": 112, "width": 112}
        codebook_crop_size = get_size_dict(codebook_crop_size, param_name="codebook_crop_size")

        self.do_resize = do_resize
        self.size = size
        self.resample = resample
        self.do_rescale = do_rescale
        self.rescale_factor = rescale_factor
        self.do_center_crop = do_center_crop
        self.crop_size = crop_size
        self.do_normalize = do_normalize
        self.image_mean = image_mean if image_mean is not None else FLAVA_IMAGE_MEAN
        self.image_std = image_std if image_std is not None else FLAVA_IMAGE_STD

        self.return_image_mask = return_image_mask
        self.input_size_patches = input_size_patches
        self.total_mask_patches = total_mask_patches
        self.mask_group_min_patches = mask_group_min_patches
        self.mask_group_max_patches = mask_group_max_patches
        self.mask_group_min_aspect_ratio = mask_group_min_aspect_ratio
        self.mask_group_max_aspect_ratio = mask_group_max_aspect_ratio

        self.return_codebook_pixels = return_codebook_pixels
        self.codebook_do_resize = codebook_do_resize
        self.codebook_size = codebook_size
        self.codebook_resample = codebook_resample
        self.codebook_do_center_crop = codebook_do_center_crop
        self.codebook_crop_size = codebook_crop_size
        self.codebook_do_rescale = codebook_do_rescale
        self.codebook_rescale_factor = codebook_rescale_factor
        self.codebook_do_map_pixels = codebook_do_map_pixels
        self.codebook_do_normalize = codebook_do_normalize
        self.codebook_image_mean = codebook_image_mean if codebook_image_mean is not None else FLAVA_CODEBOOK_MEAN
        self.codebook_image_std = codebook_image_std if codebook_image_std is not None else FLAVA_CODEBOOK_STD

    @classmethod
    def from_dict(cls, image_processor_dict: Dict[str, Any], **kwargs):
        """
        Overrides the `from_dict` method from the base class to make sure `codebook_size` and `codebook_crop_size` are
        updated if the image processor is created through `from_dict` with kwargs, e.g.
        `FlavaImageProcessor.from_pretrained(checkpoint, codebook_size=600)`.
        """
        image_processor_dict = image_processor_dict.copy()
        if "codebook_size" in kwargs:
            image_processor_dict["codebook_size"] = kwargs.pop("codebook_size")
        if "codebook_crop_size" in kwargs:
            image_processor_dict["codebook_crop_size"] = kwargs.pop("codebook_crop_size")
        return super().from_dict(image_processor_dict, **kwargs)
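
    # A hedged usage sketch: thanks to the override above, codebook-specific sizes can
    # be changed at load time (the checkpoint name below is only an example):
    #
    #     processor = FlavaImageProcessor.from_pretrained(
    #         "facebook/flava-full", codebook_size={"height": 112, "width": 112}
    #     )

    # Note: `lru_cache` memoizes on the argument tuple, so repeated `preprocess` calls
    # with identical masking settings reuse one `FlavaMaskingGenerator` instead of
    # constructing a new one per call.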
    @lru_cache()
    def masking_generator(
        self,
        input_size_patches,
        total_mask_patches,
        mask_group_min_patches,
        mask_group_max_patches,
        mask_group_min_aspect_ratio,
        mask_group_max_aspect_ratio,
    ) -> FlavaMaskingGenerator:
        return FlavaMaskingGenerator(
            input_size=input_size_patches,
            total_mask_patches=total_mask_patches,
            mask_group_min_patches=mask_group_min_patches,
            mask_group_max_patches=mask_group_max_patches,
            mask_group_min_aspect_ratio=mask_group_min_aspect_ratio,
            mask_group_max_aspect_ratio=mask_group_max_aspect_ratio,
        )

    # Copied from transformers.models.vit.image_processing_vit.ViTImageProcessor.resize with PILImageResampling.BILINEAR->PILImageResampling.BICUBIC
    def resize(
        self,
        image: np.ndarray,
        size: Dict[str, int],
        resample: PILImageResampling = PILImageResampling.BICUBIC,
        data_format: Optional[Union[str, ChannelDimension]] = None,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
        **kwargs,
    ) -> np.ndarray:
        """
        Resize an image to `(size["height"], size["width"])`.

        Args:
            image (`np.ndarray`):
                Image to resize.
            size (`Dict[str, int]`):
                Dictionary in the format `{"height": int, "width": int}` specifying the size of the output image.
            resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`):
                `PILImageResampling` filter to use when resizing the image e.g. `PILImageResampling.BICUBIC`.
            data_format (`ChannelDimension` or `str`, *optional*):
                The channel dimension format for the output image. If unset, the channel dimension format of the input
                image is used. Can be one of:
                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
                - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
            input_data_format (`ChannelDimension` or `str`, *optional*):
                The channel dimension format for the input image. If unset, the channel dimension format is inferred
                from the input image. Can be one of:
                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
                - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.

        Returns:
            `np.ndarray`: The resized image.
        """
        size = get_size_dict(size)
        if "height" not in size or "width" not in size:
            raise ValueError(f"The `size` dictionary must contain the keys `height` and `width`. Got {size.keys()}")
        output_size = (size["height"], size["width"])
        return resize(
            image,
            size=output_size,
            resample=resample,
            data_format=data_format,
            input_data_format=input_data_format,
            **kwargs,
        )
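
    # With LOGIT_LAPLACE_EPS = 0.1, `map_pixels` squeezes [0, 1] pixel values into
    # [0.1, 0.9] (e.g. 0.0 -> 0.1 and 1.0 -> 0.9), keeping codebook inputs strictly
    # inside (0, 1), in the spirit of the logit-Laplace mapping used by DALL-E's dVAE.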
    def map_pixels(self, image: np.ndarray) -> np.ndarray:
        return (1 - 2 * LOGIT_LAPLACE_EPS) * image + LOGIT_LAPLACE_EPS

    def _preprocess_image(
        self,
        image: ImageInput,
        do_resize: Optional[bool] = None,
        size: Optional[Dict[str, int]] = None,
        resample: Optional[PILImageResampling] = None,
        do_center_crop: Optional[bool] = None,
        crop_size: Optional[Dict[str, int]] = None,
        do_rescale: Optional[bool] = None,
        rescale_factor: Optional[float] = None,
        do_normalize: Optional[bool] = None,
        image_mean: Optional[Union[float, List[float]]] = None,
        image_std: Optional[Union[float, List[float]]] = None,
        do_map_pixels: Optional[bool] = None,
        data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
        input_data_format: Optional[ChannelDimension] = None,
    ) -> np.ndarray:
        """Preprocesses a single image."""
        validate_preprocess_arguments(
            do_rescale=do_rescale,
            rescale_factor=rescale_factor,
            do_normalize=do_normalize,
            image_mean=image_mean,
            image_std=image_std,
            do_center_crop=do_center_crop,
            crop_size=crop_size,
            do_resize=do_resize,
            size=size,
            resample=resample,
        )

        # All transformations expect numpy arrays.
        image = to_numpy_array(image)

        if is_scaled_image(image) and do_rescale:
            logger.warning_once(
                "It looks like you are trying to rescale already rescaled images. If the input"
                " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
            )

        if input_data_format is None:
            # We assume that all images have the same channel dimension format.
            input_data_format = infer_channel_dimension_format(image)

        if do_resize:
            image = self.resize(image=image, size=size, resample=resample, input_data_format=input_data_format)

        if do_center_crop:
            image = self.center_crop(image=image, size=crop_size, input_data_format=input_data_format)

        if do_rescale:
            image = self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format)

        if do_normalize:
            image = self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=input_data_format)

        if do_map_pixels:
            image = self.map_pixels(image)

        if data_format is not None:
            image = to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format)
        return image

    @filter_out_non_signature_kwargs()
    def preprocess(
        self,
        images: ImageInput,
        do_resize: Optional[bool] = None,
        size: Optional[Dict[str, int]] = None,
        resample: Optional[PILImageResampling] = None,
        do_center_crop: Optional[bool] = None,
        crop_size: Optional[Dict[str, int]] = None,
        do_rescale: Optional[bool] = None,
        rescale_factor: Optional[float] = None,
        do_normalize: Optional[bool] = None,
        image_mean: Optional[Union[float, List[float]]] = None,
        image_std: Optional[Union[float, List[float]]] = None,
        # Mask related params
        return_image_mask: Optional[bool] = None,
        input_size_patches: Optional[int] = None,
        total_mask_patches: Optional[int] = None,
        mask_group_min_patches: Optional[int] = None,
        mask_group_max_patches: Optional[int] = None,
        mask_group_min_aspect_ratio: Optional[float] = None,
        mask_group_max_aspect_ratio: Optional[float] = None,
        # Codebook related params
        return_codebook_pixels: Optional[bool] = None,
        codebook_do_resize: Optional[bool] = None,
        codebook_size: Optional[Dict[str, int]] = None,
        codebook_resample: Optional[int] = None,
        codebook_do_center_crop: Optional[bool] = None,
        codebook_crop_size: Optional[Dict[str, int]] = None,
        codebook_do_rescale: Optional[bool] = None,
        codebook_rescale_factor: Optional[float] = None,
        codebook_do_map_pixels: Optional[bool] = None,
        codebook_do_normalize: Optional[bool] = None,
        codebook_image_mean: Optional[Iterable[float]] = None,
        codebook_image_std: Optional[Iterable[float]] = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
        data_format: ChannelDimension = ChannelDimension.FIRST,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
    ) -> BatchFeature:
  443. """
  444. Preprocess an image or batch of images.
  445. Args:
  446. images (`ImageInput`):
  447. Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
  448. passing in images with pixel values between 0 and 1, set `do_rescale=False`.
  449. do_resize (`bool`, *optional*, defaults to `self.do_resize`):
  450. Whether to resize the image.
  451. size (`Dict[str, int]`, *optional*, defaults to `self.size`):
  452. Size of the image.
  453. resample (`int`, *optional*, defaults to `self.resample`):
  454. Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`, Only
  455. has an effect if `do_resize` is set to `True`.
  456. do_center_crop (`bool`, *optional*, defaults to `self.do_center_crop`):
  457. Whether to center crop the image.
  458. crop_size (`Dict[str, int]`, *optional*, defaults to `self.crop_size`):
  459. Size of the center crop. Only has an effect if `do_center_crop` is set to `True`.
  460. do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
  461. Whether to rescale the image values between [0 - 1].
  462. rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
  463. Rescale factor to rescale the image by if `do_rescale` is set to `True`.
  464. do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
  465. Whether to normalize the image.
  466. image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
  467. Image mean.
  468. image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
  469. Image standard deviation.
  470. return_image_mask (`bool`, *optional*, defaults to `self.return_image_mask`):
  471. Whether to return the image mask.
  472. input_size_patches (`int`, *optional*, defaults to `self.input_size_patches`):
  473. Size of the patches to extract from the image.
  474. total_mask_patches (`int`, *optional*, defaults to `self.total_mask_patches`):
  475. Total number of patches to extract from the image.
  476. mask_group_min_patches (`int`, *optional*, defaults to `self.mask_group_min_patches`):
  477. Minimum number of patches to extract from the image.
  478. mask_group_max_patches (`int`, *optional*, defaults to `self.mask_group_max_patches`):
  479. Maximum number of patches to extract from the image.
  480. mask_group_min_aspect_ratio (`float`, *optional*, defaults to `self.mask_group_min_aspect_ratio`):
  481. Minimum aspect ratio of the patches to extract from the image.
  482. mask_group_max_aspect_ratio (`float`, *optional*, defaults to `self.mask_group_max_aspect_ratio`):
  483. Maximum aspect ratio of the patches to extract from the image.
  484. return_codebook_pixels (`bool`, *optional*, defaults to `self.return_codebook_pixels`):
  485. Whether to return the codebook pixels.
  486. codebook_do_resize (`bool`, *optional*, defaults to `self.codebook_do_resize`):
  487. Whether to resize the codebook pixels.
  488. codebook_size (`Dict[str, int]`, *optional*, defaults to `self.codebook_size`):
  489. Size of the codebook pixels.
  490. codebook_resample (`int`, *optional*, defaults to `self.codebook_resample`):
  491. Resampling filter to use if resizing the codebook pixels. This can be one of the enum
  492. `PILImageResampling`, Only has an effect if `codebook_do_resize` is set to `True`.
  493. codebook_do_center_crop (`bool`, *optional*, defaults to `self.codebook_do_center_crop`):
  494. Whether to center crop the codebook pixels.
  495. codebook_crop_size (`Dict[str, int]`, *optional*, defaults to `self.codebook_crop_size`):
  496. Size of the center crop of the codebook pixels. Only has an effect if `codebook_do_center_crop` is set
  497. to `True`.
  498. codebook_do_rescale (`bool`, *optional*, defaults to `self.codebook_do_rescale`):
  499. Whether to rescale the codebook pixels values between [0 - 1].
  500. codebook_rescale_factor (`float`, *optional*, defaults to `self.codebook_rescale_factor`):
  501. Rescale factor to rescale the codebook pixels by if `codebook_do_rescale` is set to `True`.
  502. codebook_do_map_pixels (`bool`, *optional*, defaults to `self.codebook_do_map_pixels`):
  503. Whether to map the codebook pixels values.
  504. codebook_do_normalize (`bool`, *optional*, defaults to `self.codebook_do_normalize`):
  505. Whether to normalize the codebook pixels.
  506. codebook_image_mean (`float` or `List[float]`, *optional*, defaults to `self.codebook_image_mean`):
  507. Codebook pixels mean to normalize the codebook pixels by if `codebook_do_normalize` is set to `True`.
  508. codebook_image_std (`float` or `List[float]`, *optional*, defaults to `self.codebook_image_std`):
  509. Codebook pixels standard deviation to normalize the codebook pixels by if `codebook_do_normalize` is
  510. set to `True`.
  511. return_tensors (`str` or `TensorType`, *optional*):
  512. The type of tensors to return. Can be one of:
  513. - Unset: Return a list of `np.ndarray`.
  514. - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
  515. - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
  516. - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
  517. - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
  518. data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
  519. The channel dimension format for the output image. Can be one of:
  520. - `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
  521. - `ChannelDimension.LAST`: image in (height, width, num_channels) format.
  522. input_data_format (`ChannelDimension` or `str`, *optional*):
  523. The channel dimension format for the input image. If unset, the channel dimension format is inferred
  524. from the input image. Can be one of:
  525. - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
  526. - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
  527. - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
  528. """
        do_resize = do_resize if do_resize is not None else self.do_resize
        size = size if size is not None else self.size
        size = get_size_dict(size)
        resample = resample if resample is not None else self.resample
        do_center_crop = do_center_crop if do_center_crop is not None else self.do_center_crop
        crop_size = crop_size if crop_size is not None else self.crop_size
        crop_size = get_size_dict(crop_size, param_name="crop_size")
        do_rescale = do_rescale if do_rescale is not None else self.do_rescale
        rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
        do_normalize = do_normalize if do_normalize is not None else self.do_normalize
        image_mean = image_mean if image_mean is not None else self.image_mean
        image_std = image_std if image_std is not None else self.image_std

        return_image_mask = return_image_mask if return_image_mask is not None else self.return_image_mask
        input_size_patches = input_size_patches if input_size_patches is not None else self.input_size_patches
        total_mask_patches = total_mask_patches if total_mask_patches is not None else self.total_mask_patches
        mask_group_min_patches = (
            mask_group_min_patches if mask_group_min_patches is not None else self.mask_group_min_patches
        )
        mask_group_max_patches = (
            mask_group_max_patches if mask_group_max_patches is not None else self.mask_group_max_patches
        )
        mask_group_min_aspect_ratio = (
            mask_group_min_aspect_ratio
            if mask_group_min_aspect_ratio is not None
            else self.mask_group_min_aspect_ratio
        )
        mask_group_max_aspect_ratio = (
            mask_group_max_aspect_ratio
            if mask_group_max_aspect_ratio is not None
            else self.mask_group_max_aspect_ratio
        )

        return_codebook_pixels = (
            return_codebook_pixels if return_codebook_pixels is not None else self.return_codebook_pixels
        )
        codebook_do_resize = codebook_do_resize if codebook_do_resize is not None else self.codebook_do_resize
        codebook_size = codebook_size if codebook_size is not None else self.codebook_size
        codebook_size = get_size_dict(codebook_size, param_name="codebook_size")
        codebook_resample = codebook_resample if codebook_resample is not None else self.codebook_resample
        codebook_do_rescale = codebook_do_rescale if codebook_do_rescale is not None else self.codebook_do_rescale
        codebook_rescale_factor = (
            codebook_rescale_factor if codebook_rescale_factor is not None else self.codebook_rescale_factor
        )
        codebook_do_center_crop = (
            codebook_do_center_crop if codebook_do_center_crop is not None else self.codebook_do_center_crop
        )
        codebook_crop_size = codebook_crop_size if codebook_crop_size is not None else self.codebook_crop_size
        codebook_crop_size = get_size_dict(codebook_crop_size, param_name="codebook_crop_size")
        codebook_do_map_pixels = (
            codebook_do_map_pixels if codebook_do_map_pixels is not None else self.codebook_do_map_pixels
        )
        codebook_do_normalize = (
            codebook_do_normalize if codebook_do_normalize is not None else self.codebook_do_normalize
        )
        codebook_image_mean = codebook_image_mean if codebook_image_mean is not None else self.codebook_image_mean
        codebook_image_std = codebook_image_std if codebook_image_std is not None else self.codebook_image_std

        images = make_list_of_images(images)

        if not valid_images(images):
            raise ValueError(
                "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
                "torch.Tensor, tf.Tensor or jax.ndarray."
            )

        processed_images = [
            self._preprocess_image(
                image=img,
                do_resize=do_resize,
                size=size,
                resample=resample,
                do_center_crop=do_center_crop,
                crop_size=crop_size,
                do_rescale=do_rescale,
                rescale_factor=rescale_factor,
                do_normalize=do_normalize,
                image_mean=image_mean,
                image_std=image_std,
                do_map_pixels=False,
                data_format=data_format,
                input_data_format=input_data_format,
            )
            for img in images
        ]
        data = {"pixel_values": processed_images}

        if return_codebook_pixels:
            codebook_images = [
                self._preprocess_image(
                    image=img,
                    do_resize=codebook_do_resize,
                    size=codebook_size,
                    resample=codebook_resample,
                    do_center_crop=codebook_do_center_crop,
                    crop_size=codebook_crop_size,
                    do_rescale=codebook_do_rescale,
                    rescale_factor=codebook_rescale_factor,
                    do_normalize=codebook_do_normalize,
                    image_mean=codebook_image_mean,
                    image_std=codebook_image_std,
                    do_map_pixels=codebook_do_map_pixels,
                    data_format=data_format,
                    input_data_format=input_data_format,
                )
                for img in images
            ]
            data["codebook_pixel_values"] = codebook_images

        if return_image_mask:
            mask_generator = self.masking_generator(
                input_size_patches=input_size_patches,
                total_mask_patches=total_mask_patches,
                mask_group_min_patches=mask_group_min_patches,
                mask_group_max_patches=mask_group_max_patches,
                mask_group_min_aspect_ratio=mask_group_min_aspect_ratio,
                mask_group_max_aspect_ratio=mask_group_max_aspect_ratio,
            )
            masks = [mask_generator() for _ in images]
            data["bool_masked_pos"] = masks

        return BatchFeature(data=data, tensor_type=return_tensors)
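

# A minimal end-to-end sketch (illustrative only, not part of the library): run the
# processor on a random uint8 image and request the codebook pixels and the patch mask
# alongside the standard pixel values. The shapes in the comments assume the default
# 224/112 sizes and a 14x14 patch grid.
if __name__ == "__main__":
    image = np.random.randint(0, 256, size=(480, 640, 3), dtype=np.uint8)
    processor = FlavaImageProcessor()
    outputs = processor.preprocess(
        image,
        return_image_mask=True,
        return_codebook_pixels=True,
        return_tensors="np",
    )
    print(outputs["pixel_values"].shape)           # (1, 3, 224, 224)
    print(outputs["codebook_pixel_values"].shape)  # (1, 3, 112, 112)
    print(outputs["bool_masked_pos"].shape)        # (1, 14, 14)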