tokenization_utils.py

# coding=utf-8
# Copyright 2020 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Tokenization classes for python tokenizers. For fast tokenizers (provided by HuggingFace's tokenizers library) see
tokenization_utils_fast.py
"""

import bisect
import itertools
import re
import unicodedata
from collections import OrderedDict
from typing import Any, Dict, List, Optional, Tuple, Union, overload

from .tokenization_utils_base import (
    ENCODE_KWARGS_DOCSTRING,
    ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING,
    INIT_TOKENIZER_DOCSTRING,
    AddedToken,
    BatchEncoding,
    EncodedInput,
    EncodedInputPair,
    PreTokenizedInput,
    PreTokenizedInputPair,
    PreTrainedTokenizerBase,
    TextInput,
    TextInputPair,
    TruncationStrategy,
)
from .utils import PaddingStrategy, TensorType, add_end_docstrings, logging


logger = logging.get_logger(__name__)

# Slow tokenizers are saved in a vocabulary plus three separated files
SPECIAL_TOKENS_MAP_FILE = "special_tokens_map.json"
ADDED_TOKENS_FILE = "added_tokens.json"
TOKENIZER_CONFIG_FILE = "tokenizer_config.json"
class Trie:
    """
    Trie in Python. Creates a Trie out of a list of words. The trie is used to split on `added_tokens` in one pass.
    Loose reference https://en.wikipedia.org/wiki/Trie
    """

    def __init__(self, *args):
        self.data = {}
        self._tokens = set()
        self._termination_char = ""
        self.update(*args)

    def update(self, *args):
        """
        Updates the Trie with new tokens provided as arguments.

        Args:
            *args: Variable number of words to be added to the Trie.
        """
        for token in tuple(*args):
            self.add(token)

    def add(self, word: str):
        """
        Passes over every char (utf-8 char) in word and recursively adds it to the internal `data` trie representation.
        The special key `""` in `self._termination_char` is used to represent termination.

        This function is idempotent, adding the same word twice will leave the trie unchanged.

        Example:

        ```python
        >>> trie = Trie()
        >>> trie.add("Hello 友達")
        >>> trie.data
        {"H": {"e": {"l": {"l": {"o": {" ": {"友": {"達": {"": 1}}}}}}}}}

        >>> trie.add("Hello")
        >>> trie.data
        {"H": {"e": {"l": {"l": {"o": {"": 1, " ": {"友": {"達": {"": 1}}}}}}}}}
        ```
        """
        if not word:
            # Prevent empty string
            return

        self._tokens.add(word)
        ref = self.data
        for char in word:
            ref[char] = ref.setdefault(char, {})
            ref = ref[char]
        ref[self._termination_char] = 1

    def split(self, text: str) -> List[str]:
        """
        Will look for the words added to the trie within `text`. The output is the original string split along the
        boundaries of the words found.

        This trie will match the longest possible word first!

        Example:

        ```python
        >>> trie = Trie()
        >>> trie.split("[CLS] This is a extra_id_100")
        ["[CLS] This is a extra_id_100"]

        >>> trie.add("[CLS]")
        >>> trie.add("extra_id_1")
        >>> trie.add("extra_id_100")
        >>> trie.split("[CLS] This is a extra_id_100")
        ["[CLS]", " This is a ", "extra_id_100"]
        ```
        """
        # indexes are counted left of the chars index.
        # "hello", index 0, is left of h, index 1 is between h and e.
        # index 5 is right of the "o".

        # States are going to capture every possible start (indexes as above)
        # as keys, and have as values, a pointer to the position in the trie
        # where we're at. This is a partial match for now.
        # This enables to keep track of multiple matches while we're iterating
        # the string.
        # If the trie contains "blowing" and "lower" and we encounter the
        # string "blower", we need to split into ["b", "lower"].
        # This is where we need to keep track of multiple possible starts.
        states = OrderedDict()

        # This will contain every index where we need
        # to cut.
        # We force to cut at offset 0 and len(text) (added later)
        offsets = [0]

        # This is used by the lookahead which needs to skip over
        # some text where the full match exceeded the place in the initial
        # for loop
        skip = 0
        # Main loop, giving this algorithm O(n) complexity
        for current, current_char in enumerate(text):
            if skip and current < skip:
                # Prevents the lookahead from matching twice
                # like extra_id_100 and id_100
                continue

            # This will track every state
            # that stops matching, we need to stop tracking them.
            # If we look at "lowball", we're going to match "l" (add it to states), "o", "w", then
            # fail on "b", we need to remove 0 from the valid states.
            to_remove = set()
            # Whenever we find a match, we need to drop everything
            # this is a greedy algorithm, it will match on the first found token
            reset = False

            # In this case, we already have partial matches (but unfinished)
            for start, trie_pointer in states.items():
                if "" in trie_pointer:
                    # This is a final match, we need to reset and
                    # store the results in `offsets`.

                    # Lookahead to match longest first
                    # Important in case of extra_id_1 vs extra_id_100
                    # Here we are also actively looking for other earlier partial
                    # matches
                    # "[CLS]", "L", we need to match CLS even if L is special
                    for lookstart, looktrie_pointer in states.items():
                        if lookstart > start:
                            # This partial match is later, we can stop looking
                            break
                        elif lookstart < start:
                            # This partial match is earlier, the trie pointer
                            # was already updated, so index is + 1
                            lookahead_index = current + 1
                            end = current + 1
                        else:
                            # Here lookstart == start and
                            #      looktrie_pointer == trie_pointer
                            # It wasn't updated yet so indices are current ones
                            lookahead_index = current
                            end = current
                        next_char = text[lookahead_index] if lookahead_index < len(text) else None
                        if "" in looktrie_pointer:
                            start = lookstart
                            end = lookahead_index
                            skip = lookahead_index

                        while next_char in looktrie_pointer:
                            looktrie_pointer = looktrie_pointer[next_char]
                            lookahead_index += 1
                            if "" in looktrie_pointer:
                                start = lookstart
                                end = lookahead_index
                                skip = lookahead_index

                            if lookahead_index == len(text):
                                # End of string
                                break
                            next_char = text[lookahead_index]
                        # End lookahead

                    # Storing and resetting
                    offsets.append(start)
                    offsets.append(end)
                    reset = True
                    break
                elif current_char in trie_pointer:
                    # The current character being looked at has a match within the trie
                    # update the pointer (it will be stored back into states later).
                    trie_pointer = trie_pointer[current_char]

                    # Storing back the new pointer into the states.
                    # Partial matches got longer by one.
                    states[start] = trie_pointer
                else:
                    # The new character has no match in the trie, we need
                    # to stop keeping track of this partial match.
                    # We can't do it directly within the loop because of how
                    # python iteration works
                    to_remove.add(start)

            # Either clearing the full start (we found a real match)
            # Or clearing only the partial matches that didn't work.
            if reset:
                states = {}
            else:
                for start in to_remove:
                    del states[start]

            # If this character is a starting character within the trie
            # start keeping track of this partial match.
            if current >= skip and current_char in self.data:
                states[current] = self.data[current_char]

        # We have a cut at the end with states.
        for start, trie_pointer in states.items():
            if "" in trie_pointer:
                # This is a final match, we need to reset and
                # store the results in `offsets`.
                end = len(text)
                offsets.append(start)
                offsets.append(end)
                # The longest cut is always the one with the lowest start (the first
                # item), so we need to break.
                break

        return self.cut_text(text, offsets)

    def cut_text(self, text, offsets):
        # We have all the offsets now, we just need to do the actual splitting.
        # We need to eventually add the first part of the string and the eventual
        # last part.
        offsets.append(len(text))
        tokens = []
        start = 0
        for end in offsets:
            if start > end:
                logger.error(
                    "There was a bug in Trie algorithm in tokenization. Attempting to recover. Please report it"
                    " anyway."
                )
                continue
            elif start == end:
                # This might happen if there's a match at index 0
                # we're also preventing zero-width cuts in case of two
                # consecutive matches
                continue
            tokens.append(text[start:end])
            start = end

        return tokens
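
# --- Illustrative usage sketch (not part of the original module) ---
# A minimal, self-contained example of how `Trie.split` cuts raw text on added tokens,
# matching the longest token first. The helper name below is hypothetical and exists
# only for demonstration; the expected output mirrors the docstring example above.
def _example_trie_split():
    trie = Trie()
    trie.add("[CLS]")
    trie.add("extra_id_1")
    trie.add("extra_id_100")
    # "extra_id_100" wins over "extra_id_1" because the trie matches the longest token first.
    assert trie.split("[CLS] This is a extra_id_100") == ["[CLS]", " This is a ", "extra_id_100"]
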
class ExtensionsTrie(Trie):
    def __init__(self, *args):
        super().__init__(*args)

    def extensions(self, prefix: str):
        """
        Generates all extensions of a given prefix token in the Trie.

        Example:

        ```python
        >>> trie = ExtensionsTrie()
        >>> trie.add("apple")
        >>> trie.add("app")
        >>> trie.add("application")
        >>> trie.extensions("app")
        ['app', 'apple', 'application']
        ```
        """
        prefix_node = self._get_node(prefix)
        ret = self._collect_tokens(prefix_node)
        return [prefix + token for token in ret]

    def _get_node(self, token: str) -> dict:
        """
        Retrieves the node corresponding to the given token in the Trie.

        Args:
            token (str): The token for which the corresponding node needs to be retrieved.

        Returns:
            dict: The node in the Trie corresponding to the given token.
        """
        node = self.data
        for char in token:
            if char not in node:
                break

            node = node[char]
        return node

    def _collect_tokens(self, node: dict) -> list:
        """
        Generates all tokens in the Trie starting from a given node.

        Args:
            node (dict): The node in the Trie from which tokens need to be generated.

        Returns:
            list: List of tokens generated from the given node.
        """
        tokens = [self._termination_char] if self._termination_char in node else []
        for token, subtrie_head in node.items():
            if token != self._termination_char:
                subtokens = self._collect_tokens(subtrie_head)
                tokens.extend([token + subtoken for subtoken in subtokens])
        return tokens
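
# --- Illustrative usage sketch (not part of the original module) ---
# Shows what `ExtensionsTrie.extensions` returns: every added token that starts with the
# given prefix, including the prefix itself when it was added. The helper name below is
# hypothetical.
def _example_extensions_trie():
    trie = ExtensionsTrie()
    trie.add("app")
    trie.add("apple")
    trie.add("application")
    assert sorted(trie.extensions("app")) == ["app", "apple", "application"]
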
def _is_whitespace(char):
    """Checks whether `char` is a whitespace character."""
    # \t, \n, and \r are technically control characters but we treat them
    # as whitespace since they are generally considered as such.
    if char == " " or char == "\t" or char == "\n" or char == "\r":
        return True
    cat = unicodedata.category(char)
    if cat == "Zs":
        return True
    return False


def _is_control(char):
    """Checks whether `char` is a control character."""
    # These are technically control characters but we count them as whitespace
    # characters.
    if char == "\t" or char == "\n" or char == "\r":
        return False
    cat = unicodedata.category(char)
    if cat.startswith("C"):
        return True
    return False


def _is_punctuation(char):
    """Checks whether `char` is a punctuation character."""
    cp = ord(char)
    # We treat all non-letter/number ASCII as punctuation.
    # Characters such as "^", "$", and "`" are not in the Unicode
    # Punctuation class but we treat them as punctuation anyways, for
    # consistency.
    if (cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126):
        return True
    cat = unicodedata.category(char)
    if cat.startswith("P"):
        return True
    return False


def _is_end_of_word(text):
    """Checks whether the last character in text is a punctuation, control or whitespace character."""
    last_char = text[-1]
    return bool(_is_control(last_char) | _is_punctuation(last_char) | _is_whitespace(last_char))


def _is_start_of_word(text):
    """Checks whether the first character in text is a punctuation, control or whitespace character."""
    first_char = text[0]
    return bool(_is_control(first_char) | _is_punctuation(first_char) | _is_whitespace(first_char))
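
# --- Illustrative usage sketch (not part of the original module) ---
# The character-class helpers above follow BERT-style conventions: tab/newline/CR count as
# whitespace rather than control characters, and ASCII symbols such as "$" are treated as
# punctuation even though Unicode does not classify them that way. The helper name below
# is hypothetical.
def _example_char_helpers():
    assert _is_whitespace("\t") and not _is_control("\t")
    assert _is_punctuation("$")
    assert _is_start_of_word(" hello")  # leading whitespace marks a word boundary
    assert _is_end_of_word("end.")  # trailing punctuation marks a word boundary
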
def _insert_one_token_to_ordered_list(token_list: List[str], new_token: str):
    """
    Inserts one token to an ordered list if it does not already exist. Note: token_list must be sorted.
    """
    insertion_idx = bisect.bisect_left(token_list, new_token)
    # Checks if new_token is already in the ordered token_list
    if insertion_idx < len(token_list) and token_list[insertion_idx] == new_token:
        # new_token is in token_list, don't add
        return
    else:
        token_list.insert(insertion_idx, new_token)
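
# --- Illustrative usage sketch (not part of the original module) ---
# `_insert_one_token_to_ordered_list` keeps the list sorted and skips duplicates; the helper
# name below is hypothetical.
def _example_ordered_insert():
    vocab = ["apple", "cherry"]
    _insert_one_token_to_ordered_list(vocab, "banana")
    _insert_one_token_to_ordered_list(vocab, "banana")  # duplicate insert is a no-op
    assert vocab == ["apple", "banana", "cherry"]
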
@add_end_docstrings(INIT_TOKENIZER_DOCSTRING)
class PreTrainedTokenizer(PreTrainedTokenizerBase):
    """
    Base class for all slow tokenizers.

    Inherits from [`~tokenization_utils_base.PreTrainedTokenizerBase`].

    Handles all the shared methods for tokenization and special tokens, as well as methods for
    downloading/caching/loading pretrained tokenizers and for adding tokens to the vocabulary.

    This class also contains the added tokens in a unified way on top of all tokenizers so we don't have to handle the
    specific vocabulary augmentation methods of the various underlying dictionary structures (BPE, sentencepiece...).
    """

    def __init__(self, **kwargs):
        # 1. Init the tokens trie used to split on added tokens
        self.tokens_trie = Trie()

        # 2. init `_added_tokens_decoder` if child class did not
        if not hasattr(self, "_added_tokens_decoder"):
            self._added_tokens_decoder: Dict[int, AddedToken] = {}

        # 3. if an `added_tokens_decoder` is passed, we are loading from a saved tokenizer, we overwrite
        self._added_tokens_decoder.update(kwargs.pop("added_tokens_decoder", {}))
        self._added_tokens_encoder: Dict[str, int] = {k.content: v for v, k in self._added_tokens_decoder.items()}

        # 4. init the parent class
        super().__init__(**kwargs)

        # 5. If some of the special tokens are not part of the vocab, we add them, at the end.
        # the order of addition is the same as self.SPECIAL_TOKENS_ATTRIBUTES following `tokenizers`
        self._add_tokens(
            [token for token in self.all_special_tokens_extended if token not in self._added_tokens_encoder],
            special_tokens=True,
        )

        self._decode_use_source_tokenizer = False

    @property
    def is_fast(self) -> bool:
        return False

    @property
    def vocab_size(self) -> int:
        """
        `int`: Size of the base vocabulary (without the added tokens).
        """
        raise NotImplementedError

    @property
    def added_tokens_encoder(self) -> Dict[str, int]:
        """
        Returns the sorted mapping from string to index. The added tokens encoder is cached for performance
        optimisation in `self._added_tokens_encoder` for the slow tokenizers.
        """
        return {k.content: v for v, k in sorted(self._added_tokens_decoder.items(), key=lambda item: item[0])}

    @property
    def added_tokens_decoder(self) -> Dict[int, AddedToken]:
        """
        Returns the added tokens in the vocabulary as a dictionary of index to AddedToken.

        Returns:
            `Dict[int, AddedToken]`: The added tokens.
        """
        return dict(sorted(self._added_tokens_decoder.items(), key=lambda item: item[0]))

    @added_tokens_decoder.setter
    def added_tokens_decoder(self, value: Dict[int, Union[AddedToken, str]]) -> Dict[int, AddedToken]:
        # Always raise an error if string because users should define the behavior
        for index, token in value.items():
            if not isinstance(token, (str, AddedToken)) or not isinstance(index, int):
                raise TypeError(
                    f"The provided `added_tokens_decoder` has an element of type {index.__class__, token.__class__}, should be a dict of {int, Union[AddedToken, str]}"
                )

            self._added_tokens_decoder[index] = AddedToken(token) if isinstance(token, str) else token
            self._added_tokens_encoder[str(token)] = index
        self._update_total_vocab_size()

    def get_added_vocab(self) -> Dict[str, int]:
        """
        Returns the added tokens in the vocabulary as a dictionary of token to index. Results might be different from
        the fast call because for now we always add the tokens even if they are already in the vocabulary. This is
        something we should change.

        Returns:
            `Dict[str, int]`: The added tokens.
        """
        return self._added_tokens_encoder

    def __len__(self):
        """
        Size of the full vocabulary with the added tokens.
        """
        return self.total_vocab_size

    def _update_total_vocab_size(self):
        """
        Update the size of the full vocabulary with the added tokens. Counts the `keys` and not the `values` because
        otherwise if there is a hole in the vocab, we will add tokens at a wrong index. This operation is slow and
        is only updated when adding tokens.
        """
        self.total_vocab_size = len(self.get_vocab())

    def _add_tokens(self, new_tokens: Union[List[str], List[AddedToken]], special_tokens: bool = False) -> int:
        """
        Add a list of new tokens to the tokenizer class. If the new tokens are not in the vocabulary, they are added to
        it with indices starting from the length of the current vocabulary. Special tokens are sometimes already in the
        vocab which is why they have to be handled specifically.

        Args:
            new_tokens (`List[str]` or `List[tokenizers.AddedToken]`):
                Token(s) to add in vocabulary. A token is counted as added if it's not already in the vocabulary
                (tested by checking if the tokenizer assigns the index of the `unk_token` to them). If a token is part
                of the vocabulary then we simply mark this token as an `AddedToken` which allows to control the
                stripping and normalization of this token. This is NOT possible in `tokenizers`.
            special_tokens (`bool`, *optional*, defaults to `False`):
                Whether or not the tokens should be added as special tokens.

        Returns:
            `int`: The number of tokens actually added to the vocabulary.

        Examples:

        ```python
        # Let's see how to increase the vocabulary of Bert model and tokenizer
        tokenizer = BertTokenizer.from_pretrained("google-bert/bert-base-uncased")
        model = BertModel.from_pretrained("google-bert/bert-base-uncased")

        num_added_toks = tokenizer.add_tokens(["new_tok1", "my_new-tok2"])
        print("We have added", num_added_toks, "tokens")
        # Note: resize_token_embeddings expects to receive the full size of the new vocabulary, i.e. the length of the tokenizer.
        model.resize_token_embeddings(len(tokenizer))
        ```"""
        added_tokens = 0
        if new_tokens is None:
            return added_tokens
        # TODO this is fairly slow to improve!
        current_vocab = self.get_vocab().copy()
        new_idx = len(current_vocab)  # only call this once, len gives the last index + 1
        for token in new_tokens:
            if not isinstance(token, (str, AddedToken)):
                raise TypeError(f"Token {token} is not a string but a {type(token)}.")
            if str(token) == "":
                continue
            if isinstance(token, str):
                if token in self._added_tokens_encoder:
                    continue
                else:
                    # very important for fast and slow equivalence!
                    is_special = token in self.all_special_tokens or special_tokens
                    token = AddedToken(
                        token, rstrip=False, lstrip=False, normalized=not is_special, special=is_special
                    )
            elif special_tokens:
                # doing token.special=True changes the normalization! will fix in rust
                # this is important and the only reason why the AddedTokens in each class are normalized by default
                token.__setstate__({"special": True, "normalized": token.normalized})
            if token in self._added_tokens_decoder:
                continue
            if not token.special and token.normalized and getattr(self, "do_lower_case", False):
                # Normalize if requested
                token.content = token.content.lower()
            if token.content not in current_vocab:
                token_index = new_idx + added_tokens
                current_vocab[token.content] = token_index
                added_tokens += 1
            else:
                token_index = current_vocab[token.content]

            if token.special and str(token) not in self.all_special_tokens:
                self._additional_special_tokens.append(token)
            # the setter automatically updates the reverse map
            self._added_tokens_decoder[token_index] = token
            self._added_tokens_encoder[token.content] = token_index
            if self.verbose:
                logger.info(f"Adding {token} to the vocabulary")

        self._update_trie()
        self._update_total_vocab_size()
        return added_tokens

    def _update_trie(self, unique_no_split_tokens: Optional[List[str]] = []):
        for token in self._added_tokens_decoder.values():
            if token not in self.tokens_trie._tokens:
                self.tokens_trie.add(token.content)
        for token in unique_no_split_tokens:
            if token not in self.tokens_trie._tokens:
                self.tokens_trie.add(token)

    def num_special_tokens_to_add(self, pair: bool = False) -> int:
        """
        Returns the number of added tokens when encoding a sequence with special tokens.

        <Tip>

        This encodes a dummy input and checks the number of added tokens, and is therefore not efficient. Do not put
        this inside your training loop.

        </Tip>

        Args:
            pair (`bool`, *optional*, defaults to `False`):
                Whether the number of added tokens should be computed in the case of a sequence pair or a single
                sequence.

        Returns:
            `int`: Number of special tokens added to sequences.
        """
        token_ids_0 = []
        token_ids_1 = []
        return len(self.build_inputs_with_special_tokens(token_ids_0, token_ids_1 if pair else None))
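
    # Illustrative note: for a BERT-style tokenizer this typically returns 2 for a single
    # sequence ([CLS] ... [SEP]) and 3 for a pair ([CLS] ... [SEP] ... [SEP]); the exact count
    # depends on the subclass's `build_inputs_with_special_tokens` implementation.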
    def tokenize(self, text: TextInput, **kwargs) -> List[str]:
        """
        Converts a string into a sequence of tokens, using the tokenizer.

        Split in words for word-based vocabulary or sub-words for sub-word-based vocabularies
        (BPE/SentencePieces/WordPieces). Takes care of added tokens.

        Args:
            text (`str`):
                The sequence to be encoded.
            **kwargs (additional keyword arguments):
                Passed along to the model-specific `prepare_for_tokenization` preprocessing method.

        Returns:
            `List[str]`: The list of tokens.
        """
        split_special_tokens = kwargs.pop("split_special_tokens", self.split_special_tokens)

        text, kwargs = self.prepare_for_tokenization(text, **kwargs)

        if kwargs:
            logger.warning(f"Keyword arguments {kwargs} not recognized.")

        if hasattr(self, "do_lower_case") and self.do_lower_case:
            # convert non-special tokens to lowercase. Might be super slow as well?
            escaped_special_toks = [re.escape(s_tok) for s_tok in (self.all_special_tokens)]
            escaped_special_toks += [
                re.escape(s_tok.content)
                for s_tok in (self._added_tokens_decoder.values())
                if not s_tok.special and s_tok.normalized
            ]
            pattern = r"(" + r"|".join(escaped_special_toks) + r")|" + r"(.+?)"
            text = re.sub(pattern, lambda m: m.groups()[0] or m.groups()[1].lower(), text)

        if split_special_tokens:
            no_split_token = []
            tokens = [text]
        else:
            no_split_token = self._added_tokens_encoder.keys()  # don't split on any of the added tokens
            # "This is something<special_token_1> else"
            tokens = self.tokens_trie.split(text)

        # ["This is something", "<special_token_1>", " else"]
        for i, token in enumerate(tokens):
            if token in no_split_token:
                tok_extended = self._added_tokens_decoder.get(self._added_tokens_encoder[token], None)
                left = tokens[i - 1] if i > 0 else None
                right = tokens[i + 1] if i < len(tokens) - 1 else None
                if isinstance(tok_extended, AddedToken):
                    if tok_extended.rstrip and right:
                        # A bit counter-intuitive but we strip the left of the string
                        # since tok_extended.rstrip means the special token is eating all white spaces on its right
                        tokens[i + 1] = right.lstrip()
                    # Strip white spaces on the left
                    if tok_extended.lstrip and left:
                        tokens[i - 1] = left.rstrip()  # Opposite here
                    if tok_extended.single_word and left and left[-1] != " ":
                        tokens[i - 1] += token
                        tokens[i] = ""
                    elif tok_extended.single_word and right and right[0] != " ":
                        tokens[i + 1] = token + tokens[i + 1]
                        tokens[i] = ""
                else:
                    raise ValueError(
                        f"{tok_extended} cannot be tokenized because it was not properly added"
                        f" to the tokenizer. This means that it is not an `AddedToken` but a {type(tok_extended)}"
                    )
        # ["This is something", "<special_token_1>", "else"]
        tokenized_text = []
        for token in tokens:
            # Need to skip potential empty (fully stripped) tokens
            if not token:
                continue
            if token in no_split_token:
                tokenized_text.append(token)
            else:
                tokenized_text.extend(self._tokenize(token))
        # ["This", " is", " something", "<special_token_1>", "else"]
        return tokenized_text
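
    # Illustrative note on the flow of `tokenize` above: with an added token such as
    # "<special_token_1>", the input "This is something<special_token_1> else" is first cut by
    # the trie into ["This is something", "<special_token_1>", " else"]; the added token is kept
    # as-is (possibly stripping neighbouring whitespace according to its lstrip/rstrip flags),
    # and only the remaining pieces are passed to the subclass's `_tokenize`.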
    def _tokenize(self, text, **kwargs):
        """
        Converts a string into a sequence of tokens (string), using the tokenizer. Split in words for word-based
        vocabulary or sub-words for sub-word-based vocabularies (BPE/SentencePieces/WordPieces).

        Do NOT take care of added tokens.
        """
        raise NotImplementedError

    def convert_tokens_to_ids(self, tokens: Union[str, List[str]]) -> Union[int, List[int]]:
        """
        Converts a token string (or a sequence of tokens) into a single integer id (or a sequence of ids), using the
        vocabulary.

        Args:
            tokens (`str` or `List[str]`): One or several token(s) to convert to token id(s).

        Returns:
            `int` or `List[int]`: The token id or list of token ids.
        """
        if tokens is None:
            return None

        if isinstance(tokens, str):
            return self._convert_token_to_id_with_added_voc(tokens)

        ids = []
        for token in tokens:
            ids.append(self._convert_token_to_id_with_added_voc(token))
        return ids

    def _convert_token_to_id_with_added_voc(self, token):
        if token is None:
            return None

        if token in self._added_tokens_encoder:
            return self._added_tokens_encoder[token]
        return self._convert_token_to_id(token)

    def _convert_token_to_id(self, token):
        raise NotImplementedError

    def _encode_plus(
        self,
        text: Union[TextInput, PreTokenizedInput, EncodedInput],
        text_pair: Optional[Union[TextInput, PreTokenizedInput, EncodedInput]] = None,
        add_special_tokens: bool = True,
        padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
        truncation_strategy: TruncationStrategy = TruncationStrategy.DO_NOT_TRUNCATE,
        max_length: Optional[int] = None,
        stride: int = 0,
        is_split_into_words: bool = False,
        pad_to_multiple_of: Optional[int] = None,
        padding_side: Optional[str] = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
        return_token_type_ids: Optional[bool] = None,
        return_attention_mask: Optional[bool] = None,
        return_overflowing_tokens: bool = False,
        return_special_tokens_mask: bool = False,
        return_offsets_mapping: bool = False,
        return_length: bool = False,
        verbose: bool = True,
        **kwargs,
    ) -> BatchEncoding:
        def get_input_ids(text):
            if isinstance(text, str):
                tokens = self.tokenize(text, **kwargs)
                return self.convert_tokens_to_ids(tokens)
            elif isinstance(text, (list, tuple)) and len(text) > 0 and isinstance(text[0], str):
                if is_split_into_words:
                    tokens = list(
                        itertools.chain(*(self.tokenize(t, is_split_into_words=True, **kwargs) for t in text))
                    )
                    return self.convert_tokens_to_ids(tokens)
                else:
                    return self.convert_tokens_to_ids(text)
            elif isinstance(text, (list, tuple)) and len(text) > 0 and isinstance(text[0], int):
                return text
            else:
                if is_split_into_words:
                    raise ValueError(
                        f"Input {text} is not valid. Should be a string or a list/tuple of strings when"
                        " `is_split_into_words=True`."
                    )
                else:
                    raise ValueError(
                        f"Input {text} is not valid. Should be a string, a list/tuple of strings or a list/tuple of"
                        " integers."
                    )

        if return_offsets_mapping:
            raise NotImplementedError(
                "return_offset_mapping is not available when using Python tokenizers. "
                "To use this feature, change your tokenizer to one deriving from "
                "transformers.PreTrainedTokenizerFast. "
                "More information on available tokenizers at "
                "https://github.com/huggingface/transformers/pull/2674"
            )

        first_ids = get_input_ids(text)
        second_ids = get_input_ids(text_pair) if text_pair is not None else None

        return self.prepare_for_model(
            first_ids,
            pair_ids=second_ids,
            add_special_tokens=add_special_tokens,
            padding=padding_strategy.value,
            truncation=truncation_strategy.value,
            max_length=max_length,
            stride=stride,
            pad_to_multiple_of=pad_to_multiple_of,
            padding_side=padding_side,
            return_tensors=return_tensors,
            prepend_batch_axis=True,
            return_attention_mask=return_attention_mask,
            return_token_type_ids=return_token_type_ids,
            return_overflowing_tokens=return_overflowing_tokens,
            return_special_tokens_mask=return_special_tokens_mask,
            return_length=return_length,
            verbose=verbose,
        )

    def _batch_encode_plus(
        self,
        batch_text_or_text_pairs: Union[
            List[TextInput],
            List[TextInputPair],
            List[PreTokenizedInput],
            List[PreTokenizedInputPair],
            List[EncodedInput],
            List[EncodedInputPair],
        ],
        add_special_tokens: bool = True,
        padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
        truncation_strategy: TruncationStrategy = TruncationStrategy.DO_NOT_TRUNCATE,
        max_length: Optional[int] = None,
        stride: int = 0,
        is_split_into_words: bool = False,
        pad_to_multiple_of: Optional[int] = None,
        padding_side: Optional[str] = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
        return_token_type_ids: Optional[bool] = None,
        return_attention_mask: Optional[bool] = None,
        return_overflowing_tokens: bool = False,
        return_special_tokens_mask: bool = False,
        return_offsets_mapping: bool = False,
        return_length: bool = False,
        verbose: bool = True,
        split_special_tokens: bool = False,
        **kwargs,
    ) -> BatchEncoding:
        def get_input_ids(text):
            if isinstance(text, str):
                tokens = self.tokenize(text, **kwargs)
                return self.convert_tokens_to_ids(tokens)
            elif isinstance(text, (list, tuple)) and len(text) > 0 and isinstance(text[0], str):
                if is_split_into_words:
                    tokens = list(
                        itertools.chain(*(self.tokenize(t, is_split_into_words=True, **kwargs) for t in text))
                    )
                    return self.convert_tokens_to_ids(tokens)
                else:
                    return self.convert_tokens_to_ids(text)
            elif isinstance(text, (list, tuple)) and len(text) > 0 and isinstance(text[0], int):
                return text
            else:
                raise ValueError(
                    "Input is not valid. Should be a string, a list/tuple of strings or a list/tuple of integers."
                )

        if return_offsets_mapping:
            raise NotImplementedError(
                "return_offset_mapping is not available when using Python tokenizers. "
                "To use this feature, change your tokenizer to one deriving from "
                "transformers.PreTrainedTokenizerFast."
            )

        input_ids = []
        for ids_or_pair_ids in batch_text_or_text_pairs:
            if not isinstance(ids_or_pair_ids, (list, tuple)):
                ids, pair_ids = ids_or_pair_ids, None
            elif is_split_into_words and not isinstance(ids_or_pair_ids[0], (list, tuple)):
                ids, pair_ids = ids_or_pair_ids, None
            else:
                ids, pair_ids = ids_or_pair_ids

            first_ids = get_input_ids(ids)
            second_ids = get_input_ids(pair_ids) if pair_ids is not None else None
            input_ids.append((first_ids, second_ids))

        batch_outputs = self._batch_prepare_for_model(
            input_ids,
            add_special_tokens=add_special_tokens,
            padding_strategy=padding_strategy,
            truncation_strategy=truncation_strategy,
            max_length=max_length,
            stride=stride,
            pad_to_multiple_of=pad_to_multiple_of,
            padding_side=padding_side,
            return_attention_mask=return_attention_mask,
            return_token_type_ids=return_token_type_ids,
            return_overflowing_tokens=return_overflowing_tokens,
            return_special_tokens_mask=return_special_tokens_mask,
            return_length=return_length,
            return_tensors=return_tensors,
            verbose=verbose,
            split_special_tokens=split_special_tokens,
        )

        return BatchEncoding(batch_outputs)

    @add_end_docstrings(ENCODE_KWARGS_DOCSTRING, ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING)
    def _batch_prepare_for_model(
        self,
        batch_ids_pairs: List[Union[PreTokenizedInputPair, Tuple[List[int], None]]],
        add_special_tokens: bool = True,
        padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
        truncation_strategy: TruncationStrategy = TruncationStrategy.DO_NOT_TRUNCATE,
        max_length: Optional[int] = None,
        stride: int = 0,
        pad_to_multiple_of: Optional[int] = None,
        padding_side: Optional[str] = None,
        return_tensors: Optional[str] = None,
        return_token_type_ids: Optional[bool] = None,
        return_attention_mask: Optional[bool] = None,
        return_overflowing_tokens: bool = False,
        return_special_tokens_mask: bool = False,
        return_length: bool = False,
        verbose: bool = True,
        split_special_tokens: bool = False,
    ) -> BatchEncoding:
        """
        Prepares a sequence of input ids, or a pair of sequences of input ids, so that it can be used by the model. It
        adds special tokens, truncates sequences if overflowing while taking into account the special tokens and
        manages a moving window (with user defined stride) for overflowing tokens.

        Args:
            batch_ids_pairs: list of tokenized input ids or input ids pairs
        """

        batch_outputs = {}
        for first_ids, second_ids in batch_ids_pairs:
            outputs = self.prepare_for_model(
                first_ids,
                second_ids,
                add_special_tokens=add_special_tokens,
                padding=PaddingStrategy.DO_NOT_PAD.value,  # we pad in batch afterward
                truncation=truncation_strategy.value,
                max_length=max_length,
                stride=stride,
                pad_to_multiple_of=None,  # we pad in batch afterward
                padding_side=None,  # we pad in batch afterward
                return_attention_mask=False,  # we pad in batch afterward
                return_token_type_ids=return_token_type_ids,
                return_overflowing_tokens=return_overflowing_tokens,
                return_special_tokens_mask=return_special_tokens_mask,
                return_length=return_length,
                return_tensors=None,  # We convert the whole batch to tensors at the end
                prepend_batch_axis=False,
                verbose=verbose,
                split_special_tokens=split_special_tokens,
            )

            for key, value in outputs.items():
                if key not in batch_outputs:
                    batch_outputs[key] = []
                batch_outputs[key].append(value)

        batch_outputs = self.pad(
            batch_outputs,
            padding=padding_strategy.value,
            max_length=max_length,
            pad_to_multiple_of=pad_to_multiple_of,
            padding_side=padding_side,
            return_attention_mask=return_attention_mask,
        )

        batch_outputs = BatchEncoding(batch_outputs, tensor_type=return_tensors)

        return batch_outputs
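
    # Illustrative note on the batch path above: `_batch_encode_plus` converts each example to
    # ids, then `_batch_prepare_for_model` calls `prepare_for_model` per example with padding
    # disabled and pads the whole batch in a single `self.pad(...)` call at the end, so every
    # sequence in the batch ends up with the same length (and the attention mask is built there).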
    def prepare_for_tokenization(
        self, text: str, is_split_into_words: bool = False, **kwargs
    ) -> Tuple[str, Dict[str, Any]]:
        """
        Performs any necessary transformations before tokenization.

        This method should pop the arguments from kwargs and return the remaining `kwargs` as well. We test the
        `kwargs` at the end of the encoding process to be sure all the arguments have been used.

        Args:
            text (`str`):
                The text to prepare.
            is_split_into_words (`bool`, *optional*, defaults to `False`):
                Whether or not the input is already pre-tokenized (e.g., split into words). If set to `True`, the
                tokenizer assumes the input is already split into words (for instance, by splitting it on whitespace)
                which it will tokenize. This is useful for NER or token classification.
            kwargs (`Dict[str, Any]`, *optional*):
                Keyword arguments to use for the tokenization.

        Returns:
            `Tuple[str, Dict[str, Any]]`: The prepared text and the unused kwargs.
        """
        return (text, kwargs)

    def get_special_tokens_mask(
        self, token_ids_0: List, token_ids_1: Optional[List] = None, already_has_special_tokens: bool = False
    ) -> List[int]:
        """
        Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding
        special tokens using the tokenizer `prepare_for_model` or `encode_plus` methods.

        Args:
            token_ids_0 (`List[int]`):
                List of ids of the first sequence.
            token_ids_1 (`List[int]`, *optional*):
                List of ids of the second sequence.
            already_has_special_tokens (`bool`, *optional*, defaults to `False`):
                Whether or not the token list is already formatted with special tokens for the model.

        Returns:
            A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
        """
        if already_has_special_tokens:
            if token_ids_1 is not None:
                raise ValueError(
                    "You should not supply a second sequence if the provided sequence of "
                    "ids is already formatted with special tokens for the model."
                )

            return super().get_special_tokens_mask(
                token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
            )
        return [0] * ((len(token_ids_1) if token_ids_1 else 0) + len(token_ids_0))

    @overload
    def convert_ids_to_tokens(self, ids: int, skip_special_tokens: bool = False) -> str: ...

    @overload
    def convert_ids_to_tokens(self, ids: List[int], skip_special_tokens: bool = False) -> List[str]: ...

    def convert_ids_to_tokens(
        self, ids: Union[int, List[int]], skip_special_tokens: bool = False
    ) -> Union[str, List[str]]:
        """
        Converts a single index or a sequence of indices into a token or a sequence of tokens, using the vocabulary
        and added tokens.

        Args:
            ids (`int` or `List[int]`):
                The token id (or token ids) to convert to tokens.
            skip_special_tokens (`bool`, *optional*, defaults to `False`):
                Whether or not to remove special tokens in the decoding.

        Returns:
            `str` or `List[str]`: The decoded token(s).
        """
        if isinstance(ids, int):
            if ids in self._added_tokens_decoder:
                return self._added_tokens_decoder[ids].content
            else:
                return self._convert_id_to_token(ids)
        tokens = []
        for index in ids:
            index = int(index)
            if skip_special_tokens and index in self.all_special_ids:
                continue
            if index in self._added_tokens_decoder:
                tokens.append(self._added_tokens_decoder[index].content)
            else:
                tokens.append(self._convert_id_to_token(index))
        return tokens

    def _convert_id_to_token(self, index: int) -> str:
        raise NotImplementedError

    def convert_tokens_to_string(self, tokens: List[str]) -> str:
        return " ".join(tokens)
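
    # Illustrative note: the base implementation above simply joins tokens with spaces; concrete
    # subclasses typically override `convert_tokens_to_string` (for example to merge WordPiece
    # "##" continuations or SentencePiece "▁" markers back into readable text).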
    def _decode(
        self,
        token_ids: Union[int, List[int]],
        skip_special_tokens: bool = False,
        clean_up_tokenization_spaces: Optional[bool] = None,
        spaces_between_special_tokens: bool = True,
        **kwargs,
    ) -> str:
        self._decode_use_source_tokenizer = kwargs.pop("use_source_tokenizer", False)

        filtered_tokens = self.convert_ids_to_tokens(token_ids, skip_special_tokens=skip_special_tokens)
        # If given a single id, prevents splitting the string in the upcoming loop
        if isinstance(filtered_tokens, str):
            filtered_tokens = [filtered_tokens]

        legacy_added_tokens = set(self._added_tokens_encoder.keys()) - set(self.all_special_tokens) | {
            token for token in self.additional_special_tokens if self.convert_tokens_to_ids(token) >= self.vocab_size
        }
        # To avoid mixing byte-level and unicode for byte-level BPE
        # we need to build string separately for added tokens and byte-level tokens
        # cf. https://github.com/huggingface/transformers/issues/1133
        sub_texts = []
        current_sub_text = []
        # TODO @ArthurZ in version 5, special tokens should be handled in convert_tokens_to_string, while _convert_tokens_to_string
        for token in filtered_tokens:
            if skip_special_tokens and token in self.all_special_tokens:
                continue
            if token in legacy_added_tokens:
                if current_sub_text:
                    string = self.convert_tokens_to_string(current_sub_text)
                    if len(string) > 0:
                        sub_texts.append(string)
                    current_sub_text = []
                sub_texts.append(token)
            else:
                current_sub_text.append(token)
        if current_sub_text:
            sub_texts.append(self.convert_tokens_to_string(current_sub_text))

        if spaces_between_special_tokens:
            text = " ".join(sub_texts)
        else:
            text = "".join(sub_texts)

        clean_up_tokenization_spaces = (
            clean_up_tokenization_spaces
            if clean_up_tokenization_spaces is not None
            else self.clean_up_tokenization_spaces
        )
        if clean_up_tokenization_spaces:
            clean_text = self.clean_up_tokenization(text)
            return clean_text
        else:
            return text