image_processing_zoedepth.py

# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Image processor class for ZoeDepth."""

import math
from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple, Union

import numpy as np

if TYPE_CHECKING:
    from .modeling_zoedepth import ZoeDepthDepthEstimatorOutput

from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
from ...image_transforms import PaddingMode, pad, to_channel_dimension_format
from ...image_utils import (
    IMAGENET_STANDARD_MEAN,
    IMAGENET_STANDARD_STD,
    ChannelDimension,
    ImageInput,
    PILImageResampling,
    get_image_size,
    infer_channel_dimension_format,
    is_scaled_image,
    make_list_of_images,
    to_numpy_array,
    valid_images,
    validate_preprocess_arguments,
)
from ...utils import (
    TensorType,
    filter_out_non_signature_kwargs,
    is_torch_available,
    is_vision_available,
    logging,
    requires_backends,
)

if is_vision_available():
    import PIL

if is_torch_available():
    import torch
    from torch import nn


logger = logging.get_logger(__name__)


def get_resize_output_image_size(
    input_image: np.ndarray,
    output_size: Union[int, Iterable[int]],
    keep_aspect_ratio: bool,
    multiple: int,
    input_data_format: Optional[Union[str, ChannelDimension]] = None,
) -> Tuple[int, int]:
    def constrain_to_multiple_of(val, multiple, min_val=0):
        x = (np.round(val / multiple) * multiple).astype(int)

        if x < min_val:
            x = math.ceil(val / multiple) * multiple

        return x

    output_size = (output_size, output_size) if isinstance(output_size, int) else output_size

    input_height, input_width = get_image_size(input_image, input_data_format)
    output_height, output_width = output_size

    # determine new height and width
    scale_height = output_height / input_height
    scale_width = output_width / input_width

    if keep_aspect_ratio:
        # scale as little as possible
        if abs(1 - scale_width) < abs(1 - scale_height):
            # fit width
            scale_height = scale_width
        else:
            # fit height
            scale_width = scale_height

    new_height = constrain_to_multiple_of(scale_height * input_height, multiple=multiple)
    new_width = constrain_to_multiple_of(scale_width * input_width, multiple=multiple)

    return (new_height, new_width)
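

# Worked example (not part of the original module): resizing a 500x600 image towards
# (384, 512) with keep_aspect_ratio=True picks the width scale 512 / 600 ≈ 0.853 (the
# scale closer to 1), applies it to both sides, and rounds each side to the nearest
# multiple of 32:
#
#     get_resize_output_image_size(
#         np.zeros((500, 600, 3), dtype=np.uint8),
#         output_size=(384, 512),
#         keep_aspect_ratio=True,
#         multiple=32,
#     )
#
# returns (416, 512), since round(0.853 * 500 / 32) * 32 == 416 and 512 is already a
# multiple of 32.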


class ZoeDepthImageProcessor(BaseImageProcessor):
    r"""
    Constructs a ZoeDepth image processor.

    Args:
        do_pad (`bool`, *optional*, defaults to `True`):
            Whether to pad the input image.
        do_rescale (`bool`, *optional*, defaults to `True`):
            Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by `do_rescale` in
            `preprocess`.
        rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
            Scale factor to use if rescaling the image. Can be overridden by `rescale_factor` in `preprocess`.
        do_normalize (`bool`, *optional*, defaults to `True`):
            Whether to normalize the image. Can be overridden by the `do_normalize` parameter in the `preprocess`
            method.
        image_mean (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_MEAN`):
            Mean to use if normalizing the image. This is a float or list of floats the length of the number of
            channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method.
        image_std (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_STD`):
            Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
            number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
        do_resize (`bool`, *optional*, defaults to `True`):
            Whether to resize the image's (height, width) dimensions. Can be overridden by `do_resize` in `preprocess`.
        size (`Dict[str, int]`, *optional*, defaults to `{"height": 384, "width": 512}`):
            Size of the image after resizing. If `keep_aspect_ratio` is `True`, the image is resized by choosing the
            smaller of the height and width scaling factors and using it for both dimensions. If `ensure_multiple_of`
            is also set, the image is further resized to a size that is a multiple of this value.
            Can be overridden by `size` in `preprocess`.
        resample (`PILImageResampling`, *optional*, defaults to `Resampling.BILINEAR`):
            Defines the resampling filter to use if resizing the image. Can be overridden by `resample` in `preprocess`.
        keep_aspect_ratio (`bool`, *optional*, defaults to `True`):
            If `True`, the image is resized by choosing the smaller of the height and width scaling factors and using it
            for both dimensions. This ensures that the image is scaled down as little as possible while still fitting
            within the desired output size. In case `ensure_multiple_of` is also set, the image is further resized to a
            size that is a multiple of this value by rounding the height and width to the nearest multiple of this value.
            Can be overridden by `keep_aspect_ratio` in `preprocess`.
        ensure_multiple_of (`int`, *optional*, defaults to 32):
            If `do_resize` is `True`, the image is resized to a size that is a multiple of this value. Works by rounding
            the height and width to the nearest multiple of this value.
            Works both with and without `keep_aspect_ratio` being set to `True`. Can be overridden by `ensure_multiple_of`
            in `preprocess`.
    """

    model_input_names = ["pixel_values"]

    def __init__(
        self,
        do_pad: bool = True,
        do_rescale: bool = True,
        rescale_factor: Union[int, float] = 1 / 255,
        do_normalize: bool = True,
        image_mean: Optional[Union[float, List[float]]] = None,
        image_std: Optional[Union[float, List[float]]] = None,
        do_resize: bool = True,
        size: Dict[str, int] = None,
        resample: PILImageResampling = PILImageResampling.BILINEAR,
        keep_aspect_ratio: bool = True,
        ensure_multiple_of: int = 32,
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)
        self.do_rescale = do_rescale
        self.rescale_factor = rescale_factor
        self.do_pad = do_pad
        self.do_normalize = do_normalize
        self.image_mean = image_mean if image_mean is not None else IMAGENET_STANDARD_MEAN
        self.image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD
        size = size if size is not None else {"height": 384, "width": 512}
        size = get_size_dict(size)
        self.do_resize = do_resize
        self.size = size
        self.keep_aspect_ratio = keep_aspect_ratio
        self.ensure_multiple_of = ensure_multiple_of
        self.resample = resample
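
    # Hedged construction example (not part of the original module): overriding the
    # default target size while keeping the other defaults.
    #   >>> processor = ZoeDepthImageProcessor(size={"height": 512, "width": 672})
    #   >>> processor.size
    #   {'height': 512, 'width': 672}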

    def resize(
        self,
        image: np.ndarray,
        size: Dict[str, int],
        keep_aspect_ratio: bool = False,
        ensure_multiple_of: int = 1,
        resample: PILImageResampling = PILImageResampling.BILINEAR,
        data_format: Optional[Union[str, ChannelDimension]] = None,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
    ) -> np.ndarray:
        """
        Resize an image to target size `(size["height"], size["width"])`. If `keep_aspect_ratio` is `True`, the image
        is resized to the largest possible size such that the aspect ratio is preserved. If `ensure_multiple_of` is
        set, the image is resized to a size that is a multiple of this value.

        Args:
            image (`np.ndarray`):
                Image to resize.
            size (`Dict[str, int]`):
                Target size of the output image.
            keep_aspect_ratio (`bool`, *optional*, defaults to `False`):
                If `True`, the image is resized to the largest possible size such that the aspect ratio is preserved.
            ensure_multiple_of (`int`, *optional*, defaults to 1):
                The image is resized to a size that is a multiple of this value.
            resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BILINEAR`):
                Defines the resampling filter to use when resizing the image.
            data_format (`str` or `ChannelDimension`, *optional*):
                The channel dimension format of the output image. If not provided, it will be the same as the input image.
            input_data_format (`str` or `ChannelDimension`, *optional*):
                The channel dimension format of the input image. If not provided, it will be inferred.
        """
        if input_data_format is None:
            input_data_format = infer_channel_dimension_format(image)

        data_format = data_format if data_format is not None else input_data_format

        size = get_size_dict(size)
        if "height" not in size or "width" not in size:
            raise ValueError(f"The size dictionary must contain the keys 'height' and 'width'. Got {size.keys()}")

        output_size = get_resize_output_image_size(
            image,
            output_size=(size["height"], size["width"]),
            keep_aspect_ratio=keep_aspect_ratio,
            multiple=ensure_multiple_of,
            input_data_format=input_data_format,
        )

        height, width = output_size

        # TODO support align_corners=True in image_transforms.resize
        requires_backends(self, "torch")
        torch_image = torch.from_numpy(image).unsqueeze(0)
        torch_image = torch_image.permute(0, 3, 1, 2) if input_data_format == "channels_last" else torch_image

        resample_to_mode = {PILImageResampling.BILINEAR: "bilinear", PILImageResampling.BICUBIC: "bicubic"}
        mode = resample_to_mode[resample]

        resized_image = nn.functional.interpolate(
            torch_image, (int(height), int(width)), mode=mode, align_corners=True
        )
        resized_image = resized_image.squeeze().numpy()

        resized_image = to_channel_dimension_format(
            resized_image, data_format, input_channel_dim=ChannelDimension.FIRST
        )

        return resized_image
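
    # Illustrative check (not part of the original module): resizing a padded 572x746
    # array with the default target size and ensure_multiple_of=32 goes through
    # torch's interpolate with align_corners=True and keeps the channel layout of
    # the input.
    #   >>> processor = ZoeDepthImageProcessor()
    #   >>> processor.resize(
    #   ...     np.zeros((572, 746, 3), dtype=np.float32),
    #   ...     size={"height": 384, "width": 512},
    #   ...     keep_aspect_ratio=True,
    #   ...     ensure_multiple_of=32,
    #   ... ).shape
    #   (384, 512, 3)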

    def pad_image(
        self,
        image: np.ndarray,
        mode: PaddingMode = PaddingMode.REFLECT,
        data_format: Optional[Union[str, ChannelDimension]] = None,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
    ):
        """
        Pad an image as done in the original ZoeDepth implementation.

        Padding fixes the boundary artifacts in the output depth map.
        Boundary artifacts are sometimes caused by the fact that the model is trained on the NYU raw dataset,
        which has a black or white border around the image. This function pads the input image; the prediction
        is cropped back to the original size / view during post-processing.

        Args:
            image (`np.ndarray`):
                Image to pad.
            mode (`PaddingMode`):
                The padding mode to use. Can be one of:
                    - `"constant"`: pads with a constant value.
                    - `"reflect"`: pads with the reflection of the vector mirrored on the first and last values of the
                      vector along each axis.
                    - `"replicate"`: pads with the replication of the last value on the edge of the array along each axis.
                    - `"symmetric"`: pads with the reflection of the vector mirrored along the edge of the array.
            data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
                The channel dimension format for the output image. Can be one of:
                    - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                    - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
                    - Unset: Use the channel dimension format of the input image.
            input_data_format (`ChannelDimension` or `str`, *optional*):
                The channel dimension format for the input image. If unset, the channel dimension format is inferred
                from the input image. Can be one of:
                    - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                    - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
                    - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
        """
        height, width = get_image_size(image, input_data_format)

        pad_height = int(np.sqrt(height / 2) * 3)
        pad_width = int(np.sqrt(width / 2) * 3)

        return pad(
            image,
            padding=((pad_height, pad_height), (pad_width, pad_width)),
            mode=mode,
            data_format=data_format,
            input_data_format=input_data_format,
        )
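
    # Worked example (not part of the original module): for a 480x640 input,
    # pad_height = int(sqrt(480 / 2) * 3) = 46 and pad_width = int(sqrt(640 / 2) * 3) = 53,
    # so the padded image is 572x746 before resizing.
    #   >>> processor = ZoeDepthImageProcessor()
    #   >>> processor.pad_image(np.zeros((480, 640, 3), dtype=np.uint8)).shape
    #   (572, 746, 3)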

    @filter_out_non_signature_kwargs()
    def preprocess(
        self,
        images: ImageInput,
        do_pad: bool = None,
        do_rescale: bool = None,
        rescale_factor: float = None,
        do_normalize: bool = None,
        image_mean: Optional[Union[float, List[float]]] = None,
        image_std: Optional[Union[float, List[float]]] = None,
        do_resize: bool = None,
        size: Dict[str, int] = None,
        keep_aspect_ratio: bool = None,
        ensure_multiple_of: int = None,
        resample: PILImageResampling = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
        data_format: ChannelDimension = ChannelDimension.FIRST,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
    ) -> BatchFeature:
        """
        Preprocess an image or batch of images.

        Args:
            images (`ImageInput`):
                Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
                passing in images with pixel values between 0 and 1, set `do_rescale=False`.
            do_pad (`bool`, *optional*, defaults to `self.do_pad`):
                Whether to pad the input image.
            do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
                Whether to rescale the image values to the [0, 1] range.
            rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
                Rescale factor to rescale the image by if `do_rescale` is set to `True`.
            do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
                Whether to normalize the image.
            image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
                Image mean.
            image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
                Image standard deviation.
            do_resize (`bool`, *optional*, defaults to `self.do_resize`):
                Whether to resize the image.
            size (`Dict[str, int]`, *optional*, defaults to `self.size`):
                Size of the image after resizing. If `keep_aspect_ratio` is `True`, the image is resized by choosing
                the smaller of the height and width scaling factors and using it for both dimensions. If
                `ensure_multiple_of` is also set, the image is further resized to a size that is a multiple of this value.
            keep_aspect_ratio (`bool`, *optional*, defaults to `self.keep_aspect_ratio`):
                If `True` and `do_resize=True`, the image is resized by choosing the smaller of the height and width
                scaling factors and using it for both dimensions. This ensures that the image is scaled down as little
                as possible while still fitting within the desired output size. In case `ensure_multiple_of` is also
                set, the image is further resized to a size that is a multiple of this value by rounding the height and
                width to the nearest multiple of this value.
            ensure_multiple_of (`int`, *optional*, defaults to `self.ensure_multiple_of`):
                If `do_resize` is `True`, the image is resized to a size that is a multiple of this value. Works by
                rounding the height and width to the nearest multiple of this value.
                Works both with and without `keep_aspect_ratio` being set to `True`.
            resample (`int`, *optional*, defaults to `self.resample`):
                Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`. Only
                has an effect if `do_resize` is set to `True`.
            return_tensors (`str` or `TensorType`, *optional*):
                The type of tensors to return. Can be one of:
                    - Unset: Return a list of `np.ndarray`.
                    - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
                    - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
                    - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
                    - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
            data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
                The channel dimension format for the output image. Can be one of:
                    - `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                    - `ChannelDimension.LAST`: image in (height, width, num_channels) format.
            input_data_format (`ChannelDimension` or `str`, *optional*):
                The channel dimension format for the input image. If unset, the channel dimension format is inferred
                from the input image. Can be one of:
                    - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                    - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
                    - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
        """
        do_resize = do_resize if do_resize is not None else self.do_resize
        size = size if size is not None else self.size
        size = get_size_dict(size)
        keep_aspect_ratio = keep_aspect_ratio if keep_aspect_ratio is not None else self.keep_aspect_ratio
        ensure_multiple_of = ensure_multiple_of if ensure_multiple_of is not None else self.ensure_multiple_of
        resample = resample if resample is not None else self.resample
        do_rescale = do_rescale if do_rescale is not None else self.do_rescale
        rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
        do_normalize = do_normalize if do_normalize is not None else self.do_normalize
        image_mean = image_mean if image_mean is not None else self.image_mean
        image_std = image_std if image_std is not None else self.image_std
        do_pad = do_pad if do_pad is not None else self.do_pad

        images = make_list_of_images(images)

        if not valid_images(images):
            raise ValueError(
                "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
                "torch.Tensor, tf.Tensor or jax.ndarray."
            )
        validate_preprocess_arguments(
            do_rescale=do_rescale,
            rescale_factor=rescale_factor,
            do_normalize=do_normalize,
            image_mean=image_mean,
            image_std=image_std,
            do_resize=do_resize,
            size=size,
            resample=resample,
        )

        # All transformations expect numpy arrays.
        images = [to_numpy_array(image) for image in images]

        if is_scaled_image(images[0]) and do_rescale:
            logger.warning_once(
                "It looks like you are trying to rescale already rescaled images. If the input"
                " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
            )

        if input_data_format is None:
            # We assume that all images have the same channel dimension format.
            input_data_format = infer_channel_dimension_format(images[0])

        if do_rescale:
            images = [
                self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format)
                for image in images
            ]

        if do_pad:
            images = [self.pad_image(image=image, input_data_format=input_data_format) for image in images]

        if do_resize:
            images = [
                self.resize(
                    image=image,
                    size=size,
                    resample=resample,
                    keep_aspect_ratio=keep_aspect_ratio,
                    ensure_multiple_of=ensure_multiple_of,
                    input_data_format=input_data_format,
                )
                for image in images
            ]

        if do_normalize:
            images = [
                self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=input_data_format)
                for image in images
            ]

        images = [
            to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) for image in images
        ]

        data = {"pixel_values": images}
        return BatchFeature(data=data, tensor_type=return_tensors)
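
    # Hedged usage sketch (not part of the original module): calling the processor
    # (which dispatches to `preprocess`) on a single 480x640 RGB image. With the
    # defaults, the image is padded to 572x746 and then resized to 384x512 (both
    # multiples of 32), so the returned batch has shape (1, 3, 384, 512).
    #   >>> processor = ZoeDepthImageProcessor()
    #   >>> image = np.random.randint(0, 256, (480, 640, 3), dtype=np.uint8)
    #   >>> processor(images=image, return_tensors="pt")["pixel_values"].shape
    #   torch.Size([1, 3, 384, 512])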

    def post_process_depth_estimation(
        self,
        outputs: "ZoeDepthDepthEstimatorOutput",
        source_sizes: Optional[Union[TensorType, List[Tuple[int, int]]]] = None,
        target_sizes: Optional[Union[TensorType, List[Tuple[int, int]]]] = None,
        outputs_flipped: Optional["ZoeDepthDepthEstimatorOutput"] = None,
        do_remove_padding: Optional[bool] = None,
    ) -> List[Dict[str, TensorType]]:
        """
        Converts the raw output of [`ZoeDepthDepthEstimatorOutput`] into final depth predictions.
        Only supports PyTorch.

        Args:
            outputs ([`ZoeDepthDepthEstimatorOutput`]):
                Raw outputs of the model.
            source_sizes (`TensorType` or `List[Tuple[int, int]]`, *optional*):
                Tensor of shape `(batch_size, 2)` or list of tuples (`Tuple[int, int]`) containing the source size
                (height, width) of each image in the batch before preprocessing. This argument should be treated as
                required unless `do_remove_padding=False` is passed to this function.
            target_sizes (`TensorType` or `List[Tuple[int, int]]`, *optional*):
                Tensor of shape `(batch_size, 2)` or list of tuples (`Tuple[int, int]`) containing the target size
                (height, width) of each image in the batch. If left to `None`, predictions will not be resized.
            outputs_flipped ([`ZoeDepthDepthEstimatorOutput`], *optional*):
                Raw outputs of the model from flipped input (averaged out in the end).
            do_remove_padding (`bool`, *optional*):
                By default ZoeDepth adds padding equal to `int(sqrt(height / 2) * 3)` (and similarly for width) to fix
                boundary artifacts in the output depth map, so this padding needs to be removed during post-processing.
                The parameter exists here in case the user changed the image preprocessing to not include padding.

        Returns:
            `List[Dict[str, TensorType]]`: A list of dictionaries of tensors representing the processed depth
            predictions.
        """
        requires_backends(self, "torch")

        predicted_depth = outputs.predicted_depth

        if (outputs_flipped is not None) and (predicted_depth.shape != outputs_flipped.predicted_depth.shape):
            raise ValueError("Make sure that `outputs` and `outputs_flipped` have the same shape")

        if (target_sizes is not None) and (len(predicted_depth) != len(target_sizes)):
            raise ValueError(
                "Make sure that you pass in as many target sizes as the batch dimension of the predicted depth"
            )

        if do_remove_padding is None:
            do_remove_padding = self.do_pad

        if source_sizes is None and do_remove_padding:
            raise ValueError(
                "Either `source_sizes` should be passed in, or `do_remove_padding` should be set to False"
            )

        if (source_sizes is not None) and (len(predicted_depth) != len(source_sizes)):
            raise ValueError(
                "Make sure that you pass in as many source image sizes as the batch dimension of the predicted depth"
            )

        if outputs_flipped is not None:
            predicted_depth = (predicted_depth + torch.flip(outputs_flipped.predicted_depth, dims=[-1])) / 2

        predicted_depth = predicted_depth.unsqueeze(1)

        # The ZoeDepth model adds padding around the images to fix the boundary artifacts in the output depth map.
        # The padding length is `int(np.sqrt(img_h / 2) * fh)` for the height and similarly for the width,
        # where fh (and fw respectively) are equal to '3' by default.
        # See https://github.com/isl-org/ZoeDepth/blob/edb6daf45458569e24f50250ef1ed08c015f17a7/zoedepth/models/depth_model.py#L57
        # for the original implementation.
        # In this section, we remove this padding to get the final depth image and depth prediction.
        padding_factor_h = padding_factor_w = 3

        results = []
        target_sizes = [None] * len(predicted_depth) if target_sizes is None else target_sizes
        source_sizes = [None] * len(predicted_depth) if source_sizes is None else source_sizes
        for depth, target_size, source_size in zip(predicted_depth, target_sizes, source_sizes):
            # depth.shape = [1, H, W]
            if source_size is not None:
                pad_h = pad_w = 0

                if do_remove_padding:
                    pad_h = int(np.sqrt(source_size[0] / 2) * padding_factor_h)
                    pad_w = int(np.sqrt(source_size[1] / 2) * padding_factor_w)

                depth = nn.functional.interpolate(
                    depth.unsqueeze(1),
                    size=[source_size[0] + 2 * pad_h, source_size[1] + 2 * pad_w],
                    mode="bicubic",
                    align_corners=False,
                )

                if pad_h > 0:
                    depth = depth[:, :, pad_h:-pad_h, :]
                if pad_w > 0:
                    depth = depth[:, :, :, pad_w:-pad_w]

                depth = depth.squeeze(1)
            # depth.shape = [1, H, W]
            if target_size is not None:
                target_size = [target_size[0], target_size[1]]
                depth = nn.functional.interpolate(
                    depth.unsqueeze(1), size=target_size, mode="bicubic", align_corners=False
                )
            depth = depth.squeeze()
            # depth.shape = [H, W]
            results.append({"predicted_depth": depth})

        return results
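

if __name__ == "__main__":
    # Usage sketch (not part of the original module): end-to-end depth estimation with
    # this processor. The checkpoint name below is an assumption; substitute any
    # ZoeDepth checkpoint you have access to. The guard keeps the module import-safe.
    import requests
    from PIL import Image

    from transformers import ZoeDepthForDepthEstimation

    url = "http://images.cocodataset.org/val2017/000000039769.jpg"
    image = Image.open(requests.get(url, stream=True).raw)

    image_processor = ZoeDepthImageProcessor()
    model = ZoeDepthForDepthEstimation.from_pretrained("Intel/zoedepth-nyu-kitti")  # assumed checkpoint

    # Pad, resize, rescale and normalize the image, then run the model.
    inputs = image_processor(images=image, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)

    # Remove the padding added during preprocessing and resize back to the source size.
    post_processed = image_processor.post_process_depth_estimation(
        outputs, source_sizes=[(image.height, image.width)]
    )
    predicted_depth = post_processed[0]["predicted_depth"]  # tensor of shape (height, width)
    print(predicted_depth.shape)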