# coding=utf-8
# Copyright 2023 The Kakao Enterprise Authors, the MMS-TTS Authors and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tokenization class for VITS."""

import json
import os
import re
from typing import Any, Dict, List, Optional, Tuple, Union

from ...tokenization_utils import PreTrainedTokenizer
from ...utils import is_phonemizer_available, is_uroman_available, logging


if is_phonemizer_available():
    import phonemizer

if is_uroman_available():
    import uroman as ur

logger = logging.get_logger(__name__)

VOCAB_FILES_NAMES = {"vocab_file": "vocab.json"}


def has_non_roman_characters(input_string):
    # Find any character outside the ASCII range
    non_roman_pattern = re.compile(r"[^\x00-\x7F]")

    # Search the input string for non-Roman characters
    match = non_roman_pattern.search(input_string)
    has_non_roman = match is not None
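    # illustrative examples: "hello" contains no non-Roman characters, while "héllo" or "こんにちは" do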
    return has_non_roman


class VitsTokenizer(PreTrainedTokenizer):
    """
    Construct a VITS tokenizer. Also supports MMS-TTS.

    This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to
    this superclass for more information regarding those methods.

    Args:
        vocab_file (`str`):
            Path to the vocabulary file.
        language (`str`, *optional*):
            Language identifier.
        add_blank (`bool`, *optional*, defaults to `True`):
            Whether to insert token id 0 in between the other tokens.
        normalize (`bool`, *optional*, defaults to `True`):
            Whether to normalize the input text by removing all casing and punctuation.
        phonemize (`bool`, *optional*, defaults to `True`):
            Whether to convert the input text into phonemes.
        is_uroman (`bool`, *optional*, defaults to `False`):
            Whether the `uroman` Romanizer needs to be applied to the input text prior to tokenizing.
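
    Example (a minimal usage sketch, not part of the original docstring; the checkpoint name
    `facebook/mms-tts-eng` is assumed here for illustration, and depending on the checkpoint configuration the
    `phonemizer` and `uroman` packages may be required):

    ```python
    >>> from transformers import VitsTokenizer

    >>> tokenizer = VitsTokenizer.from_pretrained("facebook/mms-tts-eng")
    >>> inputs = tokenizer("Hello, world!", return_tensors="pt")  # returns `input_ids` and `attention_mask`
    ```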
  53. """

    vocab_files_names = VOCAB_FILES_NAMES
    model_input_names = ["input_ids", "attention_mask"]

    def __init__(
        self,
        vocab_file,
        pad_token="<pad>",
        unk_token="<unk>",
        language=None,
        add_blank=True,
        normalize=True,
        phonemize=True,
        is_uroman=False,
        **kwargs,
    ) -> None:
        with open(vocab_file, encoding="utf-8") as vocab_handle:
            self.encoder = json.load(vocab_handle)
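
        # reverse mapping (id -> token) used when converting ids back to tokens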
        self.decoder = {v: k for k, v in self.encoder.items()}
        self.language = language
        self.add_blank = add_blank
        self.normalize = normalize
        self.phonemize = phonemize
        self.is_uroman = is_uroman

        super().__init__(
            pad_token=pad_token,
            unk_token=unk_token,
            language=language,
            add_blank=add_blank,
            normalize=normalize,
            phonemize=phonemize,
            is_uroman=is_uroman,
            **kwargs,
        )

    @property
    def vocab_size(self):
        return len(self.encoder)

    def get_vocab(self):
        vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
        vocab.update(self.added_tokens_encoder)
        return vocab

    def normalize_text(self, input_string):
        """Lowercase the input string, respecting any special tokens that may be partly or entirely upper-cased."""
        all_vocabulary = list(self.encoder.keys()) + list(self.added_tokens_encoder.keys())
        filtered_text = ""

        i = 0
        while i < len(input_string):
            found_match = False
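            # check whether a vocabulary entry or added special token starts at position `i`;
            # if so, copy it verbatim so its original casing is preserved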
            for word in all_vocabulary:
                if input_string[i : i + len(word)] == word:
                    filtered_text += word
                    i += len(word)
                    found_match = True
                    break

            if not found_match:
                filtered_text += input_string[i].lower()
                i += 1

        return filtered_text

    def _preprocess_char(self, text):
        """Special treatment of characters in certain languages"""
        if self.language == "ron":
            text = text.replace("ț", "ţ")
        return text

    def prepare_for_tokenization(
        self, text: str, is_split_into_words: bool = False, normalize: Optional[bool] = None, **kwargs
    ) -> Tuple[str, Dict[str, Any]]:
        """
        Performs any necessary transformations before tokenization.

        This method should pop the arguments from kwargs and return the remaining `kwargs` as well. We test the
        `kwargs` at the end of the encoding process to be sure all the arguments have been used.

        Args:
            text (`str`):
                The text to prepare.
            is_split_into_words (`bool`, *optional*, defaults to `False`):
                Whether or not the input is already pre-tokenized (e.g., split into words). If set to `True`, the
                tokenizer assumes the input is already split into words (for instance, by splitting it on whitespace)
                which it will tokenize.
            normalize (`bool`, *optional*, defaults to `None`):
                Whether or not to apply punctuation and casing normalization to the text inputs. Typically, VITS is
                trained on lower-cased and un-punctuated text. Hence, normalization is used to ensure that the input
                text consists only of lower-case characters.
            kwargs (`Dict[str, Any]`, *optional*):
                Keyword arguments to use for the tokenization.

        Returns:
            `Tuple[str, Dict[str, Any]]`: The prepared text and the unused kwargs.
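
        Example (an illustrative sketch, not part of the original docstring; `tokenizer` stands for an instance of
        this class, and the exact output depends on the configured normalization and phonemization backends):

        ```python
        >>> prepared_text, unused_kwargs = tokenizer.prepare_for_tokenization("Hello, World!", normalize=True)
        ```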
  137. """
  138. normalize = normalize if normalize is not None else self.normalize
  139. if normalize:
  140. # normalise for casing
  141. text = self.normalize_text(text)
  142. filtered_text = self._preprocess_char(text)
  143. if has_non_roman_characters(filtered_text) and self.is_uroman:
  144. if not is_uroman_available():
  145. logger.warning(
  146. "Text to the tokenizer contains non-Roman characters. To apply the `uroman` pre-processing "
  147. "step automatically, ensure the `uroman` Romanizer is installed with: `pip install uroman` "
  148. "Note `uroman` requires python version >= 3.10"
  149. "Otherwise, apply the Romanizer manually as per the instructions: https://github.com/isi-nlp/uroman"
  150. )
  151. else:
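                # romanize the non-Latin-script text with uroman before it is mapped to the vocabulary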
                uroman = ur.Uroman()
                filtered_text = uroman.romanize_string(filtered_text)

        if self.phonemize:
            if not is_phonemizer_available():
                raise ImportError("Please install the `phonemizer` Python package to use this tokenizer.")

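            # convert the text to phonemes with the espeak backend, keeping punctuation and stress markers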
            filtered_text = phonemizer.phonemize(
                filtered_text,
                language="en-us",
                backend="espeak",
                strip=True,
                preserve_punctuation=True,
                with_stress=True,
            )
            filtered_text = re.sub(r"\s+", " ", filtered_text)
        elif normalize:
            # strip any chars outside of the vocab (punctuation)
            filtered_text = "".join(list(filter(lambda char: char in self.encoder, filtered_text))).strip()

        return filtered_text, kwargs

    def _tokenize(self, text: str) -> List[str]:
        """Tokenize a string by inserting the `<pad>` token at the boundary between adjacent characters."""
        tokens = list(text)

        if self.add_blank:
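            # intersperse the blank token (id 0, the `<pad>` token) between characters,
            # e.g. ["h", "i"] -> [blank, "h", blank, "i", blank]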
            interspersed = [self._convert_id_to_token(0)] * (len(tokens) * 2 + 1)
            interspersed[1::2] = tokens
            tokens = interspersed

        return tokens

    def convert_tokens_to_string(self, tokens: List[str]) -> str:
        if self.add_blank and len(tokens) > 1:
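            # drop the interleaved blank tokens added by `_tokenize`, keeping the characters at odd positions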
            tokens = tokens[1::2]
        return "".join(tokens)

    def _convert_token_to_id(self, token):
        """Converts a token (str) into an id using the vocab."""
        return self.encoder.get(token, self.encoder.get(self.unk_token))

    def _convert_id_to_token(self, index):
        """Converts an index (integer) into a token (str) using the vocab."""
        return self.decoder.get(index)

    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Union[Tuple[str], None]:
        if not os.path.isdir(save_directory):
            logger.error(f"Vocabulary path ({save_directory}) should be a directory")
            return
        vocab_file = os.path.join(
            save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
        )
        with open(vocab_file, "w", encoding="utf-8") as f:
            f.write(json.dumps(self.encoder, indent=2, sort_keys=True, ensure_ascii=False) + "\n")
        return (vocab_file,)