# image_segmentation.py

import warnings
from typing import Any, Dict, List, Union

import numpy as np

from ..utils import add_end_docstrings, is_torch_available, is_vision_available, logging, requires_backends
from .base import Pipeline, build_pipeline_init_args


if is_vision_available():
    from PIL import Image

    from ..image_utils import load_image

if is_torch_available():
    from ..models.auto.modeling_auto import (
        MODEL_FOR_IMAGE_SEGMENTATION_MAPPING_NAMES,
        MODEL_FOR_INSTANCE_SEGMENTATION_MAPPING_NAMES,
        MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING_NAMES,
        MODEL_FOR_UNIVERSAL_SEGMENTATION_MAPPING_NAMES,
    )

logger = logging.get_logger(__name__)


Prediction = Dict[str, Any]
Predictions = List[Prediction]


@add_end_docstrings(build_pipeline_init_args(has_image_processor=True))
class ImageSegmentationPipeline(Pipeline):
    """
    Image segmentation pipeline using any `AutoModelForXXXSegmentation`. This pipeline predicts masks of objects and
    their classes.

    Example:

    ```python
    >>> from transformers import pipeline

    >>> segmenter = pipeline(model="facebook/detr-resnet-50-panoptic")
    >>> segments = segmenter("https://huggingface.co/datasets/Narsil/image_dummy/raw/main/parrots.png")
    >>> len(segments)
    2

    >>> segments[0]["label"]
    'bird'

    >>> segments[1]["label"]
    'bird'

    >>> type(segments[0]["mask"])  # This is a black and white mask showing where the bird is on the original image.
    <class 'PIL.Image.Image'>

    >>> segments[0]["mask"].size
    (768, 512)
    ```

    This image segmentation pipeline can currently be loaded from [`pipeline`] using the following task identifier:
    `"image-segmentation"`.

    See the list of available models on
    [huggingface.co/models](https://huggingface.co/models?filter=image-segmentation).
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        if self.framework == "tf":
            raise ValueError(f"The {self.__class__} is only available in PyTorch.")

        requires_backends(self, "vision")
        # Any of the four segmentation model mappings is acceptable; merge them into a single
        # lookup table for the model-type check.
        mapping = MODEL_FOR_IMAGE_SEGMENTATION_MAPPING_NAMES.copy()
        mapping.update(MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING_NAMES)
        mapping.update(MODEL_FOR_INSTANCE_SEGMENTATION_MAPPING_NAMES)
        mapping.update(MODEL_FOR_UNIVERSAL_SEGMENTATION_MAPPING_NAMES)
        self.check_model_type(mapping)

    def _sanitize_parameters(self, **kwargs):
        preprocess_kwargs = {}
        postprocess_kwargs = {}
        if "subtask" in kwargs:
            postprocess_kwargs["subtask"] = kwargs["subtask"]
            preprocess_kwargs["subtask"] = kwargs["subtask"]
        if "threshold" in kwargs:
            postprocess_kwargs["threshold"] = kwargs["threshold"]
        if "mask_threshold" in kwargs:
            postprocess_kwargs["mask_threshold"] = kwargs["mask_threshold"]
        if "overlap_mask_area_threshold" in kwargs:
            postprocess_kwargs["overlap_mask_area_threshold"] = kwargs["overlap_mask_area_threshold"]
        if "timeout" in kwargs:
            warnings.warn(
                "The `timeout` argument is deprecated and will be removed in version 5 of Transformers", FutureWarning
            )
            preprocess_kwargs["timeout"] = kwargs["timeout"]
        return preprocess_kwargs, {}, postprocess_kwargs
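
    # Note that `subtask` is routed to both preprocess (OneFormer builds its `task_inputs` from it)
    # and postprocess (where it selects the post-processing function), while the threshold
    # arguments only affect postprocessing. A hedged usage sketch, reusing the checkpoint from the
    # class docstring example:
    #
    # >>> segmenter = pipeline(model="facebook/detr-resnet-50-panoptic")
    # >>> segmenter(image, subtask="panoptic", threshold=0.75, mask_threshold=0.5)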

    def __call__(self, inputs=None, **kwargs) -> Union[Predictions, List[Prediction]]:
        """
        Perform segmentation (detect masks & classes) in the image(s) passed as inputs.

        Args:
            inputs (`str`, `List[str]`, `PIL.Image` or `List[PIL.Image]`):
                The pipeline handles three types of images:

                - A string containing an HTTP(S) link pointing to an image
                - A string containing a local path to an image
                - An image loaded in PIL directly

                The pipeline accepts either a single image or a batch of images. Images in a batch must all be in the
                same format: all as HTTP(S) links, all as local paths, or all as PIL images.
            subtask (`str`, *optional*):
                Segmentation task to be performed, chosen from `semantic`, `instance` and `panoptic` depending on
                model capabilities. If not set, the pipeline will attempt to resolve in the following order:
                `panoptic`, `instance`, `semantic`.
            threshold (`float`, *optional*, defaults to 0.9):
                Probability threshold to filter out predicted masks.
            mask_threshold (`float`, *optional*, defaults to 0.5):
                Threshold to use when turning the predicted masks into binary values.
            overlap_mask_area_threshold (`float`, *optional*, defaults to 0.5):
                Mask overlap threshold to eliminate small, disconnected segments.

        Return:
            A list of dictionaries or a list of lists of dictionaries containing the result. If the input is a single
            image, returns a list of dictionaries; if the input is a list of several images, returns a list of lists
            of dictionaries, one list per image.

            The dictionaries contain the mask, label and score (where applicable) of each detected object and have
            the following keys:

            - **label** (`str`) -- The class label identified by the model.
            - **mask** (`PIL.Image`) -- A binary mask of the detected object as a PIL Image of shape (width, height)
              of the original image. Returns a mask filled with zeros if no object is found.
            - **score** (*optional* `float`) -- The confidence score for the "object" described by the label and the
              mask, when the model is capable of estimating one.
        """
        # Once the deprecation of the `images` kwarg is completed, remove the default `None` value for `inputs`
        if "images" in kwargs:
            inputs = kwargs.pop("images")
        if inputs is None:
            raise ValueError("Cannot call the image-segmentation pipeline without an inputs argument!")
        return super().__call__(inputs, **kwargs)
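
    # Hedged sketch of the single-image vs. batched return shapes documented above (the URL is the
    # one from the class docstring example):
    #
    # >>> url = "https://huggingface.co/datasets/Narsil/image_dummy/raw/main/parrots.png"
    # >>> segmenter(url)  # single image -> List[Dict], each with "label", "mask" and "score" keys
    # >>> segmenter([url, url])  # batch -> List[List[Dict]], one inner list per image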

    def preprocess(self, image, subtask=None, timeout=None):
        image = load_image(image, timeout=timeout)
        target_size = [(image.height, image.width)]
        if self.model.config.__class__.__name__ == "OneFormerConfig":
            if subtask is None:
                kwargs = {}
            else:
                kwargs = {"task_inputs": [subtask]}
            inputs = self.image_processor(images=[image], return_tensors="pt", **kwargs)
            if self.framework == "pt":
                inputs = inputs.to(self.torch_dtype)
            inputs["task_inputs"] = self.tokenizer(
                inputs["task_inputs"],
                padding="max_length",
                max_length=self.model.config.task_seq_len,
                return_tensors=self.framework,
            )["input_ids"]
        else:
            inputs = self.image_processor(images=[image], return_tensors="pt")
            if self.framework == "pt":
                inputs = inputs.to(self.torch_dtype)
        inputs["target_size"] = target_size
        return inputs
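
    # OneFormer is the one architecture handled here that conditions on the subtask at inference
    # time: its image processor emits string `task_inputs`, which must be tokenized before the
    # forward pass. For all other models the subtask only matters in `postprocess`. The checkpoint
    # below is an assumption for illustration, not taken from this file:
    #
    # >>> segmenter = pipeline(model="shi-labs/oneformer_ade20k_swin_tiny")
    # >>> segments = segmenter(image, subtask="semantic")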

    def _forward(self, model_inputs):
        # `target_size` is not a model input; pop it here and reattach it to the outputs so that
        # `postprocess` can resize the masks back to the original image.
        target_size = model_inputs.pop("target_size")
        model_outputs = self.model(**model_inputs)
        model_outputs["target_size"] = target_size
        return model_outputs

    def postprocess(
        self, model_outputs, subtask=None, threshold=0.9, mask_threshold=0.5, overlap_mask_area_threshold=0.5
    ):
        fn = None
        if subtask in {"panoptic", None} and hasattr(self.image_processor, "post_process_panoptic_segmentation"):
            fn = self.image_processor.post_process_panoptic_segmentation
        elif subtask in {"instance", None} and hasattr(self.image_processor, "post_process_instance_segmentation"):
            fn = self.image_processor.post_process_instance_segmentation

        if fn is not None:
            outputs = fn(
                model_outputs,
                threshold=threshold,
                mask_threshold=mask_threshold,
                overlap_mask_area_threshold=overlap_mask_area_threshold,
                target_sizes=model_outputs["target_size"],
            )[0]

            annotation = []
            segmentation = outputs["segmentation"]

            for segment in outputs["segments_info"]:
                mask = (segmentation == segment["id"]) * 255
                mask = Image.fromarray(mask.numpy().astype(np.uint8), mode="L")
                label = self.model.config.id2label[segment["label_id"]]
                score = segment["score"]
                annotation.append({"score": score, "label": label, "mask": mask})

        elif subtask in {"semantic", None} and hasattr(self.image_processor, "post_process_semantic_segmentation"):
            outputs = self.image_processor.post_process_semantic_segmentation(
                model_outputs, target_sizes=model_outputs["target_size"]
            )[0]

            annotation = []
            segmentation = outputs.numpy()
            labels = np.unique(segmentation)

            for label in labels:
                mask = (segmentation == label) * 255
                mask = Image.fromarray(mask.astype(np.uint8), mode="L")
                label = self.model.config.id2label[label]
                annotation.append({"score": None, "label": label, "mask": mask})
        else:
            raise ValueError(f"Subtask {subtask} is not supported for model {type(self.model)}")
        return annotation
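

# A minimal smoke test, assuming network access and the checkpoint from the class docstring
# example. Because this module uses relative imports, run it as a module, e.g.
# `python -m transformers.pipelines.image_segmentation`.
if __name__ == "__main__":
    from transformers import pipeline

    segmenter = pipeline(model="facebook/detr-resnet-50-panoptic")
    segments = segmenter("https://huggingface.co/datasets/Narsil/image_dummy/raw/main/parrots.png")
    for segment in segments:
        # Each entry carries a class label, an optional confidence score, and a PIL "L"-mode mask
        # the size of the original image.
        print(segment["label"], segment["score"], segment["mask"].size)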