# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Image processor class for OwlViT"""

import warnings
from typing import Dict, List, Optional, Tuple, Union

import numpy as np

from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
from ...image_transforms import (
    center_crop,
    center_to_corners_format,
    rescale,
    resize,
    to_channel_dimension_format,
)
from ...image_utils import (
    OPENAI_CLIP_MEAN,
    OPENAI_CLIP_STD,
    ChannelDimension,
    ImageInput,
    PILImageResampling,
    infer_channel_dimension_format,
    is_scaled_image,
    make_list_of_images,
    to_numpy_array,
    valid_images,
    validate_preprocess_arguments,
)
from ...utils import TensorType, filter_out_non_signature_kwargs, is_torch_available, logging


if is_torch_available():
    import torch


logger = logging.get_logger(__name__)


def _upcast(t):
    # Protects from numerical overflows in multiplications by upcasting to the equivalent higher type
    if t.is_floating_point():
        return t if t.dtype in (torch.float32, torch.float64) else t.float()
    else:
        return t if t.dtype in (torch.int32, torch.int64) else t.int()


def box_area(boxes):
    """
    Computes the area of a set of bounding boxes, which are specified by their (x1, y1, x2, y2) coordinates.

    Args:
        boxes (`torch.FloatTensor` of shape `(number_of_boxes, 4)`):
            Boxes for which the area will be computed. They are expected to be in (x1, y1, x2, y2) format with `0 <= x1
            < x2` and `0 <= y1 < y2`.

    Returns:
        `torch.FloatTensor`: a tensor containing the area for each box.
    """
    boxes = _upcast(boxes)
    return (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])


def box_iou(boxes1, boxes2):
    area1 = box_area(boxes1)
    area2 = box_area(boxes2)

    left_top = torch.max(boxes1[:, None, :2], boxes2[:, :2])  # [N,M,2]
    right_bottom = torch.min(boxes1[:, None, 2:], boxes2[:, 2:])  # [N,M,2]

    width_height = (right_bottom - left_top).clamp(min=0)  # [N,M,2]
    inter = width_height[:, :, 0] * width_height[:, :, 1]  # [N,M]

    union = area1[:, None] + area2 - inter

    iou = inter / union
    return iou, union
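

# Illustrative sketch (not part of the original module): a minimal example of how `box_iou`
# pairs every box in `boxes1` with every box in `boxes2`. The helper name `_box_iou_example`
# is hypothetical and nothing in the library calls it.
def _box_iou_example():
    # Two axis-aligned boxes in (x1, y1, x2, y2) format.
    boxes1 = torch.tensor([[0.0, 0.0, 2.0, 2.0]])  # area 4
    boxes2 = torch.tensor([[1.0, 1.0, 3.0, 3.0]])  # area 4, overlaps boxes1 on a 1x1 square
    iou, union = box_iou(boxes1, boxes2)
    # intersection = 1, union = 4 + 4 - 1 = 7, so iou = 1 / 7
    return iou, union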


class OwlViTImageProcessor(BaseImageProcessor):
    r"""
    Constructs an OWL-ViT image processor.

    This image processor inherits from [`ImageProcessingMixin`] which contains most of the main methods. Users should
    refer to this superclass for more information regarding those methods.

    Args:
        do_resize (`bool`, *optional*, defaults to `True`):
            Whether to resize the shorter edge of the input to a certain `size`.
        size (`Dict[str, int]`, *optional*, defaults to {"height": 768, "width": 768}):
            The size to use for resizing the image. Only has an effect if `do_resize` is set to `True`. If `size` is a
            sequence like (h, w), output size will be matched to this. If `size` is an int, then image will be resized
            to (size, size).
        resample (`int`, *optional*, defaults to `Resampling.BICUBIC`):
            An optional resampling filter. This can be one of `PIL.Image.Resampling.NEAREST`,
            `PIL.Image.Resampling.BOX`, `PIL.Image.Resampling.BILINEAR`, `PIL.Image.Resampling.HAMMING`,
            `PIL.Image.Resampling.BICUBIC` or `PIL.Image.Resampling.LANCZOS`. Only has an effect if `do_resize` is set
            to `True`.
        do_center_crop (`bool`, *optional*, defaults to `False`):
            Whether to crop the input at the center. If the input size is smaller than `crop_size` along any edge, the
            image is padded with 0's and then center cropped.
        crop_size (`Dict[str, int]`, *optional*, defaults to {"height": 768, "width": 768}):
            The size to use for center cropping the image. Only has an effect if `do_center_crop` is set to `True`.
        do_rescale (`bool`, *optional*, defaults to `True`):
            Whether to rescale the input by a certain factor.
        rescale_factor (`float`, *optional*, defaults to `1/255`):
            The factor to use for rescaling the image. Only has an effect if `do_rescale` is set to `True`.
        do_normalize (`bool`, *optional*, defaults to `True`):
            Whether or not to normalize the input with `image_mean` and `image_std`.
        image_mean (`List[float]`, *optional*, defaults to `[0.48145466, 0.4578275, 0.40821073]`):
            The sequence of means for each channel, to be used when normalizing images.
        image_std (`List[float]`, *optional*, defaults to `[0.26862954, 0.26130258, 0.27577711]`):
            The sequence of standard deviations for each channel, to be used when normalizing images.
    """

    model_input_names = ["pixel_values"]

    def __init__(
        self,
        do_resize=True,
        size=None,
        resample=PILImageResampling.BICUBIC,
        do_center_crop=False,
        crop_size=None,
        do_rescale=True,
        rescale_factor=1 / 255,
        do_normalize=True,
        image_mean=None,
        image_std=None,
        **kwargs,
    ):
        size = size if size is not None else {"height": 768, "width": 768}
        size = get_size_dict(size, default_to_square=True)

        crop_size = crop_size if crop_size is not None else {"height": 768, "width": 768}
        crop_size = get_size_dict(crop_size, default_to_square=True)

        # Early versions of the OWL-ViT config on the hub had "rescale" as a flag. This clashes with the
        # vision image processor method `rescale` as it would be set as an attribute during the super().__init__
        # call. This is for backwards compatibility.
        if "rescale" in kwargs:
            rescale_val = kwargs.pop("rescale")
            kwargs["do_rescale"] = rescale_val

        super().__init__(**kwargs)
        self.do_resize = do_resize
        self.size = size
        self.resample = resample
        self.do_center_crop = do_center_crop
        self.crop_size = crop_size
        self.do_rescale = do_rescale
        self.rescale_factor = rescale_factor
        self.do_normalize = do_normalize
        self.image_mean = image_mean if image_mean is not None else OPENAI_CLIP_MEAN
        self.image_std = image_std if image_std is not None else OPENAI_CLIP_STD
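
    # Illustrative construction (added for clarity; not in the original file). The defaults
    # reproduce the standard OWL-ViT preprocessing: resize to 768x768, rescale by 1/255, then
    # normalize with the OpenAI CLIP mean/std:
    #
    #     image_processor = OwlViTImageProcessor()
    #     custom = OwlViTImageProcessor(size={"height": 840, "width": 840}, do_center_crop=False)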

    def resize(
        self,
        image: np.ndarray,
        size: Dict[str, int],
        resample: PILImageResampling = PILImageResampling.BICUBIC,
        data_format: Optional[Union[str, ChannelDimension]] = None,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
        **kwargs,
    ) -> np.ndarray:
        """
        Resize an image to a certain size.

        Args:
            image (`np.ndarray`):
                Image to resize.
            size (`Dict[str, int]`):
                The size to resize the image to. Must contain height and width keys.
            resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`):
                The resampling filter to use when resizing the input.
            data_format (`str` or `ChannelDimension`, *optional*):
                The channel dimension format for the output image. If unset, the channel dimension format of the input
                image is used.
            input_data_format (`str` or `ChannelDimension`, *optional*):
                The channel dimension format of the input image. If not provided, it will be inferred.
        """
        size = get_size_dict(size, default_to_square=True)
        if "height" not in size or "width" not in size:
            raise ValueError("size dictionary must contain height and width keys")

        return resize(
            image,
            (size["height"], size["width"]),
            resample=resample,
            data_format=data_format,
            input_data_format=input_data_format,
            **kwargs,
        )
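
    # Illustrative call (added for clarity; `processor` and `image` are placeholder names):
    #
    #     resized = processor.resize(image, size={"height": 768, "width": 768})
    #
    # `size` must carry both "height" and "width"; a missing key raises the ValueError above.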

    def center_crop(
        self,
        image: np.ndarray,
        crop_size: Dict[str, int],
        data_format: Optional[Union[str, ChannelDimension]] = None,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
        **kwargs,
    ) -> np.ndarray:
        """
        Center crop an image to a certain size.

        Args:
            image (`np.ndarray`):
                Image to center crop.
            crop_size (`Dict[str, int]`):
                The size to center crop the image to. Must contain height and width keys.
            data_format (`str` or `ChannelDimension`, *optional*):
                The channel dimension format for the output image. If unset, the channel dimension format of the input
                image is used.
            input_data_format (`str` or `ChannelDimension`, *optional*):
                The channel dimension format of the input image. If not provided, it will be inferred.
        """
        crop_size = get_size_dict(crop_size, default_to_square=True)
        if "height" not in crop_size or "width" not in crop_size:
            raise ValueError("crop_size dictionary must contain height and width keys")

        return center_crop(
            image,
            (crop_size["height"], crop_size["width"]),
            data_format=data_format,
            input_data_format=input_data_format,
            **kwargs,
        )

    # Copied from transformers.models.detr.image_processing_detr.DetrImageProcessor.rescale
    def rescale(
        self,
        image: np.ndarray,
        rescale_factor: float,
        data_format: Optional[Union[str, ChannelDimension]] = None,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
    ) -> np.ndarray:
        """
        Rescale the image by the given factor. image = image * rescale_factor.

        Args:
            image (`np.ndarray`):
                Image to rescale.
            rescale_factor (`float`):
                The value to use for rescaling.
            data_format (`str` or `ChannelDimension`, *optional*):
                The channel dimension format for the output image. If unset, the channel dimension format of the input
                image is used. Can be one of:
                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
            input_data_format (`str` or `ChannelDimension`, *optional*):
                The channel dimension format for the input image. If unset, is inferred from the input image. Can be
                one of:
                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
        """
        return rescale(image, rescale_factor, data_format=data_format, input_data_format=input_data_format)
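
    # Worked example (added for clarity): with the default `rescale_factor=1/255`, a uint8 pixel
    # value of 255 maps to 255 * (1 / 255) = 1.0 and 128 maps to roughly 0.502, i.e. `rescale`
    # brings the image from the [0, 255] range into [0, 1] before normalization.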

    @filter_out_non_signature_kwargs()
    def preprocess(
        self,
        images: ImageInput,
        do_resize: Optional[bool] = None,
        size: Optional[Dict[str, int]] = None,
        resample: Optional[PILImageResampling] = None,
        do_center_crop: Optional[bool] = None,
        crop_size: Optional[Dict[str, int]] = None,
        do_rescale: Optional[bool] = None,
        rescale_factor: Optional[float] = None,
        do_normalize: Optional[bool] = None,
        image_mean: Optional[Union[float, List[float]]] = None,
        image_std: Optional[Union[float, List[float]]] = None,
        return_tensors: Optional[Union[TensorType, str]] = None,
        data_format: Union[str, ChannelDimension] = ChannelDimension.FIRST,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
    ) -> BatchFeature:
        """
        Prepares an image or batch of images for the model.

        Args:
            images (`ImageInput`):
                The image or batch of images to be prepared. Expects a single or batch of images with pixel values
                ranging from 0 to 255. If passing in images with pixel values between 0 and 1, set `do_rescale=False`.
            do_resize (`bool`, *optional*, defaults to `self.do_resize`):
                Whether or not to resize the input. If `True`, will resize the input to the size specified by `size`.
            size (`Dict[str, int]`, *optional*, defaults to `self.size`):
                The size to resize the input to. Only has an effect if `do_resize` is set to `True`.
            resample (`PILImageResampling`, *optional*, defaults to `self.resample`):
                The resampling filter to use when resizing the input. Only has an effect if `do_resize` is set to
                `True`.
            do_center_crop (`bool`, *optional*, defaults to `self.do_center_crop`):
                Whether or not to center crop the input. If `True`, will center crop the input to the size specified by
                `crop_size`.
            crop_size (`Dict[str, int]`, *optional*, defaults to `self.crop_size`):
                The size to center crop the input to. Only has an effect if `do_center_crop` is set to `True`.
            do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
                Whether or not to rescale the input. If `True`, will rescale the input by multiplying it by
                `rescale_factor`.
            rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
                The factor to rescale the input by. Only has an effect if `do_rescale` is set to `True`.
            do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
                Whether or not to normalize the input. If `True`, will normalize the input by subtracting `image_mean`
                and dividing by `image_std`.
            image_mean (`Union[float, List[float]]`, *optional*, defaults to `self.image_mean`):
                The mean to subtract from the input when normalizing. Only has an effect if `do_normalize` is set to
                `True`.
            image_std (`Union[float, List[float]]`, *optional*, defaults to `self.image_std`):
                The standard deviation to divide the input by when normalizing. Only has an effect if `do_normalize` is
                set to `True`.
            return_tensors (`str` or `TensorType`, *optional*):
                The type of tensors to return. Can be one of:
                - Unset: Return a list of `np.ndarray`.
                - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
                - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
                - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
                - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
            data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
                The channel dimension format for the output image. Can be one of:
                - `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                - `ChannelDimension.LAST`: image in (height, width, num_channels) format.
                - Unset: defaults to the channel dimension format of the input image.
            input_data_format (`ChannelDimension` or `str`, *optional*):
                The channel dimension format for the input image. If unset, the channel dimension format is inferred
                from the input image. Can be one of:
                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
                - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
        """
        do_resize = do_resize if do_resize is not None else self.do_resize
        size = size if size is not None else self.size
        resample = resample if resample is not None else self.resample
        do_center_crop = do_center_crop if do_center_crop is not None else self.do_center_crop
        crop_size = crop_size if crop_size is not None else self.crop_size
        do_rescale = do_rescale if do_rescale is not None else self.do_rescale
        rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
        do_normalize = do_normalize if do_normalize is not None else self.do_normalize
        image_mean = image_mean if image_mean is not None else self.image_mean
        image_std = image_std if image_std is not None else self.image_std

        images = make_list_of_images(images)

        if not valid_images(images):
            raise ValueError(
                "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
                "torch.Tensor, tf.Tensor or jax.ndarray."
            )
        validate_preprocess_arguments(
            do_rescale=do_rescale,
            rescale_factor=rescale_factor,
            do_normalize=do_normalize,
            image_mean=image_mean,
            image_std=image_std,
            do_center_crop=do_center_crop,
            crop_size=crop_size,
            do_resize=do_resize,
            size=size,
            resample=resample,
        )

        # All transformations expect numpy arrays
        images = [to_numpy_array(image) for image in images]

        if is_scaled_image(images[0]) and do_rescale:
            logger.warning_once(
                "It looks like you are trying to rescale already rescaled images. If the input"
                " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
            )

        if input_data_format is None:
            # We assume that all images have the same channel dimension format.
            input_data_format = infer_channel_dimension_format(images[0])

        if do_resize:
            images = [
                self.resize(image, size=size, resample=resample, input_data_format=input_data_format)
                for image in images
            ]

        if do_center_crop:
            images = [
                self.center_crop(image, crop_size=crop_size, input_data_format=input_data_format) for image in images
            ]

        if do_rescale:
            images = [
                self.rescale(image, rescale_factor=rescale_factor, input_data_format=input_data_format)
                for image in images
            ]

        if do_normalize:
            images = [
                self.normalize(image, mean=image_mean, std=image_std, input_data_format=input_data_format)
                for image in images
            ]

        images = [
            to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) for image in images
        ]
        encoded_inputs = BatchFeature(data={"pixel_values": images}, tensor_type=return_tensors)
        return encoded_inputs
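
    # Illustrative usage of `preprocess` (added for clarity; not part of the original file).
    # The checkpoint name and image URL below are the ones commonly used in the docs and are
    # assumptions here, not requirements:
    #
    #     from PIL import Image
    #     import requests
    #
    #     image = Image.open(requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw)
    #     image_processor = OwlViTImageProcessor.from_pretrained("google/owlvit-base-patch32")
    #     inputs = image_processor(images=image, return_tensors="pt")
    #     inputs["pixel_values"].shape  # torch.Size([1, 3, 768, 768]) with the default size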

    def post_process(self, outputs, target_sizes):
        """
        Converts the raw output of [`OwlViTForObjectDetection`] into final bounding boxes in (top_left_x, top_left_y,
        bottom_right_x, bottom_right_y) format.

        Args:
            outputs ([`OwlViTObjectDetectionOutput`]):
                Raw outputs of the model.
            target_sizes (`torch.Tensor` of shape `(batch_size, 2)`):
                Tensor containing the size (h, w) of each image of the batch. For evaluation, this must be the original
                image size (before any data augmentation). For visualization, this should be the image size after data
                augmentation, but before padding.

        Returns:
            `List[Dict]`: A list of dictionaries, each dictionary containing the scores, labels and boxes for an image
            in the batch as predicted by the model.
        """
        # TODO: (amy) add support for other frameworks
        warnings.warn(
            "`post_process` is deprecated and will be removed in v5 of Transformers, please use"
            " `post_process_object_detection` instead, with `threshold=0.` for equivalent results.",
            FutureWarning,
        )

        logits, boxes = outputs.logits, outputs.pred_boxes

        if len(logits) != len(target_sizes):
            raise ValueError("Make sure that you pass in as many target sizes as the batch dimension of the logits")
        if target_sizes.shape[1] != 2:
            raise ValueError("Each element of target_sizes must contain the size (h, w) of each image of the batch")

        probs = torch.max(logits, dim=-1)
        scores = torch.sigmoid(probs.values)
        labels = probs.indices

        # Convert to [x0, y0, x1, y1] format
        boxes = center_to_corners_format(boxes)

        # Convert from relative [0, 1] to absolute [0, height] coordinates
        img_h, img_w = target_sizes.unbind(1)
        scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1).to(boxes.device)
        boxes = boxes * scale_fct[:, None, :]

        results = [{"scores": s, "labels": l, "boxes": b} for s, l, b in zip(scores, labels, boxes)]

        return results

    def post_process_object_detection(
        self, outputs, threshold: float = 0.1, target_sizes: Optional[Union[TensorType, List[Tuple]]] = None
    ):
        """
        Converts the raw output of [`OwlViTForObjectDetection`] into final bounding boxes in (top_left_x, top_left_y,
        bottom_right_x, bottom_right_y) format.

        Args:
            outputs ([`OwlViTObjectDetectionOutput`]):
                Raw outputs of the model.
            threshold (`float`, *optional*, defaults to 0.1):
                Score threshold to keep object detection predictions.
            target_sizes (`torch.Tensor` or `List[Tuple[int, int]]`, *optional*):
                Tensor of shape `(batch_size, 2)` or list of tuples (`Tuple[int, int]`) containing the target size
                `(height, width)` of each image in the batch. If unset, predictions will not be resized.

        Returns:
            `List[Dict]`: A list of dictionaries, each dictionary containing the scores, labels and boxes for an image
            in the batch as predicted by the model.
        """
        # TODO: (amy) add support for other frameworks
        logits, boxes = outputs.logits, outputs.pred_boxes

        if target_sizes is not None:
            if len(logits) != len(target_sizes):
                raise ValueError(
                    "Make sure that you pass in as many target sizes as the batch dimension of the logits"
                )

        probs = torch.max(logits, dim=-1)
        scores = torch.sigmoid(probs.values)
        labels = probs.indices

        # Convert to [x0, y0, x1, y1] format
        boxes = center_to_corners_format(boxes)

        # Convert from relative [0, 1] to absolute [0, height] coordinates
        if target_sizes is not None:
            if isinstance(target_sizes, List):
                img_h = torch.Tensor([i[0] for i in target_sizes])
                img_w = torch.Tensor([i[1] for i in target_sizes])
            else:
                img_h, img_w = target_sizes.unbind(1)

            scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1).to(boxes.device)
            boxes = boxes * scale_fct[:, None, :]

        results = []
        for s, l, b in zip(scores, labels, boxes):
            score = s[s > threshold]
            label = l[s > threshold]
            box = b[s > threshold]
            results.append({"scores": score, "labels": label, "boxes": box})

        return results
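
    # Typical usage (illustrative, not part of the original file). `processor`, `model`, `image`
    # and `image_processor` are placeholder names for an `OwlViTProcessor`, an
    # `OwlViTForObjectDetection` model, a PIL image and this image processor:
    #
    #     inputs = processor(text=[["a photo of a cat", "a photo of a dog"]], images=image, return_tensors="pt")
    #     outputs = model(**inputs)
    #     target_sizes = torch.tensor([image.size[::-1]])  # PIL size is (width, height), so reverse to (height, width)
    #     results = image_processor.post_process_object_detection(outputs, threshold=0.1, target_sizes=target_sizes)
    #     results[0]["scores"], results[0]["labels"], results[0]["boxes"]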

    # TODO: (Amy) Make compatible with other frameworks
    def post_process_image_guided_detection(self, outputs, threshold=0.0, nms_threshold=0.3, target_sizes=None):
        """
        Converts the output of [`OwlViTForObjectDetection.image_guided_detection`] into the format expected by the COCO
        api.

        Args:
            outputs ([`OwlViTImageGuidedObjectDetectionOutput`]):
                Raw outputs of the model.
            threshold (`float`, *optional*, defaults to 0.0):
                Minimum confidence threshold to use to filter out predicted boxes.
            nms_threshold (`float`, *optional*, defaults to 0.3):
                IoU threshold for non-maximum suppression of overlapping boxes.
            target_sizes (`torch.Tensor`, *optional*):
                Tensor of shape (batch_size, 2) where each entry is the (height, width) of the corresponding image in
                the batch. If set, predicted normalized bounding boxes are rescaled to the target sizes. If left to
                None, predictions will not be unnormalized.

        Returns:
            `List[Dict]`: A list of dictionaries, each dictionary containing the scores, labels and boxes for an image
            in the batch as predicted by the model. All labels are set to None as
            `OwlViTForObjectDetection.image_guided_detection` performs one-shot object detection.
        """
        logits, target_boxes = outputs.logits, outputs.target_pred_boxes

        if target_sizes is not None and len(logits) != len(target_sizes):
            raise ValueError("Make sure that you pass in as many target sizes as the batch dimension of the logits")
        if target_sizes is not None and target_sizes.shape[1] != 2:
            raise ValueError("Each element of target_sizes must contain the size (h, w) of each image of the batch")

        probs = torch.max(logits, dim=-1)
        scores = torch.sigmoid(probs.values)

        # Convert to [x0, y0, x1, y1] format
        target_boxes = center_to_corners_format(target_boxes)

        # Apply non-maximum suppression (NMS)
        if nms_threshold < 1.0:
            for idx in range(target_boxes.shape[0]):
                for i in torch.argsort(-scores[idx]):
                    if not scores[idx][i]:
                        continue

                    ious = box_iou(target_boxes[idx][i, :].unsqueeze(0), target_boxes[idx])[0][0]
                    ious[i] = -1.0  # Mask self-IoU.
                    scores[idx][ious > nms_threshold] = 0.0

        # Convert from relative [0, 1] to absolute [0, height] coordinates
        if target_sizes is not None:
            if isinstance(target_sizes, List):
                img_h = torch.tensor([i[0] for i in target_sizes])
                img_w = torch.tensor([i[1] for i in target_sizes])
            else:
                img_h, img_w = target_sizes.unbind(1)

            scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1).to(target_boxes.device)
            target_boxes = target_boxes * scale_fct[:, None, :]

        # Compute box display alphas based on prediction scores
        results = []
        alphas = torch.zeros_like(scores)

        for idx in range(target_boxes.shape[0]):
            # Select scores for boxes matching the current query:
            query_scores = scores[idx]
            if not query_scores.nonzero().numel():
                continue

            # Apply threshold on scores before scaling
            query_scores[query_scores < threshold] = 0.0

            # Scale box alpha such that the best box for each query has alpha 1.0 and the worst box has alpha 0.1.
            # All other boxes will either belong to a different query, or will not be shown.
            max_score = torch.max(query_scores) + 1e-6
            query_alphas = (query_scores - (max_score * 0.1)) / (max_score * 0.9)
            query_alphas = torch.clip(query_alphas, 0.0, 1.0)
            alphas[idx] = query_alphas

            mask = alphas[idx] > 0
            box_scores = alphas[idx][mask]
            boxes = target_boxes[idx][mask]
            results.append({"scores": box_scores, "labels": None, "boxes": boxes})

        return results
  517. return results