zero_shot_audio_classification.py

# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import UserDict
from typing import Union

import numpy as np
import requests

from ..utils import (
    add_end_docstrings,
    logging,
)
from .audio_classification import ffmpeg_read
from .base import Pipeline, build_pipeline_init_args


logger = logging.get_logger(__name__)
@add_end_docstrings(build_pipeline_init_args(has_feature_extractor=True, has_tokenizer=True))
class ZeroShotAudioClassificationPipeline(Pipeline):
    """
    Zero-shot audio classification pipeline using `ClapModel`. This pipeline predicts the class of an audio when you
    provide an audio and a set of `candidate_labels`.

    <Tip warning={true}>

    The default `hypothesis_template` is `"This is a sound of {}."`. Make sure you update it for your usage.

    </Tip>

    Example:
    ```python
    >>> from transformers import pipeline
    >>> from datasets import load_dataset

    >>> dataset = load_dataset("ashraq/esc50")
    >>> audio = next(iter(dataset["train"]["audio"]))["array"]
    >>> classifier = pipeline(task="zero-shot-audio-classification", model="laion/clap-htsat-unfused")
    >>> classifier(audio, candidate_labels=["Sound of a dog", "Sound of vaccum cleaner"])
    [{'score': 0.9996, 'label': 'Sound of a dog'}, {'score': 0.0004, 'label': 'Sound of vaccum cleaner'}]
    ```

    Learn more about the basics of using a pipeline in the [pipeline tutorial](../pipeline_tutorial). This audio
    classification pipeline can currently be loaded from [`pipeline`] using the following task identifier:
    `"zero-shot-audio-classification"`. See the list of available models on
    [huggingface.co/models](https://huggingface.co/models?filter=zero-shot-audio-classification).
    """
    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        if self.framework != "pt":
            raise ValueError(f"The {self.__class__} is only available in PyTorch.")
        # No specific FOR_XXX available yet
    def __call__(self, audios: Union[np.ndarray, bytes, str], **kwargs):
        """
        Assign labels to the audio(s) passed as inputs.

        Args:
            audios (`str`, `List[str]`, `np.array` or `List[np.array]`):
                The pipeline handles three types of inputs:

                - A string containing an http link pointing to an audio
                - A string containing a local path to an audio
                - An audio loaded in numpy

            candidate_labels (`List[str]`):
                The candidate labels for this audio. They will be formatted using *hypothesis_template*.

            hypothesis_template (`str`, *optional*, defaults to `"This is a sound of {}."`):
                The format used in conjunction with *candidate_labels* to attempt the audio classification by
                replacing the placeholder with the candidate_labels. Pass "{}" if *candidate_labels* are
                already formatted.

        Return:
            A list of dictionaries containing one entry per proposed label. Each dictionary contains the
            following keys:

            - **label** (`str`) -- One of the suggested *candidate_labels*.
            - **score** (`float`) -- The score attributed by the model to that label. It is a value between
              0 and 1, computed as the `softmax` of `logits_per_audio`.
        """
        return super().__call__(audios, **kwargs)
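    # Illustrative usage sketch (not part of the pipeline): the three accepted input types from
    # the docstring above, plus a custom hypothesis template. The URL, file path, labels, and
    # `audio_array` below are placeholders, not real resources.
    #
    #     classifier = pipeline(task="zero-shot-audio-classification", model="laion/clap-htsat-unfused")
    #     classifier("https://example.com/dog_bark.wav", candidate_labels=["dog", "vacuum cleaner"])
    #     classifier("recordings/dog_bark.wav", candidate_labels=["dog", "vacuum cleaner"])
    #     classifier(audio_array, candidate_labels=["dog", "vacuum cleaner"],
    #                hypothesis_template="The sound of {}.")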
    def _sanitize_parameters(self, **kwargs):
        preprocess_params = {}
        if "candidate_labels" in kwargs:
            preprocess_params["candidate_labels"] = kwargs["candidate_labels"]
        if "hypothesis_template" in kwargs:
            preprocess_params["hypothesis_template"] = kwargs["hypothesis_template"]

        return preprocess_params, {}, {}
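    # The three dicts returned above map to (preprocess kwargs, forward kwargs, postprocess
    # kwargs) in the base `Pipeline` protocol; this pipeline only routes `candidate_labels` and
    # `hypothesis_template`, both of which are consumed in `preprocess`.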
    def preprocess(self, audio, candidate_labels=None, hypothesis_template="This is a sound of {}."):
        if isinstance(audio, str):
            if audio.startswith("http://") or audio.startswith("https://"):
                # We need to actually check for a real protocol, otherwise it's impossible to use a local file
                # like http_huggingface_co.png
                audio = requests.get(audio).content
            else:
                with open(audio, "rb") as f:
                    audio = f.read()

        if isinstance(audio, bytes):
            audio = ffmpeg_read(audio, self.feature_extractor.sampling_rate)

        if not isinstance(audio, np.ndarray):
            raise TypeError("We expect a numpy ndarray as input")
        if len(audio.shape) != 1:
            raise ValueError("We expect a single channel audio input for ZeroShotAudioClassificationPipeline")

        inputs = self.feature_extractor(
            [audio], sampling_rate=self.feature_extractor.sampling_rate, return_tensors="pt"
        )
        if self.framework == "pt":
            inputs = inputs.to(self.torch_dtype)
        inputs["candidate_labels"] = candidate_labels
        sequences = [hypothesis_template.format(x) for x in candidate_labels]
        text_inputs = self.tokenizer(sequences, return_tensors=self.framework, padding=True)
        inputs["text_inputs"] = [text_inputs]
        return inputs
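    # Sketch of what `preprocess` builds (based on the code above, values illustrative): the
    # hypothesis template is applied to every candidate label before tokenization, e.g. with the
    # default template
    #
    #     candidate_labels = ["dog", "vacuum cleaner"]
    #     sequences == ["This is a sound of dog.", "This is a sound of vacuum cleaner."]
    #
    # and the returned dict holds the extracted audio features plus "candidate_labels" and
    # "text_inputs" (the tokenized sequences, wrapped in a list).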
    def _forward(self, model_inputs):
        candidate_labels = model_inputs.pop("candidate_labels")
        text_inputs = model_inputs.pop("text_inputs")
        if isinstance(text_inputs[0], UserDict):
            text_inputs = text_inputs[0]
        else:
            # Batching case.
            text_inputs = text_inputs[0][0]

        outputs = self.model(**text_inputs, **model_inputs)

        model_outputs = {
            "candidate_labels": candidate_labels,
            "logits": outputs.logits_per_audio,
        }
        return model_outputs
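    # Note: `logits_per_audio` from `ClapModel` is the audio-text similarity matrix of shape
    # (num_audios, num_labels); `preprocess` always feeds a single audio, so `postprocess` below
    # reads row 0 and normalizes it into per-label probabilities with a softmax.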
    def postprocess(self, model_outputs):
        candidate_labels = model_outputs.pop("candidate_labels")
        logits = model_outputs["logits"][0]

        if self.framework == "pt":
            probs = logits.softmax(dim=0)
            scores = probs.tolist()
        else:
            raise ValueError("`tf` framework not supported.")

        result = [
            {"score": score, "label": candidate_label}
            for score, candidate_label in sorted(zip(scores, candidate_labels), key=lambda x: -x[0])
        ]
        return result
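
# For reference, a rough equivalent of the full pipeline written against `ClapModel` and
# `ClapProcessor` directly. This is a hedged sketch, not part of the pipeline: it assumes the
# processor accepts an `audios=` argument and that `audio` is a 1-D numpy array already at the
# checkpoint's sampling rate (otherwise resample / decode it first, as `preprocess` does with
# `ffmpeg_read`).
#
#     from transformers import ClapModel, ClapProcessor
#
#     model = ClapModel.from_pretrained("laion/clap-htsat-unfused")
#     processor = ClapProcessor.from_pretrained("laion/clap-htsat-unfused")
#
#     labels = ["dog", "vacuum cleaner"]
#     texts = [f"This is a sound of {label}." for label in labels]
#     inputs = processor(text=texts, audios=[audio], return_tensors="pt", padding=True)
#
#     logits = model(**inputs).logits_per_audio[0]  # shape: (num_labels,)
#     scores = logits.softmax(dim=0).tolist()
#     result = sorted(zip(scores, labels), key=lambda x: -x[0])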