image_processing_idefics.py 7.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168
  1. # coding=utf-8
  2. # Copyright 2022 The HuggingFace Inc. team. All rights reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. """Image processor class for Idefics."""
  16. from typing import Callable, Dict, List, Optional, Union
  17. from PIL import Image
  18. from ...image_processing_utils import BaseImageProcessor, BatchFeature
  19. from ...image_transforms import resize, to_channel_dimension_format
  20. from ...image_utils import (
  21. ChannelDimension,
  22. ImageInput,
  23. PILImageResampling,
  24. make_list_of_images,
  25. to_numpy_array,
  26. valid_images,
  27. )
  28. from ...utils import TensorType, is_torch_available
  29. IDEFICS_STANDARD_MEAN = [0.48145466, 0.4578275, 0.40821073]
  30. IDEFICS_STANDARD_STD = [0.26862954, 0.26130258, 0.27577711]
  31. def convert_to_rgb(image):
  32. # `image.convert("RGB")` would only work for .jpg images, as it creates a wrong background
  33. # for transparent images. The call to `alpha_composite` handles this case
  34. if image.mode == "RGB":
  35. return image
  36. image_rgba = image.convert("RGBA")
  37. background = Image.new("RGBA", image_rgba.size, (255, 255, 255))
  38. alpha_composite = Image.alpha_composite(background, image_rgba)
  39. alpha_composite = alpha_composite.convert("RGB")
  40. return alpha_composite
  41. class IdeficsImageProcessor(BaseImageProcessor):
  42. r"""
  43. Constructs a Idefics image processor.
  44. Args:
  45. image_size (`int`, *optional*, defaults to 224):
  46. Resize to image size
  47. image_mean (`float` or `List[float]`, *optional*, defaults to `IDEFICS_STANDARD_MEAN`):
  48. Mean to use if normalizing the image. This is a float or list of floats the length of the number of
  49. channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method. Can be
  50. overridden by the `image_mean` parameter in the `preprocess` method.
  51. image_std (`float` or `List[float]`, *optional*, defaults to `IDEFICS_STANDARD_STD`):
  52. Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
  53. number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
  54. Can be overridden by the `image_std` parameter in the `preprocess` method.
  55. image_num_channels (`int`, *optional*, defaults to 3):
  56. Number of image channels.
  57. """
  58. model_input_names = ["pixel_values"]
  59. def __init__(
  60. self,
  61. image_size: int = 224,
  62. image_mean: Optional[Union[float, List[float]]] = None,
  63. image_std: Optional[Union[float, List[float]]] = None,
  64. image_num_channels: Optional[int] = 3,
  65. **kwargs,
  66. ) -> None:
  67. super().__init__(**kwargs)
  68. self.image_size = image_size
  69. self.image_num_channels = image_num_channels
  70. self.image_mean = image_mean
  71. self.image_std = image_std
  72. def preprocess(
  73. self,
  74. images: ImageInput,
  75. image_num_channels: Optional[int] = 3,
  76. image_size: Optional[Dict[str, int]] = None,
  77. image_mean: Optional[Union[float, List[float]]] = None,
  78. image_std: Optional[Union[float, List[float]]] = None,
  79. transform: Callable = None,
  80. return_tensors: Optional[Union[str, TensorType]] = TensorType.PYTORCH,
  81. **kwargs,
  82. ) -> TensorType:
  83. """
  84. Preprocess a batch of images.
  85. Args:
  86. images (`ImageInput`):
  87. A list of images to preprocess.
  88. image_size (`int`, *optional*, defaults to `self.image_size`):
  89. Resize to image size
  90. image_num_channels (`int`, *optional*, defaults to `self.image_num_channels`):
  91. Number of image channels.
  92. image_mean (`float` or `List[float]`, *optional*, defaults to `IDEFICS_STANDARD_MEAN`):
  93. Mean to use if normalizing the image. This is a float or list of floats the length of the number of
  94. channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method. Can
  95. be overridden by the `image_mean` parameter in the `preprocess` method.
  96. image_std (`float` or `List[float]`, *optional*, defaults to `IDEFICS_STANDARD_STD`):
  97. Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
  98. number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess`
  99. method. Can be overridden by the `image_std` parameter in the `preprocess` method.
  100. transform (`Callable`, *optional*, defaults to `None`):
  101. A custom transform function that accepts a single image can be passed for training. For example,
  102. `torchvision.Compose` can be used to compose multiple transforms. If `None` - an inference mode is
  103. assumed - and then a preset of inference-specific transforms will be applied to the images
  104. Returns:
  105. a PyTorch tensor of the processed images
  106. """
  107. image_size = image_size if image_size is not None else self.image_size
  108. image_num_channels = image_num_channels if image_num_channels is not None else self.image_num_channels
  109. image_mean = image_mean if image_mean is not None else self.image_mean
  110. image_std = image_std if image_std is not None else self.image_std
  111. size = (image_size, image_size)
  112. if isinstance(images, list) and len(images) == 0:
  113. return []
  114. images = make_list_of_images(images)
  115. if not valid_images(images):
  116. raise ValueError(
  117. "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
  118. "torch.Tensor, tf.Tensor or jax.ndarray."
  119. )
  120. # For training a user needs to pass their own set of transforms as a Callable.
  121. # For reference this is what was used in the original IDEFICS training:
  122. # transform = transforms.Compose([
  123. # convert_to_rgb,
  124. # transforms.RandomResizedCrop((size, size), scale=(0.9, 1.0), interpolation=transforms.InterpolationMode.BICUBIC),
  125. # transforms.ToTensor(),
  126. # transforms.Normalize(mean=image_mean, std=image_std),
  127. # ])
  128. if transform is not None:
  129. if not is_torch_available():
  130. raise ImportError("To pass in `transform` torch must be installed")
  131. import torch
  132. images = [transform(x) for x in images]
  133. return torch.stack(images)
  134. # for inference we do the exact transforms that were used to train IDEFICS
  135. images = [convert_to_rgb(x) for x in images]
  136. # further transforms expect numpy arrays
  137. images = [to_numpy_array(x) for x in images]
  138. images = [resize(x, size, resample=PILImageResampling.BICUBIC) for x in images]
  139. images = [self.rescale(image=image, scale=1 / 255) for image in images]
  140. images = [self.normalize(x, mean=image_mean, std=image_std) for x in images]
  141. images = [to_channel_dimension_format(x, ChannelDimension.FIRST) for x in images]
  142. images = BatchFeature(data={"pixel_values": images}, tensor_type=return_tensors)["pixel_values"]
  143. return images