tokenization_utils_fast.py 39 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895
  1. # coding=utf-8
  2. # Copyright 2020 The HuggingFace Inc. team.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. """
  16. Tokenization classes for fast tokenizers (provided by HuggingFace's tokenizers library). For slow (python) tokenizers
  17. see tokenization_utils.py
  18. """
  19. import copy
  20. import json
  21. import os
  22. from collections import defaultdict
  23. from typing import Any, Dict, List, Optional, Tuple, Union
  24. import tokenizers.pre_tokenizers as pre_tokenizers_fast
  25. from tokenizers import Encoding as EncodingFast
  26. from tokenizers import Tokenizer as TokenizerFast
  27. from tokenizers.decoders import Decoder as DecoderFast
  28. from tokenizers.trainers import BpeTrainer, UnigramTrainer, WordLevelTrainer, WordPieceTrainer
  29. from .convert_slow_tokenizer import convert_slow_tokenizer
  30. from .integrations.ggml import convert_gguf_tokenizer
  31. from .modeling_gguf_pytorch_utils import load_gguf_checkpoint
  32. from .tokenization_utils import PreTrainedTokenizer
  33. from .tokenization_utils_base import (
  34. INIT_TOKENIZER_DOCSTRING,
  35. AddedToken,
  36. BatchEncoding,
  37. PreTokenizedInput,
  38. PreTokenizedInputPair,
  39. PreTrainedTokenizerBase,
  40. SpecialTokensMixin,
  41. TextInput,
  42. TextInputPair,
  43. TruncationStrategy,
  44. )
  45. from .utils import PaddingStrategy, add_end_docstrings, logging
logger = logging.get_logger(__name__)

# Fast tokenizers (provided by HuggingFace tokenizer's library) can be saved in a single file
TOKENIZER_FILE = "tokenizer.json"
SPECIAL_TOKENS_MAP_FILE = "special_tokens_map.json"
TOKENIZER_CONFIG_FILE = "tokenizer_config.json"
# tiktoken-style checkpoints ship the vocabulary as `tokenizer.model`
TIKTOKEN_VOCAB_FILE = "tokenizer.model"

# Slow tokenizers have an additional added tokens files
ADDED_TOKENS_FILE = "added_tokens.json"

# Extend the shared init docstring with the arguments only fast tokenizers accept.
INIT_TOKENIZER_DOCSTRING += """
        tokenizer_object ([`tokenizers.Tokenizer`]):
            A [`tokenizers.Tokenizer`] object from 🤗 tokenizers to instantiate from. See [Using tokenizers from 🤗
            tokenizers](../fast_tokenizers) for more information.
        tokenizer_file ([`str`]):
            A path to a local JSON file representing a previously serialized [`tokenizers.Tokenizer`] object from 🤗
            tokenizers.
"""

# Maps a backend model type (as reported by `tokenizers`) to the trainer class
# used when training a new tokenizer from an iterator.
MODEL_TO_TRAINER_MAPPING = {
    "BPE": BpeTrainer,
    "Unigram": UnigramTrainer,
    "WordLevel": WordLevelTrainer,
    "WordPiece": WordPieceTrainer,
}

VOCAB_FILES_NAMES = {"tokenizer_file": TOKENIZER_FILE, "vocab_file": TIKTOKEN_VOCAB_FILE}
@add_end_docstrings(INIT_TOKENIZER_DOCSTRING)
class PreTrainedTokenizerFast(PreTrainedTokenizerBase):
    """
    Base class for all fast tokenizers (wrapping HuggingFace tokenizers library).

    Inherits from [`~tokenization_utils_base.PreTrainedTokenizerBase`].

    Handles all the shared methods for tokenization and special tokens, as well as methods for
    downloading/caching/loading pretrained tokenizers, as well as adding tokens to the vocabulary.

    This class also contains the added tokens in a unified way on top of all tokenizers so we don't have to handle the
    specific vocabulary augmentation methods of the various underlying dictionary structures (BPE, sentencepiece...).
    """

    # File names probed when loading/saving this tokenizer.
    vocab_files_names = VOCAB_FILES_NAMES
    # Set by subclasses that have a matching slow (pure-Python) implementation.
    slow_tokenizer_class: PreTrainedTokenizer = None
    def __init__(self, *args, **kwargs):
        """
        Build the Rust backend from one of several sources — a `tokenizers.Tokenizer`
        object, a serialized `tokenizer.json`, a slow tokenizer (given or instantiated),
        a GGUF checkpoint, or a tiktoken vocab file — then sync truncation/padding
        state and added tokens between the backend and the Python-side configuration.
        """
        # Pop the construction-path arguments; everything remaining in `kwargs` is
        # forwarded to `PreTrainedTokenizerBase.__init__` below.
        tokenizer_object = kwargs.pop("tokenizer_object", None)
        slow_tokenizer = kwargs.pop("__slow_tokenizer", None)
        gguf_file = kwargs.pop("gguf_file", None)
        fast_tokenizer_file = kwargs.pop("tokenizer_file", None)
        from_slow = kwargs.pop("from_slow", False)
        added_tokens_decoder = kwargs.pop("added_tokens_decoder", {})
        if from_slow and slow_tokenizer is None and self.slow_tokenizer_class is None:
            raise ValueError(
                "Cannot instantiate this tokenizer from a slow version. If it's based on sentencepiece, make sure you "
                "have sentencepiece installed."
            )
        if tokenizer_object is not None:
            # Deep-copy so later mutations (added tokens, truncation/padding changes)
            # do not leak back into the caller's object.
            fast_tokenizer = copy.deepcopy(tokenizer_object)
        elif fast_tokenizer_file is not None and not from_slow:
            # We have a serialization from tokenizers which let us directly build the backend
            fast_tokenizer = TokenizerFast.from_file(fast_tokenizer_file)
        elif slow_tokenizer:
            # We need to convert a slow tokenizer to build the backend
            fast_tokenizer = convert_slow_tokenizer(slow_tokenizer)
        elif gguf_file is not None:
            # We need to convert a slow tokenizer to build the backend
            gguf_param = load_gguf_checkpoint(kwargs.get("vocab_file"))
            architecture = gguf_param["config"]["model_type"]
            tokenizer_dict = gguf_param["tokenizer"]
            tokenizer_config = gguf_param["tokenizer_config"]
            fast_tokenizer, additional_kwargs = convert_gguf_tokenizer(architecture, tokenizer_dict)
            kwargs.update(tokenizer_config)
            if len(additional_kwargs) > 0:
                kwargs.update(additional_kwargs)
        elif self.slow_tokenizer_class is not None and slow_tokenizer is not False:
            # We need to create and convert a slow tokenizer to build the backend
            slow_tokenizer = self.slow_tokenizer_class(*args, **kwargs)
            fast_tokenizer = convert_slow_tokenizer(slow_tokenizer)
        elif not slow_tokenizer:
            # We tried loading a slow_tokenizer with spm and failed, try to load with tiktoken
            self.vocab_file = kwargs.get("vocab_file", None)
            self.additional_special_tokens = kwargs.get("additional_special_tokens", [])
            fast_tokenizer = convert_slow_tokenizer(self, from_tiktoken=True)
            slow_tokenizer = None
        else:
            raise ValueError(
                "Couldn't instantiate the backend tokenizer from one of: \n"
                "(1) a `tokenizers` library serialization file, \n"
                "(2) a slow tokenizer instance to convert or \n"
                "(3) an equivalent slow tokenizer class to instantiate and convert. \n"
                "You need to have sentencepiece or tiktoken installed to convert a slow tokenizer to a fast one."
            )
        self._tokenizer = fast_tokenizer
        if slow_tokenizer is not None:
            # Inherit the slow tokenizer's recorded init kwargs so both share config.
            kwargs.update(slow_tokenizer.init_kwargs)
        self._decode_use_source_tokenizer = False
        # Mirror the backend's serialized truncation/padding state into the init
        # kwargs (without overriding values the caller passed explicitly).
        _truncation = self._tokenizer.truncation
        if _truncation is not None:
            self._tokenizer.enable_truncation(**_truncation)
            kwargs.setdefault("max_length", _truncation["max_length"])
            kwargs.setdefault("truncation_side", _truncation["direction"])
            kwargs.setdefault("stride", _truncation["stride"])
            kwargs.setdefault("truncation_strategy", _truncation["strategy"])
        else:
            self._tokenizer.no_truncation()
        _padding = self._tokenizer.padding
        if _padding is not None:
            self._tokenizer.enable_padding(**_padding)
            kwargs.setdefault("pad_token", _padding["pad_token"])
            kwargs.setdefault("pad_token_type_id", _padding["pad_type_id"])
            kwargs.setdefault("padding_side", _padding["direction"])
            kwargs.setdefault("max_length", _padding["length"])
            kwargs.setdefault("pad_to_multiple_of", _padding["pad_to_multiple_of"])
        # We call this after having initialized the backend tokenizer because we update it.
        super().__init__(**kwargs)
        self._tokenizer.encode_special_tokens = self.split_special_tokens
        # NOTE(review): this iterates the decoder's *keys* (integer indices), so the
        # hash set holds repr-hashes of ints, not of AddedToken objects — presumably
        # `.values()` was intended; verify against upstream before changing.
        added_tokens_decoder_hash = {hash(repr(token)) for token in self.added_tokens_decoder}
        tokens_to_add = [
            token
            for index, token in sorted(added_tokens_decoder.items(), key=lambda x: x[0])
            if hash(repr(token)) not in added_tokens_decoder_hash
        ]
        encoder = list(self.added_tokens_encoder.keys()) + [str(token) for token in tokens_to_add]
        # if some of the special tokens are strings, we check if we don't already have a token
        tokens_to_add += [
            token for token in self.all_special_tokens_extended if token not in encoder and token not in tokens_to_add
        ]
        if len(tokens_to_add) > 0:
            tokens = []
            special_tokens = self.all_special_tokens
            for token in tokens_to_add:
                # A token counts as special if flagged on the AddedToken itself or if
                # its string form is one of the registered special tokens.
                is_special = (
                    (token.special or str(token) in special_tokens)
                    if isinstance(token, AddedToken)
                    else str(token) in special_tokens
                )
                if isinstance(token, str):
                    token = AddedToken(token, special=is_special)
                else:
                    token.special = is_special
                tokens.append(token)
            if tokens:
                self.add_tokens(tokens)
  180. @property
  181. def is_fast(self) -> bool:
  182. return True
  183. @property
  184. def can_save_slow_tokenizer(self) -> bool:
  185. """
  186. `bool`: Whether or not the slow tokenizer can be saved. Usually for sentencepiece based slow tokenizer, this
  187. can only be `True` if the original `"sentencepiece.model"` was not deleted.
  188. """
  189. return True
  190. @property
  191. def vocab_size(self) -> int:
  192. """
  193. `int`: Size of the base vocabulary (without the added tokens).
  194. """
  195. return self._tokenizer.get_vocab_size(with_added_tokens=False)
  196. def get_vocab(self) -> Dict[str, int]:
  197. return self._tokenizer.get_vocab(with_added_tokens=True)
  198. @property
  199. def vocab(self) -> Dict[str, int]:
  200. return self.get_vocab()
  201. @property
  202. def added_tokens_encoder(self) -> Dict[str, int]:
  203. """
  204. Returns the sorted mapping from string to index. The added tokens encoder is cached for performance
  205. optimisation in `self._added_tokens_encoder` for the slow tokenizers.
  206. """
  207. return {k.content: v for v, k in sorted(self.added_tokens_decoder.items(), key=lambda item: item[0])}
    @property
    def added_tokens_decoder(self) -> Dict[int, AddedToken]:
        """
        Returns the added tokens in the vocabulary as a dictionary of index to AddedToken.

        Returns:
            `Dict[int, AddedToken]`: The added tokens, keyed by vocabulary index.
        """
        return self._tokenizer.get_added_tokens_decoder()
  216. def get_added_vocab(self) -> Dict[str, int]:
  217. """
  218. Returns the added tokens in the vocabulary as a dictionary of token to index.
  219. Returns:
  220. `Dict[str, int]`: The added tokens.
  221. """
  222. return {k.content: v for v, k in sorted(self.added_tokens_decoder.items(), key=lambda item: item[0])}
  223. def __len__(self) -> int:
  224. """
  225. Size of the full vocabulary with the added tokens.
  226. """
  227. return self._tokenizer.get_vocab_size(with_added_tokens=True)
  228. @property
  229. def backend_tokenizer(self) -> TokenizerFast:
  230. """
  231. `tokenizers.implementations.BaseTokenizer`: The Rust tokenizer used as a backend.
  232. """
  233. return self._tokenizer
  234. @property
  235. def decoder(self) -> DecoderFast:
  236. """
  237. `tokenizers.decoders.Decoder`: The Rust decoder for this tokenizer.
  238. """
  239. return self._tokenizer.decoder
  240. def _convert_encoding(
  241. self,
  242. encoding: EncodingFast,
  243. return_token_type_ids: Optional[bool] = None,
  244. return_attention_mask: Optional[bool] = None,
  245. return_overflowing_tokens: bool = False,
  246. return_special_tokens_mask: bool = False,
  247. return_offsets_mapping: bool = False,
  248. return_length: bool = False,
  249. verbose: bool = True,
  250. ) -> Tuple[Dict[str, Any], List[EncodingFast]]:
  251. """
  252. Convert the encoding representation (from low-level HuggingFace tokenizer output) to a python Dict and a list
  253. of encodings, take care of building a batch from overflowing tokens.
  254. Overflowing tokens are converted to additional examples (like batches) so the output values of the dict are
  255. lists (overflows) of lists (tokens).
  256. Output shape: (overflows, sequence length)
  257. """
  258. if return_token_type_ids is None:
  259. return_token_type_ids = "token_type_ids" in self.model_input_names
  260. if return_attention_mask is None:
  261. return_attention_mask = "attention_mask" in self.model_input_names
  262. if return_overflowing_tokens and encoding.overflowing is not None:
  263. encodings = [encoding] + encoding.overflowing
  264. else:
  265. encodings = [encoding]
  266. encoding_dict = defaultdict(list)
  267. for e in encodings:
  268. encoding_dict["input_ids"].append(e.ids)
  269. if return_token_type_ids:
  270. encoding_dict["token_type_ids"].append(e.type_ids)
  271. if return_attention_mask:
  272. encoding_dict["attention_mask"].append(e.attention_mask)
  273. if return_special_tokens_mask:
  274. encoding_dict["special_tokens_mask"].append(e.special_tokens_mask)
  275. if return_offsets_mapping:
  276. encoding_dict["offset_mapping"].append(e.offsets)
  277. if return_length:
  278. encoding_dict["length"].append(len(e.ids))
  279. return encoding_dict, encodings
  280. def convert_tokens_to_ids(self, tokens: Union[str, List[str]]) -> Union[int, List[int]]:
  281. """
  282. Converts a token string (or a sequence of tokens) in a single integer id (or a sequence of ids), using the
  283. vocabulary.
  284. Args:
  285. tokens (`str` or `List[str]`): One or several token(s) to convert to token id(s).
  286. Returns:
  287. `int` or `List[int]`: The token id or list of token ids.
  288. """
  289. if tokens is None:
  290. return None
  291. if isinstance(tokens, str):
  292. return self._convert_token_to_id_with_added_voc(tokens)
  293. return [self._convert_token_to_id_with_added_voc(token) for token in tokens]
  294. def _convert_token_to_id_with_added_voc(self, token: str) -> int:
  295. index = self._tokenizer.token_to_id(token)
  296. if index is None:
  297. return self.unk_token_id
  298. return index
  299. def _convert_id_to_token(self, index: int) -> Optional[str]:
  300. return self._tokenizer.id_to_token(int(index))
  301. def _add_tokens(self, new_tokens: List[Union[str, AddedToken]], special_tokens=False) -> int:
  302. if special_tokens:
  303. return self._tokenizer.add_special_tokens(new_tokens)
  304. return self._tokenizer.add_tokens(new_tokens)
  305. def num_special_tokens_to_add(self, pair: bool = False) -> int:
  306. """
  307. Returns the number of added tokens when encoding a sequence with special tokens.
  308. <Tip>
  309. This encodes a dummy input and checks the number of added tokens, and is therefore not efficient. Do not put
  310. this inside your training loop.
  311. </Tip>
  312. Args:
  313. pair (`bool`, *optional*, defaults to `False`):
  314. Whether the number of added tokens should be computed in the case of a sequence pair or a single
  315. sequence.
  316. Returns:
  317. `int`: Number of special tokens added to sequences.
  318. """
  319. return self._tokenizer.num_special_tokens_to_add(pair)
  320. def convert_ids_to_tokens(
  321. self, ids: Union[int, List[int]], skip_special_tokens: bool = False
  322. ) -> Union[str, List[str]]:
  323. """
  324. Converts a single index or a sequence of indices in a token or a sequence of tokens, using the vocabulary and
  325. added tokens.
  326. Args:
  327. ids (`int` or `List[int]`):
  328. The token id (or token ids) to convert to tokens.
  329. skip_special_tokens (`bool`, *optional*, defaults to `False`):
  330. Whether or not to remove special tokens in the decoding.
  331. Returns:
  332. `str` or `List[str]`: The decoded token(s).
  333. """
  334. if isinstance(ids, int):
  335. return self._tokenizer.id_to_token(ids)
  336. tokens = []
  337. for index in ids:
  338. index = int(index)
  339. if skip_special_tokens and index in self.all_special_ids:
  340. continue
  341. tokens.append(self._tokenizer.id_to_token(index))
  342. return tokens
  343. def tokenize(self, text: str, pair: Optional[str] = None, add_special_tokens: bool = False, **kwargs) -> List[str]:
  344. return self.encode_plus(text=text, text_pair=pair, add_special_tokens=add_special_tokens, **kwargs).tokens()
  345. def set_truncation_and_padding(
  346. self,
  347. padding_strategy: PaddingStrategy,
  348. truncation_strategy: TruncationStrategy,
  349. max_length: int,
  350. stride: int,
  351. pad_to_multiple_of: Optional[int],
  352. padding_side: Optional[bool],
  353. ):
  354. """
  355. Define the truncation and the padding strategies for fast tokenizers (provided by HuggingFace tokenizers
  356. library) and restore the tokenizer settings afterwards.
  357. The provided tokenizer has no padding / truncation strategy before the managed section. If your tokenizer set a
  358. padding / truncation strategy before, then it will be reset to no padding / truncation when exiting the managed
  359. section.
  360. Args:
  361. padding_strategy ([`~utils.PaddingStrategy`]):
  362. The kind of padding that will be applied to the input
  363. truncation_strategy ([`~tokenization_utils_base.TruncationStrategy`]):
  364. The kind of truncation that will be applied to the input
  365. max_length (`int`):
  366. The maximum size of a sequence.
  367. stride (`int`):
  368. The stride to use when handling overflow.
  369. pad_to_multiple_of (`int`, *optional*):
  370. If set will pad the sequence to a multiple of the provided value. This is especially useful to enable
  371. the use of Tensor Cores on NVIDIA hardware with compute capability `>= 7.5` (Volta).
  372. padding_side (`str`, *optional*):
  373. The side on which the model should have padding applied. Should be selected between ['right', 'left'].
  374. Default value is picked from the class attribute of the same name.
  375. """
  376. _truncation = self._tokenizer.truncation
  377. _padding = self._tokenizer.padding
  378. # Set truncation and padding on the backend tokenizer
  379. if truncation_strategy == TruncationStrategy.DO_NOT_TRUNCATE:
  380. if _truncation is not None:
  381. self._tokenizer.no_truncation()
  382. else:
  383. target = {
  384. "max_length": max_length,
  385. "stride": stride,
  386. "strategy": truncation_strategy.value,
  387. "direction": self.truncation_side,
  388. }
  389. # _truncation might contain more keys that the target `transformers`
  390. # supports. Use only the target keys to trigger `enable_truncation`.
  391. # This should enable this code to works on various `tokenizers`
  392. # targets.
  393. if _truncation is None:
  394. current = None
  395. else:
  396. current = {k: _truncation.get(k, None) for k in target}
  397. if current != target:
  398. self._tokenizer.enable_truncation(**target)
  399. if padding_strategy == PaddingStrategy.DO_NOT_PAD:
  400. if _padding is not None:
  401. self._tokenizer.no_padding()
  402. else:
  403. length = max_length if padding_strategy == PaddingStrategy.MAX_LENGTH else None
  404. target = {
  405. "length": length,
  406. "direction": padding_side if padding_side is not None else self.padding_side,
  407. "pad_id": self.pad_token_id,
  408. "pad_token": self.pad_token,
  409. "pad_type_id": self.pad_token_type_id,
  410. "pad_to_multiple_of": pad_to_multiple_of,
  411. }
  412. if _padding != target:
  413. self._tokenizer.enable_padding(**target)
    def _batch_encode_plus(
        self,
        batch_text_or_text_pairs: Union[
            List[TextInput], List[TextInputPair], List[PreTokenizedInput], List[PreTokenizedInputPair]
        ],
        add_special_tokens: bool = True,
        padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
        truncation_strategy: TruncationStrategy = TruncationStrategy.DO_NOT_TRUNCATE,
        max_length: Optional[int] = None,
        stride: int = 0,
        is_split_into_words: bool = False,
        pad_to_multiple_of: Optional[int] = None,
        padding_side: Optional[bool] = None,
        return_tensors: Optional[str] = None,
        return_token_type_ids: Optional[bool] = None,
        return_attention_mask: Optional[bool] = None,
        return_overflowing_tokens: bool = False,
        return_special_tokens_mask: bool = False,
        return_offsets_mapping: bool = False,
        return_length: bool = False,
        verbose: bool = True,
        split_special_tokens: bool = False,
    ) -> BatchEncoding:
        """
        Encode a batch of texts or text pairs with the Rust backend and assemble the
        results into a single `BatchEncoding`. Overflowing slices (when requested) are
        flattened into extra batch rows; `overflow_to_sample_mapping` maps each row
        back to its source sample.
        """
        if not isinstance(batch_text_or_text_pairs, (tuple, list)):
            raise TypeError(
                f"batch_text_or_text_pairs has to be a list or a tuple (got {type(batch_text_or_text_pairs)})"
            )
        # Set the truncation and padding strategy and restore the initial configuration
        self.set_truncation_and_padding(
            padding_strategy=padding_strategy,
            truncation_strategy=truncation_strategy,
            max_length=max_length,
            stride=stride,
            pad_to_multiple_of=pad_to_multiple_of,
            padding_side=padding_side,
        )
        # Only touch the backend attribute when the value actually changes.
        if self._tokenizer.encode_special_tokens != split_special_tokens:
            self._tokenizer.encode_special_tokens = split_special_tokens
        encodings = self._tokenizer.encode_batch(
            batch_text_or_text_pairs,
            add_special_tokens=add_special_tokens,
            is_pretokenized=is_split_into_words,
        )
        # Convert encoding to dict
        # `Tokens` has type: Tuple[
        #    List[Dict[str, List[List[int]]]] or List[Dict[str, 2D-Tensor]],
        #    List[EncodingFast]
        # ]
        # with nested dimensions corresponding to batch, overflows, sequence length
        tokens_and_encodings = [
            self._convert_encoding(
                encoding=encoding,
                return_token_type_ids=return_token_type_ids,
                return_attention_mask=return_attention_mask,
                return_overflowing_tokens=return_overflowing_tokens,
                return_special_tokens_mask=return_special_tokens_mask,
                return_offsets_mapping=return_offsets_mapping,
                return_length=return_length,
                verbose=verbose,
            )
            for encoding in encodings
        ]
        # Convert the output to have dict[list] from list[dict] and remove the additional overflows dimension
        # From (variable) shape (batch, overflows, sequence length) to ~ (batch * overflows, sequence length)
        # (we say ~ because the number of overflow varies with the example in the batch)
        #
        # To match each overflowing sample with the original sample in the batch
        # we add an overflow_to_sample_mapping array (see below)
        sanitized_tokens = {}
        for key in tokens_and_encodings[0][0].keys():
            stack = [e for item, _ in tokens_and_encodings for e in item[key]]
            sanitized_tokens[key] = stack
        sanitized_encodings = [e for _, item in tokens_and_encodings for e in item]
        # If returning overflowing tokens, we need to return a mapping
        # from the batch idx to the original sample
        if return_overflowing_tokens:
            overflow_to_sample_mapping = []
            for i, (toks, _) in enumerate(tokens_and_encodings):
                overflow_to_sample_mapping += [i] * len(toks["input_ids"])
            sanitized_tokens["overflow_to_sample_mapping"] = overflow_to_sample_mapping
        # Warn (once) about sequences exceeding the model's declared max length.
        for input_ids in sanitized_tokens["input_ids"]:
            self._eventual_warn_about_too_long_sequence(input_ids, max_length, verbose)
        return BatchEncoding(sanitized_tokens, sanitized_encodings, tensor_type=return_tensors)
  497. def _encode_plus(
  498. self,
  499. text: Union[TextInput, PreTokenizedInput],
  500. text_pair: Optional[Union[TextInput, PreTokenizedInput]] = None,
  501. add_special_tokens: bool = True,
  502. padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
  503. truncation_strategy: TruncationStrategy = TruncationStrategy.DO_NOT_TRUNCATE,
  504. max_length: Optional[int] = None,
  505. stride: int = 0,
  506. is_split_into_words: bool = False,
  507. pad_to_multiple_of: Optional[int] = None,
  508. padding_side: Optional[bool] = None,
  509. return_tensors: Optional[bool] = None,
  510. return_token_type_ids: Optional[bool] = None,
  511. return_attention_mask: Optional[bool] = None,
  512. return_overflowing_tokens: bool = False,
  513. return_special_tokens_mask: bool = False,
  514. return_offsets_mapping: bool = False,
  515. return_length: bool = False,
  516. verbose: bool = True,
  517. split_special_tokens: bool = False,
  518. **kwargs,
  519. ) -> BatchEncoding:
  520. batched_input = [(text, text_pair)] if text_pair else [text]
  521. batched_output = self._batch_encode_plus(
  522. batched_input,
  523. is_split_into_words=is_split_into_words,
  524. add_special_tokens=add_special_tokens,
  525. padding_strategy=padding_strategy,
  526. truncation_strategy=truncation_strategy,
  527. max_length=max_length,
  528. stride=stride,
  529. pad_to_multiple_of=pad_to_multiple_of,
  530. padding_side=padding_side,
  531. return_tensors=return_tensors,
  532. return_token_type_ids=return_token_type_ids,
  533. return_attention_mask=return_attention_mask,
  534. return_overflowing_tokens=return_overflowing_tokens,
  535. return_special_tokens_mask=return_special_tokens_mask,
  536. return_offsets_mapping=return_offsets_mapping,
  537. return_length=return_length,
  538. verbose=verbose,
  539. split_special_tokens=split_special_tokens,
  540. **kwargs,
  541. )
  542. # Return tensor is None, then we can remove the leading batch axis
  543. # Overflowing tokens are returned as a batch of output so we keep them in this case
  544. if return_tensors is None and not return_overflowing_tokens:
  545. batched_output = BatchEncoding(
  546. {
  547. key: value[0] if len(value) > 0 and isinstance(value[0], list) else value
  548. for key, value in batched_output.items()
  549. },
  550. batched_output.encodings,
  551. )
  552. self._eventual_warn_about_too_long_sequence(batched_output["input_ids"], max_length, verbose)
  553. return batched_output
  554. def convert_tokens_to_string(self, tokens: List[str]) -> str:
  555. return self.backend_tokenizer.decoder.decode(tokens)
  556. def _decode(
  557. self,
  558. token_ids: Union[int, List[int]],
  559. skip_special_tokens: bool = False,
  560. clean_up_tokenization_spaces: bool = None,
  561. **kwargs,
  562. ) -> str:
  563. self._decode_use_source_tokenizer = kwargs.pop("use_source_tokenizer", False)
  564. if isinstance(token_ids, int):
  565. token_ids = [token_ids]
  566. text = self._tokenizer.decode(token_ids, skip_special_tokens=skip_special_tokens)
  567. clean_up_tokenization_spaces = (
  568. clean_up_tokenization_spaces
  569. if clean_up_tokenization_spaces is not None
  570. else self.clean_up_tokenization_spaces
  571. )
  572. if clean_up_tokenization_spaces:
  573. clean_text = self.clean_up_tokenization(text)
  574. return clean_text
  575. else:
  576. return text
    def _save_pretrained(
        self,
        save_directory: Union[str, os.PathLike],
        file_names: Tuple[str],
        legacy_format: Optional[bool] = None,
        filename_prefix: Optional[str] = None,
    ) -> Tuple[str]:
        """
        Save a tokenizer using the slow-tokenizer/legacy format: vocabulary + added tokens as well as in a unique JSON
        file containing {config + vocab + added-tokens}.

        Args:
            save_directory: Directory the files are written into.
            file_names: File names already produced by the caller; names written here
                are appended and the extended tuple is returned.
            legacy_format: `None` saves both formats (when possible), `True` only the
                legacy/slow files, `False` only the unified `tokenizer.json`.
            filename_prefix: Optional prefix prepended to each produced file name.
        """
        save_directory = str(save_directory)
        # Legacy-only saving needs a slow tokenizer class to produce the files.
        if self.slow_tokenizer_class is None and legacy_format is True:
            raise ValueError(
                "Your tokenizer does not have a legacy version defined and therefore cannot register this version. You"
                " might consider leaving the legacy_format at `None` or setting it to `False`."
            )
        save_slow = (
            (legacy_format is None or legacy_format is True)
            and self.slow_tokenizer_class is not None
            and self.can_save_slow_tokenizer
        )
        save_fast = legacy_format is None or legacy_format is False
        if save_slow:
            added_tokens_file = os.path.join(
                save_directory, (filename_prefix + "-" if filename_prefix else "") + ADDED_TOKENS_FILE
            )
            # make sure to be forward compatible
            added_vocab = {tok: index for tok, index in self.added_tokens_encoder.items() if index >= self.vocab_size}
            if added_vocab:
                with open(added_tokens_file, "w", encoding="utf-8") as f:
                    out_str = json.dumps(added_vocab, indent=2, sort_keys=True, ensure_ascii=False) + "\n"
                    f.write(out_str)
            vocab_files = self.save_vocabulary(save_directory, filename_prefix=filename_prefix)
            file_names = file_names + vocab_files + (added_tokens_file,)
        if save_fast:
            tokenizer_file = os.path.join(
                save_directory, (filename_prefix + "-" if filename_prefix else "") + TOKENIZER_FILE
            )
            self.backend_tokenizer.save(tokenizer_file)
            file_names = file_names + (tokenizer_file,)
        return file_names
  619. def train_new_from_iterator(
  620. self,
  621. text_iterator,
  622. vocab_size,
  623. length=None,
  624. new_special_tokens=None,
  625. special_tokens_map=None,
  626. **kwargs,
  627. ):
  628. """
  629. Trains a tokenizer on a new corpus with the same defaults (in terms of special tokens or tokenization pipeline)
  630. as the current one.
  631. Args:
  632. text_iterator (generator of `List[str]`):
  633. The training corpus. Should be a generator of batches of texts, for instance a list of lists of texts
  634. if you have everything in memory.
  635. vocab_size (`int`):
  636. The size of the vocabulary you want for your tokenizer.
  637. length (`int`, *optional*):
  638. The total number of sequences in the iterator. This is used to provide meaningful progress tracking
  639. new_special_tokens (list of `str` or `AddedToken`, *optional*):
  640. A list of new special tokens to add to the tokenizer you are training.
  641. special_tokens_map (`Dict[str, str]`, *optional*):
  642. If you want to rename some of the special tokens this tokenizer uses, pass along a mapping old special
  643. token name to new special token name in this argument.
  644. kwargs (`Dict[str, Any]`, *optional*):
  645. Additional keyword arguments passed along to the trainer from the 🤗 Tokenizers library.
  646. Returns:
  647. [`PreTrainedTokenizerFast`]: A new tokenizer of the same type as the original one, trained on
  648. `text_iterator`.
  649. """
  650. tokenizer_json = json.loads(self._tokenizer.to_str())
  651. # Remove added tokens for now (uses IDs of tokens)
  652. added_tokens = tokenizer_json.pop("added_tokens")
  653. # Remove post processor for now (uses IDs of tokens)
  654. post_processor = tokenizer_json.pop("post_processor")
  655. unk_token = None
  656. # Remove vocab
  657. if tokenizer_json["model"]["type"] == "BPE":
  658. tokenizer_json["model"]["vocab"] = {}
  659. tokenizer_json["model"]["merges"] = []
  660. elif tokenizer_json["model"]["type"] == "Unigram":
  661. if tokenizer_json["model"]["unk_id"] is not None:
  662. unk_id = tokenizer_json["model"]["unk_id"]
  663. unk_token = tokenizer_json["model"]["vocab"][unk_id][0]
  664. if special_tokens_map is not None and unk_token in special_tokens_map:
  665. unk_token = special_tokens_map[unk_token]
  666. tokenizer_json["model"]["unk_id"] = 0
  667. tokenizer_json["model"]["vocab"] = [[unk_token, 0.0]]
  668. elif tokenizer_json["model"]["type"] in ["WordLevel", "WordPiece"]:
  669. tokenizer_json["model"]["vocab"] = {}
  670. else:
  671. raise ValueError(
  672. f"This method does not support this type of tokenizer (found {tokenizer_json['model']['type']}) "
  673. "only BPE, Unigram, WordLevel and WordPiece."
  674. )
  675. if (
  676. special_tokens_map is not None
  677. and "unk_token" in tokenizer_json["model"]
  678. and tokenizer_json["model"]["unk_token"] in special_tokens_map
  679. ):
  680. tokenizer_json["model"]["unk_token"] = special_tokens_map[tokenizer_json["model"]["unk_token"]]
  681. tokenizer = TokenizerFast.from_str(json.dumps(tokenizer_json))
  682. # Get the special tokens from the current tokenizer if none are specified.
  683. special_tokens = []
  684. for added_token in added_tokens:
  685. special = added_token.pop("special", None)
  686. _ = added_token.pop("id", None)
  687. if tokenizer_json["model"]["type"] != "Unigram" and not special:
  688. continue
  689. if special_tokens_map is not None and added_token["content"] in special_tokens_map:
  690. added_token["content"] = special_tokens_map[added_token["content"]]
  691. special_tokens.append(AddedToken(**added_token))
  692. if new_special_tokens is not None:
  693. special_tokens.extend(new_special_tokens)
  694. # Trainer needs to know the end of word / continuing subword thingies in BPE
  695. if (
  696. tokenizer_json["model"]["type"] == "BPE"
  697. and "continuing_subword_prefix" not in kwargs
  698. and tokenizer_json["model"]["continuing_subword_prefix"] is not None
  699. ):
  700. kwargs["continuing_subword_prefix"] = tokenizer_json["model"]["continuing_subword_prefix"]
  701. if (
  702. tokenizer_json["model"]["type"] == "BPE"
  703. and "end_of_word_suffix" not in kwargs
  704. and tokenizer_json["model"]["end_of_word_suffix"] is not None
  705. ):
  706. kwargs["end_of_word_suffix"] = tokenizer_json["model"]["end_of_word_suffix"]
  707. if tokenizer_json["model"]["type"] == "Unigram" and unk_token is not None:
  708. kwargs["unk_token"] = unk_token
  709. if (
  710. tokenizer_json["pre_tokenizer"] is not None
  711. and tokenizer_json["pre_tokenizer"]["type"] == "ByteLevel"
  712. or tokenizer_json["pre_tokenizer"]["type"] == "Sequence"
  713. and "pretokenizers" in tokenizer_json["pre_tokenizer"]
  714. and any(
  715. pretokenizer["type"] == "ByteLevel"
  716. for pretokenizer in tokenizer_json["pre_tokenizer"]["pretokenizers"]
  717. )
  718. ):
  719. kwargs["initial_alphabet"] = pre_tokenizers_fast.ByteLevel.alphabet()
  720. trainer_class = MODEL_TO_TRAINER_MAPPING[tokenizer_json["model"]["type"]]
  721. trainer = trainer_class(vocab_size=vocab_size, special_tokens=special_tokens, **kwargs)
  722. tokenizer.train_from_iterator(text_iterator, length=length, trainer=trainer)
  723. if post_processor is not None:
  724. trained_tokenizer_json = json.loads(tokenizer.to_str())
  725. # Almost done, we just have to adjust the token IDs in the post processor
  726. if "special_tokens" in post_processor:
  727. for key in post_processor["special_tokens"]:
  728. tokens = post_processor["special_tokens"][key]["tokens"]
  729. if special_tokens_map is not None:
  730. tokens = [special_tokens_map.get(token, token) for token in tokens]
  731. post_processor["special_tokens"][key]["tokens"] = tokens
  732. for token in tokens:
  733. token_id = tokenizer.token_to_id(token)
  734. if token_id is None:
  735. raise ValueError(
  736. "Attempted to set a token in the post processor that does not exist in the mapping"
  737. )
  738. post_processor["special_tokens"][key]["ids"] = [tokenizer.token_to_id(token) for token in tokens]
  739. for special_token in ["cls", "sep"]:
  740. if special_token in post_processor:
  741. token, _ = post_processor[special_token]
  742. if special_tokens_map is not None and token in special_tokens_map:
  743. token = special_tokens_map[token]
  744. token_id = tokenizer.token_to_id(token)
  745. if token_id is None:
  746. raise ValueError(
  747. "Attempted to set a token in the post processor that does not exist in the mapping"
  748. )
  749. post_processor[special_token] = [token, token_id]
  750. trained_tokenizer_json["post_processor"] = post_processor
  751. tokenizer = TokenizerFast.from_str(json.dumps(trained_tokenizer_json))
  752. kwargs = self.init_kwargs.copy()
  753. # Map pad/cls/mask token at the Transformers level
  754. special_tokens_list = SpecialTokensMixin.SPECIAL_TOKENS_ATTRIBUTES.copy()
  755. special_tokens_list.remove("additional_special_tokens")
  756. for token in special_tokens_list:
  757. # Get the private one to avoid unnecessary warnings.
  758. if getattr(self, f"_{token}") is not None:
  759. special_token = getattr(self, token)
  760. if special_tokens_map is not None and special_token in special_tokens_map:
  761. special_token = special_tokens_map[special_token]
  762. special_token_full = getattr(self, f"_{token}")
  763. if isinstance(special_token_full, AddedToken):
  764. # Create an added token with the same parameters except the content
  765. kwargs[token] = AddedToken(
  766. special_token,
  767. single_word=special_token_full.single_word,
  768. lstrip=special_token_full.lstrip,
  769. rstrip=special_token_full.rstrip,
  770. normalized=special_token_full.normalized,
  771. special=True,
  772. )
  773. else:
  774. kwargs[token] = special_token
  775. additional_special_tokens = self.additional_special_tokens
  776. if new_special_tokens is not None:
  777. additional_special_tokens.extend(new_special_tokens)
  778. if len(additional_special_tokens) > 0:
  779. kwargs["additional_special_tokens"] = additional_special_tokens
  780. return self.__class__(tokenizer_object=tokenizer, **kwargs)