# zero_shot_object_detection.py

from typing import Any, Dict, List, Union

from ..utils import add_end_docstrings, is_torch_available, is_vision_available, logging, requires_backends
from .base import ChunkPipeline, build_pipeline_init_args


if is_vision_available():
    from PIL import Image

    from ..image_utils import load_image

if is_torch_available():
    import torch

    from transformers.modeling_outputs import BaseModelOutput

    from ..models.auto.modeling_auto import MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING_NAMES

logger = logging.get_logger(__name__)

@add_end_docstrings(build_pipeline_init_args(has_image_processor=True))
class ZeroShotObjectDetectionPipeline(ChunkPipeline):
    """
    Zero-shot object detection pipeline using `OwlViTForObjectDetection`. This pipeline predicts bounding boxes of
    objects when you provide an image and a set of `candidate_labels`.

    Example:

    ```python
    >>> from transformers import pipeline

    >>> detector = pipeline(model="google/owlvit-base-patch32", task="zero-shot-object-detection")
    >>> detector(
    ...     "http://images.cocodataset.org/val2017/000000039769.jpg",
    ...     candidate_labels=["cat", "couch"],
    ... )
    [{'score': 0.287, 'label': 'cat', 'box': {'xmin': 324, 'ymin': 20, 'xmax': 640, 'ymax': 373}}, {'score': 0.254, 'label': 'cat', 'box': {'xmin': 1, 'ymin': 55, 'xmax': 315, 'ymax': 472}}, {'score': 0.121, 'label': 'couch', 'box': {'xmin': 4, 'ymin': 0, 'xmax': 642, 'ymax': 476}}]

    >>> detector(
    ...     "https://huggingface.co/datasets/Narsil/image_dummy/raw/main/parrots.png",
    ...     candidate_labels=["head", "bird"],
    ... )
    [{'score': 0.119, 'label': 'bird', 'box': {'xmin': 71, 'ymin': 170, 'xmax': 410, 'ymax': 508}}]
    ```

    Learn more about the basics of using a pipeline in the [pipeline tutorial](../pipeline_tutorial).

    This object detection pipeline can currently be loaded from [`pipeline`] using the following task identifier:
    `"zero-shot-object-detection"`.

    See the list of available models on
    [huggingface.co/models](https://huggingface.co/models?filter=zero-shot-object-detection).
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        if self.framework == "tf":
            raise ValueError(f"The {self.__class__} is only available in PyTorch.")

        requires_backends(self, "vision")
        self.check_model_type(MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING_NAMES)

    def __call__(
        self,
        image: Union[str, "Image.Image", List[Dict[str, Any]]],
        candidate_labels: Union[str, List[str]] = None,
        **kwargs,
    ):
        """
        Detect objects (bounding boxes & classes) in the image(s) passed as inputs.

        Args:
            image (`str`, `PIL.Image` or `List[Dict[str, Any]]`):
                The pipeline handles three types of images:

                - A string containing an http url pointing to an image
                - A string containing a local path to an image
                - An image loaded in PIL directly

                You can use this parameter to send a list of images directly, or a dataset or a generator like so:

                ```python
                >>> from transformers import pipeline

                >>> detector = pipeline(model="google/owlvit-base-patch32", task="zero-shot-object-detection")
                >>> detector(
                ...     [
                ...         {
                ...             "image": "http://images.cocodataset.org/val2017/000000039769.jpg",
                ...             "candidate_labels": ["cat", "couch"],
                ...         },
                ...         {
                ...             "image": "http://images.cocodataset.org/val2017/000000039769.jpg",
                ...             "candidate_labels": ["cat", "couch"],
                ...         },
                ...     ]
                ... )
                [[{'score': 0.287, 'label': 'cat', 'box': {'xmin': 324, 'ymin': 20, 'xmax': 640, 'ymax': 373}}, {'score': 0.254, 'label': 'cat', 'box': {'xmin': 1, 'ymin': 55, 'xmax': 315, 'ymax': 472}}, {'score': 0.121, 'label': 'couch', 'box': {'xmin': 4, 'ymin': 0, 'xmax': 642, 'ymax': 476}}], [{'score': 0.287, 'label': 'cat', 'box': {'xmin': 324, 'ymin': 20, 'xmax': 640, 'ymax': 373}}, {'score': 0.254, 'label': 'cat', 'box': {'xmin': 1, 'ymin': 55, 'xmax': 315, 'ymax': 472}}, {'score': 0.121, 'label': 'couch', 'box': {'xmin': 4, 'ymin': 0, 'xmax': 642, 'ymax': 476}}]]
                ```

            candidate_labels (`str` or `List[str]` or `List[List[str]]`):
                What the model should recognize in the image.

            threshold (`float`, *optional*, defaults to 0.1):
                The minimum probability a prediction must have to be returned.

            top_k (`int`, *optional*, defaults to None):
                The number of top predictions that will be returned by the pipeline. If the provided number is `None`
                or higher than the number of predictions available, it defaults to the number of predictions.

            timeout (`float`, *optional*, defaults to None):
                The maximum time in seconds to wait for fetching images from the web. If None, no timeout is set and
                the call may block forever.

        Return:
            A list of lists containing prediction results, one list per input image. Each list contains dictionaries
            with the following keys:

            - **label** (`str`) -- Text query corresponding to the found object.
            - **score** (`float`) -- Score corresponding to the object (between 0 and 1).
            - **box** (`Dict[str, int]`) -- Bounding box of the detected object in the image's original size. It is a
              dictionary with `xmin`, `ymin`, `xmax`, `ymax` keys.
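
            For example, to keep only the single highest-scoring detection above a stricter threshold (a sketch; the
            exact scores and boxes depend on the checkpoint, so no output is shown):

            ```python
            >>> detector(
            ...     "http://images.cocodataset.org/val2017/000000039769.jpg",
            ...     candidate_labels=["cat", "couch"],
            ...     threshold=0.3,
            ...     top_k=1,
            ... )  # doctest: +SKIP
            ```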
  93. """
  94. if "text_queries" in kwargs:
  95. candidate_labels = kwargs.pop("text_queries")
  96. if isinstance(image, (str, Image.Image)):
  97. inputs = {"image": image, "candidate_labels": candidate_labels}
  98. else:
  99. inputs = image
  100. results = super().__call__(inputs, **kwargs)
  101. return results

    def _sanitize_parameters(self, **kwargs):
        preprocess_params = {}
        if "timeout" in kwargs:
            preprocess_params["timeout"] = kwargs["timeout"]
        postprocess_params = {}
        if "threshold" in kwargs:
            postprocess_params["threshold"] = kwargs["threshold"]
        if "top_k" in kwargs:
            postprocess_params["top_k"] = kwargs["top_k"]
        return preprocess_params, {}, postprocess_params
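
    # As a `ChunkPipeline`, a single image fans out into one forward pass per candidate
    # label: `preprocess` yields one chunk per label, `_forward` runs the model on each
    # chunk, and `postprocess` merges the per-label detections into one sorted list.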

    def preprocess(self, inputs, timeout=None):
        image = load_image(inputs["image"], timeout=timeout)
        candidate_labels = inputs["candidate_labels"]
        # A comma-separated string such as "cat,couch" is also accepted.
        if isinstance(candidate_labels, str):
            candidate_labels = candidate_labels.split(",")

        target_size = torch.tensor([[image.height, image.width]], dtype=torch.int32)
        for i, candidate_label in enumerate(candidate_labels):
            text_inputs = self.tokenizer(candidate_label, return_tensors=self.framework)
            image_features = self.image_processor(image, return_tensors=self.framework)
            if self.framework == "pt":
                image_features = image_features.to(self.torch_dtype)
            yield {
                "is_last": i == len(candidate_labels) - 1,
                "target_size": target_size,
                "candidate_label": candidate_label,
                **text_inputs,
                **image_features,
            }

    def _forward(self, model_inputs):
        target_size = model_inputs.pop("target_size")
        candidate_label = model_inputs.pop("candidate_label")
        is_last = model_inputs.pop("is_last")
        outputs = self.model(**model_inputs)

        model_outputs = {"target_size": target_size, "candidate_label": candidate_label, "is_last": is_last, **outputs}
        return model_outputs

    def postprocess(self, model_outputs, threshold=0.1, top_k=None):
        results = []
        for model_output in model_outputs:
            label = model_output["candidate_label"]
            model_output = BaseModelOutput(model_output)
            # Rescale boxes to the original image size and drop detections below `threshold`.
            outputs = self.image_processor.post_process_object_detection(
                outputs=model_output, threshold=threshold, target_sizes=model_output["target_size"]
            )[0]

            for index in outputs["scores"].nonzero():
                score = outputs["scores"][index].item()
                box = self._get_bounding_box(outputs["boxes"][index][0])
                result = {"score": score, "label": label, "box": box}
                results.append(result)

        results = sorted(results, key=lambda x: x["score"], reverse=True)
        if top_k:
            results = results[:top_k]

        return results

    def _get_bounding_box(self, box: "torch.Tensor") -> Dict[str, int]:
        """
        Turns a tensor [xmin, ymin, xmax, ymax] into a dict {"xmin": xmin, ...}.

        Args:
            box (`torch.Tensor`): Tensor containing the coordinates in corners format.

        Returns:
            bbox (`Dict[str, int]`): Dict containing the coordinates in corners format.
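
        Example (illustrative; `pipe` stands in for an instantiated pipeline and is not defined here):

        ```python
        >>> import torch

        >>> pipe._get_bounding_box(torch.tensor([1.2, 55.9, 315.1, 472.4]))  # doctest: +SKIP
        {'xmin': 1, 'ymin': 55, 'xmax': 315, 'ymax': 472}
        ```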
  161. """
  162. if self.framework != "pt":
  163. raise ValueError("The ZeroShotObjectDetectionPipeline is only available in PyTorch.")
  164. xmin, ymin, xmax, ymax = box.int().tolist()
  165. bbox = {
  166. "xmin": xmin,
  167. "ymin": ymin,
  168. "xmax": xmax,
  169. "ymax": ymax,
  170. }
  171. return bbox
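
# A minimal end-to-end usage sketch (not part of the library module). It assumes
# network access, an installed `torch`, and the public `google/owlvit-base-patch32`
# checkpoint:
#
#     from transformers import pipeline
#
#     detector = pipeline(model="google/owlvit-base-patch32", task="zero-shot-object-detection")
#     # Labels may be a list or a comma-separated string; tune `threshold`/`top_k` as needed.
#     detections = detector(
#         "http://images.cocodataset.org/val2017/000000039769.jpg",
#         candidate_labels="cat,couch",
#         threshold=0.1,
#     )
#     for det in detections:
#         print(det["label"], round(det["score"], 3), det["box"])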