
# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Image processor class for DPT."""
import math
from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple, Union


if TYPE_CHECKING:
    from ...modeling_outputs import DepthEstimatorOutput

import numpy as np

from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
from ...image_transforms import pad, resize, to_channel_dimension_format
from ...image_utils import (
    IMAGENET_STANDARD_MEAN,
    IMAGENET_STANDARD_STD,
    ChannelDimension,
    ImageInput,
    PILImageResampling,
    get_image_size,
    infer_channel_dimension_format,
    is_scaled_image,
    is_torch_available,
    is_torch_tensor,
    make_list_of_images,
    to_numpy_array,
    valid_images,
    validate_preprocess_arguments,
)
from ...utils import (
    TensorType,
    filter_out_non_signature_kwargs,
    is_vision_available,
    logging,
    requires_backends,
)


if is_torch_available():
    import torch

if is_vision_available():
    import PIL


logger = logging.get_logger(__name__)


def get_resize_output_image_size(
    input_image: np.ndarray,
    output_size: Union[int, Iterable[int]],
    keep_aspect_ratio: bool,
    multiple: int,
    input_data_format: Optional[Union[str, ChannelDimension]] = None,
) -> Tuple[int, int]:
    """
    Compute the output size for `resize`: start from `output_size`, optionally preserve the input aspect ratio, and
    round each side to a multiple of `multiple`.
    """

    def constrain_to_multiple_of(val, multiple, min_val=0, max_val=None):
        x = round(val / multiple) * multiple

        if max_val is not None and x > max_val:
            x = math.floor(val / multiple) * multiple

        if x < min_val:
            x = math.ceil(val / multiple) * multiple

        return x

    output_size = (output_size, output_size) if isinstance(output_size, int) else output_size

    input_height, input_width = get_image_size(input_image, input_data_format)
    output_height, output_width = output_size

    # determine new height and width
    scale_height = output_height / input_height
    scale_width = output_width / input_width

    if keep_aspect_ratio:
        # scale as little as possible
        if abs(1 - scale_width) < abs(1 - scale_height):
            # fit width
            scale_height = scale_width
        else:
            # fit height
            scale_width = scale_height

    new_height = constrain_to_multiple_of(scale_height * input_height, multiple=multiple)
    new_width = constrain_to_multiple_of(scale_width * input_width, multiple=multiple)

    return (new_height, new_width)
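
# An illustrative walk-through of the helper above (the numbers are only an example):
# for a 480x640 input and `output_size=(384, 384)` with `keep_aspect_ratio=True` and
# `multiple=32`, the height scale (384/480 = 0.8) is closer to 1 than the width scale
# (384/640 = 0.6), so both sides are scaled by 0.8 and then snapped to a multiple of 32:
#
#     get_resize_output_image_size(image, (384, 384), keep_aspect_ratio=True, multiple=32)
#     # -> (384, 512)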


class DPTImageProcessor(BaseImageProcessor):
    r"""
    Constructs a DPT image processor.

    Args:
        do_resize (`bool`, *optional*, defaults to `True`):
            Whether to resize the image's (height, width) dimensions. Can be overridden by `do_resize` in
            `preprocess`.
        size (`Dict[str, int]`, *optional*, defaults to `{"height": 384, "width": 384}`):
            Size of the image after resizing. Can be overridden by `size` in `preprocess`.
        resample (`PILImageResampling`, *optional*, defaults to `Resampling.BICUBIC`):
            Defines the resampling filter to use if resizing the image. Can be overridden by `resample` in
            `preprocess`.
        keep_aspect_ratio (`bool`, *optional*, defaults to `False`):
            If `True`, the image is resized to the largest possible size such that the aspect ratio is preserved. Can
            be overridden by `keep_aspect_ratio` in `preprocess`.
        ensure_multiple_of (`int`, *optional*, defaults to 1):
            If `do_resize` is `True`, the image is resized to a size that is a multiple of this value. Can be
            overridden by `ensure_multiple_of` in `preprocess`.
        do_rescale (`bool`, *optional*, defaults to `True`):
            Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by `do_rescale` in
            `preprocess`.
        rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
            Scale factor to use if rescaling the image. Can be overridden by `rescale_factor` in `preprocess`.
        do_normalize (`bool`, *optional*, defaults to `True`):
            Whether to normalize the image. Can be overridden by the `do_normalize` parameter in the `preprocess`
            method.
        image_mean (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_MEAN`):
            Mean to use if normalizing the image. This is a float or list of floats the length of the number of
            channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method.
        image_std (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_STD`):
            Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
            number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
        do_pad (`bool`, *optional*, defaults to `False`):
            Whether to apply center padding. This was introduced in the DINOv2 paper, which uses the model in
            combination with DPT.
        size_divisor (`int`, *optional*):
            If `do_pad` is `True`, pads the image dimensions to be divisible by this value. This was introduced in the
            DINOv2 paper, which uses the model in combination with DPT.
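
    Example (an illustrative sketch; the checkpoint name is only an assumption, any DPT checkpoint works):

    ```python
    >>> from transformers import DPTImageProcessor

    >>> # The defaults resize to 384x384 with bicubic resampling and apply ImageNet normalization.
    >>> image_processor = DPTImageProcessor()

    >>> # Alternatively, load the configuration shipped with a pretrained checkpoint.
    >>> image_processor = DPTImageProcessor.from_pretrained("Intel/dpt-large")
    ```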
  118. """

    model_input_names = ["pixel_values"]

    def __init__(
        self,
        do_resize: bool = True,
        size: Optional[Dict[str, int]] = None,
        resample: PILImageResampling = PILImageResampling.BICUBIC,
        keep_aspect_ratio: bool = False,
        ensure_multiple_of: int = 1,
        do_rescale: bool = True,
        rescale_factor: Union[int, float] = 1 / 255,
        do_normalize: bool = True,
        image_mean: Optional[Union[float, List[float]]] = None,
        image_std: Optional[Union[float, List[float]]] = None,
        do_pad: bool = False,
        size_divisor: Optional[int] = None,
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)
        size = size if size is not None else {"height": 384, "width": 384}
        size = get_size_dict(size)
        self.do_resize = do_resize
        self.size = size
        self.keep_aspect_ratio = keep_aspect_ratio
        self.ensure_multiple_of = ensure_multiple_of
        self.resample = resample
        self.do_rescale = do_rescale
        self.rescale_factor = rescale_factor
        self.do_normalize = do_normalize
        self.image_mean = image_mean if image_mean is not None else IMAGENET_STANDARD_MEAN
        self.image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD
        self.do_pad = do_pad
        self.size_divisor = size_divisor

    def resize(
        self,
        image: np.ndarray,
        size: Dict[str, int],
        keep_aspect_ratio: bool = False,
        ensure_multiple_of: int = 1,
        resample: PILImageResampling = PILImageResampling.BICUBIC,
        data_format: Optional[Union[str, ChannelDimension]] = None,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
        **kwargs,
    ) -> np.ndarray:
        """
        Resize an image to target size `(size["height"], size["width"])`. If `keep_aspect_ratio` is `True`, the image
        is resized to the largest possible size such that the aspect ratio is preserved. If `ensure_multiple_of` is
        set, the image is resized to a size that is a multiple of this value.

        Args:
            image (`np.ndarray`):
                Image to resize.
            size (`Dict[str, int]`):
                Target size of the output image.
            keep_aspect_ratio (`bool`, *optional*, defaults to `False`):
                If `True`, the image is resized to the largest possible size such that the aspect ratio is preserved.
            ensure_multiple_of (`int`, *optional*, defaults to 1):
                The image is resized to a size that is a multiple of this value.
            resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`):
                Resampling filter to use when resizing the image. Only has an effect if the image is resized.
            data_format (`str` or `ChannelDimension`, *optional*):
                The channel dimension format of the image. If not provided, it will be the same as the input image.
            input_data_format (`str` or `ChannelDimension`, *optional*):
                The channel dimension format of the input image. If not provided, it will be inferred.
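
        Example (a minimal sketch on a random array; the shapes are illustrative):

        ```python
        >>> import numpy as np
        >>> from transformers import DPTImageProcessor

        >>> image_processor = DPTImageProcessor()
        >>> image = np.random.randint(0, 256, (480, 640, 3), dtype=np.uint8)
        >>> image_processor.resize(image, size={"height": 384, "width": 384}).shape
        (384, 384, 3)
        ```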
  184. """
  185. size = get_size_dict(size)
  186. if "height" not in size or "width" not in size:
  187. raise ValueError(f"The size dictionary must contain the keys 'height' and 'width'. Got {size.keys()}")
  188. output_size = get_resize_output_image_size(
  189. image,
  190. output_size=(size["height"], size["width"]),
  191. keep_aspect_ratio=keep_aspect_ratio,
  192. multiple=ensure_multiple_of,
  193. input_data_format=input_data_format,
  194. )
  195. return resize(
  196. image,
  197. size=output_size,
  198. resample=resample,
  199. data_format=data_format,
  200. input_data_format=input_data_format,
  201. **kwargs,
  202. )

    def pad_image(
        self,
        image: np.ndarray,
        size_divisor: int,
        data_format: Optional[Union[str, ChannelDimension]] = None,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
    ):
        """
        Center pad an image so that its height and width are multiples of `size_divisor`.

        Args:
            image (`np.ndarray`):
                Image to pad.
            size_divisor (`int`):
                The width and height of the image will be padded to a multiple of this number.
            data_format (`ChannelDimension` or `str`, *optional*):
                The channel dimension format for the output image. Can be one of:
                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
                - Unset: Use the channel dimension format of the input image.
            input_data_format (`ChannelDimension` or `str`, *optional*):
                The channel dimension format for the input image. If unset, the channel dimension format is inferred
                from the input image. Can be one of:
                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
                - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
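
        Example (illustrative numbers; `size_divisor=32` is just an example value):

        ```python
        >>> import numpy as np
        >>> from transformers import DPTImageProcessor

        >>> image_processor = DPTImageProcessor()
        >>> image = np.zeros((3, 500, 620))  # channels-first, height 500, width 620
        >>> image_processor.pad_image(image, size_divisor=32).shape  # 500 -> 512, 620 -> 640
        (3, 512, 640)
        ```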
  228. """
  229. def _get_pad(size, size_divisor):
  230. new_size = math.ceil(size / size_divisor) * size_divisor
  231. pad_size = new_size - size
  232. pad_size_left = pad_size // 2
  233. pad_size_right = pad_size - pad_size_left
  234. return pad_size_left, pad_size_right
  235. if input_data_format is None:
  236. input_data_format = infer_channel_dimension_format(image)
  237. height, width = get_image_size(image, input_data_format)
  238. pad_size_left, pad_size_right = _get_pad(height, size_divisor)
  239. pad_size_top, pad_size_bottom = _get_pad(width, size_divisor)
  240. return pad(image, ((pad_size_left, pad_size_right), (pad_size_top, pad_size_bottom)), data_format=data_format)

    @filter_out_non_signature_kwargs()
    def preprocess(
        self,
        images: ImageInput,
        do_resize: bool = None,
        size: Dict[str, int] = None,
        keep_aspect_ratio: bool = None,
        ensure_multiple_of: int = None,
        resample: PILImageResampling = None,
        do_rescale: bool = None,
        rescale_factor: float = None,
        do_normalize: bool = None,
        image_mean: Optional[Union[float, List[float]]] = None,
        image_std: Optional[Union[float, List[float]]] = None,
        do_pad: bool = None,
        size_divisor: int = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
        data_format: ChannelDimension = ChannelDimension.FIRST,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
    ) -> BatchFeature:
  261. """
  262. Preprocess an image or batch of images.
  263. Args:
  264. images (`ImageInput`):
  265. Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
  266. passing in images with pixel values between 0 and 1, set `do_rescale=False`.
  267. do_resize (`bool`, *optional*, defaults to `self.do_resize`):
  268. Whether to resize the image.
  269. size (`Dict[str, int]`, *optional*, defaults to `self.size`):
  270. Size of the image after reszing. If `keep_aspect_ratio` is `True`, the image is resized to the largest
  271. possible size such that the aspect ratio is preserved. If `ensure_multiple_of` is set, the image is
  272. resized to a size that is a multiple of this value.
  273. keep_aspect_ratio (`bool`, *optional*, defaults to `self.keep_aspect_ratio`):
  274. Whether to keep the aspect ratio of the image. If False, the image will be resized to (size, size). If
  275. True, the image will be resized to keep the aspect ratio and the size will be the maximum possible.
  276. ensure_multiple_of (`int`, *optional*, defaults to `self.ensure_multiple_of`):
  277. Ensure that the image size is a multiple of this value.
  278. resample (`int`, *optional*, defaults to `self.resample`):
  279. Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`, Only
  280. has an effect if `do_resize` is set to `True`.
  281. do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
  282. Whether to rescale the image values between [0 - 1].
  283. rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
  284. Rescale factor to rescale the image by if `do_rescale` is set to `True`.
  285. do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
  286. Whether to normalize the image.
  287. image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
  288. Image mean.
  289. image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
  290. Image standard deviation.
  291. return_tensors (`str` or `TensorType`, *optional*):
  292. The type of tensors to return. Can be one of:
  293. - Unset: Return a list of `np.ndarray`.
  294. - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
  295. - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
  296. - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
  297. - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
  298. data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
  299. The channel dimension format for the output image. Can be one of:
  300. - `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
  301. - `ChannelDimension.LAST`: image in (height, width, num_channels) format.
  302. input_data_format (`ChannelDimension` or `str`, *optional*):
  303. The channel dimension format for the input image. If unset, the channel dimension format is inferred
  304. from the input image. Can be one of:
  305. - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
  306. - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
  307. - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
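
        Example (a minimal sketch; the default processor resizes to 384x384, so the output shape follows from that):

        ```python
        >>> import numpy as np
        >>> from transformers import DPTImageProcessor

        >>> image_processor = DPTImageProcessor()
        >>> image = np.random.randint(0, 256, (480, 640, 3), dtype=np.uint8)
        >>> inputs = image_processor(images=image, return_tensors="pt")
        >>> inputs["pixel_values"].shape
        torch.Size([1, 3, 384, 384])
        ```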
  308. """
  309. do_resize = do_resize if do_resize is not None else self.do_resize
  310. size = size if size is not None else self.size
  311. size = get_size_dict(size)
  312. keep_aspect_ratio = keep_aspect_ratio if keep_aspect_ratio is not None else self.keep_aspect_ratio
  313. ensure_multiple_of = ensure_multiple_of if ensure_multiple_of is not None else self.ensure_multiple_of
  314. resample = resample if resample is not None else self.resample
  315. do_rescale = do_rescale if do_rescale is not None else self.do_rescale
  316. rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
  317. do_normalize = do_normalize if do_normalize is not None else self.do_normalize
  318. image_mean = image_mean if image_mean is not None else self.image_mean
  319. image_std = image_std if image_std is not None else self.image_std
  320. do_pad = do_pad if do_pad is not None else self.do_pad
  321. size_divisor = size_divisor if size_divisor is not None else self.size_divisor
  322. images = make_list_of_images(images)
  323. if not valid_images(images):
  324. raise ValueError(
  325. "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
  326. "torch.Tensor, tf.Tensor or jax.ndarray."
  327. )
  328. validate_preprocess_arguments(
  329. do_rescale=do_rescale,
  330. rescale_factor=rescale_factor,
  331. do_normalize=do_normalize,
  332. image_mean=image_mean,
  333. image_std=image_std,
  334. do_pad=do_pad,
  335. size_divisibility=size_divisor,
  336. do_resize=do_resize,
  337. size=size,
  338. resample=resample,
  339. )
        # All transformations expect numpy arrays.
        images = [to_numpy_array(image) for image in images]

        if is_scaled_image(images[0]) and do_rescale:
            logger.warning_once(
                "It looks like you are trying to rescale already rescaled images. If the input"
                " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
            )

        if input_data_format is None:
            # We assume that all images have the same channel dimension format.
            input_data_format = infer_channel_dimension_format(images[0])

        if do_resize:
            images = [
                self.resize(
                    image=image,
                    size=size,
                    resample=resample,
                    keep_aspect_ratio=keep_aspect_ratio,
                    ensure_multiple_of=ensure_multiple_of,
                    input_data_format=input_data_format,
                )
                for image in images
            ]

        if do_rescale:
            images = [
                self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format)
                for image in images
            ]

        if do_normalize:
            images = [
                self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=input_data_format)
                for image in images
            ]

        if do_pad:
            images = [
                self.pad_image(image=image, size_divisor=size_divisor, input_data_format=input_data_format)
                for image in images
            ]

        images = [
            to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) for image in images
        ]

        data = {"pixel_values": images}
        return BatchFeature(data=data, tensor_type=return_tensors)

    # Copied from transformers.models.beit.image_processing_beit.BeitImageProcessor.post_process_semantic_segmentation with Beit->DPT
    def post_process_semantic_segmentation(self, outputs, target_sizes: List[Tuple] = None):
        """
        Converts the output of [`DPTForSemanticSegmentation`] into semantic segmentation maps. Only supports PyTorch.

        Args:
            outputs ([`DPTForSemanticSegmentation`]):
                Raw outputs of the model.
            target_sizes (`List[Tuple]` of length `batch_size`, *optional*):
                List of tuples corresponding to the requested final size (height, width) of each prediction. If unset,
                predictions will not be resized.

        Returns:
            semantic_segmentation: `List[torch.Tensor]` of length `batch_size`, where each item is a semantic
            segmentation map of shape (height, width) corresponding to the target_sizes entry (if `target_sizes` is
            specified). Each entry of each `torch.Tensor` corresponds to a semantic class id.
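
        Example (a sketch; `Intel/dpt-large-ade` is assumed to be an ADE20k-finetuned DPT checkpoint):

        ```python
        >>> from PIL import Image
        >>> import requests
        >>> from transformers import DPTImageProcessor, DPTForSemanticSegmentation

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> image_processor = DPTImageProcessor.from_pretrained("Intel/dpt-large-ade")
        >>> model = DPTForSemanticSegmentation.from_pretrained("Intel/dpt-large-ade")

        >>> inputs = image_processor(images=image, return_tensors="pt")
        >>> outputs = model(**inputs)

        >>> # one (height, width) tuple per image maps predictions back to the input resolution
        >>> segmentation = image_processor.post_process_semantic_segmentation(
        ...     outputs, target_sizes=[image.size[::-1]]
        ... )[0]
        ```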
  396. """
  397. # TODO: add support for other frameworks
  398. logits = outputs.logits
  399. # Resize logits and compute semantic segmentation maps
  400. if target_sizes is not None:
  401. if len(logits) != len(target_sizes):
  402. raise ValueError(
  403. "Make sure that you pass in as many target sizes as the batch dimension of the logits"
  404. )
  405. if is_torch_tensor(target_sizes):
  406. target_sizes = target_sizes.numpy()
  407. semantic_segmentation = []
  408. for idx in range(len(logits)):
  409. resized_logits = torch.nn.functional.interpolate(
  410. logits[idx].unsqueeze(dim=0), size=target_sizes[idx], mode="bilinear", align_corners=False
  411. )
  412. semantic_map = resized_logits[0].argmax(dim=0)
  413. semantic_segmentation.append(semantic_map)
  414. else:
  415. semantic_segmentation = logits.argmax(dim=1)
  416. semantic_segmentation = [semantic_segmentation[i] for i in range(semantic_segmentation.shape[0])]
  417. return semantic_segmentation

    def post_process_depth_estimation(
        self,
        outputs: "DepthEstimatorOutput",
        target_sizes: Optional[Union[TensorType, List[Tuple[int, int]]]] = None,
    ) -> List[Dict[str, TensorType]]:
        """
        Converts the raw output of [`DepthEstimatorOutput`] into final depth predictions and depth PIL images.
        Only supports PyTorch.

        Args:
            outputs ([`DepthEstimatorOutput`]):
                Raw outputs of the model.
            target_sizes (`TensorType` or `List[Tuple[int, int]]`, *optional*):
                Tensor of shape `(batch_size, 2)` or list of tuples (`Tuple[int, int]`) containing the target size
                (height, width) of each image in the batch. If left to `None`, predictions will not be resized.

        Returns:
            `List[Dict[str, TensorType]]`: A list of dictionaries of tensors representing the processed depth
            predictions.
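
        Example (a sketch; `Intel/dpt-large` is assumed to be a DPT depth-estimation checkpoint):

        ```python
        >>> import torch
        >>> from PIL import Image
        >>> import requests
        >>> from transformers import DPTImageProcessor, DPTForDepthEstimation

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> image_processor = DPTImageProcessor.from_pretrained("Intel/dpt-large")
        >>> model = DPTForDepthEstimation.from_pretrained("Intel/dpt-large")

        >>> inputs = image_processor(images=image, return_tensors="pt")
        >>> with torch.no_grad():
        ...     outputs = model(**inputs)

        >>> # interpolate each prediction back to the original (height, width)
        >>> post_processed = image_processor.post_process_depth_estimation(
        ...     outputs, target_sizes=[image.size[::-1]]
        ... )
        >>> predicted_depth = post_processed[0]["predicted_depth"]
        ```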
  435. """
  436. requires_backends(self, "torch")
  437. predicted_depth = outputs.predicted_depth
  438. if (target_sizes is not None) and (len(predicted_depth) != len(target_sizes)):
  439. raise ValueError(
  440. "Make sure that you pass in as many target sizes as the batch dimension of the predicted depth"
  441. )
  442. results = []
  443. target_sizes = [None] * len(predicted_depth) if target_sizes is None else target_sizes
  444. for depth, target_size in zip(predicted_depth, target_sizes):
  445. if target_size is not None:
  446. depth = torch.nn.functional.interpolate(
  447. depth.unsqueeze(0).unsqueeze(1), size=target_size, mode="bicubic", align_corners=False
  448. ).squeeze()
  449. results.append({"predicted_depth": depth})
  450. return results