image_processing_imagegpt.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299
  1. # coding=utf-8
  2. # Copyright 2022 The HuggingFace Inc. team. All rights reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. """Image processor class for ImageGPT."""
  16. from typing import Dict, List, Optional, Union
  17. import numpy as np
  18. from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
  19. from ...image_transforms import rescale, resize, to_channel_dimension_format
  20. from ...image_utils import (
  21. ChannelDimension,
  22. ImageInput,
  23. PILImageResampling,
  24. infer_channel_dimension_format,
  25. is_scaled_image,
  26. make_list_of_images,
  27. to_numpy_array,
  28. valid_images,
  29. validate_preprocess_arguments,
  30. )
  31. from ...utils import TensorType, filter_out_non_signature_kwargs, is_vision_available, logging
  32. if is_vision_available():
  33. import PIL
  34. logger = logging.get_logger(__name__)
  35. def squared_euclidean_distance(a, b):
  36. b = b.T
  37. a2 = np.sum(np.square(a), axis=1)
  38. b2 = np.sum(np.square(b), axis=0)
  39. ab = np.matmul(a, b)
  40. d = a2[:, None] - 2 * ab + b2[None, :]
  41. return d
  42. def color_quantize(x, clusters):
  43. x = x.reshape(-1, 3)
  44. d = squared_euclidean_distance(x, clusters)
  45. return np.argmin(d, axis=1)
  46. class ImageGPTImageProcessor(BaseImageProcessor):
  47. r"""
  48. Constructs a ImageGPT image processor. This image processor can be used to resize images to a smaller resolution
  49. (such as 32x32 or 64x64), normalize them and finally color quantize them to obtain sequences of "pixel values"
  50. (color clusters).
  51. Args:
  52. clusters (`np.ndarray` or `List[List[int]]`, *optional*):
  53. The color clusters to use, of shape `(n_clusters, 3)` when color quantizing. Can be overriden by `clusters`
  54. in `preprocess`.
  55. do_resize (`bool`, *optional*, defaults to `True`):
  56. Whether to resize the image's dimensions to `(size["height"], size["width"])`. Can be overridden by
  57. `do_resize` in `preprocess`.
  58. size (`Dict[str, int]` *optional*, defaults to `{"height": 256, "width": 256}`):
  59. Size of the image after resizing. Can be overridden by `size` in `preprocess`.
  60. resample (`PILImageResampling`, *optional*, defaults to `Resampling.BILINEAR`):
  61. Resampling filter to use if resizing the image. Can be overridden by `resample` in `preprocess`.
  62. do_normalize (`bool`, *optional*, defaults to `True`):
  63. Whether to normalize the image pixel value to between [-1, 1]. Can be overridden by `do_normalize` in
  64. `preprocess`.
  65. do_color_quantize (`bool`, *optional*, defaults to `True`):
  66. Whether to color quantize the image. Can be overridden by `do_color_quantize` in `preprocess`.
  67. """
  68. model_input_names = ["pixel_values"]
  69. def __init__(
  70. self,
  71. # clusters is a first argument to maintain backwards compatibility with the old ImageGPTImageProcessor
  72. clusters: Optional[Union[List[List[int]], np.ndarray]] = None,
  73. do_resize: bool = True,
  74. size: Dict[str, int] = None,
  75. resample: PILImageResampling = PILImageResampling.BILINEAR,
  76. do_normalize: bool = True,
  77. do_color_quantize: bool = True,
  78. **kwargs,
  79. ) -> None:
  80. super().__init__(**kwargs)
  81. size = size if size is not None else {"height": 256, "width": 256}
  82. size = get_size_dict(size)
  83. self.clusters = np.array(clusters) if clusters is not None else None
  84. self.do_resize = do_resize
  85. self.size = size
  86. self.resample = resample
  87. self.do_normalize = do_normalize
  88. self.do_color_quantize = do_color_quantize
  89. # Copied from transformers.models.vit.image_processing_vit.ViTImageProcessor.resize
  90. def resize(
  91. self,
  92. image: np.ndarray,
  93. size: Dict[str, int],
  94. resample: PILImageResampling = PILImageResampling.BILINEAR,
  95. data_format: Optional[Union[str, ChannelDimension]] = None,
  96. input_data_format: Optional[Union[str, ChannelDimension]] = None,
  97. **kwargs,
  98. ) -> np.ndarray:
  99. """
  100. Resize an image to `(size["height"], size["width"])`.
  101. Args:
  102. image (`np.ndarray`):
  103. Image to resize.
  104. size (`Dict[str, int]`):
  105. Dictionary in the format `{"height": int, "width": int}` specifying the size of the output image.
  106. resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BILINEAR`):
  107. `PILImageResampling` filter to use when resizing the image e.g. `PILImageResampling.BILINEAR`.
  108. data_format (`ChannelDimension` or `str`, *optional*):
  109. The channel dimension format for the output image. If unset, the channel dimension format of the input
  110. image is used. Can be one of:
  111. - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
  112. - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
  113. - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
  114. input_data_format (`ChannelDimension` or `str`, *optional*):
  115. The channel dimension format for the input image. If unset, the channel dimension format is inferred
  116. from the input image. Can be one of:
  117. - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
  118. - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
  119. - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
  120. Returns:
  121. `np.ndarray`: The resized image.
  122. """
  123. size = get_size_dict(size)
  124. if "height" not in size or "width" not in size:
  125. raise ValueError(f"The `size` dictionary must contain the keys `height` and `width`. Got {size.keys()}")
  126. output_size = (size["height"], size["width"])
  127. return resize(
  128. image,
  129. size=output_size,
  130. resample=resample,
  131. data_format=data_format,
  132. input_data_format=input_data_format,
  133. **kwargs,
  134. )
  135. def normalize(
  136. self,
  137. image: np.ndarray,
  138. data_format: Optional[Union[str, ChannelDimension]] = None,
  139. input_data_format: Optional[Union[str, ChannelDimension]] = None,
  140. ) -> np.ndarray:
  141. """
  142. Normalizes an images' pixel values to between [-1, 1].
  143. Args:
  144. image (`np.ndarray`):
  145. Image to normalize.
  146. data_format (`str` or `ChannelDimension`, *optional*):
  147. The channel dimension format of the image. If not provided, it will be the same as the input image.
  148. input_data_format (`ChannelDimension` or `str`, *optional*):
  149. The channel dimension format of the input image. If not provided, it will be inferred.
  150. """
  151. image = rescale(image=image, scale=1 / 127.5, data_format=data_format, input_data_format=input_data_format)
  152. image = image - 1
  153. return image
  154. @filter_out_non_signature_kwargs()
  155. def preprocess(
  156. self,
  157. images: ImageInput,
  158. do_resize: bool = None,
  159. size: Dict[str, int] = None,
  160. resample: PILImageResampling = None,
  161. do_normalize: bool = None,
  162. do_color_quantize: Optional[bool] = None,
  163. clusters: Optional[Union[List[List[int]], np.ndarray]] = None,
  164. return_tensors: Optional[Union[str, TensorType]] = None,
  165. data_format: Optional[Union[str, ChannelDimension]] = ChannelDimension.FIRST,
  166. input_data_format: Optional[Union[str, ChannelDimension]] = None,
  167. ) -> PIL.Image.Image:
  168. """
  169. Preprocess an image or batch of images.
  170. Args:
  171. images (`ImageInput`):
  172. Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
  173. passing in images with pixel values between 0 and 1, set `do_normalize=False`.
  174. do_resize (`bool`, *optional*, defaults to `self.do_resize`):
  175. Whether to resize the image.
  176. size (`Dict[str, int]`, *optional*, defaults to `self.size`):
  177. Size of the image after resizing.
  178. resample (`int`, *optional*, defaults to `self.resample`):
  179. Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`, Only
  180. has an effect if `do_resize` is set to `True`.
  181. do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
  182. Whether to normalize the image
  183. do_color_quantize (`bool`, *optional*, defaults to `self.do_color_quantize`):
  184. Whether to color quantize the image.
  185. clusters (`np.ndarray` or `List[List[int]]`, *optional*, defaults to `self.clusters`):
  186. Clusters used to quantize the image of shape `(n_clusters, 3)`. Only has an effect if
  187. `do_color_quantize` is set to `True`.
  188. return_tensors (`str` or `TensorType`, *optional*):
  189. The type of tensors to return. Can be one of:
  190. - Unset: Return a list of `np.ndarray`.
  191. - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
  192. - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
  193. - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
  194. - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
  195. data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
  196. The channel dimension format for the output image. Can be one of:
  197. - `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
  198. - `ChannelDimension.LAST`: image in (height, width, num_channels) format.
  199. Only has an effect if `do_color_quantize` is set to `False`.
  200. input_data_format (`ChannelDimension` or `str`, *optional*):
  201. The channel dimension format for the input image. If unset, the channel dimension format is inferred
  202. from the input image. Can be one of:
  203. - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
  204. - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
  205. - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
  206. """
  207. do_resize = do_resize if do_resize is not None else self.do_resize
  208. size = size if size is not None else self.size
  209. size = get_size_dict(size)
  210. resample = resample if resample is not None else self.resample
  211. do_normalize = do_normalize if do_normalize is not None else self.do_normalize
  212. do_color_quantize = do_color_quantize if do_color_quantize is not None else self.do_color_quantize
  213. clusters = clusters if clusters is not None else self.clusters
  214. clusters = np.array(clusters)
  215. images = make_list_of_images(images)
  216. if not valid_images(images):
  217. raise ValueError(
  218. "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
  219. "torch.Tensor, tf.Tensor or jax.ndarray."
  220. )
  221. # Here, normalize() is using a constant factor to divide pixel values.
  222. # hence, the method does not need iamge_mean and image_std.
  223. validate_preprocess_arguments(
  224. do_resize=do_resize,
  225. size=size,
  226. resample=resample,
  227. )
  228. if do_color_quantize and clusters is None:
  229. raise ValueError("Clusters must be specified if do_color_quantize is True.")
  230. # All transformations expect numpy arrays.
  231. images = [to_numpy_array(image) for image in images]
  232. if is_scaled_image(images[0]) and do_normalize:
  233. logger.warning_once(
  234. "It looks like you are trying to rescale already rescaled images. If you wish to do this, "
  235. "make sure to set `do_normalize` to `False` and that pixel values are between [-1, 1].",
  236. )
  237. if input_data_format is None:
  238. # We assume that all images have the same channel dimension format.
  239. input_data_format = infer_channel_dimension_format(images[0])
  240. if do_resize:
  241. images = [
  242. self.resize(image=image, size=size, resample=resample, input_data_format=input_data_format)
  243. for image in images
  244. ]
  245. if do_normalize:
  246. images = [self.normalize(image=image, input_data_format=input_data_format) for image in images]
  247. if do_color_quantize:
  248. images = [to_channel_dimension_format(image, ChannelDimension.LAST, input_data_format) for image in images]
  249. # color quantize from (batch_size, height, width, 3) to (batch_size, height, width)
  250. images = np.array(images)
  251. images = color_quantize(images, clusters).reshape(images.shape[:-1])
  252. # flatten to (batch_size, height*width)
  253. batch_size = images.shape[0]
  254. images = images.reshape(batch_size, -1)
  255. # We need to convert back to a list of images to keep consistent behaviour across processors.
  256. images = list(images)
  257. else:
  258. images = [
  259. to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format)
  260. for image in images
  261. ]
  262. data = {"input_ids": images}
  263. return BatchFeature(data=data, tensor_type=return_tensors)