tokenization_led_fast.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327
  1. # coding=utf-8
  2. # Copyright 2021 Iz Beltagy, Matthew E. Peters, Arman Cohan and The HuggingFace Inc. team. All rights reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. """Tokenization classes for LED."""
  16. import json
  17. from typing import Dict, List, Optional, Tuple, Union
  18. from tokenizers import pre_tokenizers, processors
  19. from ...tokenization_utils_base import AddedToken, BatchEncoding, EncodedInput
  20. from ...tokenization_utils_fast import PreTrainedTokenizerFast
  21. from ...utils import PaddingStrategy, logging
  22. from .tokenization_led import LEDTokenizer
  23. logger = logging.get_logger(__name__)
  24. VOCAB_FILES_NAMES = {"vocab_file": "vocab.json", "merges_file": "merges.txt", "tokenizer_file": "tokenizer.json"}
  25. class LEDTokenizerFast(PreTrainedTokenizerFast):
  26. r"""
  27. Construct a "fast" LED tokenizer (backed by HuggingFace's *tokenizers* library), derived from the GPT-2 tokenizer,
  28. using byte-level Byte-Pair-Encoding.
  29. This tokenizer has been trained to treat spaces like parts of the tokens (a bit like sentencepiece) so a word will
  30. be encoded differently whether it is at the beginning of the sentence (without space) or not:
  31. ```python
  32. >>> from transformers import LEDTokenizerFast
  33. >>> tokenizer = LEDTokenizerFast.from_pretrained("allenai/led-base-16384")
  34. >>> tokenizer("Hello world")["input_ids"]
  35. [0, 31414, 232, 2]
  36. >>> tokenizer(" Hello world")["input_ids"]
  37. [0, 20920, 232, 2]
  38. ```
  39. You can get around that behavior by passing `add_prefix_space=True` when instantiating this tokenizer or when you
  40. call it on some text, but since the model was not pretrained this way, it might yield a decrease in performance.
  41. <Tip>
  42. When used with `is_split_into_words=True`, this tokenizer needs to be instantiated with `add_prefix_space=True`.
  43. </Tip>
  44. This tokenizer inherits from [`PreTrainedTokenizerFast`] which contains most of the main methods. Users should
  45. refer to this superclass for more information regarding those methods.
  46. Args:
  47. vocab_file (`str`):
  48. Path to the vocabulary file.
  49. merges_file (`str`):
  50. Path to the merges file.
  51. errors (`str`, *optional*, defaults to `"replace"`):
  52. Paradigm to follow when decoding bytes to UTF-8. See
  53. [bytes.decode](https://docs.python.org/3/library/stdtypes.html#bytes.decode) for more information.
  54. bos_token (`str`, *optional*, defaults to `"<s>"`):
  55. The beginning of sequence token that was used during pretraining. Can be used a sequence classifier token.
  56. <Tip>
  57. When building a sequence using special tokens, this is not the token that is used for the beginning of
  58. sequence. The token used is the `cls_token`.
  59. </Tip>
  60. eos_token (`str`, *optional*, defaults to `"</s>"`):
  61. The end of sequence token.
  62. <Tip>
  63. When building a sequence using special tokens, this is not the token that is used for the end of sequence.
  64. The token used is the `sep_token`.
  65. </Tip>
  66. sep_token (`str`, *optional*, defaults to `"</s>"`):
  67. The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for
  68. sequence classification or for a text and a question for question answering. It is also used as the last
  69. token of a sequence built with special tokens.
  70. cls_token (`str`, *optional*, defaults to `"<s>"`):
  71. The classifier token which is used when doing sequence classification (classification of the whole sequence
  72. instead of per-token classification). It is the first token of the sequence when built with special tokens.
  73. unk_token (`str`, *optional*, defaults to `"<unk>"`):
  74. The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
  75. token instead.
  76. pad_token (`str`, *optional*, defaults to `"<pad>"`):
  77. The token used for padding, for example when batching sequences of different lengths.
  78. mask_token (`str`, *optional*, defaults to `"<mask>"`):
  79. The token used for masking values. This is the token used when training this model with masked language
  80. modeling. This is the token which the model will try to predict.
  81. add_prefix_space (`bool`, *optional*, defaults to `False`):
  82. Whether or not to add an initial space to the input. This allows to treat the leading word just as any
  83. other word. (LED tokenizer detect beginning of words by the preceding space).
  84. trim_offsets (`bool`, *optional*, defaults to `True`):
  85. Whether the post processing step should trim offsets to avoid including whitespaces.
  86. """
  87. vocab_files_names = VOCAB_FILES_NAMES
  88. slow_tokenizer_class = LEDTokenizer
  89. model_input_names = ["input_ids", "attention_mask"]
  90. # Copied from transformers.models.bart.tokenization_bart_fast.BartTokenizerFast.__init__
  91. def __init__(
  92. self,
  93. vocab_file=None,
  94. merges_file=None,
  95. tokenizer_file=None,
  96. errors="replace",
  97. bos_token="<s>",
  98. eos_token="</s>",
  99. sep_token="</s>",
  100. cls_token="<s>",
  101. unk_token="<unk>",
  102. pad_token="<pad>",
  103. mask_token="<mask>",
  104. add_prefix_space=False,
  105. trim_offsets=True,
  106. **kwargs,
  107. ):
  108. # we have to specify that this tokens is special otherwise adding it will reset the normalized flag to `False` in `add_special_tokens`
  109. mask_token = (
  110. AddedToken(mask_token, lstrip=True, normalized=True, special=True)
  111. if isinstance(mask_token, str)
  112. else mask_token
  113. )
  114. super().__init__(
  115. vocab_file,
  116. merges_file,
  117. tokenizer_file=tokenizer_file,
  118. errors=errors,
  119. bos_token=bos_token,
  120. eos_token=eos_token,
  121. sep_token=sep_token,
  122. cls_token=cls_token,
  123. unk_token=unk_token,
  124. pad_token=pad_token,
  125. mask_token=mask_token,
  126. add_prefix_space=add_prefix_space,
  127. trim_offsets=trim_offsets,
  128. **kwargs,
  129. )
  130. pre_tok_state = json.loads(self.backend_tokenizer.pre_tokenizer.__getstate__())
  131. if pre_tok_state.get("add_prefix_space", add_prefix_space) != add_prefix_space:
  132. pre_tok_class = getattr(pre_tokenizers, pre_tok_state.pop("type"))
  133. pre_tok_state["add_prefix_space"] = add_prefix_space
  134. self.backend_tokenizer.pre_tokenizer = pre_tok_class(**pre_tok_state)
  135. self.add_prefix_space = add_prefix_space
  136. # the pre_tokenizer is already updated in the GPT2TokenizerFast `__init__`
  137. tokenizer_component = "post_processor"
  138. tokenizer_component_instance = getattr(self.backend_tokenizer, tokenizer_component, None)
  139. if tokenizer_component_instance:
  140. state = json.loads(tokenizer_component_instance.__getstate__())
  141. # The lists 'sep' and 'cls' must be cased in tuples for the object `post_processor_class`
  142. if "sep" in state:
  143. state["sep"] = tuple(state["sep"])
  144. if "cls" in state:
  145. state["cls"] = tuple(state["cls"])
  146. changes_to_apply = False
  147. if state.get("add_prefix_space", add_prefix_space) != add_prefix_space:
  148. state["add_prefix_space"] = add_prefix_space
  149. changes_to_apply = True
  150. if state.get("trim_offsets", trim_offsets) != trim_offsets:
  151. state["trim_offsets"] = trim_offsets
  152. changes_to_apply = True
  153. if changes_to_apply:
  154. component_class = getattr(processors, state.pop("type"))
  155. new_value = component_class(**state)
  156. setattr(self.backend_tokenizer, tokenizer_component, new_value)
  157. @property
  158. # Copied from transformers.models.bart.tokenization_bart_fast.BartTokenizerFast.mask_token with BART->LED
  159. def mask_token(self) -> str:
  160. """
  161. `str`: Mask token, to use when training a model with masked-language modeling. Log an error if used while not
  162. having been set.
  163. LED tokenizer has a special mask token to be usable in the fill-mask pipeline. The mask token will greedily
  164. comprise the space before the *<mask>*.
  165. """
  166. if self._mask_token is None:
  167. if self.verbose:
  168. logger.error("Using mask_token, but it is not set yet.")
  169. return None
  170. return str(self._mask_token)
  171. @mask_token.setter
  172. def mask_token(self, value):
  173. """
  174. Overriding the default behavior of the mask token to have it eat the space before it.
  175. This is needed to preserve backward compatibility with all the previously used models based on LED.
  176. """
  177. # Mask token behave like a normal word, i.e. include the space before it
  178. # So we set lstrip to True
  179. value = AddedToken(value, lstrip=True, rstrip=False) if isinstance(value, str) else value
  180. self._mask_token = value
  181. # Copied from transformers.models.bart.tokenization_bart_fast.BartTokenizerFast._batch_encode_plus
  182. def _batch_encode_plus(self, *args, **kwargs) -> BatchEncoding:
  183. is_split_into_words = kwargs.get("is_split_into_words", False)
  184. if is_split_into_words and not self.add_prefix_space:
  185. raise ValueError(
  186. f"You need to instantiate {self.__class__.__name__} with add_prefix_space=True "
  187. "to use it with pretokenized inputs."
  188. )
  189. return super()._batch_encode_plus(*args, **kwargs)
  190. # Copied from transformers.models.bart.tokenization_bart_fast.BartTokenizerFast._encode_plus
  191. def _encode_plus(self, *args, **kwargs) -> BatchEncoding:
  192. is_split_into_words = kwargs.get("is_split_into_words", False)
  193. if is_split_into_words and not self.add_prefix_space:
  194. raise ValueError(
  195. f"You need to instantiate {self.__class__.__name__} with add_prefix_space=True "
  196. "to use it with pretokenized inputs."
  197. )
  198. return super()._encode_plus(*args, **kwargs)
  199. # Copied from transformers.models.bart.tokenization_bart_fast.BartTokenizerFast.save_vocabulary
  200. def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
  201. files = self._tokenizer.model.save(save_directory, name=filename_prefix)
  202. return tuple(files)
  203. # Copied from transformers.models.bart.tokenization_bart_fast.BartTokenizerFast.build_inputs_with_special_tokens
  204. def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
  205. output = [self.bos_token_id] + token_ids_0 + [self.eos_token_id]
  206. if token_ids_1 is None:
  207. return output
  208. return output + [self.eos_token_id] + token_ids_1 + [self.eos_token_id]
  209. # Copied from transformers.models.bart.tokenization_bart_fast.BartTokenizerFast.create_token_type_ids_from_sequences with BART->LED
  210. def create_token_type_ids_from_sequences(
  211. self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
  212. ) -> List[int]:
  213. """
  214. Create a mask from the two sequences passed to be used in a sequence-pair classification task. LED does not
  215. make use of token type ids, therefore a list of zeros is returned.
  216. Args:
  217. token_ids_0 (`List[int]`):
  218. List of IDs.
  219. token_ids_1 (`List[int]`, *optional*):
  220. Optional second list of IDs for sequence pairs.
  221. Returns:
  222. `List[int]`: List of zeros.
  223. """
  224. sep = [self.sep_token_id]
  225. cls = [self.cls_token_id]
  226. if token_ids_1 is None:
  227. return len(cls + token_ids_0 + sep) * [0]
  228. return len(cls + token_ids_0 + sep + sep + token_ids_1 + sep) * [0]
  229. # Copied from transformers.models.led.tokenization_led.LEDTokenizer._pad
  230. def _pad(
  231. self,
  232. encoded_inputs: Union[Dict[str, EncodedInput], BatchEncoding],
  233. max_length: Optional[int] = None,
  234. padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
  235. pad_to_multiple_of: Optional[int] = None,
  236. padding_side: Optional[bool] = None,
  237. return_attention_mask: Optional[bool] = None,
  238. ) -> dict:
  239. encoded_inputs = super()._pad(
  240. encoded_inputs=encoded_inputs,
  241. max_length=max_length,
  242. padding_strategy=padding_strategy,
  243. pad_to_multiple_of=pad_to_multiple_of,
  244. padding_side=padding_side,
  245. return_attention_mask=return_attention_mask,
  246. )
  247. # Load from model defaults
  248. if return_attention_mask is None:
  249. return_attention_mask = "attention_mask" in self.model_input_names
  250. if return_attention_mask and "global_attention_mask" in encoded_inputs:
  251. required_input = encoded_inputs[self.model_input_names[0]]
  252. # `global_attention_mask` need to have the same length as other (sequential) inputs.
  253. needs_to_be_padded = len(encoded_inputs["global_attention_mask"]) != len(required_input)
  254. if needs_to_be_padded:
  255. difference = len(required_input) - len(encoded_inputs["global_attention_mask"])
  256. if self.padding_side == "right":
  257. # Use `-1` since `0` in `global_attention_mask` means `local attention` instead of `not to attend`
  258. encoded_inputs["global_attention_mask"] = (
  259. encoded_inputs["global_attention_mask"] + [-1] * difference
  260. )
  261. elif self.padding_side == "left":
  262. encoded_inputs["global_attention_mask"] = [-1] * difference + encoded_inputs[
  263. "global_attention_mask"
  264. ]
  265. else:
  266. raise ValueError("Invalid padding strategy:" + str(self.padding_side))
  267. return encoded_inputs