tokenization_jukebox.py

# coding=utf-8
# Copyright 2022 The Open AI Team Authors and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tokenization classes for OpenAI Jukebox."""
import json
import os
import re
import unicodedata
from json.encoder import INFINITY
from typing import Any, Dict, List, Optional, Tuple, Union

import numpy as np
import regex

from ....tokenization_utils import AddedToken, PreTrainedTokenizer
from ....tokenization_utils_base import BatchEncoding
from ....utils import TensorType, is_flax_available, is_tf_available, is_torch_available, logging
from ....utils.generic import _is_jax, _is_numpy


logger = logging.get_logger(__name__)

VOCAB_FILES_NAMES = {
    "artists_file": "artists.json",
    "lyrics_file": "lyrics.json",
    "genres_file": "genres.json",
}


class JukeboxTokenizer(PreTrainedTokenizer):
    """
    Constructs a Jukebox tokenizer. Jukebox can be conditioned on 3 different inputs:
        - Artists, unique ids are associated to each artist from the provided dictionary.
        - Genres, unique ids are associated to each genre from the provided dictionary.
        - Lyrics, character based tokenization. Must be initialized with the list of characters that are inside the
          vocabulary.

    This tokenizer does not require training. It should be able to process a different number of inputs, as the
    conditioning of the model can be done on the three different queries. If None is provided, default values will be
    used, depending on the number of genres on which the model should be conditioned (`n_genres`).

    ```python
    >>> from transformers import JukeboxTokenizer

    >>> tokenizer = JukeboxTokenizer.from_pretrained("openai/jukebox-1b-lyrics")
    >>> tokenizer("Alan Jackson", "Country Rock", "old town road")["input_ids"]
    [tensor([[   0,    0,    0, 6785,  546,   41,   38,   30,   76,   46,   41,   49,
               40,   76,   44,   41,   27,   30]]), tensor([[  0,   0,   0, 145,   0]]), tensor([[  0,   0,   0, 145,   0]])]
    ```

    You can get around that behavior by passing `add_prefix_space=True` when instantiating this tokenizer or when you
    call it on some text, but since the model was not pretrained this way, it might yield a decrease in performance.

    <Tip>

    If nothing is provided, the genres and the artist will either be selected randomly or set to None.

    </Tip>

    This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer
    to this superclass for more information regarding those methods.

    However, the code does not allow that and only supports composing from various genres.

    Args:
        artists_file (`str`):
            Path to the vocabulary file which contains a mapping between artists and ids. The default file supports
            both "v2" and "v3".
        genres_file (`str`):
            Path to the vocabulary file which contains a mapping between genres and ids.
        lyrics_file (`str`):
            Path to the vocabulary file which contains the accepted characters for the lyrics tokenization.
        version (`List[str]`, *optional*, defaults to `["v3", "v2", "v2"]`):
            List of the tokenizer versions. The `5b-lyrics`'s top level prior model was trained using `v3` instead of
            `v2`.
        n_genres (`int`, *optional*, defaults to 5):
            Maximum number of genres to use for composition.
        max_n_lyric_tokens (`int`, *optional*, defaults to 512):
            Maximum number of lyric tokens to keep.
        unk_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
            The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
            token instead.
    """

    vocab_files_names = VOCAB_FILES_NAMES
    model_input_names = ["input_ids", "attention_mask"]

    def __init__(
        self,
        artists_file,
        genres_file,
        lyrics_file,
        version=["v3", "v2", "v2"],
        max_n_lyric_tokens=512,
        n_genres=5,
        unk_token="<|endoftext|>",
        **kwargs,
    ):
        unk_token = AddedToken(unk_token, lstrip=False, rstrip=False) if isinstance(unk_token, str) else unk_token
        self.version = version
        self.max_n_lyric_tokens = max_n_lyric_tokens
        self.n_genres = n_genres
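        # Id 0 is mapped to the unknown token; out-of-vocabulary artists, genres and lyric characters also fall
        # back to id 0 (see `_convert_token_to_id`).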
        self._added_tokens_decoder = {0: unk_token}
        with open(artists_file, encoding="utf-8") as vocab_handle:
            self.artists_encoder = json.load(vocab_handle)

        with open(genres_file, encoding="utf-8") as vocab_handle:
            self.genres_encoder = json.load(vocab_handle)

        with open(lyrics_file, encoding="utf-8") as vocab_handle:
            self.lyrics_encoder = json.load(vocab_handle)

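        # Pattern matching characters that are outside the lyrics vocabulary; they are stripped from the lyrics in
        # `prepare_for_tokenization`.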
        oov = r"[^A-Za-z0-9.,:;!?\-'\"()\[\] \t\n]+"
        # In v2, we had a n_vocab=80 and in v3 we missed + and so n_vocab=79 of characters.
        if len(self.lyrics_encoder) == 79:
            oov = oov.replace(r"\-'", r"\-+'")

        self.out_of_vocab = regex.compile(oov)
        self.artists_decoder = {v: k for k, v in self.artists_encoder.items()}
        self.genres_decoder = {v: k for k, v in self.genres_encoder.items()}
        self.lyrics_decoder = {v: k for k, v in self.lyrics_encoder.items()}
        super().__init__(
            unk_token=unk_token,
            n_genres=n_genres,
            version=version,
            max_n_lyric_tokens=max_n_lyric_tokens,
            **kwargs,
        )

    @property
    def vocab_size(self):
        return len(self.artists_encoder) + len(self.genres_encoder) + len(self.lyrics_encoder)

    def get_vocab(self):
        return {
            "artists_encoder": self.artists_encoder,
            "genres_encoder": self.genres_encoder,
            "lyrics_encoder": self.lyrics_encoder,
        }

    def _convert_token_to_id(self, list_artists, list_genres, list_lyrics):
        """Converts the artist, genre and lyrics tokens to their index using the vocabulary.
        Genre lists are padded with `-1` up to `n_genres`, and only the first lyric sequence is converted to
        character ids.
        """
        artists_id = [self.artists_encoder.get(artist, 0) for artist in list_artists]
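        # Encode each genre list and pad it with -1 up to `n_genres`, so every level carries a fixed-size genre
        # conditioning.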
        for genres in range(len(list_genres)):
            list_genres[genres] = [self.genres_encoder.get(genre, 0) for genre in list_genres[genres]]
            list_genres[genres] = list_genres[genres] + [-1] * (self.n_genres - len(list_genres[genres]))
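        # Only the first sequence is populated with lyric character ids; the two remaining levels stay empty.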
        lyric_ids = [[self.lyrics_encoder.get(character, 0) for character in list_lyrics[0]], [], []]
        return artists_id, list_genres, lyric_ids

    def _tokenize(self, lyrics):
        """
        Converts a string into a sequence of tokens (string), using the tokenizer. Split in words for word-based
        vocabulary or sub-words for sub-word-based vocabularies (BPE/SentencePieces/WordPieces).

        Do NOT take care of added tokens. Only the lyrics are split into characters for the character-based vocabulary.
        """
        # only lyrics are not tokenized, but character based is easily handled
        return list(lyrics)

    def tokenize(self, artist, genre, lyrics, **kwargs):
        """
        Converts three strings into three sequences of tokens using the tokenizer.
        """
        artist, genre, lyrics = self.prepare_for_tokenization(artist, genre, lyrics)
        lyrics = self._tokenize(lyrics)
        return artist, genre, lyrics

    def prepare_for_tokenization(
        self, artists: str, genres: str, lyrics: str, is_split_into_words: bool = False
    ) -> Tuple[str, str, str, Dict[str, Any]]:
        """
        Performs any necessary transformations before tokenization.

        Args:
            artists (`str`):
                The artist name to prepare. This will mostly lowercase the string.
            genres (`str`):
                The genre name to prepare. This will mostly lowercase the string.
            lyrics (`str`):
                The lyrics to prepare.
            is_split_into_words (`bool`, *optional*, defaults to `False`):
                Whether or not the input is already pre-tokenized (e.g., split into words). If set to `True`, the
                tokenizer assumes the input is already split into words (for instance, by splitting it on whitespace)
                which it will tokenize. This is useful for NER or token classification.
        """
        for idx in range(len(self.version)):
            if self.version[idx] == "v3":
                artists[idx] = artists[idx].lower()
                genres[idx] = [genres[idx].lower()]
            else:
                artists[idx] = self._normalize(artists[idx]) + ".v2"
                genres[idx] = [
                    self._normalize(genre) + ".v2" for genre in genres[idx].split("_")
                ]  # split is for the full dictionary with combined genres
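        # The lyrics vocabulary depends on the first entry of `version`: for "v2" it is rebuilt from a fixed
        # character set, with `<unk>` mapped to id 0.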
        if self.version[0] == "v2":
            self.out_of_vocab = regex.compile(r"[^A-Za-z0-9.,:;!?\-'\"()\[\] \t\n]+")
            vocab = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789.,:;!?-+'\"()[] \t\n"
            self.vocab = {vocab[index]: index + 1 for index in range(len(vocab))}
            self.vocab["<unk>"] = 0
            self.n_vocab = len(vocab) + 1
            self.lyrics_encoder = self.vocab
            self.lyrics_decoder = {v: k for k, v in self.vocab.items()}
            self.lyrics_decoder[0] = ""
        else:
            self.out_of_vocab = regex.compile(r"[^A-Za-z0-9.,:;!?\-+'\"()\[\] \t\n]+")
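        # Clean the lyrics: strip accents, turn backslashes into newlines and drop characters matched by
        # `out_of_vocab`. The trailing empty lists keep the (lyrics, [], []) structure used by `_convert_token_to_id`.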
        lyrics = self._run_strip_accents(lyrics)
        lyrics = lyrics.replace("\\", "\n")
        lyrics = self.out_of_vocab.sub("", lyrics), [], []

        return artists, genres, lyrics

    def _run_strip_accents(self, text):
        """Strips accents from a piece of text."""
        text = unicodedata.normalize("NFD", text)
        output = []
        for char in text:
            cat = unicodedata.category(char)
            if cat == "Mn":
                continue
            output.append(char)
        return "".join(output)

    def _normalize(self, text: str) -> str:
        """
        Normalizes the input text. This process is used for the genres and the artist names.

        Args:
            text (`str`):
                Artist or genre string to normalize.
        """
        accepted = (
            [chr(i) for i in range(ord("a"), ord("z") + 1)]
            + [chr(i) for i in range(ord("A"), ord("Z") + 1)]
            + [chr(i) for i in range(ord("0"), ord("9") + 1)]
            + ["."]
        )
        accepted = frozenset(accepted)
        pattern = re.compile(r"_+")
        text = "".join([c if c in accepted else "_" for c in text.lower()])
        text = pattern.sub("_", text).strip("_")
        return text

    def convert_lyric_tokens_to_string(self, lyrics: List[str]) -> str:
        return " ".join(lyrics)

    def convert_to_tensors(
        self, inputs, tensor_type: Optional[Union[str, TensorType]] = None, prepend_batch_axis: bool = False
    ):
        """
        Convert the inner content to tensors.

        Args:
            tensor_type (`str` or [`~utils.TensorType`], *optional*):
                The type of tensors to use. If `str`, should be one of the values of the enum [`~utils.TensorType`].
                If unset, no modification is done.
            prepend_batch_axis (`bool`, *optional*, defaults to `False`):
                Whether or not to add the batch dimension during the conversion.
        """
        # Convert to TensorType
        if not isinstance(tensor_type, TensorType):
            tensor_type = TensorType(tensor_type)

        # Get a function reference for the correct framework
        if tensor_type == TensorType.TENSORFLOW:
            if not is_tf_available():
                raise ImportError(
                    "Unable to convert output to TensorFlow tensors format, TensorFlow is not installed."
                )
            import tensorflow as tf

            as_tensor = tf.constant
            is_tensor = tf.is_tensor
        elif tensor_type == TensorType.PYTORCH:
            if not is_torch_available():
                raise ImportError("Unable to convert output to PyTorch tensors format, PyTorch is not installed.")
            import torch

            as_tensor = torch.tensor
            is_tensor = torch.is_tensor
        elif tensor_type == TensorType.JAX:
            if not is_flax_available():
                raise ImportError("Unable to convert output to JAX tensors format, JAX is not installed.")
            import jax.numpy as jnp  # noqa: F811

            as_tensor = jnp.array
            is_tensor = _is_jax
        else:
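            # No framework-specific tensor type was requested: fall back to plain numpy arrays.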
            as_tensor = np.asarray
            is_tensor = _is_numpy

        # Do the tensor conversion in batch
        try:
            if prepend_batch_axis:
                inputs = [inputs]

            if not is_tensor(inputs):
                inputs = as_tensor(inputs)
        except:  # noqa E722
            raise ValueError(
                "Unable to create tensor, you should probably activate truncation and/or padding "
                "with 'padding=True' 'truncation=True' to have batched tensors with the same length."
            )

        return inputs

    def __call__(self, artist, genres, lyrics="", return_tensors="pt") -> BatchEncoding:
        """Convert the raw string to a list of token ids

        Args:
            artist (`str`):
                Name of the artist.
            genres (`str`):
                List of genres that will be mixed to condition the audio.
            lyrics (`str`, *optional*, defaults to `""`):
                Lyrics used to condition the generation.
        """
        input_ids = [0, 0, 0]
        artist = [artist] * len(self.version)
        genres = [genres] * len(self.version)

        artists_tokens, genres_tokens, lyrics_tokens = self.tokenize(artist, genres, lyrics)
        artists_id, genres_ids, full_tokens = self._convert_token_to_id(artists_tokens, genres_tokens, lyrics_tokens)
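        # The "attention mask" is a list of -inf values with the same length as the lyric token sequence of the
        # last level.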
        attention_masks = [-INFINITY] * len(full_tokens[-1])
        input_ids = [
            self.convert_to_tensors(
                [input_ids + [artists_id[i]] + genres_ids[i] + full_tokens[i]], tensor_type=return_tensors
            )
            for i in range(len(self.version))
        ]
        return BatchEncoding({"input_ids": input_ids, "attention_masks": attention_masks})

    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
        """
        Saves the tokenizer's vocabulary dictionary to the provided save_directory.

        Args:
            save_directory (`str`):
                A path to the directory where the vocabulary files will be saved. It must be an existing directory.
            filename_prefix (`Optional[str]`, *optional*):
                A prefix to add to the names of the files saved by the tokenizer.
        """
        if not os.path.isdir(save_directory):
            logger.error(f"Vocabulary path ({save_directory}) should be a directory")
            return

        artists_file = os.path.join(
            save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["artists_file"]
        )
        with open(artists_file, "w", encoding="utf-8") as f:
            f.write(json.dumps(self.artists_encoder, ensure_ascii=False))

        genres_file = os.path.join(
            save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["genres_file"]
        )
        with open(genres_file, "w", encoding="utf-8") as f:
            f.write(json.dumps(self.genres_encoder, ensure_ascii=False))

        lyrics_file = os.path.join(
            save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["lyrics_file"]
        )
        with open(lyrics_file, "w", encoding="utf-8") as f:
            f.write(json.dumps(self.lyrics_encoder, ensure_ascii=False))

        return (artists_file, genres_file, lyrics_file)

    def _convert_id_to_token(self, artists_index, genres_index, lyric_index):
        """
        Converts an index (integer) into a token (str) using the vocab.

        Args:
            artists_index (`int`):
                Index of the artist in its corresponding dictionary.
            genres_index (`Union[List[int], int]`):
                Index of the genre in its corresponding dictionary.
            lyric_index (`List[int]`):
                List of character indices, each of which corresponds to a character.
        """
        artist = self.artists_decoder.get(artists_index)
        genres = [self.genres_decoder.get(genre) for genre in genres_index]
        lyrics = [self.lyrics_decoder.get(character) for character in lyric_index]
        return artist, genres, lyrics