# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Image processor class for OWLv2."""

import warnings
from typing import Dict, List, Optional, Tuple, Union

import numpy as np

from ...image_processing_utils import BaseImageProcessor, BatchFeature
from ...image_transforms import (
    center_to_corners_format,
    pad,
    to_channel_dimension_format,
)
from ...image_utils import (
    OPENAI_CLIP_MEAN,
    OPENAI_CLIP_STD,
    ChannelDimension,
    ImageInput,
    PILImageResampling,
    get_image_size,
    infer_channel_dimension_format,
    is_scaled_image,
    make_list_of_images,
    to_numpy_array,
    valid_images,
    validate_preprocess_arguments,
)
from ...utils import (
    TensorType,
    filter_out_non_signature_kwargs,
    is_scipy_available,
    is_torch_available,
    is_vision_available,
    logging,
    requires_backends,
)


if is_torch_available():
    import torch

if is_vision_available():
    import PIL

if is_scipy_available():
    from scipy import ndimage as ndi


logger = logging.get_logger(__name__)
# Copied from transformers.models.owlvit.image_processing_owlvit._upcast
def _upcast(t):
    # Protects from numerical overflows in multiplications by upcasting to the equivalent higher type
    if t.is_floating_point():
        return t if t.dtype in (torch.float32, torch.float64) else t.float()
    else:
        return t if t.dtype in (torch.int32, torch.int64) else t.int()


# Copied from transformers.models.owlvit.image_processing_owlvit.box_area
def box_area(boxes):
    """
    Computes the area of a set of bounding boxes, which are specified by their (x1, y1, x2, y2) coordinates.

    Args:
        boxes (`torch.FloatTensor` of shape `(number_of_boxes, 4)`):
            Boxes for which the area will be computed. They are expected to be in (x1, y1, x2, y2) format with `0 <= x1
            < x2` and `0 <= y1 < y2`.

    Returns:
        `torch.FloatTensor`: a tensor containing the area for each box.
    """
    boxes = _upcast(boxes)
    return (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])


# Copied from transformers.models.owlvit.image_processing_owlvit.box_iou
def box_iou(boxes1, boxes2):
    area1 = box_area(boxes1)
    area2 = box_area(boxes2)

    left_top = torch.max(boxes1[:, None, :2], boxes2[:, :2])  # [N,M,2]
    right_bottom = torch.min(boxes1[:, None, 2:], boxes2[:, 2:])  # [N,M,2]

    width_height = (right_bottom - left_top).clamp(min=0)  # [N,M,2]
    inter = width_height[:, :, 0] * width_height[:, :, 1]  # [N,M]

    union = area1[:, None] + area2 - inter

    iou = inter / union
    return iou, union
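
# Illustrative sketch of `box_iou` (values chosen for demonstration only): two 2x2 boxes
# that overlap in a 1x1 region have an intersection of 1 and a union of 4 + 4 - 1 = 7,
# so the IoU is 1/7 ≈ 0.1429.
#
#     boxes1 = torch.tensor([[0.0, 0.0, 2.0, 2.0]])
#     boxes2 = torch.tensor([[1.0, 1.0, 3.0, 3.0]])
#     iou, union = box_iou(boxes1, boxes2)
#     # iou   -> tensor([[0.1429]])
#     # union -> tensor([[7.]])
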
def _preprocess_resize_output_shape(image, output_shape):
    """Validate resize output shape according to input image.

    Args:
        image (`np.ndarray`):
            Image to be resized.
        output_shape (`iterable`):
            Size of the generated output image `(rows, cols[, ...][, dim])`. If `dim` is not provided, the number of
            channels is preserved.

    Returns:
        image (`np.ndarray`):
            The input image, but with additional singleton dimensions appended in the case where `len(output_shape) >
            input.ndim`.
        output_shape (`Tuple`):
            The output shape converted to tuple.

    Raises:
        ValueError:
            If output_shape length is smaller than the image number of dimensions.

    Notes:
        The input image is reshaped if its number of dimensions is not equal to output_shape_length.
    """
    output_shape = tuple(output_shape)
    output_ndim = len(output_shape)
    input_shape = image.shape
    if output_ndim > image.ndim:
        # append dimensions to input_shape
        input_shape += (1,) * (output_ndim - image.ndim)
        image = np.reshape(image, input_shape)
    elif output_ndim == image.ndim - 1:
        # multichannel case: append shape of last axis
        output_shape = output_shape + (image.shape[-1],)
    elif output_ndim < image.ndim:
        raise ValueError("output_shape length cannot be smaller than the image number of dimensions")

    return image, output_shape
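
# Illustrative sketch of the multichannel branch above (shapes are hypothetical): a
# channels-last image of shape (480, 640, 3) with a requested output shape of (960, 960)
# hits the `output_ndim == image.ndim - 1` case, so the channel axis is appended to the
# output shape while the image itself is returned unchanged.
#
#     image = np.zeros((480, 640, 3))
#     image, output_shape = _preprocess_resize_output_shape(image, (960, 960))
#     # image.shape  -> (480, 640, 3)
#     # output_shape -> (960, 960, 3)
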
def _clip_warp_output(input_image, output_image):
    """Clip the output image to the range of values of the input image.

    Note that the clipped result is returned; `output_image` itself is not modified in place.

    Adapted from:
    https://github.com/scikit-image/scikit-image/blob/b4b521d6f0a105aabeaa31699949f78453ca3511/skimage/transform/_warps.py#L640.

    Args:
        input_image : ndarray
            Input image.
        output_image : ndarray
            Output image to be clipped.
    """
    min_val = np.min(input_image)
    if np.isnan(min_val):
        # NaNs detected, use NaN-safe min/max
        min_func = np.nanmin
        max_func = np.nanmax
        min_val = min_func(input_image)
    else:
        min_func = np.min
        max_func = np.max

    max_val = max_func(input_image)
    output_image = np.clip(output_image, min_val, max_val)

    return output_image
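
# Illustrative sketch of `_clip_warp_output` (values are hypothetical): interpolation can
# overshoot the original value range, and the clip pulls it back.
#
#     input_image = np.array([0.0, 1.0])
#     output_image = np.array([-0.1, 1.2])
#     _clip_warp_output(input_image, output_image)
#     # -> array([0., 1.])
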
class Owlv2ImageProcessor(BaseImageProcessor):
    r"""
    Constructs an OWLv2 image processor.

    Args:
        do_rescale (`bool`, *optional*, defaults to `True`):
            Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by `do_rescale` in
            the `preprocess` method.
        rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
            Scale factor to use if rescaling the image. Can be overridden by `rescale_factor` in the `preprocess`
            method.
        do_pad (`bool`, *optional*, defaults to `True`):
            Whether to pad the image to a square with gray pixels on the bottom and the right. Can be overridden by
            `do_pad` in the `preprocess` method.
        do_resize (`bool`, *optional*, defaults to `True`):
            Controls whether to resize the image's (height, width) dimensions to the specified `size`. Can be
            overridden by `do_resize` in the `preprocess` method.
        size (`Dict[str, int]`, *optional*, defaults to `{"height": 960, "width": 960}`):
            Size to resize the image to. Can be overridden by `size` in the `preprocess` method.
        resample (`PILImageResampling`, *optional*, defaults to `Resampling.BILINEAR`):
            Resampling method to use if resizing the image. Can be overridden by `resample` in the `preprocess` method.
        do_normalize (`bool`, *optional*, defaults to `True`):
            Whether to normalize the image. Can be overridden by the `do_normalize` parameter in the `preprocess`
            method.
        image_mean (`float` or `List[float]`, *optional*, defaults to `OPENAI_CLIP_MEAN`):
            Mean to use if normalizing the image. This is a float or list of floats the length of the number of
            channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method.
        image_std (`float` or `List[float]`, *optional*, defaults to `OPENAI_CLIP_STD`):
            Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
            number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
    """

    model_input_names = ["pixel_values"]

    def __init__(
        self,
        do_rescale: bool = True,
        rescale_factor: Union[int, float] = 1 / 255,
        do_pad: bool = True,
        do_resize: bool = True,
        size: Dict[str, int] = None,
        resample: PILImageResampling = PILImageResampling.BILINEAR,
        do_normalize: bool = True,
        image_mean: Optional[Union[float, List[float]]] = None,
        image_std: Optional[Union[float, List[float]]] = None,
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)
        self.do_rescale = do_rescale
        self.rescale_factor = rescale_factor
        self.do_pad = do_pad
        self.do_resize = do_resize
        self.size = size if size is not None else {"height": 960, "width": 960}
        self.resample = resample
        self.do_normalize = do_normalize
        self.image_mean = image_mean if image_mean is not None else OPENAI_CLIP_MEAN
        self.image_std = image_std if image_std is not None else OPENAI_CLIP_STD
    def pad(
        self,
        image: np.ndarray,
        data_format: Optional[Union[str, ChannelDimension]] = None,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
    ):
        """
        Pad an image to a square with gray pixels on the bottom and the right, as per the original OWLv2
        implementation.

        Args:
            image (`np.ndarray`):
                Image to pad.
            data_format (`str` or `ChannelDimension`, *optional*):
                The channel dimension format of the image. If not provided, it will be the same as the input image.
            input_data_format (`ChannelDimension` or `str`, *optional*):
                The channel dimension format of the input image. If not provided, it will be inferred from the input
                image.
        """
        height, width = get_image_size(image)
        size = max(height, width)
        image = pad(
            image=image,
            padding=((0, size - height), (0, size - width)),
            constant_values=0.5,
            data_format=data_format,
            input_data_format=input_data_format,
        )

        return image
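
    # Illustrative sketch of `pad` (shapes are hypothetical): a rescaled channels-first image
    # of shape (3, 480, 640) is padded on the bottom to the square max(height, width) = 640,
    # using the constant 0.5 (mid-gray once pixel values are in [0, 1]).
    #
    #     image = np.random.rand(3, 480, 640)
    #     padded = Owlv2ImageProcessor().pad(image, input_data_format=ChannelDimension.FIRST)
    #     # padded.shape -> (3, 640, 640)
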
    def resize(
        self,
        image: np.ndarray,
        size: Dict[str, int],
        anti_aliasing: bool = True,
        anti_aliasing_sigma=None,
        data_format: Optional[Union[str, ChannelDimension]] = None,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
        **kwargs,
    ) -> np.ndarray:
        """
        Resize an image as per the original implementation.

        Args:
            image (`np.ndarray`):
                Image to resize.
            size (`Dict[str, int]`):
                Dictionary containing the height and width to resize the image to.
            anti_aliasing (`bool`, *optional*, defaults to `True`):
                Whether to apply anti-aliasing when downsampling the image.
            anti_aliasing_sigma (`float`, *optional*, defaults to `None`):
                Standard deviation for Gaussian kernel when downsampling the image. If `None`, it will be calculated
                automatically.
            data_format (`str` or `ChannelDimension`, *optional*):
                The channel dimension format of the image. If not provided, it will be the same as the input image.
            input_data_format (`ChannelDimension` or `str`, *optional*):
                The channel dimension format of the input image. If not provided, it will be inferred from the input
                image.
        """
        requires_backends(self, "scipy")

        output_shape = (size["height"], size["width"])
        image = to_channel_dimension_format(image, ChannelDimension.LAST)
        image, output_shape = _preprocess_resize_output_shape(image, output_shape)
        input_shape = image.shape
        factors = np.divide(input_shape, output_shape)

        # Translate modes used by np.pad to those used by scipy.ndimage
        ndi_mode = "mirror"
        cval = 0
        order = 1
        if anti_aliasing:
            if anti_aliasing_sigma is None:
                anti_aliasing_sigma = np.maximum(0, (factors - 1) / 2)
            else:
                anti_aliasing_sigma = np.atleast_1d(anti_aliasing_sigma) * np.ones_like(factors)
                if np.any(anti_aliasing_sigma < 0):
                    raise ValueError("Anti-aliasing standard deviation must be greater than or equal to zero")
                elif np.any((anti_aliasing_sigma > 0) & (factors <= 1)):
                    warnings.warn(
                        "Anti-aliasing standard deviation greater than zero but not down-sampling along all axes"
                    )
            filtered = ndi.gaussian_filter(image, anti_aliasing_sigma, cval=cval, mode=ndi_mode)
        else:
            filtered = image

        zoom_factors = [1 / f for f in factors]
        out = ndi.zoom(filtered, zoom_factors, order=order, mode=ndi_mode, cval=cval, grid_mode=True)

        image = _clip_warp_output(image, out)

        image = to_channel_dimension_format(image, input_data_format, ChannelDimension.LAST)
        image = (
            to_channel_dimension_format(image, data_format, input_data_format) if data_format is not None else image
        )
        return image
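
    # Illustrative sketch of the anti-aliased resize above (shapes are hypothetical): a padded
    # channels-last image of shape (1920, 1920, 3) resized to {"height": 960, "width": 960}
    # gives downsampling factors (2.0, 2.0, 1.0), so the Gaussian sigma is (0.5, 0.5, 0.0)
    # and only the spatial axes are smoothed before `ndi.zoom` runs with factors (0.5, 0.5, 1.0).
    #
    #     image = np.random.rand(1920, 1920, 3)
    #     resized = Owlv2ImageProcessor().resize(
    #         image, {"height": 960, "width": 960}, input_data_format=ChannelDimension.LAST
    #     )
    #     # resized.shape -> (960, 960, 3)
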
    @filter_out_non_signature_kwargs()
    def preprocess(
        self,
        images: ImageInput,
        do_pad: bool = None,
        do_resize: bool = None,
        size: Dict[str, int] = None,
        do_rescale: bool = None,
        rescale_factor: float = None,
        do_normalize: bool = None,
        image_mean: Optional[Union[float, List[float]]] = None,
        image_std: Optional[Union[float, List[float]]] = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
        data_format: ChannelDimension = ChannelDimension.FIRST,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
    ) -> BatchFeature:
        """
        Preprocess an image or batch of images.

        Args:
            images (`ImageInput`):
                Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
                passing in images with pixel values between 0 and 1, set `do_rescale=False`.
            do_pad (`bool`, *optional*, defaults to `self.do_pad`):
                Whether to pad the image to a square with gray pixels on the bottom and the right.
            do_resize (`bool`, *optional*, defaults to `self.do_resize`):
                Whether to resize the image.
            size (`Dict[str, int]`, *optional*, defaults to `self.size`):
                Size to resize the image to.
            do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
                Whether to rescale the image values to the [0, 1] range.
            rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
                Rescale factor to rescale the image by if `do_rescale` is set to `True`.
            do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
                Whether to normalize the image.
            image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
                Image mean.
            image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
                Image standard deviation.
            return_tensors (`str` or `TensorType`, *optional*):
                The type of tensors to return. Can be one of:
                - Unset: Return a list of `np.ndarray`.
                - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
                - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
                - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
                - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
            data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
                The channel dimension format for the output image. Can be one of:
                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
                - Unset: Use the channel dimension format of the input image.
            input_data_format (`ChannelDimension` or `str`, *optional*):
                The channel dimension format for the input image. If unset, the channel dimension format is inferred
                from the input image. Can be one of:
                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
                - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
        """
        do_rescale = do_rescale if do_rescale is not None else self.do_rescale
        rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
        do_pad = do_pad if do_pad is not None else self.do_pad
        do_resize = do_resize if do_resize is not None else self.do_resize
        do_normalize = do_normalize if do_normalize is not None else self.do_normalize
        image_mean = image_mean if image_mean is not None else self.image_mean
        image_std = image_std if image_std is not None else self.image_std
        size = size if size is not None else self.size

        images = make_list_of_images(images)

        if not valid_images(images):
            raise ValueError(
                "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
                "torch.Tensor, tf.Tensor or jax.ndarray."
            )

        # Here, the pad and resize methods differ from those of the other image processors:
        # resize() does not take a resampling argument and pad() does not take a pad size
        # (the maximum of (height, width) is used instead), so these arguments do not need
        # to be passed to validate_preprocess_arguments.
        validate_preprocess_arguments(
            do_rescale=do_rescale,
            rescale_factor=rescale_factor,
            do_normalize=do_normalize,
            image_mean=image_mean,
            image_std=image_std,
            size=size,
        )

        # All transformations expect numpy arrays.
        images = [to_numpy_array(image) for image in images]

        if is_scaled_image(images[0]) and do_rescale:
            logger.warning_once(
                "It looks like you are trying to rescale already rescaled images. If the input"
                " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
            )

        if input_data_format is None:
            # We assume that all images have the same channel dimension format.
            input_data_format = infer_channel_dimension_format(images[0])

        if do_rescale:
            images = [
                self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format)
                for image in images
            ]

        if do_pad:
            images = [self.pad(image=image, input_data_format=input_data_format) for image in images]

        if do_resize:
            images = [
                self.resize(
                    image=image,
                    size=size,
                    input_data_format=input_data_format,
                )
                for image in images
            ]

        if do_normalize:
            images = [
                self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=input_data_format)
                for image in images
            ]

        images = [
            to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) for image in images
        ]

        data = {"pixel_values": images}
        return BatchFeature(data=data, tensor_type=return_tensors)
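
    # Illustrative usage sketch (any RGB PIL image works; loading it is up to the caller):
    #
    #     from transformers import Owlv2ImageProcessor
    #
    #     image_processor = Owlv2ImageProcessor()
    #     inputs = image_processor(images=image, return_tensors="pt")
    #     # inputs["pixel_values"].shape -> torch.Size([1, 3, 960, 960])
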
    def post_process_object_detection(
        self, outputs, threshold: float = 0.1, target_sizes: Union[TensorType, List[Tuple]] = None
    ):
        """
        Converts the raw output of [`Owlv2ForObjectDetection`] into final bounding boxes in (top_left_x, top_left_y,
        bottom_right_x, bottom_right_y) format.

        Args:
            outputs ([`Owlv2ObjectDetectionOutput`]):
                Raw outputs of the model.
            threshold (`float`, *optional*, defaults to 0.1):
                Score threshold to keep object detection predictions.
            target_sizes (`torch.Tensor` or `List[Tuple[int, int]]`, *optional*):
                Tensor of shape `(batch_size, 2)` or list of tuples (`Tuple[int, int]`) containing the target size
                `(height, width)` of each image in the batch. If unset, predictions will not be resized.

        Returns:
            `List[Dict]`: A list of dictionaries, each dictionary containing the scores, labels and boxes for an image
            in the batch as predicted by the model.
        """
        # TODO: (amy) add support for other frameworks
        logits, boxes = outputs.logits, outputs.pred_boxes

        if target_sizes is not None:
            if len(logits) != len(target_sizes):
                raise ValueError(
                    "Make sure that you pass in as many target sizes as the batch dimension of the logits"
                )

        probs = torch.max(logits, dim=-1)
        scores = torch.sigmoid(probs.values)
        labels = probs.indices

        # Convert to [x0, y0, x1, y1] format
        boxes = center_to_corners_format(boxes)

        # Convert from relative [0, 1] to absolute [0, height] coordinates
        if target_sizes is not None:
            if isinstance(target_sizes, List):
                img_h = torch.Tensor([i[0] for i in target_sizes])
                img_w = torch.Tensor([i[1] for i in target_sizes])
            else:
                img_h, img_w = target_sizes.unbind(1)

            # Rescale coordinates; the image is padded to a square for inference,
            # which is why boxes are scaled by the longer side.
            size = torch.max(img_h, img_w)
            scale_fct = torch.stack([size, size, size, size], dim=1).to(boxes.device)
            boxes = boxes * scale_fct[:, None, :]

        results = []
        for s, l, b in zip(scores, labels, boxes):
            score = s[s > threshold]
            label = l[s > threshold]
            box = b[s > threshold]
            results.append({"scores": score, "labels": label, "boxes": box})

        return results
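
    # Illustrative usage sketch (assumes `model` is an `Owlv2ForObjectDetection` instance and
    # `inputs` holds the text queries plus the preprocessed image of original size (height, width)):
    #
    #     outputs = model(**inputs)
    #     target_sizes = torch.tensor([[height, width]])
    #     results = image_processor.post_process_object_detection(
    #         outputs, threshold=0.2, target_sizes=target_sizes
    #     )
    #     # results[0]["boxes"] holds (x0, y0, x1, y1) boxes in absolute pixel coordinates,
    #     # scaled by max(height, width) to undo the square padding.
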
    # Copied from transformers.models.owlvit.image_processing_owlvit.OwlViTImageProcessor.post_process_image_guided_detection
    def post_process_image_guided_detection(self, outputs, threshold=0.0, nms_threshold=0.3, target_sizes=None):
        """
        Converts the output of [`OwlViTForObjectDetection.image_guided_detection`] into the format expected by the COCO
        API.

        Args:
            outputs ([`OwlViTImageGuidedObjectDetectionOutput`]):
                Raw outputs of the model.
            threshold (`float`, *optional*, defaults to 0.0):
                Minimum confidence threshold to use to filter out predicted boxes.
            nms_threshold (`float`, *optional*, defaults to 0.3):
                IoU threshold for non-maximum suppression of overlapping boxes.
            target_sizes (`torch.Tensor`, *optional*):
                Tensor of shape (batch_size, 2) where each entry is the (height, width) of the corresponding image in
                the batch. If set, predicted normalized bounding boxes are rescaled to the target sizes. If left to
                None, predictions will not be unnormalized.

        Returns:
            `List[Dict]`: A list of dictionaries, each dictionary containing the scores, labels and boxes for an image
            in the batch as predicted by the model. All labels are set to None as
            `OwlViTForObjectDetection.image_guided_detection` performs one-shot object detection.
        """
        logits, target_boxes = outputs.logits, outputs.target_pred_boxes

        if target_sizes is not None and len(logits) != len(target_sizes):
            raise ValueError("Make sure that you pass in as many target sizes as the batch dimension of the logits")
        if target_sizes is not None and target_sizes.shape[1] != 2:
            raise ValueError("Each element of target_sizes must contain the size (h, w) of each image of the batch")

        probs = torch.max(logits, dim=-1)
        scores = torch.sigmoid(probs.values)

        # Convert to [x0, y0, x1, y1] format
        target_boxes = center_to_corners_format(target_boxes)

        # Apply non-maximum suppression (NMS)
        if nms_threshold < 1.0:
            for idx in range(target_boxes.shape[0]):
                for i in torch.argsort(-scores[idx]):
                    if not scores[idx][i]:
                        continue

                    ious = box_iou(target_boxes[idx][i, :].unsqueeze(0), target_boxes[idx])[0][0]
                    ious[i] = -1.0  # Mask self-IoU.
                    scores[idx][ious > nms_threshold] = 0.0

        # Convert from relative [0, 1] to absolute [0, height] coordinates
        if target_sizes is not None:
            if isinstance(target_sizes, List):
                img_h = torch.tensor([i[0] for i in target_sizes])
                img_w = torch.tensor([i[1] for i in target_sizes])
            else:
                img_h, img_w = target_sizes.unbind(1)

            scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1).to(target_boxes.device)
            target_boxes = target_boxes * scale_fct[:, None, :]

        # Compute box display alphas based on prediction scores
        results = []
        alphas = torch.zeros_like(scores)
        for idx in range(target_boxes.shape[0]):
            # Select scores for boxes matching the current query:
            query_scores = scores[idx]
            if not query_scores.nonzero().numel():
                continue

            # Apply threshold on scores before scaling
            query_scores[query_scores < threshold] = 0.0

            # Scale box alpha such that the best box for each query has alpha 1.0 and the worst box has alpha 0.1.
            # All other boxes will either belong to a different query, or will not be shown.
            max_score = torch.max(query_scores) + 1e-6
            query_alphas = (query_scores - (max_score * 0.1)) / (max_score * 0.9)
            query_alphas = torch.clip(query_alphas, 0.0, 1.0)
            alphas[idx] = query_alphas

            mask = alphas[idx] > 0
            box_scores = alphas[idx][mask]
            boxes = target_boxes[idx][mask]
            results.append({"scores": box_scores, "labels": None, "boxes": boxes})

        return results
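
    # Illustrative usage sketch (assumes `model` is an `Owlv2ForObjectDetection` instance and
    # `inputs` was built from a target image together with a `query_images` example to detect):
    #
    #     outputs = model.image_guided_detection(**inputs)
    #     target_sizes = torch.tensor([[height, width]])
    #     results = image_processor.post_process_image_guided_detection(
    #         outputs, threshold=0.6, nms_threshold=0.3, target_sizes=target_sizes
    #     )
    #     # Each entry holds "scores" (display alphas in [0, 1]), "labels" (always None for
    #     # one-shot detection) and "boxes" in (x0, y0, x1, y1) pixel coordinates.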