image_processing_fuyu.py
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Image processor class for Fuyu."""

import math
from typing import Dict, List, Optional, Union

import numpy as np

from ...image_processing_utils import BaseImageProcessor, BatchFeature
from ...image_transforms import (
    pad,
    resize,
    to_channel_dimension_format,
)
from ...image_utils import (
    ChannelDimension,
    ImageInput,
    PILImageResampling,
    get_image_size,
    infer_channel_dimension_format,
    is_scaled_image,
    is_valid_image,
    make_list_of_images,
    to_numpy_array,
    validate_preprocess_arguments,
)
from ...utils import (
    TensorType,
    filter_out_non_signature_kwargs,
    is_torch_available,
    is_torch_device,
    is_torch_dtype,
    logging,
    requires_backends,
)


if is_torch_available():
    import torch


logger = logging.get_logger(__name__)

def make_list_of_list_of_images(
    images: Union[List[List[ImageInput]], List[ImageInput], ImageInput],
) -> List[List[ImageInput]]:
    if is_valid_image(images):
        return [[images]]

    if isinstance(images, list) and all(isinstance(image, list) for image in images):
        return images

    if isinstance(images, list):
        return [make_list_of_images(image) for image in images]

    raise ValueError("images must be a list of list of images or a list of images or an image.")
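
# Illustrative examples of the normalization performed by `make_list_of_list_of_images` above
# (`img_a` and `img_b` are hypothetical valid images): every accepted input shape is mapped to a
# batch of per-sample image lists, e.g.
#     make_list_of_list_of_images(img_a)              -> [[img_a]]
#     make_list_of_list_of_images([img_a, img_b])     -> [[img_a], [img_b]]
#     make_list_of_list_of_images([[img_a], [img_b]]) -> [[img_a], [img_b]]  (already nested, unchanged)
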
class FuyuBatchFeature(BatchFeature):
    """
    BatchFeature class for Fuyu image processor and processor.

    The outputs dictionary from the processors contains a mix of tensors and lists of tensors.
    """

    def convert_to_tensors(self, tensor_type: Optional[Union[str, TensorType]] = None):
        """
        Convert the inner content to tensors.

        Args:
            tensor_type (`str` or [`~utils.TensorType`], *optional*):
                The type of tensors to use. If `str`, should be one of the values of the enum [`~utils.TensorType`]. If
                `None`, no modification is done.
        """
        if tensor_type is None:
            return self

        is_tensor, as_tensor = self._get_is_as_tensor_fns(tensor_type=tensor_type)

        def _convert_tensor(elem):
            if is_tensor(elem):
                return elem
            return as_tensor(elem)

        def _safe_convert_tensor(elem):
            try:
                return _convert_tensor(elem)
            except:  # noqa E722
                if key == "overflowing_values":
                    raise ValueError("Unable to create tensor returning overflowing values of different lengths. ")
                raise ValueError(
                    "Unable to create tensor, you should probably activate padding "
                    "with 'padding=True' to have batched tensors with the same length."
                )

        # Do the tensor conversion in batch
        for key, value in self.items():
            if isinstance(value, list) and isinstance(value[0], list):
                # List[List[Any]] -> List[List[Tensor]]
                self[key] = [[_safe_convert_tensor(elem) for elem in elems] for elems in value]
            elif isinstance(value, list):
                # List[Any] -> List[Tensor]
                self[key] = [_safe_convert_tensor(elem) for elem in value]
            else:
                # Any -> Tensor
                self[key] = _safe_convert_tensor(value)
        return self
    def to(self, *args, **kwargs) -> "BatchFeature":
        """
        Send all values to device by calling `v.to(*args, **kwargs)` (PyTorch only). This should support casting in
        different `dtypes` and sending the `BatchFeature` to a different `device`.

        Args:
            args (`Tuple`):
                Will be passed to the `to(...)` function of the tensors.
            kwargs (`Dict`, *optional*):
                Will be passed to the `to(...)` function of the tensors.

        Returns:
            [`BatchFeature`]: The same instance after modification.
        """
        requires_backends(self, ["torch"])
        import torch  # noqa

        new_data = {}
        device = kwargs.get("device")
        # Check if the args are a device or a dtype
        if device is None and len(args) > 0:
            # device should be always the first argument
            arg = args[0]
            if is_torch_dtype(arg):
                # The first argument is a dtype
                pass
            elif isinstance(arg, str) or is_torch_device(arg) or isinstance(arg, int):
                device = arg
            else:
                # it's something else
                raise ValueError(f"Attempting to cast a BatchFeature to type {str(arg)}. This is not supported.")

        def _to(elem):
            # check if v is a floating point
            if torch.is_floating_point(elem):
                # cast and send to device
                return elem.to(*args, **kwargs)
            if device is not None:
                return elem.to(device=device)
            return elem

        # We cast only floating point tensors to avoid issues with tokenizers casting `LongTensor` to `FloatTensor`
        for k, v in self.items():
            if isinstance(v, list) and isinstance(v[0], list):
                # Data structure is a list of lists
                new_v = []
                for elems in v:
                    new_v.append([_to(elem) for elem in elems])
                new_data[k] = new_v
            elif isinstance(v, list):
                # Data structure is a list
                new_data[k] = [_to(elem) for elem in v]
            else:
                new_data[k] = _to(v)
        self.data = new_data
        return self

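# Illustrative sketch of how the overrides above behave (the variable names below are
# hypothetical): a FuyuBatchFeature whose values mix tensors and nested lists of tensors can
# still be converted and moved as a unit, e.g.
#     features = FuyuBatchFeature(data={"image_patches": [[patches_0], [patches_1]]})
#     features = features.convert_to_tensors("pt")  # converts every nested element
#     features = features.to("cuda:0")              # floating-point tensors are cast/moved, integer tensors only moved
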
class FuyuImageProcessor(BaseImageProcessor):
    """
    This class should handle the image processing part before the main FuyuForCausalLM. In particular, it should
    handle:

    - Processing Images:
        Taking a batch of images as input. If the images are variable-sized, it resizes them based on the desired patch
        dimensions. The output image size is always (img_h, img_w) = (1080, 1920).

        Then, it patches up these images using the patchify_image function.

    - Creating Image Input IDs:
        For each patch, a placeholder ID is given to identify where these patches belong in a token sequence. For
        variable-sized images, each line of patches is terminated with a newline ID.

    - Image Patch Indices:
        For each image patch, the code maintains an index where these patches should be inserted in a token stream.

    Args:
        do_resize (`bool`, *optional*, defaults to `True`):
            Whether to resize the image to `size`.
        size (`Dict[str, int]`, *optional*, defaults to `{"height": 1080, "width": 1920}`):
            Dictionary in the format `{"height": int, "width": int}` specifying the size of the output image.
        resample (`PILImageResampling`, *optional*, defaults to `Resampling.BILINEAR`):
            `PILImageResampling` filter to use when resizing the image e.g. `PILImageResampling.BILINEAR`.
        do_pad (`bool`, *optional*, defaults to `True`):
            Whether to pad the image to `size`.
        padding_value (`float`, *optional*, defaults to 1.0):
            The value to pad the image with.
        padding_mode (`str`, *optional*, defaults to `"constant"`):
            The padding mode to use when padding the image.
        do_normalize (`bool`, *optional*, defaults to `True`):
            Whether to normalize the image.
        image_mean (`float`, *optional*, defaults to 0.5):
            The mean to use when normalizing the image.
        image_std (`float`, *optional*, defaults to 0.5):
            The standard deviation to use when normalizing the image.
        do_rescale (`bool`, *optional*, defaults to `True`):
            Whether to rescale the image.
        rescale_factor (`float`, *optional*, defaults to `1 / 255`):
            The factor to use when rescaling the image.
        patch_size (`Dict[str, int]`, *optional*, defaults to `{"height": 30, "width": 30}`):
            Dictionary in the format `{"height": int, "width": int}` specifying the size of the patches.
    """

    model_input_names = [
        "images",
        "image_input_ids",
        "image_patches",
        "image_patch_indices_per_batch",
        "image_patch_indices_per_subsequence",
    ]

    def __init__(
        self,
        do_resize: bool = True,
        size: Optional[Dict[str, int]] = None,
        resample: PILImageResampling = PILImageResampling.BILINEAR,
        do_pad: bool = True,
        padding_value: float = 1.0,
        padding_mode: str = "constant",
        do_normalize: bool = True,
        image_mean: Union[float, List[float]] = 0.5,
        image_std: Union[float, List[float]] = 0.5,
        do_rescale: bool = True,
        rescale_factor: float = 1 / 255,
        patch_size: Optional[Dict[str, int]] = None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.do_resize = do_resize
        self.size = size if size is not None else {"height": 1080, "width": 1920}
        self.resample = resample
        self.do_pad = do_pad
        self.padding_value = padding_value
        self.padding_mode = padding_mode
        self.do_normalize = do_normalize
        self.image_mean = image_mean
        self.image_std = image_std
        self.do_rescale = do_rescale
        self.rescale_factor = rescale_factor
        self.patch_size = patch_size if patch_size is not None else {"height": 30, "width": 30}

    def resize(
        self,
        image: np.ndarray,
        size: Dict[str, int],
        resample: PILImageResampling = PILImageResampling.BILINEAR,
        data_format: Optional[Union[str, ChannelDimension]] = None,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
        **kwargs,
    ) -> np.ndarray:
        """
        Resize an image to `(size["height"], size["width"])`.

        Args:
            image (`np.ndarray`):
                Image to resize.
            size (`Dict[str, int]`):
                Dictionary in the format `{"height": int, "width": int}` specifying the size of the output image.
            resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BILINEAR`):
                `PILImageResampling` filter to use when resizing the image e.g. `PILImageResampling.BILINEAR`.
            data_format (`ChannelDimension` or `str`, *optional*):
                The channel dimension format for the output image. If unset, the channel dimension format of the input
                image is used. Can be one of:
                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
                - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
            input_data_format (`ChannelDimension` or `str`, *optional*):
                The channel dimension format for the input image. If unset, the channel dimension format is inferred
                from the input image. Can be one of:
                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
                - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.

        Returns:
            `np.ndarray`: The resized image.
        """
        image_height, image_width = get_image_size(image, input_data_format)
        target_height, target_width = size["height"], size["width"]

        if image_width <= target_width and image_height <= target_height:
            return image

        height_scale_factor = target_height / image_height
        width_scale_factor = target_width / image_width
        optimal_scale_factor = min(height_scale_factor, width_scale_factor)

        new_height = int(image_height * optimal_scale_factor)
        new_width = int(image_width * optimal_scale_factor)

        scaled_image = resize(
            image=image,
            size=(new_height, new_width),
            resample=resample,
            data_format=data_format,
            input_data_format=input_data_format,
            **kwargs,
        )
        return scaled_image

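    # Illustrative note on the resizing rule above: images that already fit inside `size` are
    # returned untouched, while larger images are shrunk with a single aspect-ratio-preserving
    # scale factor. For example, a 2160x3840 input targeting the default 1080x1920 gives
    # min(1080/2160, 1920/3840) = 0.5, so it becomes 1080x1920; a 1080x3840 input gives
    # min(1.0, 0.5) = 0.5 and becomes 540x1920, leaving the height to be padded later.
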
    def pad_image(
        self,
        image: np.ndarray,
        size: Dict[str, int],
        mode: str = "constant",
        constant_values: float = 1.0,
        data_format: Optional[Union[str, ChannelDimension]] = None,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
    ) -> np.ndarray:
        """
        Pad an image to `(size["height"], size["width"])`.

        Args:
            image (`np.ndarray`):
                Image to pad.
            size (`Dict[str, int]`):
                Dictionary in the format `{"height": int, "width": int}` specifying the size of the output image.
            data_format (`ChannelDimension` or `str`, *optional*):
                The data format of the output image. If unset, the same format as the input image is used.
            input_data_format (`ChannelDimension` or `str`, *optional*):
                The channel dimension format of the input image. If not provided, it will be inferred.
        """
        image_height, image_width = get_image_size(image, input_data_format)
        target_height, target_width = size["height"], size["width"]
        padding_top = 0
        padding_left = 0
        padding_bottom = target_height - image_height
        padding_right = target_width - image_width
        padded_image = pad(
            image,
            padding=((padding_top, padding_bottom), (padding_left, padding_right)),
            mode=mode,
            constant_values=constant_values,
            data_format=data_format,
            input_data_format=input_data_format,
        )
        return padded_image

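    # Illustrative note on the padding rule above: padding is only ever added on the bottom and
    # right edges, so the original pixels stay anchored at the top-left corner. Continuing the
    # resize example, a 540x1920 image padded to the default 1080x1920 gets 540 rows of the
    # padding value (1.0 by default) appended below it and no extra columns on the right.
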
    @filter_out_non_signature_kwargs()
    def preprocess(
        self,
        images,
        do_resize: Optional[bool] = None,
        size: Optional[Dict[str, int]] = None,
        resample: Optional[PILImageResampling] = None,
        do_pad: Optional[bool] = None,
        padding_value: Optional[float] = None,
        padding_mode: Optional[str] = None,
        do_normalize: Optional[bool] = None,
        image_mean: Optional[float] = None,
        image_std: Optional[float] = None,
        do_rescale: Optional[bool] = None,
        rescale_factor: Optional[float] = None,
        patch_size: Optional[Dict[str, int]] = None,
        data_format: Optional[Union[str, ChannelDimension]] = ChannelDimension.FIRST,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
        return_tensors: Optional[TensorType] = None,
    ):
        """
        Utility function to preprocess the images and extract necessary information about original formats.

        Args:
            images (`ImageInput`):
                Images to preprocess. Expects a single image, a list of images or a list of lists of images. Pixel
                values range from 0 to 255, or between 0 and 1 if `do_rescale` is `False`.
            do_resize (`bool`, *optional*, defaults to `self.do_resize`):
                Whether to resize the image to `size`.
            size (`Dict[str, int]`, *optional*, defaults to `self.size`):
                Dictionary in the format `{"height": int, "width": int}` specifying the size of the output image.
            resample (`PILImageResampling`, *optional*, defaults to `self.resample`):
                `PILImageResampling` filter to use when resizing the image e.g. `PILImageResampling.BILINEAR`.
            do_pad (`bool`, *optional*, defaults to `self.do_pad`):
                Whether to pad the image to `size`.
            padding_value (`float`, *optional*, defaults to `self.padding_value`):
                The value to pad the image with.
            padding_mode (`str`, *optional*, defaults to `self.padding_mode`):
                The padding mode to use when padding the image.
            do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
                Whether to normalize the image.
            image_mean (`float`, *optional*, defaults to `self.image_mean`):
                The mean to use when normalizing the image.
            image_std (`float`, *optional*, defaults to `self.image_std`):
                The standard deviation to use when normalizing the image.
            do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
                Whether to rescale the image.
            rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
                The factor to use when rescaling the image.
            patch_size (`Dict[str, int]`, *optional*, defaults to `self.patch_size`):
                Dictionary in the format `{"height": int, "width": int}` specifying the size of the patches.
            return_tensors (`str` or `TensorType`, *optional*):
                The type of tensors to return. Can be one of:
                - Unset: Return a list of `np.ndarray`.
                - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
                - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
                - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
                - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
            data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
                The channel dimension format of the output image. Can be one of:
                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
            input_data_format (`ChannelDimension` or `str`, *optional*):
                The channel dimension format for the input image. If unset, the channel dimension format is inferred
                from the input image. Can be one of:
                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
        """
        do_resize = do_resize if do_resize is not None else self.do_resize
        size = size if size is not None else self.size
        resample = resample if resample is not None else self.resample
        do_pad = do_pad if do_pad is not None else self.do_pad
        do_rescale = do_rescale if do_rescale is not None else self.do_rescale
        rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
        do_normalize = do_normalize if do_normalize is not None else self.do_normalize
        image_mean = image_mean if image_mean is not None else self.image_mean
        image_std = image_std if image_std is not None else self.image_std
        padding_value = padding_value if padding_value is not None else self.padding_value
        padding_mode = padding_mode if padding_mode is not None else self.padding_mode
        patch_size = patch_size if patch_size is not None else self.patch_size

        if isinstance(images, list) and any(isinstance(elem, list) and len(elem) >= 2 for elem in images):
            raise ValueError("Multiple images for a single sample are not yet supported.")

        batch_images = make_list_of_list_of_images(images)

        validate_preprocess_arguments(
            do_rescale=do_rescale,
            rescale_factor=rescale_factor,
            do_normalize=do_normalize,
            image_mean=image_mean,
            image_std=image_std,
            do_pad=do_pad,
            size_divisibility=size,  # There is no pad divisibility in this processor, but pad requires the size arg.
            do_resize=do_resize,
            size=size,
            resample=resample,
        )

        # All transformations expect numpy arrays.
        batch_images = [[to_numpy_array(image) for image in images] for images in batch_images]

        if is_scaled_image(batch_images[0][0]) and do_rescale:
            logger.warning_once(
                "It looks like you are trying to rescale already rescaled images. If the input"
                " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
            )

        if input_data_format is None:
            # We assume that all images have the same channel dimension format.
            input_data_format = infer_channel_dimension_format(batch_images[0][0])

        original_image_sizes = [get_image_size(images[0], channel_dim=input_data_format) for images in batch_images]

        if do_resize:
            batch_images = [
                [self.resize(image, size=size, input_data_format=input_data_format) for image in images]
                for images in batch_images
            ]

        image_sizes = [get_image_size(images[0], channel_dim=input_data_format) for images in batch_images]
        image_unpadded_heights = [[image_size[0]] for image_size in image_sizes]
        image_unpadded_widths = [[image_size[1]] for image_size in image_sizes]

        # scale_h is the same as scale_w
        image_scale_factors = [
            [resized_size[0] / original_size[0]]
            for original_size, resized_size in zip(original_image_sizes, image_sizes)
        ]

        if do_pad:
            batch_images = [
                [
                    self.pad_image(
                        image,
                        size=size,
                        mode=padding_mode,
                        constant_values=padding_value,
                        input_data_format=input_data_format,
                    )
                    for image in images
                ]
                for images in batch_images
            ]

        if do_rescale:
            batch_images = [
                [self.rescale(image, scale=rescale_factor, input_data_format=input_data_format) for image in images]
                for images in batch_images
            ]

        if do_normalize:
            batch_images = [
                [
                    self.normalize(image, mean=image_mean, std=image_std, input_data_format=input_data_format)
                    for image in images
                ]
                for images in batch_images
            ]

        if data_format is not None:
            batch_images = [
                [to_channel_dimension_format(image, data_format, input_data_format) for image in images]
                for images in batch_images
            ]

        data = {
            "images": batch_images,
            "image_unpadded_heights": image_unpadded_heights,
            "image_unpadded_widths": image_unpadded_widths,
            "image_scale_factors": image_scale_factors,
        }
        return FuyuBatchFeature(data=data, tensor_type=return_tensors)

    def get_num_patches(self, image_height: int, image_width: int, patch_size: Dict[str, int] = None) -> int:
        """
        Calculate number of patches required to encode an image.

        Args:
            image_height (`int`):
                Height of the image.
            image_width (`int`):
                Width of the image.
            patch_size (`Dict[str, int]`, *optional*, defaults to `self.patch_size`):
                Dictionary in the format `{"height": int, "width": int}` specifying the size of the patches.
        """
        patch_size = patch_size if patch_size is not None else self.patch_size
        # Use the resolved `patch_size` so an explicitly passed value is not silently ignored.
        patch_height, patch_width = patch_size["height"], patch_size["width"]

        if image_height % patch_height != 0:
            raise ValueError(f"{image_height=} must be divisible by {patch_height}")
        if image_width % patch_width != 0:
            raise ValueError(f"{image_width=} must be divisible by {patch_width}")

        num_patches_per_dim_h = image_height // patch_height
        num_patches_per_dim_w = image_width // patch_width
        num_patches = num_patches_per_dim_h * num_patches_per_dim_w
        return num_patches

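    # Illustrative arithmetic for the default configuration: a padded 1080x1920 image with
    # 30x30 patches yields (1080 // 30) * (1920 // 30) = 36 * 64 = 2304 patches, i.e.
    #     FuyuImageProcessor().get_num_patches(image_height=1080, image_width=1920)  # -> 2304
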
    def patchify_image(self, image: "torch.Tensor", patch_size: Optional[Dict[str, int]] = None) -> "torch.Tensor":
        """
        Convert an image into a tensor of patches.

        Args:
            image (`torch.Tensor`):
                Image to convert. Shape: [batch, channels, height, width]
            patch_size (`Dict[str, int]`, *optional*, defaults to `self.patch_size`):
                Dictionary in the format `{"height": int, "width": int}` specifying the size of the patches.
        """
        requires_backends(self, ["torch"])
        patch_size = patch_size if patch_size is not None else self.patch_size
        patch_height, patch_width = patch_size["height"], patch_size["width"]

        # TODO refer to https://github.com/ArthurZucker/transformers/blob/0f0a3fe5ca5697ee58faeb5b53f049af720b5e98/src/transformers/models/vit_mae/modeling_vit_mae.py#L871
        # torch implementation is faster but does not handle non-squares
        batch_size, channels, _, _ = image.shape
        unfolded_along_height = image.unfold(2, patch_height, patch_height)
        patches = unfolded_along_height.unfold(3, patch_width, patch_width)
        patches = patches.contiguous()
        patches = patches.view(batch_size, channels, -1, patch_height, patch_width)
        patches = patches.permute(0, 2, 3, 4, 1)
        patches = patches.reshape(batch_size, -1, channels * patch_height * patch_width)
        return patches

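    # Illustrative shapes for the unfold-based patchification above, using the defaults: an input
    # of shape [1, 3, 1080, 1920] is split into 36 x 64 = 2304 patches and returned with shape
    # [1, 2304, 3 * 30 * 30] = [1, 2304, 2700]; each patch is flattened in (height, width, channel)
    # order because of the final permute.
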
    def preprocess_with_tokenizer_info(
        self,
        image_input: "torch.Tensor",
        image_present: "torch.Tensor",
        image_unpadded_h: "torch.Tensor",
        image_unpadded_w: "torch.Tensor",
        image_placeholder_id: int,
        image_newline_id: int,
        variable_sized: bool,
        patch_size: Optional[Dict[str, int]] = None,
    ) -> FuyuBatchFeature:
        """Process images for model input. In particular, variable-sized images are handled here.

        Args:
            image_input (`torch.Tensor` of shape [batch_size, subsequence_size, num_channels, height, width]):
                Tensor of images padded to model input size.
            image_present (`torch.Tensor` of shape [batch_size, subsequence_size, num_images]):
                Tensor of 1s and 0s indicating whether an image is present.
            image_unpadded_h (`torch.Tensor` of shape [batch_size, subsequence_size]):
                Tensor of unpadded image heights.
            image_unpadded_w (`torch.Tensor` of shape [batch_size, subsequence_size]):
                Tensor of unpadded image widths.
            image_placeholder_id (int):
                The id of the image placeholder token. Comes from an associated tokenizer.
            image_newline_id (int):
                The id of the image newline token. Comes from an associated tokenizer.
            variable_sized (bool):
                Whether to process images as variable-sized.
            patch_size (`Dict[str, int]`, *optional*, defaults to `self.patch_size`):
                Size of the patches.
        """
        requires_backends(self, ["torch"])

        patch_size = patch_size if patch_size is not None else self.patch_size
        patch_height, patch_width = patch_size["height"], patch_size["width"]

        # Only images that are present.
        images: List[List[torch.Tensor]] = []
        batch_image_patches: List[List[torch.Tensor]] = []
        # Image input ids for every subsequence, including ones with no image present.
        batch_image_input_ids: List[List[torch.Tensor]] = []
        for batch_index in range(image_input.shape[0]):
            image_input_ids = []
            image_patches = []
            for subseq_index in range(image_input.shape[1]):
                if image_present[batch_index, subseq_index]:
                    image = image_input[batch_index, subseq_index]
                    image_height, image_width = image.shape[1], image.shape[2]
                    if variable_sized:
                        # The min() is required here due to floating point issues:
                        # math.ceil(torch.tensor(300).cuda() / 30) == 11
                        new_h = min(
                            image_height,
                            math.ceil(image_unpadded_h[batch_index, subseq_index] / patch_height) * patch_height,
                        )
                        new_w = min(
                            image_width,
                            math.ceil(image_unpadded_w[batch_index, subseq_index] / patch_width) * patch_width,
                        )
                        image = image[:, :new_h, :new_w]
                        image_height, image_width = new_h, new_w

                    num_patches = self.get_num_patches(image_height=image_height, image_width=image_width)
                    tensor_of_image_ids = torch.full(
                        [num_patches], image_placeholder_id, dtype=torch.int32, device=image_input.device
                    )
                    patches = self.patchify_image(image=image.unsqueeze(0)).squeeze(0)
                    assert num_patches == patches.shape[0]

                    if variable_sized:
                        # Now terminate each line with |NEWLINE|.
                        tensor_of_image_ids = tensor_of_image_ids.reshape(-1, image_width // patch_width)
                        newline_ids = torch.full(
                            [tensor_of_image_ids.shape[0], 1],
                            image_newline_id,
                            dtype=torch.int32,
                            device=image_input.device,
                        )
                        tensor_of_image_ids = torch.cat([tensor_of_image_ids, newline_ids], dim=1)
                        tensor_of_image_ids = tensor_of_image_ids.reshape(-1)

                    images.append([image])
                    image_input_ids.append(tensor_of_image_ids)
                    image_patches.append(patches)
                else:
                    image_input_ids.append(torch.tensor([], dtype=torch.int32, device=image_input.device))

            batch_image_input_ids.append(image_input_ids)
            batch_image_patches.append(image_patches)

        # Create image_patch_input_indices, where non-negative values correspond to image patches to be inserted in
        # the stream.
        image_patch_indices_per_batch: List[List[torch.Tensor]] = []
        image_patch_indices_per_subsequence: List[List[torch.Tensor]] = []
        for sample_image_input_ids in batch_image_input_ids:
            index_offset = 0
            per_batch_indices = []
            per_subsequence_indices = []
            for subseq_image_input_ids in sample_image_input_ids:
                # Indices of image patches.
                patches_mask = subseq_image_input_ids == image_placeholder_id
                num_patches = torch.count_nonzero(patches_mask)
                indices = torch.arange(num_patches, dtype=torch.int64, device=subseq_image_input_ids.device).type_as(
                    subseq_image_input_ids
                )

                # Place those indices in the image input ids token stream, with -1 representing non-index tokens.
                indices_in_stream_per_batch = torch.full_like(subseq_image_input_ids, -1)
                indices_in_stream_per_subsequence = torch.full_like(subseq_image_input_ids, -1)
                patches_inds = torch.nonzero(patches_mask, as_tuple=True)[0]

                indices_in_stream_per_batch[patches_inds] = indices + index_offset
                indices_in_stream_per_subsequence[patches_inds] = indices

                per_batch_indices.append(indices_in_stream_per_batch)
                per_subsequence_indices.append(indices_in_stream_per_subsequence)
                index_offset += num_patches

            image_patch_indices_per_batch.append(per_batch_indices)
            image_patch_indices_per_subsequence.append(per_subsequence_indices)

        return FuyuBatchFeature(
            data={
                "images": images,
                "image_input_ids": batch_image_input_ids,
                "image_patches": batch_image_patches,
                "image_patch_indices_per_batch": image_patch_indices_per_batch,
                "image_patch_indices_per_subsequence": image_patch_indices_per_subsequence,
            }
        )
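

# Illustrative end-to-end sketch of the image pipeline above (assumptions: PIL is installed and a
# local file named "example.png" exists; neither is part of this module):
#
#     from PIL import Image
#
#     processor = FuyuImageProcessor()
#     image = Image.open("example.png")
#     outputs = processor.preprocess(image, return_tensors="pt")
#     # `outputs` is a FuyuBatchFeature with "images", "image_unpadded_heights",
#     # "image_unpadded_widths" and "image_scale_factors"; each image is resized (if needed),
#     # padded to 1080x1920, rescaled and normalized before being patchified downstream.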