processing_clvp.py 3.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990
  1. # coding=utf-8
  2. # Copyright 2023 The HuggingFace Inc. team.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. """
  16. Processor class for CLVP
  17. """
  18. from ...processing_utils import ProcessorMixin
  class ClvpProcessor(ProcessorMixin):
      r"""
      Constructs a CLVP processor which wraps a CLVP Feature Extractor and a CLVP Tokenizer into a single processor.

      [`ClvpProcessor`] offers all the functionalities of [`ClvpFeatureExtractor`] and [`ClvpTokenizer`]. See the
      [`~ClvpProcessor.__call__`], [`~ClvpProcessor.decode`] and [`~ClvpProcessor.batch_decode`] for more information.

      Args:
          feature_extractor (`ClvpFeatureExtractor`):
              An instance of [`ClvpFeatureExtractor`]. The feature extractor is a required input.
          tokenizer (`ClvpTokenizer`):
              An instance of [`ClvpTokenizer`]. The tokenizer is a required input.
      """

      feature_extractor_class = "ClvpFeatureExtractor"
      tokenizer_class = "ClvpTokenizer"
      model_input_names = [
          "input_ids",
          "input_features",
          "attention_mask",
      ]

      def __init__(self, feature_extractor, tokenizer):
          super().__init__(feature_extractor, tokenizer)

      def __call__(self, *args, **kwargs):
          """
          Forwards the `raw_speech` and `sampling_rate` arguments to [`~ClvpFeatureExtractor.__call__`] and the `text`
          argument to [`~ClvpTokenizer.__call__`]. Please refer to the docstring of the above two methods for more
          information.

          `raw_speech`, `sampling_rate` and `text` are read from the keyword arguments only; positional arguments
          are accepted but not forwarded. All remaining keyword arguments are passed unchanged to both the feature
          extractor and the tokenizer.

          Returns:
              The feature extractor output when only `raw_speech` is given, the tokenizer output when only `text`
              is given, and otherwise the feature extractor output with the tokenizer's `input_ids` (and
              `attention_mask`, when the tokenizer returned one) added to it.

          Raises:
              ValueError: If neither `raw_speech` nor `text` is provided.
          """
          raw_speech = kwargs.pop("raw_speech", None)
          sampling_rate = kwargs.pop("sampling_rate", None)
          text = kwargs.pop("text", None)

          if raw_speech is None and text is None:
              raise ValueError("You need to specify either an `raw_speech` or `text` input to process.")

          # Text-only input: the tokenizer output is returned as-is.
          if raw_speech is None:
              return self.tokenizer(text, **kwargs)

          inputs = self.feature_extractor(raw_speech, sampling_rate=sampling_rate, **kwargs)

          # Audio-only input: the feature extractor output is returned as-is.
          if text is None:
              return inputs

          # Both inputs: merge the token fields into the audio features.
          encodings = self.tokenizer(text, **kwargs)
          inputs["input_ids"] = encodings["input_ids"]
          # Forwarded kwargs (e.g. `return_attention_mask=False`) can make the tokenizer leave the mask out, so
          # only copy it when it is actually there instead of failing with a `KeyError`.
          if "attention_mask" in encodings:
              inputs["attention_mask"] = encodings["attention_mask"]
          return inputs

      # Copied from transformers.models.whisper.processing_whisper.WhisperProcessor.batch_decode with Whisper->Clvp
      def batch_decode(self, *args, **kwargs):
          """
          This method forwards all its arguments to ClvpTokenizer's [`~PreTrainedTokenizer.batch_decode`]. Please
          refer to the docstring of this method for more information.
          """
          return self.tokenizer.batch_decode(*args, **kwargs)

      # Copied from transformers.models.whisper.processing_whisper.WhisperProcessor.decode with Whisper->Clvp
      def decode(self, *args, **kwargs):
          """
          This method forwards all its arguments to ClvpTokenizer's [`~PreTrainedTokenizer.decode`]. Please refer to
          the docstring of this method for more information.
          """
          return self.tokenizer.decode(*args, **kwargs)