convert_slow_tokenizer.py

# coding=utf-8
# Copyright 2018 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Utilities to convert slow tokenizers into their fast tokenizer counterparts.

All the conversions are grouped here to gather SentencePiece dependencies outside of the fast tokenizer files and
allow making our dependency on SentencePiece optional.
"""

import warnings
from typing import Dict, List, Tuple

from packaging import version
from tokenizers import AddedToken, Regex, Tokenizer, decoders, normalizers, pre_tokenizers, processors
from tokenizers.models import BPE, Unigram, WordPiece

from .utils import is_protobuf_available, is_sentencepiece_available, logging, requires_backends
from .utils.import_utils import PROTOBUF_IMPORT_ERROR


logger = logging.get_logger(__name__)


def import_protobuf(error_message=""):
    if is_sentencepiece_available():
        from sentencepiece import sentencepiece_model_pb2

        return sentencepiece_model_pb2
    if is_protobuf_available():
        import google.protobuf

        if version.parse(google.protobuf.__version__) < version.parse("4.0.0"):
            from transformers.utils import sentencepiece_model_pb2
        else:
            from transformers.utils import sentencepiece_model_pb2_new as sentencepiece_model_pb2
        return sentencepiece_model_pb2
    else:
        raise ImportError(PROTOBUF_IMPORT_ERROR.format(error_message))


def _get_prepend_scheme(add_prefix_space: bool, original_tokenizer) -> str:
    if add_prefix_space:
        prepend_scheme = "always"
        if not getattr(original_tokenizer, "legacy", True):
            prepend_scheme = "first"
    else:
        prepend_scheme = "never"
    return prepend_scheme


def generate_merges(vocab, vocab_scores):
    reverse = vocab_scores is not None
    vocab_scores = dict(vocab_scores) if reverse else vocab

    merges = []
    for merge, piece_score in vocab_scores.items():
        local = []
        for index in range(1, len(merge)):
            piece_l, piece_r = merge[:index], merge[index:]
            if piece_l in vocab and piece_r in vocab:
                local.append((piece_l, piece_r, piece_score))
        local = sorted(local, key=lambda x: (vocab[x[0]], vocab[x[1]]))
        merges.extend(local)

    merges = sorted(merges, key=lambda val: (val[2], len(val[0]), len(val[1])), reverse=reverse)
    merges = [(val[0], val[1]) for val in merges]
    return merges
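

# Illustrative sketch (not part of the original module): what `generate_merges` returns for a
# toy SentencePiece-style vocabulary. The pieces below are made up purely for this example.
def _toy_generate_merges_example():
    toy_vocab = {"l": 0, "o": 1, "w": 2, "lo": 3, "low": 4}
    # Only splits whose two halves are both in the vocab survive, so the expected result is
    # [("l", "o"), ("lo", "w")], ordered by the rank of the merged piece ("lo" before "low").
    return generate_merges(toy_vocab, vocab_scores=None)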


class SentencePieceExtractor:
    """
    Extractor implementation for SentencePiece trained models. https://github.com/google/sentencepiece
    """

    def __init__(self, model: str):
        requires_backends(self, "sentencepiece")
        from sentencepiece import SentencePieceProcessor

        self.sp = SentencePieceProcessor()
        self.sp.Load(model)

    def extract(self, vocab_scores=None) -> Tuple[Dict[str, int], List[Tuple]]:
        """
        By default, returns the vocab and merges in their original order; passing `vocab_scores` orders the merges
        by piece score instead.
        """
        sp = self.sp
        vocab = {sp.id_to_piece(index): index for index in range(sp.GetPieceSize())}

        merges = generate_merges(vocab, vocab_scores)

        return vocab, merges


class GemmaSentencePieceExtractor(SentencePieceExtractor):
    def extract(self, vocab_scores=None) -> Tuple[Dict[str, int], List[Tuple]]:
        """
        By default, returns the vocab and merges in their original order; passing `vocab_scores` orders the merges
        by piece score instead.
        """
        sp = self.sp
        vocab = {sp.id_to_piece(index): index for index in range(sp.GetPieceSize())}

        # There is a missing token in the vocab. We have to do this to support merges:
        # "<0x09>" is the bytefallback for `\t`
        vocab["\t"] = vocab.get("<0x09>")

        merges = generate_merges(vocab, vocab_scores)
        return vocab, merges


def check_number_comma(piece: str) -> bool:
    return len(piece) < 2 or piece[-1] != "," or not piece[-2].isdigit()


class Converter:
    def __init__(self, original_tokenizer):
        self.original_tokenizer = original_tokenizer

    def converted(self) -> Tokenizer:
        raise NotImplementedError()


class BertConverter(Converter):
    def converted(self) -> Tokenizer:
        vocab = self.original_tokenizer.vocab
        tokenizer = Tokenizer(WordPiece(vocab, unk_token=str(self.original_tokenizer.unk_token)))

        tokenize_chinese_chars = False
        strip_accents = False
        do_lower_case = False
        if hasattr(self.original_tokenizer, "basic_tokenizer"):
            tokenize_chinese_chars = self.original_tokenizer.basic_tokenizer.tokenize_chinese_chars
            strip_accents = self.original_tokenizer.basic_tokenizer.strip_accents
            do_lower_case = self.original_tokenizer.basic_tokenizer.do_lower_case

        tokenizer.normalizer = normalizers.BertNormalizer(
            clean_text=True,
            handle_chinese_chars=tokenize_chinese_chars,
            strip_accents=strip_accents,
            lowercase=do_lower_case,
        )
        tokenizer.pre_tokenizer = pre_tokenizers.BertPreTokenizer()

        cls = str(self.original_tokenizer.cls_token)
        sep = str(self.original_tokenizer.sep_token)
        cls_token_id = self.original_tokenizer.cls_token_id
        sep_token_id = self.original_tokenizer.sep_token_id

        tokenizer.post_processor = processors.TemplateProcessing(
            single=f"{cls}:0 $A:0 {sep}:0",
            pair=f"{cls}:0 $A:0 {sep}:0 $B:1 {sep}:1",
            special_tokens=[
                (cls, cls_token_id),
                (sep, sep_token_id),
            ],
        )
        tokenizer.decoder = decoders.WordPiece(prefix="##")

        return tokenizer


class SplinterConverter(Converter):
    def converted(self) -> Tokenizer:
        vocab = self.original_tokenizer.vocab
        tokenizer = Tokenizer(WordPiece(vocab, unk_token=str(self.original_tokenizer.unk_token)))

        tokenize_chinese_chars = False
        strip_accents = False
        do_lower_case = False
        if hasattr(self.original_tokenizer, "basic_tokenizer"):
            tokenize_chinese_chars = self.original_tokenizer.basic_tokenizer.tokenize_chinese_chars
            strip_accents = self.original_tokenizer.basic_tokenizer.strip_accents
            do_lower_case = self.original_tokenizer.basic_tokenizer.do_lower_case

        tokenizer.normalizer = normalizers.BertNormalizer(
            clean_text=True,
            handle_chinese_chars=tokenize_chinese_chars,
            strip_accents=strip_accents,
            lowercase=do_lower_case,
        )
        tokenizer.pre_tokenizer = pre_tokenizers.BertPreTokenizer()

        cls = str(self.original_tokenizer.cls_token)
        sep = str(self.original_tokenizer.sep_token)
        question = str(self.original_tokenizer.question_token)
        dot = "."
        cls_token_id = self.original_tokenizer.cls_token_id
        sep_token_id = self.original_tokenizer.sep_token_id
        question_token_id = self.original_tokenizer.question_token_id
        dot_token_id = self.original_tokenizer.convert_tokens_to_ids(".")

        if self.original_tokenizer.padding_side == "right":
            pair = f"{cls}:0 $A:0 {question} {dot} {sep}:0 $B:1 {sep}:1"
        else:
            pair = f"{cls}:0 $A:0 {sep}:0 $B:1 {question} {dot} {sep}:1"

        tokenizer.post_processor = processors.TemplateProcessing(
            single=f"{cls}:0 $A:0 {sep}:0",
            pair=pair,
            special_tokens=[
                (cls, cls_token_id),
                (sep, sep_token_id),
                (question, question_token_id),
                (dot, dot_token_id),
            ],
        )
        tokenizer.decoder = decoders.WordPiece(prefix="##")

        return tokenizer


class FunnelConverter(Converter):
    def converted(self) -> Tokenizer:
        vocab = self.original_tokenizer.vocab
        tokenizer = Tokenizer(WordPiece(vocab, unk_token=str(self.original_tokenizer.unk_token)))

        tokenize_chinese_chars = False
        strip_accents = False
        do_lower_case = False
        if hasattr(self.original_tokenizer, "basic_tokenizer"):
            tokenize_chinese_chars = self.original_tokenizer.basic_tokenizer.tokenize_chinese_chars
            strip_accents = self.original_tokenizer.basic_tokenizer.strip_accents
            do_lower_case = self.original_tokenizer.basic_tokenizer.do_lower_case

        tokenizer.normalizer = normalizers.BertNormalizer(
            clean_text=True,
            handle_chinese_chars=tokenize_chinese_chars,
            strip_accents=strip_accents,
            lowercase=do_lower_case,
        )
        tokenizer.pre_tokenizer = pre_tokenizers.BertPreTokenizer()

        cls = str(self.original_tokenizer.cls_token)
        sep = str(self.original_tokenizer.sep_token)
        cls_token_id = self.original_tokenizer.cls_token_id
        sep_token_id = self.original_tokenizer.sep_token_id

        tokenizer.post_processor = processors.TemplateProcessing(
            single=f"{cls}:2 $A:0 {sep}:0",  # token_type_id is 2 for Funnel transformer
            pair=f"{cls}:2 $A:0 {sep}:0 $B:1 {sep}:1",
            special_tokens=[
                (cls, cls_token_id),
                (sep, sep_token_id),
            ],
        )
        tokenizer.decoder = decoders.WordPiece(prefix="##")

        return tokenizer


class MPNetConverter(Converter):
    def converted(self) -> Tokenizer:
        vocab = self.original_tokenizer.vocab
        tokenizer = Tokenizer(WordPiece(vocab, unk_token=str(self.original_tokenizer.unk_token)))

        tokenize_chinese_chars = False
        strip_accents = False
        do_lower_case = False
        if hasattr(self.original_tokenizer, "basic_tokenizer"):
            tokenize_chinese_chars = self.original_tokenizer.basic_tokenizer.tokenize_chinese_chars
            strip_accents = self.original_tokenizer.basic_tokenizer.strip_accents
            do_lower_case = self.original_tokenizer.basic_tokenizer.do_lower_case

        tokenizer.normalizer = normalizers.BertNormalizer(
            clean_text=True,
            handle_chinese_chars=tokenize_chinese_chars,
            strip_accents=strip_accents,
            lowercase=do_lower_case,
        )
        tokenizer.pre_tokenizer = pre_tokenizers.BertPreTokenizer()

        cls = str(self.original_tokenizer.cls_token)
        sep = str(self.original_tokenizer.sep_token)
        cls_token_id = self.original_tokenizer.cls_token_id
        sep_token_id = self.original_tokenizer.sep_token_id

        tokenizer.post_processor = processors.TemplateProcessing(
            single=f"{cls}:0 $A:0 {sep}:0",
            pair=f"{cls}:0 $A:0 {sep}:0 {sep}:0 $B:1 {sep}:1",  # MPNet uses two [SEP] tokens
            special_tokens=[
                (cls, cls_token_id),
                (sep, sep_token_id),
            ],
        )
        tokenizer.decoder = decoders.WordPiece(prefix="##")

        return tokenizer


class OpenAIGPTConverter(Converter):
    def converted(self) -> Tokenizer:
        vocab = self.original_tokenizer.encoder
        merges = list(self.original_tokenizer.bpe_ranks.keys())
        unk_token = self.original_tokenizer.unk_token

        tokenizer = Tokenizer(
            BPE(
                vocab=vocab,
                merges=merges,
                dropout=None,
                unk_token=str(unk_token),
                end_of_word_suffix="</w>",
                fuse_unk=False,
            )
        )

        if tokenizer.token_to_id(str(unk_token)) is not None:
            tokenizer.add_special_tokens([str(unk_token)])

        tokenizer.normalizer = normalizers.BertNormalizer(lowercase=True)
        tokenizer.pre_tokenizer = pre_tokenizers.BertPreTokenizer()
        tokenizer.decoder = decoders.BPEDecoder(suffix="</w>")

        return tokenizer


class GPT2Converter(Converter):
    def converted(self, vocab: Dict[str, int] = None, merges: List[Tuple[str, str]] = None) -> Tokenizer:
        if not vocab:
            vocab = self.original_tokenizer.encoder
        if not merges:
            merges = list(self.original_tokenizer.bpe_ranks)

        tokenizer = Tokenizer(
            BPE(
                vocab=vocab,
                merges=merges,
                dropout=None,
                continuing_subword_prefix="",
                end_of_word_suffix="",
                fuse_unk=False,
            )
        )

        add_prefix_space = getattr(self.original_tokenizer, "add_prefix_space", False)
        tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=add_prefix_space)
        tokenizer.decoder = decoders.ByteLevel()
        if getattr(self.original_tokenizer, "add_bos_token", False):
            bos = self.original_tokenizer.bos_token
            bos_token_id = self.original_tokenizer.bos_token_id
            tokenizer.post_processor = processors.TemplateProcessing(
                single=f"{bos}:0 $A:0",
                pair=f"{bos}:0 $A:0 $B:1",
                special_tokens=[
                    (bos, bos_token_id),
                ],
            )
        else:
            # XXX trim_offsets=False actually means this post_processor doesn't
            # really do anything.
            tokenizer.post_processor = processors.ByteLevel(trim_offsets=False)
        return tokenizer
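

# Illustrative sketch (not part of the original module): converting a byte-level BPE slow
# tokenizer with the converter above. "gpt2" is just an example checkpoint name.
def _example_gpt2_conversion():
    from transformers import GPT2Tokenizer

    slow_tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
    # Returns a `tokenizers.Tokenizer` with ByteLevel pre-tokenization and decoding.
    return GPT2Converter(slow_tokenizer).converted()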


class HerbertConverter(Converter):
    def converted(self) -> Tokenizer:
        tokenizer_info_str = "#version:"
        token_suffix = "</w>"

        vocab = self.original_tokenizer.encoder
        merges = list(self.original_tokenizer.bpe_ranks.keys())
        if tokenizer_info_str in merges[0][0]:
            merges = merges[1:]

        tokenizer = Tokenizer(
            BPE(
                vocab,
                merges,
                dropout=None,
                unk_token=self.original_tokenizer.unk_token,
                end_of_word_suffix=token_suffix,
            )
        )

        tokenizer.normalizer = normalizers.BertNormalizer(lowercase=False, strip_accents=False)
        tokenizer.pre_tokenizer = pre_tokenizers.BertPreTokenizer()
        tokenizer.decoder = decoders.BPEDecoder(suffix=token_suffix)
        tokenizer.post_processor = processors.BertProcessing(
            sep=(self.original_tokenizer.sep_token, self.original_tokenizer.sep_token_id),
            cls=(self.original_tokenizer.cls_token, self.original_tokenizer.cls_token_id),
        )

        return tokenizer


class Qwen2Converter(Converter):
    def converted(self, vocab: Dict[str, int] = None, merges: List[Tuple[str, str]] = None) -> Tokenizer:
        if not vocab:
            vocab = self.original_tokenizer.encoder
        if not merges:
            merges = list(self.original_tokenizer.bpe_ranks.keys())

        tokenizer = Tokenizer(
            BPE(
                vocab=vocab,
                merges=merges,
                dropout=None,
                unk_token=None,
                continuing_subword_prefix="",
                end_of_word_suffix="",
                fuse_unk=False,
                byte_fallback=False,
            )
        )

        tokenizer.normalizer = normalizers.NFC()

        tokenizer.pre_tokenizer = pre_tokenizers.Sequence(
            [
                pre_tokenizers.Split(
                    Regex(
                        r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"""
                    ),
                    behavior="isolated",
                    invert=False,
                ),
                pre_tokenizers.ByteLevel(
                    add_prefix_space=getattr(self.original_tokenizer, "add_prefix_space", False),
                    use_regex=False,
                ),
            ]
        )

        tokenizer.decoder = decoders.ByteLevel()
        tokenizer.post_processor = processors.ByteLevel(trim_offsets=False)

        return tokenizer


class RobertaConverter(Converter):
    def converted(self) -> Tokenizer:
        ot = self.original_tokenizer
        vocab = ot.encoder
        merges = list(ot.bpe_ranks.keys())

        tokenizer = Tokenizer(
            BPE(
                vocab=vocab,
                merges=merges,
                dropout=None,
                continuing_subword_prefix="",
                end_of_word_suffix="",
                fuse_unk=False,
            )
        )

        tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=ot.add_prefix_space)
        tokenizer.decoder = decoders.ByteLevel()
        tokenizer.post_processor = processors.RobertaProcessing(
            sep=(ot.sep_token, ot.sep_token_id),
            cls=(ot.cls_token, ot.cls_token_id),
            add_prefix_space=ot.add_prefix_space,
            trim_offsets=True,  # True by default on Roberta (historical)
        )

        return tokenizer


class RoFormerConverter(Converter):
    def converted(self) -> Tokenizer:
        from .models.roformer.tokenization_utils import JiebaPreTokenizer

        vocab = self.original_tokenizer.vocab
        tokenizer = Tokenizer(WordPiece(vocab, unk_token=str(self.original_tokenizer.unk_token)))

        strip_accents = False
        do_lower_case = False
        if hasattr(self.original_tokenizer, "basic_tokenizer"):
            strip_accents = self.original_tokenizer.basic_tokenizer.strip_accents
            do_lower_case = self.original_tokenizer.basic_tokenizer.do_lower_case

        tokenizer.normalizer = normalizers.BertNormalizer(
            clean_text=True,
            handle_chinese_chars=False,
            strip_accents=strip_accents,
            lowercase=do_lower_case,
        )
        tokenizer.pre_tokenizer = pre_tokenizers.PreTokenizer.custom(JiebaPreTokenizer(vocab))

        cls = str(self.original_tokenizer.cls_token)
        sep = str(self.original_tokenizer.sep_token)
        cls_token_id = self.original_tokenizer.cls_token_id
        sep_token_id = self.original_tokenizer.sep_token_id

        tokenizer.post_processor = processors.TemplateProcessing(
            single=f"{cls}:0 $A:0 {sep}:0",
            pair=f"{cls}:0 $A:0 {sep}:0 $B:1 {sep}:1",
            special_tokens=[
                (cls, cls_token_id),
                (sep, sep_token_id),
            ],
        )
        tokenizer.decoder = decoders.WordPiece(prefix="##")

        return tokenizer


class DebertaConverter(Converter):
    def converted(self) -> Tokenizer:
        ot = self.original_tokenizer
        vocab = ot.encoder
        merges = list(ot.bpe_ranks.keys())

        tokenizer = Tokenizer(
            BPE(
                vocab=vocab,
                merges=merges,
                dropout=None,
                continuing_subword_prefix="",
                end_of_word_suffix="",
                fuse_unk=False,
            )
        )

        tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=ot.add_prefix_space)
        tokenizer.decoder = decoders.ByteLevel()
        tokenizer.post_processor = processors.TemplateProcessing(
            single="[CLS]:0 $A:0 [SEP]:0",
            pair="[CLS]:0 $A:0 [SEP]:0 $B:1 [SEP]:1",
            special_tokens=[
                ("[CLS]", self.original_tokenizer.convert_tokens_to_ids("[CLS]")),
                ("[SEP]", self.original_tokenizer.convert_tokens_to_ids("[SEP]")),
            ],
        )

        return tokenizer


class SpmConverter(Converter):
    handle_byte_fallback = False
    SpmExtractor = SentencePieceExtractor
    special_tokens = {}

    def __init__(self, *args):
        requires_backends(self, "protobuf")

        super().__init__(*args)

        # from .utils import sentencepiece_model_pb2 as model_pb2
        model_pb2 = import_protobuf()

        m = model_pb2.ModelProto()
        with open(self.original_tokenizer.vocab_file, "rb") as f:
            m.ParseFromString(f.read())
        self.proto = m

        if self.proto.trainer_spec.byte_fallback and not self.handle_byte_fallback:
            warnings.warn(
                "The sentencepiece tokenizer that you are converting to a fast tokenizer uses the byte fallback option"
                " which is not implemented in the fast tokenizers. In practice this means that the fast version of the"
                " tokenizer can produce unknown tokens whereas the sentencepiece version would have converted these "
                "unknown tokens into a sequence of byte tokens matching the original piece of text."
            )

    def vocab(self, proto):
        return [(piece.piece, piece.score) for piece in proto.pieces]

    def unk_id(self, proto):
        return proto.trainer_spec.unk_id

    def tokenizer(self, proto):
        model_type = proto.trainer_spec.model_type
        vocab_scores = self.vocab(proto)

        if model_type == 1:
            tokenizer = Tokenizer(
                Unigram(
                    vocab_scores,
                    unk_id=self.unk_id(proto),
                    byte_fallback=self.handle_byte_fallback,
                )
            )
        elif model_type == 2:
            _, merges = self.SpmExtractor(self.original_tokenizer.vocab_file).extract(vocab_scores)
            bpe_vocab = {word: i for i, (word, score) in enumerate(vocab_scores)}
            tokenizer = Tokenizer(
                BPE(
                    bpe_vocab,
                    merges,
                    unk_token=proto.trainer_spec.unk_piece,
                    fuse_unk=True,
                    byte_fallback=self.handle_byte_fallback,
                    dropout=None,
                )
            )
        else:
            raise Exception(
                "You're trying to run a `Unigram` model but your file was trained with a different algorithm"
            )

        # control tokens are special
        # user defined symbols are not
        # both user and control tokens are AddedTokens
        # Add user defined symbols (type == 4) from sentencepiece (https://github.com/google/sentencepiece/blob/6225e08edb2577757163b3f5dbba4c0b670ef445/src/sentencepiece_model.proto#L299C29-L299C33)
        spm_added_tokens = [
            (id, p.piece, p.type == 3 or p.piece in self.special_tokens)
            for id, p in enumerate(proto.pieces)
            if p.type in [3, 4]
        ]
        tokenizer.add_tokens(
            [
                AddedToken(token, normalized=False, special=special)
                for id, token, special in sorted(spm_added_tokens, key=lambda x: x[0])
            ]
        )

        return tokenizer

    def normalizer(self, proto):
        precompiled_charsmap = proto.normalizer_spec.precompiled_charsmap
        _normalizers = [
            normalizers.Strip(left=False, right=True),  # stripping is important
            normalizers.Replace(Regex(" {2,}"), "▁"),
        ]
        if not precompiled_charsmap:
            return normalizers.Sequence(_normalizers)
        else:
            return normalizers.Sequence([normalizers.Precompiled(precompiled_charsmap)] + _normalizers)

    def pre_tokenizer(self, replacement, add_prefix_space):
        prepend_scheme = _get_prepend_scheme(add_prefix_space, self.original_tokenizer)
        return pre_tokenizers.Metaspace(replacement=replacement, prepend_scheme=prepend_scheme)

    def post_processor(self):
        return None

    def decoder(self, replacement, add_prefix_space):
        prepend_scheme = _get_prepend_scheme(add_prefix_space, self.original_tokenizer)
        return decoders.Metaspace(replacement=replacement, prepend_scheme=prepend_scheme)

    def converted(self) -> Tokenizer:
        tokenizer = self.tokenizer(self.proto)

        # Assemble the tokenizer
        normalizer = self.normalizer(self.proto)
        if normalizer is not None:
            tokenizer.normalizer = normalizer

        replacement = "▁"
        add_prefix_space = True
        if hasattr(self.original_tokenizer, "add_prefix_space"):
            add_prefix_space = self.original_tokenizer.add_prefix_space

        pre_tokenizer = self.pre_tokenizer(replacement, add_prefix_space)
        if pre_tokenizer is not None:
            tokenizer.pre_tokenizer = pre_tokenizer

        tokenizer.decoder = self.decoder(replacement, add_prefix_space)
        post_processor = self.post_processor()
        if post_processor:
            tokenizer.post_processor = post_processor

        return tokenizer
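

# Illustrative sketch (not part of the original module): a hypothetical SpmConverter subclass
# showing the hooks that the concrete converters below typically override. The "[CLS]"/"[SEP]"
# template is an assumption for the example, not tied to any real model.
class _ExampleSpmConverter(SpmConverter):
    handle_byte_fallback = False

    def vocab(self, proto):
        # Rescore or reorder pieces here when the slow tokenizer diverges from the .model file.
        return [(piece.piece, piece.score) for piece in proto.pieces]

    def unk_id(self, proto):
        # Position of the <unk> piece in the vocab returned above.
        return proto.trainer_spec.unk_id

    def post_processor(self):
        # Special-token template applied after tokenization, mirroring the slow tokenizer.
        return processors.TemplateProcessing(
            single="[CLS]:0 $A:0 [SEP]:0",
            pair="[CLS]:0 $A:0 [SEP]:0 $B:1 [SEP]:1",
            special_tokens=[
                ("[CLS]", self.original_tokenizer.convert_tokens_to_ids("[CLS]")),
                ("[SEP]", self.original_tokenizer.convert_tokens_to_ids("[SEP]")),
            ],
        )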


class AlbertConverter(SpmConverter):
    def vocab(self, proto):
        return [
            (piece.piece, piece.score) if check_number_comma(piece.piece) else (piece.piece, piece.score - 100)
            for piece in proto.pieces
        ]

    def normalizer(self, proto):
        list_normalizers = [
            normalizers.Replace("``", '"'),
            normalizers.Replace("''", '"'),
        ]
        if not self.original_tokenizer.keep_accents:
            list_normalizers.append(normalizers.NFKD())
            list_normalizers.append(normalizers.StripAccents())
        if self.original_tokenizer.do_lower_case:
            list_normalizers.append(normalizers.Lowercase())

        precompiled_charsmap = proto.normalizer_spec.precompiled_charsmap

        if precompiled_charsmap:
            list_normalizers.append(normalizers.Precompiled(precompiled_charsmap))

        list_normalizers.append(normalizers.Replace(Regex(" {2,}"), " "))
        return normalizers.Sequence(list_normalizers)

    def post_processor(self):
        return processors.TemplateProcessing(
            single="[CLS]:0 $A:0 [SEP]:0",
            pair="[CLS]:0 $A:0 [SEP]:0 $B:1 [SEP]:1",
            special_tokens=[
                ("[CLS]", self.original_tokenizer.convert_tokens_to_ids("[CLS]")),
                ("[SEP]", self.original_tokenizer.convert_tokens_to_ids("[SEP]")),
            ],
        )


class BarthezConverter(SpmConverter):
    def unk_id(self, proto):
        unk_id = 3
        return unk_id

    def post_processor(self):
        return processors.TemplateProcessing(
            single="<s> $A </s>",
            pair="<s> $A </s> </s> $B </s>",
            special_tokens=[
                ("<s>", self.original_tokenizer.convert_tokens_to_ids("<s>")),
                ("</s>", self.original_tokenizer.convert_tokens_to_ids("</s>")),
            ],
        )


class CamembertConverter(SpmConverter):
    def vocab(self, proto):
        vocab = [
            ("<s>NOTUSED", 0.0),
            ("<pad>", 0.0),
            ("</s>NOTUSED", 0.0),
            ("<unk>", 0.0),
            ("<unk>NOTUSED", -100),
        ]
        # We down-grade the original SentencePiece by -100 to avoid using it and use our added token instead
        vocab += [(piece.piece, piece.score) for piece in proto.pieces[1:]]
        vocab += [("<mask>", 0.0)]
        return vocab

    def unk_id(self, proto):
        # See vocab unk position
        return 3

    def post_processor(self):
        return processors.TemplateProcessing(
            single="<s> $A </s>",
            pair="<s> $A </s> </s> $B </s>",
            special_tokens=[
                ("<s>", self.original_tokenizer.convert_tokens_to_ids("<s>")),
                ("</s>", self.original_tokenizer.convert_tokens_to_ids("</s>")),
            ],
        )


class DebertaV2Converter(SpmConverter):
    def pre_tokenizer(self, replacement, add_prefix_space):
        list_pretokenizers = []
        if self.original_tokenizer.split_by_punct:
            list_pretokenizers.append(pre_tokenizers.Punctuation(behavior="isolated"))
        prepend_scheme = _get_prepend_scheme(add_prefix_space, self.original_tokenizer)
        list_pretokenizers.append(pre_tokenizers.Metaspace(replacement=replacement, prepend_scheme=prepend_scheme))
        return pre_tokenizers.Sequence(list_pretokenizers)

    def normalizer(self, proto):
        list_normalizers = []
        if self.original_tokenizer.do_lower_case:
            list_normalizers.append(normalizers.Lowercase())
        list_normalizers.append(normalizers.Strip())

        precompiled_charsmap = proto.normalizer_spec.precompiled_charsmap
        if precompiled_charsmap:
            list_normalizers.append(normalizers.Precompiled(precompiled_charsmap))
        list_normalizers.append(normalizers.Replace(Regex(" {2,}"), " "))

        return normalizers.Sequence(list_normalizers)

    def post_processor(self):
        return processors.TemplateProcessing(
            single="[CLS]:0 $A:0 [SEP]:0",
            pair="[CLS]:0 $A:0 [SEP]:0 $B:1 [SEP]:1",
            special_tokens=[
                ("[CLS]", self.original_tokenizer.convert_tokens_to_ids("[CLS]")),
                ("[SEP]", self.original_tokenizer.convert_tokens_to_ids("[SEP]")),
            ],
        )


class MBartConverter(SpmConverter):
    def vocab(self, proto):
        vocab = [
            ("<s>", 0.0),
            ("<pad>", 0.0),
            ("</s>", 0.0),
            ("<unk>", 0.0),
        ]
        vocab += [(piece.piece, piece.score) for piece in proto.pieces[3:]]
        vocab += [
            ("ar_AR", 0.0),
            ("cs_CZ", 0.0),
            ("de_DE", 0.0),
            ("en_XX", 0.0),
            ("es_XX", 0.0),
            ("et_EE", 0.0),
            ("fi_FI", 0.0),
            ("fr_XX", 0.0),
            ("gu_IN", 0.0),
            ("hi_IN", 0.0),
            ("it_IT", 0.0),
            ("ja_XX", 0.0),
            ("kk_KZ", 0.0),
            ("ko_KR", 0.0),
            ("lt_LT", 0.0),
            ("lv_LV", 0.0),
            ("my_MM", 0.0),
            ("ne_NP", 0.0),
            ("nl_XX", 0.0),
            ("ro_RO", 0.0),
            ("ru_RU", 0.0),
            ("si_LK", 0.0),
            ("tr_TR", 0.0),
            ("vi_VN", 0.0),
            ("zh_CN", 0.0),
        ]
        vocab += [("<mask>", 0.0)]
        return vocab

    def unk_id(self, proto):
        return 3

    def post_processor(self):
        return processors.TemplateProcessing(
            single="$A </s> en_XX",
            pair="$A $B </s> en_XX",
            special_tokens=[
                ("en_XX", self.original_tokenizer.convert_tokens_to_ids("en_XX")),
                ("</s>", self.original_tokenizer.convert_tokens_to_ids("</s>")),
            ],
        )


class MBart50Converter(SpmConverter):
    def vocab(self, proto):
        vocab = [
            ("<s>", 0.0),
            ("<pad>", 0.0),
            ("</s>", 0.0),
            ("<unk>", 0.0),
        ]
        vocab += [(piece.piece, piece.score) for piece in proto.pieces[3:]]
        vocab += [("ar_AR", 0.0), ("cs_CZ", 0.0), ("de_DE", 0.0), ("en_XX", 0.0), ("es_XX", 0.0), ("et_EE", 0.0), ("fi_FI", 0.0), ("fr_XX", 0.0), ("gu_IN", 0.0), ("hi_IN", 0.0), ("it_IT", 0.0), ("ja_XX", 0.0), ("kk_KZ", 0.0), ("ko_KR", 0.0), ("lt_LT", 0.0), ("lv_LV", 0.0), ("my_MM", 0.0), ("ne_NP", 0.0), ("nl_XX", 0.0), ("ro_RO", 0.0), ("ru_RU", 0.0), ("si_LK", 0.0), ("tr_TR", 0.0), ("vi_VN", 0.0), ("zh_CN", 0.0), ("af_ZA", 0.0), ("az_AZ", 0.0), ("bn_IN", 0.0), ("fa_IR", 0.0), ("he_IL", 0.0), ("hr_HR", 0.0), ("id_ID", 0.0), ("ka_GE", 0.0), ("km_KH", 0.0), ("mk_MK", 0.0), ("ml_IN", 0.0), ("mn_MN", 0.0), ("mr_IN", 0.0), ("pl_PL", 0.0), ("ps_AF", 0.0), ("pt_XX", 0.0), ("sv_SE", 0.0), ("sw_KE", 0.0), ("ta_IN", 0.0), ("te_IN", 0.0), ("th_TH", 0.0), ("tl_XX", 0.0), ("uk_UA", 0.0), ("ur_PK", 0.0), ("xh_ZA", 0.0), ("gl_ES", 0.0), ("sl_SI", 0.0)]  # fmt: skip
        vocab += [("<mask>", 0.0)]
        return vocab

    def unk_id(self, proto):
        return 3

    def post_processor(self):
        return processors.TemplateProcessing(
            single="en_XX $A </s>",
            pair="en_XX $A $B </s>",
            special_tokens=[
                ("en_XX", self.original_tokenizer.convert_tokens_to_ids("en_XX")),
                ("</s>", self.original_tokenizer.convert_tokens_to_ids("</s>")),
            ],
        )


class NllbConverter(SpmConverter):
    def vocab(self, proto):
        vocab = [
            ("<s>", 0.0),
            ("<pad>", 0.0),
            ("</s>", 0.0),
            ("<unk>", 0.0),
        ]
        vocab += [(piece.piece, piece.score) for piece in proto.pieces[3:]]
        return vocab

    def unk_id(self, proto):
        return 3

    def post_processor(self):
        return processors.TemplateProcessing(
            single="eng_Latn $A </s>",
            pair="eng_Latn $A $B </s>",
            special_tokens=[
                ("eng_Latn", self.original_tokenizer.convert_tokens_to_ids("eng_Latn")),
                ("</s>", self.original_tokenizer.convert_tokens_to_ids("</s>")),
            ],
        )


class SeamlessM4TConverter(SpmConverter):
    def vocab(self, proto):
        vocab = [
            ("<pad>", 0.0),
            ("<unk>", 0.0),
            ("<s>", 0.0),
            ("</s>", 0.0),
        ]
        vocab += [(piece.piece, piece.score) for piece in proto.pieces[3:]]
        return vocab

    def unk_id(self, proto):
        return self.original_tokenizer.unk_token_id

    def post_processor(self):
        return processors.TemplateProcessing(
            single="__eng__ $A </s>",
            pair="__eng__ $A $B </s>",
            special_tokens=[
                ("__eng__", self.original_tokenizer.convert_tokens_to_ids("__eng__")),
                ("</s>", self.original_tokenizer.convert_tokens_to_ids("</s>")),
            ],
        )


class XLMRobertaConverter(SpmConverter):
    def vocab(self, proto):
        vocab = [
            ("<s>", 0.0),
            ("<pad>", 0.0),
            ("</s>", 0.0),
            ("<unk>", 0.0),
        ]
        vocab += [(piece.piece, piece.score) for piece in proto.pieces[3:]]
        vocab += [("<mask>", 0.0)]
        return vocab

    def unk_id(self, proto):
        unk_id = 3
        return unk_id

    def post_processor(self):
        return processors.TemplateProcessing(
            single="<s> $A </s>",
            pair="<s> $A </s> </s> $B </s>",
            special_tokens=[
                ("<s>", self.original_tokenizer.convert_tokens_to_ids("<s>")),
                ("</s>", self.original_tokenizer.convert_tokens_to_ids("</s>")),
            ],
        )


class XLNetConverter(SpmConverter):
    def vocab(self, proto):
        return [
            (piece.piece, piece.score) if check_number_comma(piece.piece) else (piece.piece, piece.score - 100)
            for piece in proto.pieces
        ]

    def normalizer(self, proto):
        list_normalizers = [
            normalizers.Replace("``", '"'),
            normalizers.Replace("''", '"'),
        ]
        if not self.original_tokenizer.keep_accents:
            list_normalizers.append(normalizers.NFKD())
            list_normalizers.append(normalizers.StripAccents())
        if self.original_tokenizer.do_lower_case:
            list_normalizers.append(normalizers.Lowercase())

        precompiled_charsmap = proto.normalizer_spec.precompiled_charsmap

        if precompiled_charsmap:
            list_normalizers.append(normalizers.Precompiled(precompiled_charsmap))

        list_normalizers.append(normalizers.Replace(Regex(" {2,}"), " "))
        return normalizers.Sequence(list_normalizers)

    def post_processor(self):
        return processors.TemplateProcessing(
            single="$A:0 <sep>:0 <cls>:2",
            pair="$A:0 <sep>:0 $B:1 <sep>:1 <cls>:2",
            special_tokens=[
                ("<sep>", self.original_tokenizer.convert_tokens_to_ids("<sep>")),
                ("<cls>", self.original_tokenizer.convert_tokens_to_ids("<cls>")),
            ],
        )


class ReformerConverter(SpmConverter):
    pass


class RemBertConverter(SpmConverter):
    # Inspired from AlbertConverter
    def normalizer(self, proto):
        list_normalizers = [
            normalizers.Replace("``", '"'),
            normalizers.Replace("''", '"'),
            normalizers.Replace(Regex(" {2,}"), " "),
        ]
        if not self.original_tokenizer.keep_accents:
            list_normalizers.append(normalizers.NFKD())
            list_normalizers.append(normalizers.StripAccents())
        if self.original_tokenizer.do_lower_case:
            list_normalizers.append(normalizers.Lowercase())

        precompiled_charsmap = proto.normalizer_spec.precompiled_charsmap

        if precompiled_charsmap:
            list_normalizers.append(normalizers.Precompiled(precompiled_charsmap))

        return normalizers.Sequence(list_normalizers)

    def post_processor(self):
        return processors.TemplateProcessing(
            single="[CLS]:0 $A:0 [SEP]:0",
            pair="[CLS]:0 $A:0 [SEP]:0 $B:1 [SEP]:1",
            special_tokens=[
                ("[CLS]", self.original_tokenizer.convert_tokens_to_ids("[CLS]")),
                ("[SEP]", self.original_tokenizer.convert_tokens_to_ids("[SEP]")),
            ],
        )


class BertGenerationConverter(SpmConverter):
    pass


class PegasusConverter(SpmConverter):
    def vocab(self, proto):
        vocab = [
            (self.original_tokenizer.pad_token, 0.0),
            (self.original_tokenizer.eos_token, 0.0),
        ]

        if self.original_tokenizer.mask_token_sent is not None:
            vocab += [(self.original_tokenizer.mask_token_sent, 0.0)]

        if (
            self.original_tokenizer.mask_token is not None
            and self.original_tokenizer.mask_token_id < self.original_tokenizer.offset
        ):
            vocab += [(self.original_tokenizer.mask_token, 0.0)]

        vocab += [(f"<unk_{i}>", -100.0) for i in range(2, self.original_tokenizer.offset)]
        vocab += [(piece.piece, piece.score) for piece in proto.pieces[2:]]
        return vocab

    def unk_id(self, proto):
        return proto.trainer_spec.unk_id + self.original_tokenizer.offset

    def pre_tokenizer(self, replacement, add_prefix_space):
        prepend_scheme = _get_prepend_scheme(add_prefix_space, self.original_tokenizer)
        return pre_tokenizers.Sequence(
            [
                pre_tokenizers.WhitespaceSplit(),
                pre_tokenizers.Metaspace(replacement=replacement, prepend_scheme=prepend_scheme),
            ]
        )

    def post_processor(self):
        eos = self.original_tokenizer.eos_token
        special_tokens = [
            (eos, self.original_tokenizer.eos_token_id),
        ]
        return processors.TemplateProcessing(single=["$A", eos], pair=["$A", "$B", eos], special_tokens=special_tokens)


class T5Converter(SpmConverter):
    def vocab(self, proto):
        num_extra_ids = self.original_tokenizer._extra_ids
        vocab = [(piece.piece, piece.score) for piece in proto.pieces]
        vocab += [(f"<extra_id_{i}>", 0.0) for i in range(num_extra_ids - 1, -1, -1)]
        return vocab

    def post_processor(self):
        return processors.TemplateProcessing(
            single=["$A", "</s>"],
            pair=["$A", "</s>", "$B", "</s>"],
            special_tokens=[
                ("</s>", self.original_tokenizer.convert_tokens_to_ids("</s>")),
            ],
        )


class UdopConverter(SpmConverter):
    def post_processor(self):
        return processors.TemplateProcessing(
            single=["$A", "</s>"],
            pair=["$A", "</s>", "$B", "</s>"],
            special_tokens=[
                ("</s>", self.original_tokenizer.convert_tokens_to_ids("</s>")),
            ],
        )


class WhisperConverter(Converter):
    def converted(self) -> Tokenizer:
        vocab = self.original_tokenizer.encoder
        merges = list(self.original_tokenizer.bpe_ranks.keys())

        tokenizer = Tokenizer(
            BPE(
                vocab=vocab,
                merges=merges,
                dropout=None,
                continuing_subword_prefix="",
                end_of_word_suffix="",
                fuse_unk=False,
            )
        )

        tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=self.original_tokenizer.add_prefix_space)
        tokenizer.decoder = decoders.ByteLevel()

        prefix_token_ids = self.original_tokenizer.prefix_tokens
        prefixes = self.original_tokenizer.convert_ids_to_tokens(prefix_token_ids)
        eos = self.original_tokenizer.eos_token
        eos_token_id = self.original_tokenizer.eos_token_id
        prefix_template = " ".join([f"{token}:0" for token in prefixes])
        tokenizer.post_processor = processors.TemplateProcessing(
            single=f"{prefix_template} $A:0 {eos}:0",
            pair=f"{prefix_template} $A:0 $B:1 {eos}:1",
            special_tokens=[
                (eos, eos_token_id),
                *zip(prefixes, prefix_token_ids),
            ],
        )

        return tokenizer


class BigBirdConverter(SpmConverter):
    def post_processor(self):
        return processors.TemplateProcessing(
            single="[CLS]:0 $A:0 [SEP]:0",
            pair="[CLS]:0 $A:0 [SEP]:0 $B:1 [SEP]:1",
            special_tokens=[
                ("[CLS]", self.original_tokenizer.convert_tokens_to_ids("[CLS]")),
                ("[SEP]", self.original_tokenizer.convert_tokens_to_ids("[SEP]")),
            ],
        )


class CLIPConverter(Converter):
    def converted(self) -> Tokenizer:
        vocab = self.original_tokenizer.encoder
        merges = list(self.original_tokenizer.bpe_ranks.keys())
        unk_token = self.original_tokenizer.unk_token

        tokenizer = Tokenizer(
            BPE(
                vocab=vocab,
                merges=merges,
                dropout=None,
                continuing_subword_prefix="",
                end_of_word_suffix="</w>",
                fuse_unk=False,
                unk_token=str(unk_token),
            )
        )

        tokenizer.normalizer = normalizers.Sequence(
            [normalizers.NFC(), normalizers.Replace(Regex(r"\s+"), " "), normalizers.Lowercase()]
        )
        tokenizer.pre_tokenizer = pre_tokenizers.Sequence(
            [
                pre_tokenizers.Split(
                    Regex(r"""'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+"""),
                    behavior="removed",
                    invert=True,
                ),
                pre_tokenizers.ByteLevel(add_prefix_space=False),
            ]
        )
        tokenizer.decoder = decoders.ByteLevel()

        # Hack to have a ByteLevel and TemplateProcessor
        tokenizer.post_processor = processors.RobertaProcessing(
            sep=(self.original_tokenizer.eos_token, self.original_tokenizer.eos_token_id),
            cls=(self.original_tokenizer.bos_token, self.original_tokenizer.bos_token_id),
            add_prefix_space=False,
            trim_offsets=False,
        )
        return tokenizer


class LayoutLMv2Converter(Converter):
    def converted(self) -> Tokenizer:
        vocab = self.original_tokenizer.vocab
        tokenizer = Tokenizer(WordPiece(vocab, unk_token=str(self.original_tokenizer.unk_token)))

        tokenize_chinese_chars = False
        strip_accents = False
        do_lower_case = True
        if hasattr(self.original_tokenizer, "basic_tokenizer"):
            tokenize_chinese_chars = self.original_tokenizer.basic_tokenizer.tokenize_chinese_chars
            strip_accents = self.original_tokenizer.basic_tokenizer.strip_accents
            do_lower_case = self.original_tokenizer.basic_tokenizer.do_lower_case

        tokenizer.normalizer = normalizers.BertNormalizer(
            clean_text=True,
            handle_chinese_chars=tokenize_chinese_chars,
            strip_accents=strip_accents,
            lowercase=do_lower_case,
        )
        tokenizer.pre_tokenizer = pre_tokenizers.BertPreTokenizer()

        cls = str(self.original_tokenizer.cls_token)
        sep = str(self.original_tokenizer.sep_token)
        cls_token_id = self.original_tokenizer.cls_token_id
        sep_token_id = self.original_tokenizer.sep_token_id

        tokenizer.post_processor = processors.TemplateProcessing(
            single=f"{cls}:0 $A:0 {sep}:0",
            pair=f"{cls}:0 $A:0 {sep}:0 $B:1 {sep}:1",
            special_tokens=[
                (cls, cls_token_id),
                (sep, sep_token_id),
            ],
        )
        tokenizer.decoder = decoders.WordPiece(prefix="##")

        return tokenizer


class BlenderbotConverter(Converter):
    def converted(self) -> Tokenizer:
        ot = self.original_tokenizer
        vocab = ot.encoder
        merges = list(ot.bpe_ranks.keys())

        tokenizer = Tokenizer(
            BPE(
                vocab=vocab,
                merges=merges,
                dropout=None,
                continuing_subword_prefix="",
                end_of_word_suffix="",
                fuse_unk=False,
            )
        )

        tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=ot.add_prefix_space)
        tokenizer.decoder = decoders.ByteLevel()
        tokenizer.post_processor = processors.TemplateProcessing(
            single=f"$A:0 {ot.eos_token}:0",
            special_tokens=[
                (ot.eos_token, ot.eos_token_id),
            ],
        )

        return tokenizer


class XGLMConverter(SpmConverter):
    def vocab(self, proto):
        vocab = [
            ("<s>", 0.0),
            ("<pad>", 0.0),
            ("</s>", 0.0),
            ("<unk>", 0.0),
        ]
        vocab += [(piece.piece, piece.score) for piece in proto.pieces[3:]]
        vocab += [("<madeupword0>", 0.0), ("<madeupword1>", 0.0), ("<madeupword2>", 0.0), ("<madeupword3>", 0.0), ("<madeupword4>", 0.0), ("<madeupword5>", 0.0), ("<madeupword6>", 0.0)]  # fmt: skip
        return vocab

    def unk_id(self, proto):
        unk_id = 3
        return unk_id

    def post_processor(self):
        return processors.TemplateProcessing(
            single="</s> $A",
            pair="</s> $A </s> </s> $B",
            special_tokens=[
                ("<s>", self.original_tokenizer.convert_tokens_to_ids("<s>")),
                ("</s>", self.original_tokenizer.convert_tokens_to_ids("</s>")),
            ],
        )


class GemmaConvert(SpmConverter):
    handle_byte_fallback = True
    SpmExtractor = GemmaSentencePieceExtractor
    # start and end of turn tokens must be marked as special
    special_tokens = {"<start_of_turn>", "<end_of_turn>"}

    """
    split_by_unicode_script: true
    split_by_number: true
    split_by_whitespace: true
    treat_whitespace_as_suffix: false
    allow_whitespace_only_pieces: true
    split_digits: true
    byte_fallback: true
    """

    def normalizer(self, proto):
        return normalizers.Replace(" ", "▁")

    def vocab(self, proto):
        vocab = [
            (self.original_tokenizer.pad_token, 0.0),
            (self.original_tokenizer.eos_token, 0.0),
            (self.original_tokenizer.bos_token, 0.0),
        ]
        for piece in proto.pieces[3:]:
            if piece.piece == "<0x09>":
                vocab += [("\t", piece.score)]
            else:
                vocab += [(piece.piece, piece.score)]
        # vocab += [(piece.piece, piece.score) for piece in proto.pieces[3:]]
        return vocab

    def pre_tokenizer(self, replacement, add_prefix_space):
        return pre_tokenizers.Split(" ", "merged_with_previous")

    def unk_id(self, proto):
        unk_id = 3
        return unk_id

    def decoder(self, replacement, add_prefix_space):
        return decoders.Sequence(
            [
                decoders.Replace("▁", " "),
                decoders.ByteFallback(),
                decoders.Fuse(),
            ]
        )


class LlamaConverter(SpmConverter):
    handle_byte_fallback = True

    def vocab(self, proto):
        vocab = [
            (self.original_tokenizer.convert_ids_to_tokens(0), 0.0),
            (self.original_tokenizer.convert_ids_to_tokens(1), 0.0),
            (self.original_tokenizer.convert_ids_to_tokens(2), 0.0),
        ]
        vocab += [(piece.piece, piece.score) for piece in proto.pieces[3:]]
        return vocab

    def unk_id(self, proto):
        unk_id = 0
        return unk_id

    def decoder(self, replacement, add_prefix_space):
        sequence = [
            decoders.Replace("▁", " "),
            decoders.ByteFallback(),
            decoders.Fuse(),
        ]
        if add_prefix_space:
            sequence += [decoders.Strip(content=" ", left=1)]
        return decoders.Sequence(sequence)

    def normalizer(self, proto):
        if getattr(self.original_tokenizer, "legacy", True):
            sequence = []
            if getattr(self.original_tokenizer, "add_prefix_space", True):
                sequence += [normalizers.Prepend(prepend="▁")]
            sequence += [normalizers.Replace(pattern=" ", content="▁")]
            return normalizers.Sequence(sequence)
        return None  # non-legacy, no normalizer

    def pre_tokenizer(self, replacement, add_prefix_space):
        if not getattr(self.original_tokenizer, "legacy", True):  # non-legacy, we need a replace
            prepend_scheme = _get_prepend_scheme(add_prefix_space, self.original_tokenizer)
            return pre_tokenizers.Metaspace(replacement=replacement, prepend_scheme=prepend_scheme, split=False)
        return None

    def post_processor(self):
        # the processor is defined in the LlamaTokenizerFast class.
        return None


class MarkupLMConverter(Converter):
    def converted(self) -> Tokenizer:
        ot = self.original_tokenizer
        vocab = ot.encoder
        merges = list(ot.bpe_ranks.keys())

        tokenizer = Tokenizer(
            BPE(
                vocab=vocab,
                merges=merges,
                dropout=None,
                continuing_subword_prefix="",
                end_of_word_suffix="",
                fuse_unk=False,
                unk_token=self.original_tokenizer.unk_token,
            )
        )

        tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=ot.add_prefix_space)
        tokenizer.decoder = decoders.ByteLevel()

        cls = str(self.original_tokenizer.cls_token)
        sep = str(self.original_tokenizer.sep_token)
        cls_token_id = self.original_tokenizer.cls_token_id
        sep_token_id = self.original_tokenizer.sep_token_id

        tokenizer.post_processor = processors.TemplateProcessing(
            single=f"{cls} $A {sep}",
            pair=f"{cls} $A {sep} $B {sep}",
            special_tokens=[
                (cls, cls_token_id),
                (sep, sep_token_id),
            ],
        )

        return tokenizer


class MoshiConverter(SpmConverter):
    handle_byte_fallback = True

    def __init__(self, vocab_file, model_max_length=None, **kwargs):
        requires_backends(self, "protobuf")

        Converter.__init__(self, vocab_file)

        # from .utils import sentencepiece_model_pb2 as model_pb2
        model_pb2 = import_protobuf()
        m = model_pb2.ModelProto()
        with open(vocab_file, "rb") as f:
            m.ParseFromString(f.read())
        self.proto = m

    def normalizer(self, proto):
        precompiled_charsmap = proto.normalizer_spec.precompiled_charsmap
        _normalizers = [
            normalizers.Replace(" ", "▁"),
        ]
        if not precompiled_charsmap:
            return normalizers.Sequence(_normalizers)
        else:
            return normalizers.Sequence([normalizers.Precompiled(precompiled_charsmap)] + _normalizers)

    def decoder(self, replacement, add_prefix_space):
        sequence = [
            decoders.Replace("▁", " "),
            decoders.ByteFallback(),
            decoders.Fuse(),
        ]
        if add_prefix_space:
            sequence += [decoders.Strip(content=" ", left=1)]
        return decoders.Sequence(sequence)

    def pre_tokenizer(self, replacement, add_prefix_space):
        prepend_scheme = "first"
        return pre_tokenizers.Metaspace(replacement=replacement, prepend_scheme=prepend_scheme, split=False)


# Copied from transformers.models.gpt2.tokenization_gpt2.bytes_to_unicode
def bytes_to_unicode():
    """
    Returns a mapping of utf-8 bytes to unicode strings. We specifically avoid mapping to whitespace/control
    characters the bpe code barfs on.

    The reversible bpe codes work on unicode strings. This means you need a large # of unicode characters in your vocab
    if you want to avoid UNKs. When you're at something like a 10B token dataset you end up needing around 5K for
    decent coverage. This is a significant percentage of your normal, say, 32K bpe vocab. To avoid that, we want lookup
    tables between utf-8 bytes and unicode strings.
    """
    bs = (
        list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + list(range(ord("®"), ord("ÿ") + 1))
    )
    cs = bs[:]
    n = 0
    for b in range(2**8):
        if b not in bs:
            bs.append(b)
            cs.append(2**8 + n)
            n += 1
    cs = [chr(n) for n in cs]
    return dict(zip(bs, cs))
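

# Illustrative sketch (not part of the original module): how the byte-to-unicode table is used
# to turn raw UTF-8 bytes into the printable characters stored in byte-level BPE vocabularies.
def _example_byte_level_encode(text: str = "héllo world") -> str:
    table = bytes_to_unicode()
    # Every one of the 256 byte values maps to a distinct printable character; for instance the
    # space byte 0x20 maps to "Ġ", which is why byte-level vocab entries look like "Ġworld".
    return "".join(table[b] for b in text.encode("utf-8"))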


class TikTokenConverter:
    """
    A general tiktoken converter.
    """

    def __init__(
        self,
        vocab_file=None,
        pattern=r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+""",
        add_prefix_space=False,
        additional_special_tokens=None,
        *args,
        **kwargs,
    ):
        super().__init__(*args)
        self.vocab_file = vocab_file
        self.pattern = pattern
        self.add_prefix_space = add_prefix_space
        self.additional_special_tokens = additional_special_tokens

    def extract_vocab_merges_from_model(self, tiktoken_url: str):
        try:
            from tiktoken.load import load_tiktoken_bpe
        except Exception:
            raise ValueError(
                "`tiktoken` is required to read a `tiktoken` file. Install it with `pip install tiktoken`."
            )

        bpe_ranks = load_tiktoken_bpe(tiktoken_url)
        byte_encoder = bytes_to_unicode()

        def token_bytes_to_string(b):
            return "".join([byte_encoder[ord(char)] for char in b.decode("latin-1")])

        merges = []
        vocab = {}
        for token, rank in bpe_ranks.items():
            vocab[token_bytes_to_string(token)] = rank
            if len(token) == 1:
                continue
            local = []
            for index in range(1, len(token)):
                piece_l, piece_r = token[:index], token[index:]
                if piece_l in bpe_ranks and piece_r in bpe_ranks and (piece_l + piece_r) in bpe_ranks:
                    local.append((piece_l, piece_r, rank))
            local = sorted(local, key=lambda x: (bpe_ranks[x[0]], bpe_ranks[x[1]]), reverse=False)
            merges.extend(local)
        merges = sorted(merges, key=lambda val: val[2], reverse=False)
        merges = [(token_bytes_to_string(val[0]), token_bytes_to_string(val[1])) for val in merges]
        return vocab, merges

    def tokenizer(self):
        vocab_scores, merges = self.extract_vocab_merges_from_model(self.vocab_file)
        tokenizer = Tokenizer(BPE(vocab_scores, merges, fuse_unk=False))
        if hasattr(tokenizer.model, "ignore_merges"):
            tokenizer.model.ignore_merges = True
        return tokenizer

    def converted(self) -> Tokenizer:
        tokenizer = self.tokenizer()
        tokenizer.pre_tokenizer = pre_tokenizers.Sequence(
            [
                pre_tokenizers.Split(Regex(self.pattern), behavior="isolated", invert=False),
                pre_tokenizers.ByteLevel(add_prefix_space=self.add_prefix_space, use_regex=False),
            ]
        )
        tokenizer.decoder = decoders.ByteLevel()
        tokenizer.add_special_tokens(self.additional_special_tokens)

        tokenizer.post_processor = processors.ByteLevel(trim_offsets=False)
        return tokenizer
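

# Illustrative sketch (not part of the original module): converting a tiktoken-format BPE file
# directly. The file path and the special tokens below are hypothetical placeholders.
def _example_tiktoken_conversion():
    converter = TikTokenConverter(
        vocab_file="path/to/tokenizer.model",  # hypothetical tiktoken-format file
        additional_special_tokens=["<|begin_of_text|>", "<|end_of_text|>"],  # assumed tokens
    )
    # Returns a `tokenizers.Tokenizer` backed by byte-level BPE with the pattern defined above.
    return converter.converted()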


SLOW_TO_FAST_CONVERTERS = {
    "AlbertTokenizer": AlbertConverter,
    "BartTokenizer": RobertaConverter,
    "BarthezTokenizer": BarthezConverter,
    "BertTokenizer": BertConverter,
    "BigBirdTokenizer": BigBirdConverter,
    "BlenderbotTokenizer": BlenderbotConverter,
    "CamembertTokenizer": CamembertConverter,
    "CLIPTokenizer": CLIPConverter,
    "CodeGenTokenizer": GPT2Converter,
    "ConvBertTokenizer": BertConverter,
    "DebertaTokenizer": DebertaConverter,
    "DebertaV2Tokenizer": DebertaV2Converter,
    "DistilBertTokenizer": BertConverter,
    "DPRReaderTokenizer": BertConverter,
    "DPRQuestionEncoderTokenizer": BertConverter,
    "DPRContextEncoderTokenizer": BertConverter,
    "ElectraTokenizer": BertConverter,
    "FNetTokenizer": AlbertConverter,
    "FunnelTokenizer": FunnelConverter,
    "GPT2Tokenizer": GPT2Converter,
    "HerbertTokenizer": HerbertConverter,
    "LayoutLMTokenizer": BertConverter,
    "LayoutLMv2Tokenizer": BertConverter,
    "LayoutLMv3Tokenizer": RobertaConverter,
    "LayoutXLMTokenizer": XLMRobertaConverter,
    "LongformerTokenizer": RobertaConverter,
    "LEDTokenizer": RobertaConverter,
    "LxmertTokenizer": BertConverter,
    "MarkupLMTokenizer": MarkupLMConverter,
    "MBartTokenizer": MBartConverter,
    "MBart50Tokenizer": MBart50Converter,
    "MPNetTokenizer": MPNetConverter,
    "MobileBertTokenizer": BertConverter,
    "MvpTokenizer": RobertaConverter,
    "NllbTokenizer": NllbConverter,
    "OpenAIGPTTokenizer": OpenAIGPTConverter,
    "PegasusTokenizer": PegasusConverter,
    "Qwen2Tokenizer": Qwen2Converter,
    "RealmTokenizer": BertConverter,
    "ReformerTokenizer": ReformerConverter,
    "RemBertTokenizer": RemBertConverter,
    "RetriBertTokenizer": BertConverter,
    "RobertaTokenizer": RobertaConverter,
    "RoFormerTokenizer": RoFormerConverter,
    "SeamlessM4TTokenizer": SeamlessM4TConverter,
    "SqueezeBertTokenizer": BertConverter,
    "T5Tokenizer": T5Converter,
    "UdopTokenizer": UdopConverter,
    "WhisperTokenizer": WhisperConverter,
    "XLMRobertaTokenizer": XLMRobertaConverter,
    "XLNetTokenizer": XLNetConverter,
    "SplinterTokenizer": SplinterConverter,
    "XGLMTokenizer": XGLMConverter,
    "LlamaTokenizer": LlamaConverter,
    "CodeLlamaTokenizer": LlamaConverter,
    "GemmaTokenizer": GemmaConvert,
    "Phi3Tokenizer": LlamaConverter,
}


def convert_slow_tokenizer(transformer_tokenizer, from_tiktoken=False) -> Tokenizer:
    """
    Utility to convert a slow tokenizer instance into a fast tokenizer instance.

    Args:
        transformer_tokenizer ([`~tokenization_utils_base.PreTrainedTokenizer`]):
            Instance of a slow tokenizer to convert into the backend tokenizer for
            [`~tokenization_utils_base.PreTrainedTokenizerFast`].
        from_tiktoken (bool, optional): Whether to use the `tiktoken` library to convert the tokenizer instead of
            sentencepiece. Defaults to False.

    Return:
        An instance of [`~tokenizers.Tokenizer`] to be used as the backend tokenizer of a
        [`~tokenization_utils_base.PreTrainedTokenizerFast`]
    """
    tokenizer_class_name = transformer_tokenizer.__class__.__name__

    if tokenizer_class_name in SLOW_TO_FAST_CONVERTERS and not from_tiktoken:
        converter_class = SLOW_TO_FAST_CONVERTERS[tokenizer_class_name]
        return converter_class(transformer_tokenizer).converted()
    else:
        try:
            logger.info("Converting from Tiktoken")
            return TikTokenConverter(
                vocab_file=transformer_tokenizer.vocab_file,
                additional_special_tokens=transformer_tokenizer.additional_special_tokens,
            ).converted()
        except Exception:
            raise ValueError(
                "Converting from Tiktoken failed, if a converter for SentencePiece is available, provide a model path "
                "with a SentencePiece tokenizer.model file. "
                f"Currently available slow->fast converters: {list(SLOW_TO_FAST_CONVERTERS.keys())}"
            )
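

# Illustrative sketch (not part of the original module): end-to-end use of
# `convert_slow_tokenizer`. "bert-base-uncased" is only an example checkpoint.
if __name__ == "__main__":
    from transformers import BertTokenizer

    slow = BertTokenizer.from_pretrained("bert-base-uncased")
    fast_backend = convert_slow_tokenizer(slow)  # a `tokenizers.Tokenizer` instance
    print(fast_backend.encode("Hello world!").tokens)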