tokenization_t5_fast.py

# coding=utf-8
# Copyright 2018 T5 Authors and HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tokenization class for model T5."""

import os
import re
import warnings
from shutil import copyfile
from typing import List, Optional, Tuple

from ...tokenization_utils_fast import PreTrainedTokenizerFast
from ...utils import is_sentencepiece_available, logging


if is_sentencepiece_available():
    from .tokenization_t5 import T5Tokenizer
else:
    T5Tokenizer = None


logger = logging.get_logger(__name__)

VOCAB_FILES_NAMES = {"vocab_file": "spiece.model", "tokenizer_file": "tokenizer.json"}


# TODO(PVP) - this should be removed in Transformers v5
class T5TokenizerFast(PreTrainedTokenizerFast):
    """
    Construct a "fast" T5 tokenizer (backed by HuggingFace's *tokenizers* library). Based on
    [Unigram](https://huggingface.co/docs/tokenizers/python/latest/components.html?highlight=unigram#models).

    This tokenizer inherits from [`PreTrainedTokenizerFast`] which contains most of the main methods. Users should
    refer to this superclass for more information regarding those methods.

    Args:
        vocab_file (`str`):
            [SentencePiece](https://github.com/google/sentencepiece) file (generally has a *.spm* extension) that
            contains the vocabulary necessary to instantiate a tokenizer.
        eos_token (`str`, *optional*, defaults to `"</s>"`):
            The end of sequence token.

            <Tip>

            When building a sequence using special tokens, this is not the token that is used for the end of sequence.
            The token used is the `sep_token`.

            </Tip>

        unk_token (`str`, *optional*, defaults to `"<unk>"`):
            The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
            token instead.
        pad_token (`str`, *optional*, defaults to `"<pad>"`):
            The token used for padding, for example when batching sequences of different lengths.
        extra_ids (`int`, *optional*, defaults to 100):
            Number of extra ids added to the vocabulary for use as sentinels. These tokens are accessible as
            "<extra_id_{%d}>" where "{%d}" is a number between 0 and extra_ids-1. The tokens can be retrieved by
            calling the `get_sentinel_tokens` method, and their token ids by calling the `get_sentinel_token_ids`
            method.
        additional_special_tokens (`List[str]`, *optional*):
            Additional special tokens used by the tokenizer.
        add_prefix_space (`bool`, *optional*):
            Whether or not the tokenizer should automatically add a prefix space.
        from_slow (`bool`, *optional*, defaults to `False`):
            Whether or not the tokenizer should be converted from a slow one. If `add_prefix_space` is set, this will
            be set to `True`.
    """

    vocab_files_names = VOCAB_FILES_NAMES
    model_input_names = ["input_ids", "attention_mask"]
    slow_tokenizer_class = T5Tokenizer

    prefix_tokens: List[int] = []

    def __init__(
        self,
        vocab_file=None,
        tokenizer_file=None,
        eos_token="</s>",
        unk_token="<unk>",
        pad_token="<pad>",
        extra_ids=100,
        additional_special_tokens=None,
        add_prefix_space=None,
        **kwargs,
    ):
        # Add extra_ids to the special token list
        if additional_special_tokens is not None:
            extra_tokens = [x for x in additional_special_tokens if "<extra_id_" in str(x)]
            if len(extra_tokens) < 1:
                additional_special_tokens += [f"<extra_id_{i}>" for i in range(extra_ids)]
            elif extra_ids > 0 and extra_ids != len(extra_tokens):
                raise ValueError(
                    f"Both extra_ids ({extra_ids}) and additional_special_tokens ({additional_special_tokens}) are"
                    " provided to T5Tokenizer. In this case the additional_special_tokens must include the extra_ids"
                    " tokens"
                )
        else:
            extra_tokens = [f"<extra_id_{i}>" for i in range(extra_ids)]
            additional_special_tokens = extra_tokens

        if add_prefix_space is not None:
            logger.warning_once(
                "You set `add_prefix_space`. The tokenizer needs to be converted from the slow tokenizers"
            )
            kwargs["from_slow"] = True

        super().__init__(
            vocab_file,
            tokenizer_file=tokenizer_file,
            eos_token=eos_token,
            unk_token=unk_token,
            pad_token=pad_token,
            extra_ids=extra_ids,
            additional_special_tokens=additional_special_tokens,
            **kwargs,
        )

        self.vocab_file = vocab_file
        self._extra_ids = extra_ids
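
    # Illustrative sketch (not part of the original file; assumes the "t5-small" checkpoint is
    # reachable). With the default `extra_ids=100`, the constructor above registers the sentinel
    # tokens "<extra_id_0>" ... "<extra_id_99>" as additional special tokens:
    #
    #     tok = T5TokenizerFast.from_pretrained("t5-small")
    #     len(tok.get_sentinel_tokens())                     # 100
    #     "<extra_id_0>" in tok.additional_special_tokens    # True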

    @property
    def can_save_slow_tokenizer(self) -> bool:
        return os.path.isfile(self.vocab_file) if self.vocab_file else False

    @staticmethod
    def _eventually_correct_t5_max_length(pretrained_model_name_or_path, max_model_length, init_max_model_length):
        if pretrained_model_name_or_path in T5TokenizerFast.max_model_input_sizes:
            deprecated_max_model_length = T5TokenizerFast.max_model_input_sizes[pretrained_model_name_or_path]
            if init_max_model_length is not None and init_max_model_length != max_model_length:
                return init_max_model_length
            elif init_max_model_length is None:
                warnings.warn(
                    "This tokenizer was incorrectly instantiated with a model max length of"
                    f" {deprecated_max_model_length} which will be corrected in Transformers v5.\nFor now, this"
                    " behavior is kept to avoid breaking backwards compatibility when padding/encoding with"
                    " `truncation is True`.\n- Be aware that you SHOULD NOT rely on"
                    f" {pretrained_model_name_or_path} automatically truncating your input to"
                    f" {deprecated_max_model_length} when padding/encoding.\n- If you want to encode/pad to sequences"
                    f" longer than {deprecated_max_model_length} you can either instantiate this tokenizer with"
                    " `model_max_length` or pass `max_length` when encoding/padding.\n- To avoid this warning, please"
                    " instantiate this tokenizer with `model_max_length` set to your preferred value.",
                    FutureWarning,
                )

        return max_model_length

    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
        if not self.can_save_slow_tokenizer:
            raise ValueError(
                "Your fast tokenizer does not have the necessary information to save the vocabulary for a slow "
                "tokenizer."
            )

        if not os.path.isdir(save_directory):
            logger.error(f"Vocabulary path ({save_directory}) should be a directory")
            return
        out_vocab_file = os.path.join(
            save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
        )

        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file):
            copyfile(self.vocab_file, out_vocab_file)
            logger.info(f"Copy vocab file to {out_vocab_file}")

        return (out_vocab_file,)
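
    # Illustrative sketch (not part of the original file; the target directory is hypothetical and
    # must already exist). Saving only works when the original SentencePiece file is available,
    # i.e. `can_save_slow_tokenizer` is True:
    #
    #     tok = T5TokenizerFast.from_pretrained("t5-small")
    #     if tok.can_save_slow_tokenizer:
    #         tok.save_vocabulary("/tmp/t5_vocab")   # -> ("/tmp/t5_vocab/spiece.model",)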

    def build_inputs_with_special_tokens(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """
        Build model inputs from a sequence or a pair of sequences for sequence classification tasks by concatenating
        and adding special tokens. A sequence has the following format:

        - single sequence: `X </s>`
        - pair of sequences: `A </s> B </s>`

        Args:
            token_ids_0 (`List[int]`):
                List of IDs to which the special tokens will be added.
            token_ids_1 (`List[int]`, *optional*):
                Optional second list of IDs for sequence pairs.

        Returns:
            `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
        """
        token_ids_0 = token_ids_0 + [self.eos_token_id]
        if token_ids_1 is None:
            return self.prefix_tokens + token_ids_0
        else:
            token_ids_1 = token_ids_1 + [self.eos_token_id]
            return self.prefix_tokens + token_ids_0 + token_ids_1
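
    # Illustrative sketch (not part of the original file; assumes the "t5-small" checkpoint, whose
    # eos_token_id is 1, and arbitrary example ids). The method above yields the documented
    # `X </s>` and `A </s> B </s>` layouts:
    #
    #     tok = T5TokenizerFast.from_pretrained("t5-small")
    #     tok.build_inputs_with_special_tokens([100, 200])          # [100, 200, 1]
    #     tok.build_inputs_with_special_tokens([100, 200], [300])   # [100, 200, 1, 300, 1]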

    def create_token_type_ids_from_sequences(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """
        Create a mask from the two sequences passed to be used in a sequence-pair classification task. T5 does not
        make use of token type ids, therefore a list of zeros is returned.

        Args:
            token_ids_0 (`List[int]`):
                List of IDs.
            token_ids_1 (`List[int]`, *optional*):
                Optional second list of IDs for sequence pairs.

        Returns:
            `List[int]`: List of zeros.
        """
        eos = [self.eos_token_id]
        if token_ids_1 is None:
            return len(token_ids_0 + eos) * [0]
        return len(token_ids_0 + eos + token_ids_1 + eos) * [0]
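
    # Illustrative sketch (not part of the original file; example ids are arbitrary). The mask is
    # all zeros, with one extra position per sequence for the appended EOS token:
    #
    #     tok = T5TokenizerFast.from_pretrained("t5-small")
    #     tok.create_token_type_ids_from_sequences([100, 200])          # [0, 0, 0]
    #     tok.create_token_type_ids_from_sequences([100, 200], [300])   # [0, 0, 0, 0, 0]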

    def get_sentinel_tokens(self):
        # Keep only the additional special tokens that match the sentinel pattern "<extra_id_N>".
        return list(
            set(filter(lambda x: re.search(r"<extra_id_\d+>", x) is not None, self.additional_special_tokens))
        )

    def get_sentinel_token_ids(self):
        return [self.convert_tokens_to_ids(token) for token in self.get_sentinel_tokens()]
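

# Illustrative sketch (not part of the original file; assumes the "t5-small" checkpoint is
# reachable). The sentinel helpers above are typically used for span corruption, where masked
# spans are replaced by "<extra_id_0>", "<extra_id_1>", ...:
#
#     from transformers import T5TokenizerFast
#
#     tok = T5TokenizerFast.from_pretrained("t5-small")
#     ids = tok("The <extra_id_0> walks in <extra_id_1> park").input_ids
#     assert set(tok.get_sentinel_token_ids()) & set(ids)   # sentinel ids appear in the encoding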