| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220 |
- from typing import Any, Dict, List, Union
- from ..utils import add_end_docstrings, is_torch_available, is_vision_available, logging, requires_backends
- from .base import ChunkPipeline, build_pipeline_init_args
- if is_vision_available():
- from PIL import Image
- from ..image_utils import load_image
- if is_torch_available():
- import torch
- from transformers.modeling_outputs import BaseModelOutput
- from ..models.auto.modeling_auto import MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING_NAMES
# Module-level logger obtained from the `logging` helpers imported from `..utils`.
logger = logging.get_logger(__name__)
@add_end_docstrings(build_pipeline_init_args(has_image_processor=True))
class ZeroShotObjectDetectionPipeline(ChunkPipeline):
    """
    Zero shot object detection pipeline using `OwlViTForObjectDetection`. This pipeline predicts bounding boxes of
    objects when you provide an image and a set of `candidate_labels`.

    Example:

    ```python
    >>> from transformers import pipeline

    >>> detector = pipeline(model="google/owlvit-base-patch32", task="zero-shot-object-detection")
    >>> detector(
    ...     "http://images.cocodataset.org/val2017/000000039769.jpg",
    ...     candidate_labels=["cat", "couch"],
    ... )
    [{'score': 0.287, 'label': 'cat', 'box': {'xmin': 324, 'ymin': 20, 'xmax': 640, 'ymax': 373}}, {'score': 0.254, 'label': 'cat', 'box': {'xmin': 1, 'ymin': 55, 'xmax': 315, 'ymax': 472}}, {'score': 0.121, 'label': 'couch', 'box': {'xmin': 4, 'ymin': 0, 'xmax': 642, 'ymax': 476}}]

    >>> detector(
    ...     "https://huggingface.co/datasets/Narsil/image_dummy/raw/main/parrots.png",
    ...     candidate_labels=["head", "bird"],
    ... )
    [{'score': 0.119, 'label': 'bird', 'box': {'xmin': 71, 'ymin': 170, 'xmax': 410, 'ymax': 508}}]
    ```

    Learn more about the basics of using a pipeline in the [pipeline tutorial](../pipeline_tutorial)

    This object detection pipeline can currently be loaded from [`pipeline`] using the following task identifier:
    `"zero-shot-object-detection"`.

    See the list of available models on
    [huggingface.co/models](https://huggingface.co/models?filter=zero-shot-object-detection).
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        # Only the PyTorch path is implemented (see `preprocess` / `_get_bounding_box`).
        if self.framework == "tf":
            raise ValueError(f"The {self.__class__} is only available in PyTorch.")

        requires_backends(self, "vision")
        self.check_model_type(MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING_NAMES)

    def __call__(
        self,
        image: Union[str, "Image.Image", List[Dict[str, Any]]],
        candidate_labels: Union[str, List[str], None] = None,
        **kwargs,
    ):
        """
        Detect objects (bounding boxes & classes) in the image(s) passed as inputs.

        Args:
            image (`str`, `PIL.Image` or `List[Dict[str, Any]]`):
                The pipeline handles three types of images:

                - A string containing an http url pointing to an image
                - A string containing a local path to an image
                - An image loaded in PIL directly

                You can use this parameter to send directly a list of images, or a dataset or a generator like so:

                ```python
                >>> from transformers import pipeline

                >>> detector = pipeline(model="google/owlvit-base-patch32", task="zero-shot-object-detection")
                >>> detector(
                ...     [
                ...         {
                ...             "image": "http://images.cocodataset.org/val2017/000000039769.jpg",
                ...             "candidate_labels": ["cat", "couch"],
                ...         },
                ...         {
                ...             "image": "http://images.cocodataset.org/val2017/000000039769.jpg",
                ...             "candidate_labels": ["cat", "couch"],
                ...         },
                ...     ]
                ... )
                [[{'score': 0.287, 'label': 'cat', 'box': {'xmin': 324, 'ymin': 20, 'xmax': 640, 'ymax': 373}}, {'score': 0.254, 'label': 'cat', 'box': {'xmin': 1, 'ymin': 55, 'xmax': 315, 'ymax': 472}}, {'score': 0.121, 'label': 'couch', 'box': {'xmin': 4, 'ymin': 0, 'xmax': 642, 'ymax': 476}}], [{'score': 0.287, 'label': 'cat', 'box': {'xmin': 324, 'ymin': 20, 'xmax': 640, 'ymax': 373}}, {'score': 0.254, 'label': 'cat', 'box': {'xmin': 1, 'ymin': 55, 'xmax': 315, 'ymax': 472}}, {'score': 0.121, 'label': 'couch', 'box': {'xmin': 4, 'ymin': 0, 'xmax': 642, 'ymax': 476}}]]
                ```

            candidate_labels (`str` or `List[str]`):
                What the model should recognize in the image. A single string is split on commas and surrounding
                whitespace is stripped from each piece, so `"cat, couch"` is equivalent to `["cat", "couch"]`.

            threshold (`float`, *optional*, defaults to 0.1):
                The probability necessary to make a prediction.

            top_k (`int`, *optional*, defaults to None):
                The number of top predictions that will be returned by the pipeline. If the provided number is `None`
                or higher than the number of predictions available, all predictions are returned.

            timeout (`float`, *optional*, defaults to None):
                The maximum time in seconds to wait for fetching images from the web. If None, no timeout is set and
                the call may block forever.

        Return:
            For a single image, a list of dictionaries sorted by decreasing score; when a list of inputs is passed,
            one such list per input image. Each dictionary has the following keys:

            - **label** (`str`) -- Text query corresponding to the found object.
            - **score** (`float`) -- Score corresponding to the object (between 0 and 1).
            - **box** (`Dict[str,int]`) -- Bounding box of the detected object in image's original size. It is a
              dictionary with `xmin`, `ymin`, `xmax`, `ymax` keys.
        """
        # `text_queries` is accepted as an alternative keyword for `candidate_labels`; when given it takes precedence.
        if "text_queries" in kwargs:
            candidate_labels = kwargs.pop("text_queries")

        if isinstance(image, (str, Image.Image)):
            inputs = {"image": image, "candidate_labels": candidate_labels}
        else:
            # Already packed input (e.g. list of {"image": ..., "candidate_labels": ...} dicts, a dataset or a
            # generator); it is forwarded untouched and `preprocess` reads those two keys from each item.
            inputs = image
        results = super().__call__(inputs, **kwargs)
        return results

    def _sanitize_parameters(self, **kwargs):
        """Route call-time kwargs: `timeout` to `preprocess`; `threshold` / `top_k` to `postprocess`."""
        preprocess_params = {}
        if "timeout" in kwargs:
            preprocess_params["timeout"] = kwargs["timeout"]

        postprocess_params = {}
        if "threshold" in kwargs:
            postprocess_params["threshold"] = kwargs["threshold"]
        if "top_k" in kwargs:
            postprocess_params["top_k"] = kwargs["top_k"]

        return preprocess_params, {}, postprocess_params

    def preprocess(self, inputs, timeout=None):
        """
        Load the image and yield one model-input chunk per candidate label.

        Args:
            inputs (`dict`): Mapping with an `"image"` entry (anything `load_image` accepts) and a
                `"candidate_labels"` entry (a comma-separated `str` or a list of labels).
            timeout (`float`, *optional*): Forwarded to `load_image`.

        Yields:
            `dict`: The tokenized label, the image features, the original `(height, width)` as `target_size`, the
            raw `candidate_label` and an `is_last` flag marking the final chunk for this image.
        """
        image = load_image(inputs["image"], timeout=timeout)
        candidate_labels = inputs["candidate_labels"]
        if isinstance(candidate_labels, str):
            # "cat, couch" -> ["cat", "couch"]. The number of pieces is unchanged (no filtering), so the chunk
            # count and the `is_last` flag stay exactly as `str.split` would produce them.
            candidate_labels = [label.strip() for label in candidate_labels.split(",")]

        target_size = torch.tensor([[image.height, image.width]], dtype=torch.int32)

        # The image features do not depend on the label: compute them once and share them across all chunks
        # (as is already done for `target_size`) instead of re-running the image processor per label.
        image_features = self.image_processor(image, return_tensors=self.framework)
        if self.framework == "pt":
            image_features = image_features.to(self.torch_dtype)

        last_index = len(candidate_labels) - 1
        for i, candidate_label in enumerate(candidate_labels):
            text_inputs = self.tokenizer(candidate_label, return_tensors=self.framework)
            yield {
                "is_last": i == last_index,
                "target_size": target_size,
                "candidate_label": candidate_label,
                **text_inputs,
                **image_features,
            }

    def _forward(self, model_inputs):
        """Run the model on one chunk, carrying the bookkeeping keys through to `postprocess`."""
        # Strip the non-model keys before calling the model, then re-attach them to its outputs.
        target_size = model_inputs.pop("target_size")
        candidate_label = model_inputs.pop("candidate_label")
        is_last = model_inputs.pop("is_last")

        outputs = self.model(**model_inputs)

        model_outputs = {"target_size": target_size, "candidate_label": candidate_label, "is_last": is_last, **outputs}
        return model_outputs

    def postprocess(self, model_outputs, threshold=0.1, top_k=None):
        """
        Merge the per-label chunk outputs for one image into a single list of detections.

        Args:
            model_outputs (`List[dict]`): The `_forward` outputs for every candidate label of one image.
            threshold (`float`, *optional*, defaults to 0.1): Forwarded to the image processor's
                `post_process_object_detection`.
            top_k (`int`, *optional*): If not `None`, keep at most this many detections (after sorting).

        Returns:
            `List[dict]`: `{"score", "label", "box"}` dictionaries sorted by decreasing score.
        """
        results = []
        for model_output in model_outputs:
            label = model_output["candidate_label"]
            model_output = BaseModelOutput(model_output)
            outputs = self.image_processor.post_process_object_detection(
                outputs=model_output, threshold=threshold, target_sizes=model_output["target_size"]
            )[0]

            # Every detection returned by the post-processor is kept; the old `scores.nonzero()` loop silently
            # dropped zero-score detections whenever `threshold <= 0`.
            for score, box in zip(outputs["scores"], outputs["boxes"]):
                results.append({"score": score.item(), "label": label, "box": self._get_bounding_box(box)})

        results.sort(key=lambda result: result["score"], reverse=True)
        # `is not None` so that `top_k=0` returns no predictions instead of being treated as "no limit".
        if top_k is not None:
            results = results[:top_k]

        return results

    def _get_bounding_box(self, box: "torch.Tensor") -> Dict[str, int]:
        """
        Turns list [xmin, ymin, xmax, ymax] into dict { "xmin": xmin, ... }

        Args:
            box (`torch.Tensor`): Tensor containing the coordinates in corners format, ordered
                `(xmin, ymin, xmax, ymax)`.

        Returns:
            bbox (`Dict[str, int]`): Dict containing the coordinates in corners format, truncated to integers.
        """
        if self.framework != "pt":
            raise ValueError("The ZeroShotObjectDetectionPipeline is only available in PyTorch.")
        xmin, ymin, xmax, ymax = box.int().tolist()
        bbox = {
            "xmin": xmin,
            "ymin": ymin,
            "xmax": xmax,
            "ymax": ymax,
        }
        return bbox
|