| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427 |
- # coding=utf-8
- # Copyright 2022, UCLA NLP, The Facebook AI Research Team Authors and The HuggingFace Inc. team.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- import os
- from shutil import copyfile
- from typing import Any, Dict, List, Optional, Tuple
- import sentencepiece as spm
- from ...tokenization_utils import AddedToken, BatchEncoding, PreTrainedTokenizer
- from ...utils import logging
logger = logging.get_logger(__name__)

# SentencePiece prefixes word-initial pieces with this marker character instead of a space.
SPIECE_UNDERLINE = "▁"

# File names under which the vocabulary artifacts are saved/loaded.
VOCAB_FILES_NAMES = {
    "vocab_file": "sentencepiece.bpe.model",
    "tokenizer_file": "tokenizer.json",
}
# Language-code special tokens (fairseq spelling) available for each checkpoint family.
# The "multi" family is the "base" family extended with four more programming languages.
FAIRSEQ_LANGUAGE_CODES = {"base": ["__java__", "__python__", "__en_XX__"]}
FAIRSEQ_LANGUAGE_CODES["multi"] = FAIRSEQ_LANGUAGE_CODES["base"] + [
    "__javascript__",
    "__php__",
    "__ruby__",
    "__go__",
]

# Plain language name (e.g. "python") -> fairseq special-token spelling (e.g. "__python__").
# Every code is wrapped in leading/trailing underscores, so stripping them recovers the plain name
# (internal underscores such as the one in "en_XX" are preserved by str.strip).
FAIRSEQ_LANGUAGE_CODES_MAP = {code.strip("_"): code for code in FAIRSEQ_LANGUAGE_CODES["multi"]}
class PLBartTokenizer(PreTrainedTokenizer):
    """
    Construct a PLBART tokenizer.

    Adapted from [`RobertaTokenizer`] and [`XLNetTokenizer`]. Based on
    [SentencePiece](https://github.com/google/sentencepiece).

    The tokenization method is `<tokens> <eos> <language code>` for both source and target language documents: no
    prefix tokens are added, and the language code (when one is set) follows the end-of-sequence token. When no
    language is set, only `<eos>` is appended.

    Args:
        vocab_file (`str`):
            Path to the vocabulary file.
        src_lang (`str`, *optional*):
            A string representing the source language. Either the plain name (e.g. `"python"`) or the special-token
            spelling (e.g. `"__python__"`) is accepted.
        tgt_lang (`str`, *optional*):
            A string representing the target language. Accepted in the same forms as `src_lang`.
        bos_token (`str`, *optional*, defaults to `"<s>"`):
            The start of sequence token.
        eos_token (`str`, *optional*, defaults to `"</s>"`):
            The end of sequence token.
        sep_token (`str`, *optional*, defaults to `"</s>"`):
            The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for
            sequence classification or for a text and a question for question answering. It is also used as the last
            token of a sequence built with special tokens.
        cls_token (`str`, *optional*, defaults to `"<s>"`):
            The cls token, which is a special token used as the first token for all tasks.
        unk_token (`str`, *optional*, defaults to `"<unk>"`):
            The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
            token instead.
        pad_token (`str`, *optional*, defaults to `"<pad>"`):
            The token used for padding, for example when batching sequences of different lengths.
        mask_token(`str`, *optional*, defaults to `"<mask>"`):
            The token used for masking values. This is the token used when training this model with masking tasks. This
            is only used in the `"base"` tokenizer type. For `"multi"` tokenizer, masking is never done for the
            downstream tasks.
        language_codes (`str`, *optional*, defaults to `"base"`):
            What language codes to use. Should be one of `"base"` or `"multi"`.
        tokenizer_file (`str`, *optional*):
            Forwarded unchanged to the parent `PreTrainedTokenizer` constructor.
        sp_model_kwargs (`dict`, *optional*):
            Will be passed to the `SentencePieceProcessor.__init__()` method. The [Python wrapper for
            SentencePiece](https://github.com/google/sentencepiece/tree/master/python) can be used, among other things,
            to set:

            - `enable_sampling`: Enable subword regularization.
            - `nbest_size`: Sampling parameters for unigram. Invalid for BPE-Dropout.

              - `nbest_size = {0,1}`: No sampling is performed.
              - `nbest_size > 1`: samples from the nbest_size results.
              - `nbest_size < 0`: assuming that nbest_size is infinite and samples from the all hypothesis (lattice)
                using forward-filtering-and-backward-sampling algorithm.

            - `alpha`: Smoothing parameter for unigram sampling, and dropout probability of merge operations for
              BPE-dropout.
        additional_special_tokens (`List[str]`, *optional*):
            Extra special tokens, appended after the language-code tokens. Entries that are already language codes are
            not added twice.
        clean_up_tokenization_spaces (`bool`, *optional*, defaults to `True`):
            Forwarded unchanged to the parent `PreTrainedTokenizer` constructor.

    Examples:

    ```python
    >>> from transformers import PLBartTokenizer

    >>> tokenizer = PLBartTokenizer.from_pretrained("uclanlp/plbart-python-en_XX", src_lang="python", tgt_lang="en_XX")
    >>> example_python_phrase = "def maximum(a,b,c):NEW_LINE_INDENTreturn max([a,b,c])"
    >>> expected_translation_english = "Returns the maximum value of a b c."
    >>> inputs = tokenizer(example_python_phrase, text_target=expected_translation_english, return_tensors="pt")
    ```"""

    vocab_files_names = VOCAB_FILES_NAMES
    model_input_names = ["input_ids", "attention_mask"]

    # Class-level defaults; every instance rebinds its own lists in `_set_lang_special_tokens` during __init__.
    prefix_tokens: List[int] = []
    suffix_tokens: List[int] = []

    def __init__(
        self,
        vocab_file,
        bos_token="<s>",
        eos_token="</s>",
        sep_token="</s>",
        cls_token="<s>",
        unk_token="<unk>",
        pad_token="<pad>",
        mask_token="<mask>",
        language_codes="base",
        tokenizer_file=None,
        src_lang=None,
        tgt_lang=None,
        sp_model_kwargs: Optional[Dict[str, Any]] = None,
        additional_special_tokens=None,
        clean_up_tokenization_spaces=True,
        **kwargs,
    ):
        # Mask token behave like a normal word, i.e. include the space before it
        mask_token = AddedToken(mask_token, lstrip=True, rstrip=False) if isinstance(mask_token, str) else mask_token
        self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs

        # Accept "python" as well as "__python__".
        src_lang = self._convert_lang_code_special_format(src_lang)
        tgt_lang = self._convert_lang_code_special_format(tgt_lang)

        # The vocabulary state below must exist before the parent constructor runs.
        self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
        self.sp_model.Load(str(vocab_file))
        self.vocab_file = vocab_file
        self.language_codes = language_codes
        fairseq_language_codes = FAIRSEQ_LANGUAGE_CODES[self.language_codes]

        # Original fairseq vocab and spm vocab must be "aligned":
        # Vocab    |    0    |    1    |   2    |    3    |  4  |  5  |  6  |   7   |   8   |  9
        # -------- | ------- | ------- | ------ | ------- | --- | --- | --- | ----- | ----- | ----
        # fairseq  | '<s>'   | '<pad>' | '</s>' | '<unk>' | ',' | '.' | '▁' | 's'   | '▁de' | '-'
        # spm      | '<unk>' | '<s>'   | '</s>' | ','     | '.' | '▁' | 's' | '▁de' | '-'   | '▁a'

        # Mimic fairseq token-to-id alignment for the first 4 token
        self.fairseq_tokens_to_ids = {"<s>": 0, "<pad>": 1, "</s>": 2, "<unk>": 3}

        # The first "real" token "," has position 4 in the original fairseq vocab and position 3 in the spm vocab
        self.fairseq_offset = 1

        # Language codes get ids directly after the (offset) spm vocabulary.
        self.sp_model_size = len(self.sp_model)
        self.lang_code_to_id = {
            code: self.sp_model_size + i + self.fairseq_offset for i, code in enumerate(fairseq_language_codes)
        }
        self.id_to_lang_code = {v: k for k, v in self.lang_code_to_id.items()}

        # Only the "base" vocabulary has a <mask> id, placed right after the language codes.
        if self.language_codes == "base":
            self.fairseq_tokens_to_ids["<mask>"] = len(self.sp_model) + len(self.lang_code_to_id) + self.fairseq_offset

        self.fairseq_tokens_to_ids.update(self.lang_code_to_id)
        self.fairseq_ids_to_tokens = {v: k for k, v in self.fairseq_tokens_to_ids.items()}

        _additional_special_tokens = list(self.lang_code_to_id.keys())
        if additional_special_tokens is not None:
            # Only add those special tokens if they are not already there.
            _additional_special_tokens.extend(
                [t for t in additional_special_tokens if t not in _additional_special_tokens]
            )

        # "base" allows no source language; "multi" falls back to English.
        if self.language_codes == "base":
            self._src_lang = src_lang
            self.cur_lang_code_id = (
                self.lang_code_to_id[self._src_lang] if self._src_lang is not None else self._src_lang
            )
        else:
            self._src_lang = src_lang if src_lang is not None else "__en_XX__"
            self.cur_lang_code_id = self.lang_code_to_id[self._src_lang]
        # NOTE(review): `cur_lang_code_id` is only assigned here. The special-token setters below update
        # `cur_lang_code` instead, so this attribute keeps the construction-time source language.

        super().__init__(
            bos_token=bos_token,
            eos_token=eos_token,
            unk_token=unk_token,
            sep_token=sep_token,
            cls_token=cls_token,
            pad_token=pad_token,
            mask_token=mask_token,
            language_codes=language_codes,
            tokenizer_file=tokenizer_file,
            src_lang=src_lang,
            tgt_lang=tgt_lang,
            additional_special_tokens=_additional_special_tokens,
            sp_model_kwargs=self.sp_model_kwargs,
            clean_up_tokenization_spaces=clean_up_tokenization_spaces,
            **kwargs,
        )

        self.tgt_lang = tgt_lang
        self.set_src_lang_special_tokens(self._src_lang)

    def __getstate__(self):
        """Pickle support: the SentencePiece processor is replaced by its serialized model proto."""
        state = self.__dict__.copy()
        state["sp_model"] = None
        state["sp_model_proto"] = self.sp_model.serialized_model_proto()
        return state

    def __setstate__(self, d):
        """Unpickle support: rebuild the SentencePiece processor from the serialized proto."""
        self.__dict__ = d

        # for backward compatibility
        if not hasattr(self, "sp_model_kwargs"):
            self.sp_model_kwargs = {}

        self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
        self.sp_model.LoadFromSerializedProto(self.sp_model_proto)

    @property
    def vocab_size(self):
        """Offset spm vocabulary plus language codes, plus one `<mask>` id for the "base" vocabulary."""
        size = len(self.sp_model) + len(self.lang_code_to_id) + self.fairseq_offset
        if self.language_codes == "base":
            size += 1  # Plus 1 for the mask token
        return size

    @property
    def src_lang(self) -> str:
        return self._src_lang

    @src_lang.setter
    def src_lang(self, new_src_lang: str) -> None:
        new_src_lang = self._convert_lang_code_special_format(new_src_lang)
        self._src_lang = new_src_lang
        self.set_src_lang_special_tokens(self._src_lang)

    def get_special_tokens_mask(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
    ) -> List[int]:
        """
        Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
        special tokens using the tokenizer `prepare_for_model` method.

        Args:
            token_ids_0 (`List[int]`):
                List of IDs.
            token_ids_1 (`List[int]`, *optional*):
                Optional second list of IDs for sequence pairs.
            already_has_special_tokens (`bool`, *optional*, defaults to `False`):
                Whether or not the token list is already formatted with special tokens for the model.

        Returns:
            `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
        """
        if already_has_special_tokens:
            return super().get_special_tokens_mask(
                token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
            )

        # Mirrors the layout of `build_inputs_with_special_tokens`: prefix + X (+ Y) + suffix.
        prefix_ones = [1] * len(self.prefix_tokens)
        suffix_ones = [1] * len(self.suffix_tokens)
        if token_ids_1 is None:
            return prefix_ones + ([0] * len(token_ids_0)) + suffix_ones
        return prefix_ones + ([0] * len(token_ids_0)) + ([0] * len(token_ids_1)) + suffix_ones

    def build_inputs_with_special_tokens(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """
        Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and
        adding special tokens. A PLBART sequence has the following format, where `X` represents the sequence:

        - `input_ids` (for encoder) `X [eos, src_lang_code]`
        - `decoder_input_ids`: (for decoder) `X [eos, tgt_lang_code]`

        BOS is never used. Pairs of sequences are not the expected use case, but they will be handled without a
        separator.

        Args:
            token_ids_0 (`List[int]`):
                List of IDs to which the special tokens will be added.
            token_ids_1 (`List[int]`, *optional*):
                Optional second list of IDs for sequence pairs.

        Returns:
            `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
        """
        if token_ids_1 is None:
            return self.prefix_tokens + token_ids_0 + self.suffix_tokens
        # We don't expect to process pairs, but leave the pair logic for API consistency
        return self.prefix_tokens + token_ids_0 + token_ids_1 + self.suffix_tokens

    def create_token_type_ids_from_sequences(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """
        Create a mask from the two sequences passed to be used in a sequence-pair classification task. PLBart does not
        make use of token type ids, therefore a list of zeros is returned.

        The list has the same length as the output of `build_inputs_with_special_tokens` for the same arguments, so
        it lines up position-for-position with the `input_ids` built from them.

        Args:
            token_ids_0 (`List[int]`):
                List of IDs.
            token_ids_1 (`List[int]`, *optional*):
                Optional second list of IDs for sequence pairs.

        Returns:
            `List[int]`: List of zeros.
        """
        # Derive the length from the real layout (no cls, current prefix/suffix tokens) rather than assuming a
        # `<s> X </s>` / `<s> A </s></s> B </s>` layout, which this tokenizer never produces.
        return [0] * len(self.build_inputs_with_special_tokens(token_ids_0, token_ids_1))

    def _build_translation_inputs(
        self, raw_inputs, return_tensors: str, src_lang: Optional[str], tgt_lang: Optional[str], **extra_kwargs
    ):
        """Used by translation pipeline, to prepare inputs for the generate function"""
        if src_lang is None or tgt_lang is None:
            raise ValueError("Translation requires a `src_lang` and a `tgt_lang` for this model")
        self.src_lang = self._convert_lang_code_special_format(src_lang)
        self.tgt_lang = self._convert_lang_code_special_format(tgt_lang)
        inputs = self(raw_inputs, add_special_tokens=True, return_tensors=return_tensors, **extra_kwargs)
        # Generation is forced to start with the target language code.
        tgt_lang_id = self.convert_tokens_to_ids(self.tgt_lang)
        inputs["forced_bos_token_id"] = tgt_lang_id
        return inputs

    def get_vocab(self):
        """Return the token -> id mapping for every base id, updated with `added_tokens_encoder`."""
        vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
        vocab.update(self.added_tokens_encoder)
        return vocab

    def _tokenize(self, text: str) -> List[str]:
        return self.sp_model.encode(text, out_type=str)

    def _convert_token_to_id(self, token):
        """Converts a token (str) in an id using the vocab."""
        if token in self.fairseq_tokens_to_ids:
            return self.fairseq_tokens_to_ids[token]
        spm_id = self.sp_model.PieceToId(token)

        # Need to return unknown token if the SP model returned 0
        return spm_id + self.fairseq_offset if spm_id else self.unk_token_id

    def _convert_id_to_token(self, index):
        """Converts an index (integer) in a token (str) using the vocab."""
        if index in self.fairseq_ids_to_tokens:
            return self.fairseq_ids_to_tokens[index]
        return self.sp_model.IdToPiece(index - self.fairseq_offset)

    def convert_tokens_to_string(self, tokens):
        """Converts a sequence of tokens (strings for sub-words) in a single string."""
        out_string = "".join(tokens).replace(SPIECE_UNDERLINE, " ").strip()
        return out_string

    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Optional[Tuple[str]]:
        """
        Write the SentencePiece model into `save_directory`.

        The original model file is copied when it exists on disk. Otherwise the in-memory serialized proto is written.
        Returns `None` (after logging an error) if `save_directory` is not a directory.
        """
        if not os.path.isdir(save_directory):
            logger.error(f"Vocabulary path ({save_directory}) should be a directory")
            return
        out_vocab_file = os.path.join(
            save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
        )

        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):
            copyfile(self.vocab_file, out_vocab_file)
        elif not os.path.isfile(self.vocab_file):
            with open(out_vocab_file, "wb") as fi:
                content_spiece_model = self.sp_model.serialized_model_proto()
                fi.write(content_spiece_model)

        return (out_vocab_file,)

    def prepare_seq2seq_batch(
        self,
        src_texts: List[str],
        src_lang: str = "en_XX",
        tgt_texts: Optional[List[str]] = None,
        tgt_lang: str = "python",
        **kwargs,
    ) -> BatchEncoding:
        self.src_lang = self._convert_lang_code_special_format(src_lang)
        self.tgt_lang = self._convert_lang_code_special_format(tgt_lang)
        return super().prepare_seq2seq_batch(src_texts, tgt_texts, **kwargs)

    def _switch_to_input_mode(self):
        return self.set_src_lang_special_tokens(self.src_lang)

    def _switch_to_target_mode(self):
        return self.set_tgt_lang_special_tokens(self.tgt_lang)

    def set_src_lang_special_tokens(self, src_lang) -> None:
        """Reset the special tokens to the source lang setting. No prefix and suffix=[eos, src_lang_code]."""
        self._set_lang_special_tokens(src_lang)

    def set_tgt_lang_special_tokens(self, lang: str) -> None:
        """Reset the special tokens to the target language setting. No prefix and suffix=[eos, tgt_lang_code]."""
        self._set_lang_special_tokens(lang)

    def _set_lang_special_tokens(self, lang: Optional[str]) -> None:
        """
        Shared implementation of the source/target setters: empty prefix, suffix `[eos, lang_code]`.

        When `lang` is `None` the suffix is just `[eos]`. An unknown language raises `KeyError`.
        """
        lang = self._convert_lang_code_special_format(lang)
        self.cur_lang_code = self.lang_code_to_id[lang] if lang is not None else None
        self.prefix_tokens = []
        if self.cur_lang_code is not None:
            self.suffix_tokens = [self.eos_token_id, self.cur_lang_code]
        else:
            self.suffix_tokens = [self.eos_token_id]

    def _convert_lang_code_special_format(self, lang: Optional[str]) -> Optional[str]:
        """Map a plain language name (e.g. "python") to its special-token form; other values pass through."""
        return FAIRSEQ_LANGUAGE_CODES_MAP.get(lang, lang)
|