# coding=utf-8
# Copyright 2021 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import base64
import os
from io import BytesIO
from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple, Union

import numpy as np
import requests
from packaging import version

from .utils import (
    ExplicitEnum,
    is_jax_tensor,
    is_numpy_array,
    is_tf_tensor,
    is_torch_available,
    is_torch_tensor,
    is_torchvision_available,
    is_vision_available,
    logging,
    requires_backends,
    to_numpy,
)
from .utils.constants import (  # noqa: F401
    IMAGENET_DEFAULT_MEAN,
    IMAGENET_DEFAULT_STD,
    IMAGENET_STANDARD_MEAN,
    IMAGENET_STANDARD_STD,
    OPENAI_CLIP_MEAN,
    OPENAI_CLIP_STD,
)
if is_vision_available():
    import PIL.Image
    import PIL.ImageOps

    if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"):
        PILImageResampling = PIL.Image.Resampling
    else:
        PILImageResampling = PIL.Image

if is_torchvision_available():
    from torchvision.transforms import InterpolationMode

    pil_torch_interpolation_mapping = {
        PILImageResampling.NEAREST: InterpolationMode.NEAREST,
        PILImageResampling.BOX: InterpolationMode.BOX,
        PILImageResampling.BILINEAR: InterpolationMode.BILINEAR,
        PILImageResampling.HAMMING: InterpolationMode.HAMMING,
        PILImageResampling.BICUBIC: InterpolationMode.BICUBIC,
        PILImageResampling.LANCZOS: InterpolationMode.LANCZOS,
    }
if TYPE_CHECKING:
    if is_torch_available():
        import torch


logger = logging.get_logger(__name__)
ImageInput = Union[
    "PIL.Image.Image", np.ndarray, "torch.Tensor", List["PIL.Image.Image"], List[np.ndarray], List["torch.Tensor"]
]  # noqa

VideoInput = Union[
    List["PIL.Image.Image"],
    "np.ndarray",
    "torch.Tensor",
    List["np.ndarray"],
    List["torch.Tensor"],
    List[List["PIL.Image.Image"]],
    List[List["np.ndarray"]],
    List[List["torch.Tensor"]],
]  # noqa
class ChannelDimension(ExplicitEnum):
    FIRST = "channels_first"
    LAST = "channels_last"


class AnnotationFormat(ExplicitEnum):
    COCO_DETECTION = "coco_detection"
    COCO_PANOPTIC = "coco_panoptic"


# Deprecated misspelling of `AnnotationFormat`, kept so that existing imports keep working.
class AnnotionFormat(ExplicitEnum):
    COCO_DETECTION = AnnotationFormat.COCO_DETECTION.value
    COCO_PANOPTIC = AnnotationFormat.COCO_PANOPTIC.value


AnnotationType = Dict[str, Union[int, str, List[Dict]]]
def is_pil_image(img):
    return is_vision_available() and isinstance(img, PIL.Image.Image)


class ImageType(ExplicitEnum):
    PIL = "pillow"
    TORCH = "torch"
    NUMPY = "numpy"
    TENSORFLOW = "tensorflow"
    JAX = "jax"


def get_image_type(image):
    if is_pil_image(image):
        return ImageType.PIL
    if is_torch_tensor(image):
        return ImageType.TORCH
    if is_numpy_array(image):
        return ImageType.NUMPY
    if is_tf_tensor(image):
        return ImageType.TENSORFLOW
    if is_jax_tensor(image):
        return ImageType.JAX
    raise ValueError(f"Unrecognised image type {type(image)}")


def is_valid_image(img):
    return is_pil_image(img) or is_numpy_array(img) or is_torch_tensor(img) or is_tf_tensor(img) or is_jax_tensor(img)
def valid_images(imgs):
    # If we have a list of images, make sure every image is valid
    if isinstance(imgs, (list, tuple)):
        for img in imgs:
            if not valid_images(img):
                return False
    # If not a list or tuple, we have been given a single image or a batched tensor of images
    elif not is_valid_image(imgs):
        return False
    return True


def is_batched(img):
    if isinstance(img, (list, tuple)):
        return is_valid_image(img[0])
    return False
def is_scaled_image(image: np.ndarray) -> bool:
    """
    Checks to see whether the pixel values have already been rescaled to [0, 1].
    """
    if image.dtype == np.uint8:
        return False

    # It's possible the image has pixel values in [0, 255] but is of floating type
    return np.min(image) >= 0 and np.max(image) <= 1
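
# A minimal usage sketch of `is_scaled_image` (editor's illustration; the arrays
# below are made up and kept in comments so importing the module has no side effects):
#
#     is_scaled_image(np.zeros((3, 4, 4), dtype=np.uint8))  # -> False (integer dtype is never "scaled")
#     is_scaled_image(np.random.rand(3, 4, 4))              # -> True (floats already in [0, 1])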
def make_list_of_images(images, expected_ndims: int = 3) -> List[ImageInput]:
    """
    Ensure that the input is a list of images. If the input is a single image, it is converted to a list of length 1.
    If the input is a batch of images, it is converted to a list of images.

    Args:
        images (`ImageInput`):
            Image or images to turn into a list of images.
        expected_ndims (`int`, *optional*, defaults to 3):
            Expected number of dimensions for a single input image. If the input image has a different number of
            dimensions, an error is raised.
    """
    if is_batched(images):
        return images

    # Either the input is a single image, in which case we create a list of length 1
    if isinstance(images, PIL.Image.Image):
        # PIL images are never batched
        return [images]

    if is_valid_image(images):
        if images.ndim == expected_ndims + 1:
            # Batch of images
            images = list(images)
        elif images.ndim == expected_ndims:
            # Single image
            images = [images]
        else:
            raise ValueError(
                f"Invalid image shape. Expected either {expected_ndims + 1} or {expected_ndims} dimensions, but got"
                f" {images.ndim} dimensions."
            )
        return images
    raise ValueError(
        "Invalid image type. Expected either PIL.Image.Image, numpy.ndarray, torch.Tensor, tf.Tensor or "
        f"jax.ndarray, but got {type(images)}."
    )
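
# Usage sketch for `make_list_of_images` (shapes invented for the example):
#
#     make_list_of_images(np.zeros((3, 32, 32)))     # single 3-dim image -> list of length 1
#     make_list_of_images(np.zeros((5, 3, 32, 32)))  # batch of 5         -> list of 5 arrays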
def to_numpy_array(img) -> np.ndarray:
    if not is_valid_image(img):
        raise ValueError(f"Invalid image type: {type(img)}")

    if is_vision_available() and isinstance(img, PIL.Image.Image):
        return np.array(img)
    return to_numpy(img)
def infer_channel_dimension_format(
    image: np.ndarray, num_channels: Optional[Union[int, Tuple[int, ...]]] = None
) -> ChannelDimension:
    """
    Infers the channel dimension format of `image`.

    Args:
        image (`np.ndarray`):
            The image to infer the channel dimension of.
        num_channels (`int` or `Tuple[int, ...]`, *optional*, defaults to `(1, 3)`):
            The number of channels of the image.

    Returns:
        The channel dimension of the image.
    """
    num_channels = num_channels if num_channels is not None else (1, 3)
    num_channels = (num_channels,) if isinstance(num_channels, int) else num_channels

    if image.ndim == 3:
        first_dim, last_dim = 0, 2
    elif image.ndim == 4:
        first_dim, last_dim = 1, 3
    else:
        raise ValueError(f"Unsupported number of image dimensions: {image.ndim}")

    if image.shape[first_dim] in num_channels and image.shape[last_dim] in num_channels:
        logger.warning(
            f"The channel dimension is ambiguous. Got image shape {image.shape}. Assuming channels are the first dimension."
        )
        return ChannelDimension.FIRST
    elif image.shape[first_dim] in num_channels:
        return ChannelDimension.FIRST
    elif image.shape[last_dim] in num_channels:
        return ChannelDimension.LAST
    raise ValueError("Unable to infer channel dimension format")
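
# Illustration of the inference rule above (shapes invented for the example):
#
#     infer_channel_dimension_format(np.zeros((3, 224, 224)))  # -> ChannelDimension.FIRST (CHW)
#     infer_channel_dimension_format(np.zeros((224, 224, 3)))  # -> ChannelDimension.LAST (HWC)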
def get_channel_dimension_axis(
    image: np.ndarray, input_data_format: Optional[Union[ChannelDimension, str]] = None
) -> int:
    """
    Returns the channel dimension axis of the image.

    Args:
        image (`np.ndarray`):
            The image to get the channel dimension axis of.
        input_data_format (`ChannelDimension` or `str`, *optional*):
            The channel dimension format of the image. If `None`, will infer the channel dimension from the image.

    Returns:
        The channel dimension axis of the image.
    """
    if input_data_format is None:
        input_data_format = infer_channel_dimension_format(image)
    if input_data_format == ChannelDimension.FIRST:
        return image.ndim - 3
    elif input_data_format == ChannelDimension.LAST:
        return image.ndim - 1
    raise ValueError(f"Unsupported data format: {input_data_format}")
def get_image_size(image: np.ndarray, channel_dim: Optional[ChannelDimension] = None) -> Tuple[int, int]:
    """
    Returns the (height, width) dimensions of the image.

    Args:
        image (`np.ndarray`):
            The image to get the dimensions of.
        channel_dim (`ChannelDimension`, *optional*):
            Which dimension the channel dimension is in. If `None`, will infer the channel dimension from the image.

    Returns:
        A tuple of the image's height and width.
    """
    if channel_dim is None:
        channel_dim = infer_channel_dimension_format(image)

    if channel_dim == ChannelDimension.FIRST:
        return image.shape[-2], image.shape[-1]
    elif channel_dim == ChannelDimension.LAST:
        return image.shape[-3], image.shape[-2]
    else:
        raise ValueError(f"Unsupported data format: {channel_dim}")
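
# Usage sketch for `get_image_size` (shapes invented for the example):
#
#     get_image_size(np.zeros((3, 480, 640)))                         # -> (480, 640), format inferred
#     get_image_size(np.zeros((480, 640, 3)), ChannelDimension.LAST)  # -> (480, 640)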
def is_valid_annotation_coco_detection(annotation: Dict[str, Union[List, Tuple]]) -> bool:
    if (
        isinstance(annotation, dict)
        and "image_id" in annotation
        and "annotations" in annotation
        and isinstance(annotation["annotations"], (list, tuple))
        and (
            # an image can have no annotations
            len(annotation["annotations"]) == 0 or isinstance(annotation["annotations"][0], dict)
        )
    ):
        return True
    return False


def is_valid_annotation_coco_panoptic(annotation: Dict[str, Union[List, Tuple]]) -> bool:
    if (
        isinstance(annotation, dict)
        and "image_id" in annotation
        and "segments_info" in annotation
        and "file_name" in annotation
        and isinstance(annotation["segments_info"], (list, tuple))
        and (
            # an image can have no segments
            len(annotation["segments_info"]) == 0 or isinstance(annotation["segments_info"][0], dict)
        )
    ):
        return True
    return False


def valid_coco_detection_annotations(annotations: Iterable[Dict[str, Union[List, Tuple]]]) -> bool:
    return all(is_valid_annotation_coco_detection(ann) for ann in annotations)


def valid_coco_panoptic_annotations(annotations: Iterable[Dict[str, Union[List, Tuple]]]) -> bool:
    return all(is_valid_annotation_coco_panoptic(ann) for ann in annotations)
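
# Minimal annotation dicts that satisfy the validators above (values invented):
#
#     valid_coco_detection_annotations([{"image_id": 1, "annotations": []}])                           # -> True
#     valid_coco_panoptic_annotations([{"image_id": 1, "file_name": "img.png", "segments_info": []}])  # -> True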
def load_image(image: Union[str, "PIL.Image.Image"], timeout: Optional[float] = None) -> "PIL.Image.Image":
    """
    Loads `image` to a PIL Image.

    Args:
        image (`str` or `PIL.Image.Image`):
            The image to convert to the PIL Image format.
        timeout (`float`, *optional*):
            The timeout value in seconds for the URL request.

    Returns:
        `PIL.Image.Image`: A PIL Image.
    """
    requires_backends(load_image, ["vision"])
    if isinstance(image, str):
        if image.startswith("http://") or image.startswith("https://"):
            # We need to actually check for a real protocol, otherwise it's impossible to use a local file
            # like http_huggingface_co.png
            image = PIL.Image.open(BytesIO(requests.get(image, timeout=timeout).content))
        elif os.path.isfile(image):
            image = PIL.Image.open(image)
        else:
            if image.startswith("data:image/"):
                image = image.split(",")[1]

            # Try to load as base64
            try:
                b64 = base64.decodebytes(image.encode())
                image = PIL.Image.open(BytesIO(b64))
            except Exception as e:
                raise ValueError(
                    f"Incorrect image source. Must be a valid URL starting with `http://` or `https://`, a valid path to an image file, or a base64 encoded string. Got {image}. Failed with {e}"
                )
    elif isinstance(image, PIL.Image.Image):
        image = image
    else:
        raise TypeError(
            "Incorrect format used for image. Should be a URL linking to an image, a base64 string, a local path, or a PIL image."
        )
    image = PIL.ImageOps.exif_transpose(image)
    image = image.convert("RGB")
    return image
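
# The three input forms `load_image` accepts (paths and URLs below are placeholders,
# not real resources):
#
#     load_image("https://example.com/cat.png")      # fetched over HTTP with `requests`
#     load_image("/path/to/local_image.jpg")         # opened from disk
#     load_image("data:image/png;base64,iVBOR...")   # decoded from a base64 data URI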
def validate_preprocess_arguments(
    do_rescale: Optional[bool] = None,
    rescale_factor: Optional[float] = None,
    do_normalize: Optional[bool] = None,
    image_mean: Optional[Union[float, List[float]]] = None,
    image_std: Optional[Union[float, List[float]]] = None,
    do_pad: Optional[bool] = None,
    size_divisibility: Optional[int] = None,
    do_center_crop: Optional[bool] = None,
    crop_size: Optional[Dict[str, int]] = None,
    do_resize: Optional[bool] = None,
    size: Optional[Dict[str, int]] = None,
    resample: Optional["PILImageResampling"] = None,
):
    """
    Checks validity of typically used arguments in an `ImageProcessor` `preprocess` method.
    Raises `ValueError` if arguments incompatibility is caught.
    Many incompatibilities are model-specific. `do_pad` sometimes needs `size_divisor`,
    sometimes `size_divisibility`, and sometimes `size`. New models and processors added should follow
    existing arguments when possible.
    """
    if do_rescale and rescale_factor is None:
        raise ValueError("`rescale_factor` must be specified if `do_rescale` is `True`.")

    if do_pad and size_divisibility is None:
        # Here, size_divisor might be passed as the value of size
        raise ValueError(
            "Depending on the model, `size_divisibility`, `size_divisor`, `pad_size` or `size` must be specified if `do_pad` is `True`."
        )

    if do_normalize and (image_mean is None or image_std is None):
        raise ValueError("`image_mean` and `image_std` must both be specified if `do_normalize` is `True`.")

    if do_center_crop and crop_size is None:
        raise ValueError("`crop_size` must be specified if `do_center_crop` is `True`.")

    if do_resize and (size is None or resample is None):
        raise ValueError("`size` and `resample` must be specified if `do_resize` is `True`.")
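
# How the checks above fire (arguments invented for the example):
#
#     validate_preprocess_arguments(do_rescale=True, rescale_factor=1 / 255)  # passes silently
#     validate_preprocess_arguments(do_normalize=True)                        # raises ValueError (no mean/std)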
# In the future we can add a TF implementation here when we have TF models.
class ImageFeatureExtractionMixin:
    """
    Mixin that contains utilities for preparing image features.
    """

    def _ensure_format_supported(self, image):
        if not isinstance(image, (PIL.Image.Image, np.ndarray)) and not is_torch_tensor(image):
            raise ValueError(
                f"Got type {type(image)} which is not supported, only `PIL.Image.Image`, `np.ndarray` and "
                "`torch.Tensor` are."
            )
    def to_pil_image(self, image, rescale=None):
        """
        Converts `image` to a PIL Image. Optionally rescales it and puts the channel dimension back as the last axis
        if needed.

        Args:
            image (`PIL.Image.Image` or `numpy.ndarray` or `torch.Tensor`):
                The image to convert to the PIL Image format.
            rescale (`bool`, *optional*):
                Whether or not to apply the scaling factor (to make pixel values integers between 0 and 255). Will
                default to `True` if the image type is a floating type, `False` otherwise.
        """
        self._ensure_format_supported(image)

        if is_torch_tensor(image):
            image = image.numpy()

        if isinstance(image, np.ndarray):
            if rescale is None:
                # rescale defaults to the array being of floating type.
                rescale = isinstance(image.flat[0], np.floating)
            # If the channel has been moved to first dim, we put it back at the end.
            if image.ndim == 3 and image.shape[0] in [1, 3]:
                image = image.transpose(1, 2, 0)
            if rescale:
                image = image * 255
            image = image.astype(np.uint8)
            return PIL.Image.fromarray(image)
        return image
    def convert_rgb(self, image):
        """
        Converts `PIL.Image.Image` to RGB format.

        Args:
            image (`PIL.Image.Image`):
                The image to convert.
        """
        self._ensure_format_supported(image)
        if not isinstance(image, PIL.Image.Image):
            return image

        return image.convert("RGB")

    def rescale(self, image: np.ndarray, scale: Union[float, int]) -> np.ndarray:
        """
        Rescales a numpy image by `scale` amount.
        """
        self._ensure_format_supported(image)
        return image * scale
    def to_numpy_array(self, image, rescale=None, channel_first=True):
        """
        Converts `image` to a numpy array. Optionally rescales it and puts the channel dimension as the first
        dimension.

        Args:
            image (`PIL.Image.Image` or `np.ndarray` or `torch.Tensor`):
                The image to convert to a NumPy array.
            rescale (`bool`, *optional*):
                Whether or not to apply the scaling factor (to make pixel values floats between 0. and 1.). Will
                default to `True` if the image is a PIL Image or an array/tensor of integers, `False` otherwise.
            channel_first (`bool`, *optional*, defaults to `True`):
                Whether or not to permute the dimensions of the image to put the channel dimension first.
        """
        self._ensure_format_supported(image)

        if isinstance(image, PIL.Image.Image):
            image = np.array(image)

        if is_torch_tensor(image):
            image = image.numpy()

        rescale = isinstance(image.flat[0], np.integer) if rescale is None else rescale

        if rescale:
            image = self.rescale(image.astype(np.float32), 1 / 255.0)

        if channel_first and image.ndim == 3:
            image = image.transpose(2, 0, 1)

        return image
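
    # Usage sketch for the two converters above (editor's illustration; `pil_image`
    # is a placeholder for any PIL.Image.Image):
    #
    #     mixin = ImageFeatureExtractionMixin()
    #     arr = mixin.to_numpy_array(pil_image)  # float32, channels-first, values in [0, 1]
    #     pil = mixin.to_pil_image(arr)          # back to a uint8 PIL image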
    def expand_dims(self, image):
        """
        Expands 2-dimensional `image` to 3 dimensions.

        Args:
            image (`PIL.Image.Image` or `np.ndarray` or `torch.Tensor`):
                The image to expand.
        """
        self._ensure_format_supported(image)

        # Do nothing if PIL image
        if isinstance(image, PIL.Image.Image):
            return image

        if is_torch_tensor(image):
            image = image.unsqueeze(0)
        else:
            image = np.expand_dims(image, axis=0)
        return image
    def normalize(self, image, mean, std, rescale=False):
        """
        Normalizes `image` with `mean` and `std`. Note that this will trigger a conversion of `image` to a NumPy
        array if it's a PIL Image.

        Args:
            image (`PIL.Image.Image` or `np.ndarray` or `torch.Tensor`):
                The image to normalize.
            mean (`List[float]` or `np.ndarray` or `torch.Tensor`):
                The mean (per channel) to use for normalization.
            std (`List[float]` or `np.ndarray` or `torch.Tensor`):
                The standard deviation (per channel) to use for normalization.
            rescale (`bool`, *optional*, defaults to `False`):
                Whether or not to rescale the image to be between 0 and 1. If a PIL image is provided, scaling will
                happen automatically.
        """
        self._ensure_format_supported(image)

        if isinstance(image, PIL.Image.Image):
            image = self.to_numpy_array(image, rescale=True)
        # If the input image is a PIL image, it automatically gets rescaled. If it's another
        # type it may need rescaling.
        elif rescale:
            if isinstance(image, np.ndarray):
                image = self.rescale(image.astype(np.float32), 1 / 255.0)
            elif is_torch_tensor(image):
                image = self.rescale(image.float(), 1 / 255.0)

        if isinstance(image, np.ndarray):
            if not isinstance(mean, np.ndarray):
                mean = np.array(mean).astype(image.dtype)
            if not isinstance(std, np.ndarray):
                std = np.array(std).astype(image.dtype)
        elif is_torch_tensor(image):
            import torch

            if not isinstance(mean, torch.Tensor):
                if isinstance(mean, np.ndarray):
                    mean = torch.from_numpy(mean)
                else:
                    mean = torch.tensor(mean)
            if not isinstance(std, torch.Tensor):
                if isinstance(std, np.ndarray):
                    std = torch.from_numpy(std)
                else:
                    std = torch.tensor(std)

        if image.ndim == 3 and image.shape[0] in [1, 3]:
            return (image - mean[:, None, None]) / std[:, None, None]
        else:
            return (image - mean) / std
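
    # Typical call, using the ImageNet constants imported at the top of this file
    # (`chw_array` is a placeholder for a channels-first float array):
    #
    #     mixin.normalize(chw_array, mean=IMAGENET_STANDARD_MEAN, std=IMAGENET_STANDARD_STD)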
    def resize(self, image, size, resample=None, default_to_square=True, max_size=None):
        """
        Resizes `image`. Enforces conversion of input to PIL.Image.

        Args:
            image (`PIL.Image.Image` or `np.ndarray` or `torch.Tensor`):
                The image to resize.
            size (`int` or `Tuple[int, int]`):
                The size to use for resizing the image. If `size` is a sequence like (h, w), output size will be
                matched to this.

                If `size` is an int and `default_to_square` is `True`, then image will be resized to (size, size). If
                `size` is an int and `default_to_square` is `False`, then the smaller edge of the image will be
                matched to this number, i.e., if height > width, then image will be rescaled to
                (size * height / width, size).
            resample (`int`, *optional*, defaults to `PILImageResampling.BILINEAR`):
                The filter to use for resampling.
            default_to_square (`bool`, *optional*, defaults to `True`):
                How to convert `size` when it is a single int. If set to `True`, the `size` will be converted to a
                square (`size`, `size`). If set to `False`, will replicate
                [`torchvision.transforms.Resize`](https://pytorch.org/vision/stable/transforms.html#torchvision.transforms.Resize)
                with support for resizing only the smallest edge and providing an optional `max_size`.
            max_size (`int`, *optional*, defaults to `None`):
                The maximum allowed for the longer edge of the resized image: if the longer edge of the image is
                greater than `max_size` after being resized according to `size`, then the image is resized again so
                that the longer edge is equal to `max_size`. As a result, `size` might be overruled, i.e. the smaller
                edge may be shorter than `size`. Only used if `default_to_square` is `False`.

        Returns:
            image: A resized `PIL.Image.Image`.
        """
        resample = resample if resample is not None else PILImageResampling.BILINEAR

        self._ensure_format_supported(image)

        if not isinstance(image, PIL.Image.Image):
            image = self.to_pil_image(image)

        if isinstance(size, list):
            size = tuple(size)

        if isinstance(size, int) or len(size) == 1:
            if default_to_square:
                size = (size, size) if isinstance(size, int) else (size[0], size[0])
            else:
                width, height = image.size
                # specified size only for the smallest edge
                short, long = (width, height) if width <= height else (height, width)
                requested_new_short = size if isinstance(size, int) else size[0]

                if short == requested_new_short:
                    return image

                new_short, new_long = requested_new_short, int(requested_new_short * long / short)

                if max_size is not None:
                    if max_size <= requested_new_short:
                        raise ValueError(
                            f"max_size = {max_size} must be strictly greater than the requested "
                            f"size for the smaller edge size = {size}"
                        )
                    if new_long > max_size:
                        new_short, new_long = int(max_size * new_short / new_long), max_size

                size = (new_short, new_long) if width <= height else (new_long, new_short)

        return image.resize(size, resample=resample)
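
    # Two common `resize` patterns (sizes invented for the example):
    #
    #     mixin.resize(img, 224)                           # -> exactly (224, 224)
    #     mixin.resize(img, 256, default_to_square=False)  # shorter edge -> 256, aspect ratio kept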
    def center_crop(self, image, size):
        """
        Crops `image` to the given size using a center crop. Note that if the image is too small to be cropped to
        the size given, it will be padded (so the returned result has the size asked).

        Args:
            image (`PIL.Image.Image` or `np.ndarray` or `torch.Tensor` of shape (n_channels, height, width) or (height, width, n_channels)):
                The image to resize.
            size (`int` or `Tuple[int, int]`):
                The size to which crop the image.

        Returns:
            new_image: A center cropped `PIL.Image.Image` or `np.ndarray` or `torch.Tensor` of shape: (n_channels,
            height, width).
        """
        self._ensure_format_supported(image)

        if not isinstance(size, tuple):
            size = (size, size)

        # PIL Image.size is (width, height) but NumPy array and torch Tensors have (height, width)
        if is_torch_tensor(image) or isinstance(image, np.ndarray):
            if image.ndim == 2:
                image = self.expand_dims(image)
            image_shape = image.shape[1:] if image.shape[0] in [1, 3] else image.shape[:2]
        else:
            image_shape = (image.size[1], image.size[0])

        top = (image_shape[0] - size[0]) // 2
        bottom = top + size[0]  # In case size is odd, (image_shape[0] + size[0]) // 2 won't give the proper result.
        left = (image_shape[1] - size[1]) // 2
        right = left + size[1]  # In case size is odd, (image_shape[1] + size[1]) // 2 won't give the proper result.

        # For PIL Images we have a method to crop directly.
        if isinstance(image, PIL.Image.Image):
            return image.crop((left, top, right, bottom))

        # Check if image is in (n_channels, height, width) or (height, width, n_channels) format
        channel_first = image.shape[0] in [1, 3]

        # Transpose (height, width, n_channels) format images
        if not channel_first:
            if isinstance(image, np.ndarray):
                image = image.transpose(2, 0, 1)
            if is_torch_tensor(image):
                image = image.permute(2, 0, 1)

        # Check if cropped area is within image boundaries
        if top >= 0 and bottom <= image_shape[0] and left >= 0 and right <= image_shape[1]:
            return image[..., top:bottom, left:right]

        # Otherwise, we may need to pad if the image is too small. Oh joy...
        new_shape = image.shape[:-2] + (max(size[0], image_shape[0]), max(size[1], image_shape[1]))
        if isinstance(image, np.ndarray):
            new_image = np.zeros_like(image, shape=new_shape)
        elif is_torch_tensor(image):
            new_image = image.new_zeros(new_shape)

        top_pad = (new_shape[-2] - image_shape[0]) // 2
        bottom_pad = top_pad + image_shape[0]
        left_pad = (new_shape[-1] - image_shape[1]) // 2
        right_pad = left_pad + image_shape[1]
        new_image[..., top_pad:bottom_pad, left_pad:right_pad] = image

        top += top_pad
        bottom += top_pad
        left += left_pad
        right += left_pad

        new_image = new_image[
            ..., max(0, top) : min(new_image.shape[-2], bottom), max(0, left) : min(new_image.shape[-1], right)
        ]

        return new_image
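
    # Center-crop sketch (shape invented): a 256x256 channels-first array cropped to 224x224.
    #
    #     mixin.center_crop(np.zeros((3, 256, 256)), 224)  # -> array of shape (3, 224, 224)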
    def flip_channel_order(self, image):
        """
        Flips the channel order of `image` from RGB to BGR, or vice versa. Note that this will trigger a conversion
        of `image` to a NumPy array if it's a PIL Image.

        Args:
            image (`PIL.Image.Image` or `np.ndarray` or `torch.Tensor`):
                The image whose color channels to flip. If `np.ndarray` or `torch.Tensor`, the channel dimension
                should be first.
        """
        self._ensure_format_supported(image)

        if isinstance(image, PIL.Image.Image):
            image = self.to_numpy_array(image)

        return image[::-1, :, :]
    def rotate(self, image, angle, resample=None, expand=0, center=None, translate=None, fillcolor=None):
        """
        Returns a rotated copy of `image`. This method returns a copy of `image`, rotated the given number of
        degrees counter clockwise around its centre.

        Args:
            image (`PIL.Image.Image` or `np.ndarray` or `torch.Tensor`):
                The image to rotate. If `np.ndarray` or `torch.Tensor`, will be converted to `PIL.Image.Image`
                before rotating.

        Returns:
            image: A rotated `PIL.Image.Image`.
        """
        resample = resample if resample is not None else PIL.Image.NEAREST

        self._ensure_format_supported(image)

        if not isinstance(image, PIL.Image.Image):
            image = self.to_pil_image(image)

        return image.rotate(
            angle, resample=resample, expand=expand, center=center, translate=translate, fillcolor=fillcolor
        )
def validate_annotations(
    annotation_format: AnnotationFormat,
    supported_annotation_formats: Tuple[AnnotationFormat, ...],
    annotations: List[Dict],
) -> None:
    if annotation_format not in supported_annotation_formats:
        raise ValueError(
            f"Unsupported annotation format: {annotation_format} must be one of {supported_annotation_formats}"
        )

    if annotation_format is AnnotationFormat.COCO_DETECTION:
        if not valid_coco_detection_annotations(annotations):
            raise ValueError(
                "Invalid COCO detection annotations. Annotations must be a dict (single image) or list of dicts "
                "(batch of images) with the following keys: `image_id` and `annotations`, with the latter "
                "being a list of annotations in the COCO format."
            )

    if annotation_format is AnnotationFormat.COCO_PANOPTIC:
        if not valid_coco_panoptic_annotations(annotations):
            raise ValueError(
                "Invalid COCO panoptic annotations. Annotations must be a dict (single image) or list of dicts "
                "(batch of images) with the following keys: `image_id`, `file_name` and `segments_info`, with "
                "the latter being a list of annotations in the COCO format."
            )
def validate_kwargs(valid_processor_keys: List[str], captured_kwargs: List[str]):
    unused_keys = set(captured_kwargs).difference(set(valid_processor_keys))
    if unused_keys:
        unused_key_str = ", ".join(unused_keys)
        # TODO raise a warning here instead of simply logging?
        logger.warning(f"Unused or unrecognized kwargs: {unused_key_str}.")
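
# Sketch of `validate_kwargs` (key names invented): anything not in the valid set is
# reported in a single warning.
#
#     validate_kwargs(valid_processor_keys=["do_resize", "size"], captured_kwargs=["size", "do_flip"])
#     # logs "Unused or unrecognized kwargs: do_flip."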
|