tokenization_esm.py

# coding=utf-8
# Copyright 2022 Meta and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tokenization classes for ESM."""
import os
from typing import List, Optional

from ...tokenization_utils import PreTrainedTokenizer
from ...utils import logging


logger = logging.get_logger(__name__)

VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt"}


def load_vocab_file(vocab_file):
    with open(vocab_file, "r") as f:
        lines = f.read().splitlines()
    return [l.strip() for l in lines]
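
# For reference, the vocabulary files shipped with pretrained ESM checkpoints are plain text
# with one token per line. A rough, illustrative sketch of such a file (not an exact listing
# of any particular checkpoint's vocab):
#
#   <cls>
#   <pad>
#   <eos>
#   <unk>
#   L
#   A
#   G
#   ...
#   <mask>
#
# load_vocab_file returns these lines in order, so a token's line number is its id.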


class EsmTokenizer(PreTrainedTokenizer):
    """
    Constructs an ESM tokenizer.
    """

    vocab_files_names = VOCAB_FILES_NAMES
    model_input_names = ["input_ids", "attention_mask"]

    def __init__(
        self,
        vocab_file,
        unk_token="<unk>",
        cls_token="<cls>",
        pad_token="<pad>",
        mask_token="<mask>",
        eos_token="<eos>",
        **kwargs,
    ):
        self.all_tokens = load_vocab_file(vocab_file)
        self._id_to_token = dict(enumerate(self.all_tokens))
        self._token_to_id = {tok: ind for ind, tok in enumerate(self.all_tokens)}
        super().__init__(
            unk_token=unk_token,
            cls_token=cls_token,
            pad_token=pad_token,
            mask_token=mask_token,
            eos_token=eos_token,
            **kwargs,
        )

        # TODO, all the tokens are added? But they are also part of the vocab... bit strange.
        # none of them are special, but they all need special splitting.
        self.unique_no_split_tokens = self.all_tokens
        self._update_trie(self.unique_no_split_tokens)
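
    # Note on splitting: because every vocabulary entry is registered as a no-split token in
    # the trie above, `tokenize()` can segment a raw protein string into per-residue tokens
    # even when it contains no whitespace. Roughly:
    #
    #   tokenizer.tokenize("MKTV")  ->  ["M", "K", "T", "V"]
    #
    # (illustrative example; the exact behaviour depends on the vocabulary file that was loaded).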

    def _convert_id_to_token(self, index: int) -> str:
        return self._id_to_token.get(index, self.unk_token)

    def _convert_token_to_id(self, token: str) -> int:
        return self._token_to_id.get(token, self._token_to_id.get(self.unk_token))

    def _tokenize(self, text, **kwargs):
        return text.split()
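
    # `_tokenize` itself is a plain whitespace split (e.g. "M K T" -> ["M", "K", "T"]); in
    # practice it only sees the pieces of the input that the no-split trie did not already match.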

    def get_vocab(self):
        base_vocab = self._token_to_id.copy()
        base_vocab.update(self.added_tokens_encoder)
        return base_vocab

    def token_to_id(self, token: str) -> int:
        return self._token_to_id.get(token, self._token_to_id.get(self.unk_token))

    def id_to_token(self, index: int) -> str:
        return self._id_to_token.get(index, self.unk_token)

    def build_inputs_with_special_tokens(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        cls = [self.cls_token_id]
        sep = [self.eos_token_id]  # No sep token in ESM vocabulary
        if token_ids_1 is None:
            if self.eos_token_id is None:
                return cls + token_ids_0
            else:
                return cls + token_ids_0 + sep
        elif self.eos_token_id is None:
            raise ValueError("Cannot tokenize multiple sequences when EOS token is not set!")
        return cls + token_ids_0 + sep + token_ids_1 + sep  # Multiple inputs always have an EOS token
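
    # Worked example (assuming, purely for illustration, that <cls> has id 0 and <eos> has id 2):
    #
    #   build_inputs_with_special_tokens([5, 6, 7])        ->  [0, 5, 6, 7, 2]
    #   build_inputs_with_special_tokens([5, 6], [8, 9])   ->  [0, 5, 6, 2, 8, 9, 2]
    #
    # i.e. <cls> sequence <eos> for a single sequence, and <cls> seq_0 <eos> seq_1 <eos> for pairs.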

    def get_special_tokens_mask(
        self, token_ids_0: List, token_ids_1: Optional[List] = None, already_has_special_tokens: bool = False
    ) -> List[int]:
        """
        Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding
        special tokens using the tokenizer `prepare_for_model` or `encode_plus` methods.

        Args:
            token_ids_0 (`List[int]`):
                List of ids of the first sequence.
            token_ids_1 (`List[int]`, *optional*):
                List of ids of the second sequence.
            already_has_special_tokens (`bool`, *optional*, defaults to `False`):
                Whether or not the token list is already formatted with special tokens for the model.

        Returns:
            A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
        """
        if already_has_special_tokens:
            if token_ids_1 is not None:
                raise ValueError(
                    "You should not supply a second sequence if the provided sequence of "
                    "ids is already formatted with special tokens for the model."
                )
            return [1 if token in self.all_special_ids else 0 for token in token_ids_0]
        mask = [1] + ([0] * len(token_ids_0)) + [1]
        if token_ids_1 is not None:
            mask += [0] * len(token_ids_1) + [1]
        return mask
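
    # Worked example (illustrative):
    #
    #   get_special_tokens_mask([5, 6, 7])        ->  [1, 0, 0, 0, 1]
    #   get_special_tokens_mask([5, 6], [8, 9])   ->  [1, 0, 0, 1, 0, 0, 1]
    #
    # matching the <cls> ... <eos> layout produced by build_inputs_with_special_tokens.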

    def save_vocabulary(self, save_directory, filename_prefix: Optional[str] = None):
        vocab_file = os.path.join(save_directory, (filename_prefix + "-" if filename_prefix else "") + "vocab.txt")
        with open(vocab_file, "w") as f:
            f.write("\n".join(self.all_tokens))
        return (vocab_file,)

    @property
    def vocab_size(self) -> int:
        return len(self.all_tokens)
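

# Minimal usage sketch (assuming the class is imported from an installed `transformers` package
# and an ESM-2 checkpoint on the Hub that ships a vocab.txt, such as facebook/esm2_t6_8M_UR50D):
#
#   from transformers import EsmTokenizer
#
#   tokenizer = EsmTokenizer.from_pretrained("facebook/esm2_t6_8M_UR50D")
#   enc = tokenizer("MKTVRQERLKSIVRILERSKEPVSGAQLAEELSVSRQVIVQDIAYLRSLGYNIVAT")
#   enc["input_ids"]        # <cls> + one id per residue + <eos>
#   enc["attention_mask"]   # all ones for an unpadded single sequence
#
# The checkpoint name and protein string above are only examples, not part of this module.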