import os
from typing import List, Union

import tensorflow as tf
from tensorflow_text import BertTokenizer as BertTokenizerLayer
from tensorflow_text import FastBertTokenizer, ShrinkLongestTrimmer, case_fold_utf8, combine_segments, pad_model_inputs

from ...modeling_tf_utils import keras
from .tokenization_bert import BertTokenizer


class TFBertTokenizer(keras.layers.Layer):
  9. """
  10. This is an in-graph tokenizer for BERT. It should be initialized similarly to other tokenizers, using the
  11. `from_pretrained()` method. It can also be initialized with the `from_tokenizer()` method, which imports settings
  12. from an existing standard tokenizer object.
  13. In-graph tokenizers, unlike other Hugging Face tokenizers, are actually Keras layers and are designed to be run
  14. when the model is called, rather than during preprocessing. As a result, they have somewhat more limited options
  15. than standard tokenizer classes. They are most useful when you want to create an end-to-end model that goes
  16. straight from `tf.string` inputs to outputs.
  17. Args:
  18. vocab_list (`list`):
  19. List containing the vocabulary.
  20. do_lower_case (`bool`, *optional*, defaults to `True`):
  21. Whether or not to lowercase the input when tokenizing.
  22. cls_token_id (`str`, *optional*, defaults to `"[CLS]"`):
  23. The classifier token which is used when doing sequence classification (classification of the whole sequence
  24. instead of per-token classification). It is the first token of the sequence when built with special tokens.
  25. sep_token_id (`str`, *optional*, defaults to `"[SEP]"`):
  26. The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for
  27. sequence classification or for a text and a question for question answering. It is also used as the last
  28. token of a sequence built with special tokens.
  29. pad_token_id (`str`, *optional*, defaults to `"[PAD]"`):
  30. The token used for padding, for example when batching sequences of different lengths.
  31. padding (`str`, defaults to `"longest"`):
  32. The type of padding to use. Can be either `"longest"`, to pad only up to the longest sample in the batch,
  33. or `"max_length", to pad all inputs to the maximum length supported by the tokenizer.
  34. truncation (`bool`, *optional*, defaults to `True`):
  35. Whether to truncate the sequence to the maximum length.
  36. max_length (`int`, *optional*, defaults to `512`):
  37. The maximum length of the sequence, used for padding (if `padding` is "max_length") and/or truncation (if
  38. `truncation` is `True`).
  39. pad_to_multiple_of (`int`, *optional*, defaults to `None`):
  40. If set, the sequence will be padded to a multiple of this value.
  41. return_token_type_ids (`bool`, *optional*, defaults to `True`):
  42. Whether to return token_type_ids.
  43. return_attention_mask (`bool`, *optional*, defaults to `True`):
  44. Whether to return the attention_mask.
  45. use_fast_bert_tokenizer (`bool`, *optional*, defaults to `True`):
  46. If True, will use the FastBertTokenizer class from Tensorflow Text. If False, will use the BertTokenizer
  47. class instead. BertTokenizer supports some additional options, but is slower and cannot be exported to
  48. TFLite.
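
    Example (a minimal, illustrative sketch of an end-to-end string-in model; the checkpoint name and the
    pooled-output head used here are assumptions, not part of this class):

    ```python
    import tensorflow as tf
    from transformers import TFAutoModel, TFBertTokenizer

    tokenizer = TFBertTokenizer.from_pretrained("google-bert/bert-base-uncased")
    model = TFAutoModel.from_pretrained("google-bert/bert-base-uncased")

    # Because the tokenizer is a Keras layer, it can run inside the model graph,
    # so the combined model accepts raw strings as input.
    text_input = tf.keras.Input(shape=(), dtype=tf.string, name="text")
    tokenized = tokenizer(text_input)
    pooled = model(tokenized).pooler_output
    end_to_end_model = tf.keras.Model(inputs=text_input, outputs=pooled)

    end_to_end_model(tf.constant(["Hello, TensorFlow!"]))
    ```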
  49. """

    def __init__(
        self,
        vocab_list: List,
        do_lower_case: bool,
        cls_token_id: int = None,
        sep_token_id: int = None,
        pad_token_id: int = None,
        padding: str = "longest",
        truncation: bool = True,
        max_length: int = 512,
        pad_to_multiple_of: int = None,
        return_token_type_ids: bool = True,
        return_attention_mask: bool = True,
        use_fast_bert_tokenizer: bool = True,
        **tokenizer_kwargs,
    ):
        super().__init__()
        if use_fast_bert_tokenizer:
            self.tf_tokenizer = FastBertTokenizer(
                vocab_list, token_out_type=tf.int64, lower_case_nfd_strip_accents=do_lower_case, **tokenizer_kwargs
            )
        else:
            lookup_table = tf.lookup.StaticVocabularyTable(
                tf.lookup.KeyValueTensorInitializer(
                    keys=vocab_list,
                    key_dtype=tf.string,
                    values=tf.range(tf.size(vocab_list, out_type=tf.int64), dtype=tf.int64),
                    value_dtype=tf.int64,
                ),
                num_oov_buckets=1,
            )
            self.tf_tokenizer = BertTokenizerLayer(
                lookup_table, token_out_type=tf.int64, lower_case=do_lower_case, **tokenizer_kwargs
            )

        self.vocab_list = vocab_list
        self.do_lower_case = do_lower_case
        self.cls_token_id = vocab_list.index("[CLS]") if cls_token_id is None else cls_token_id
        self.sep_token_id = vocab_list.index("[SEP]") if sep_token_id is None else sep_token_id
        self.pad_token_id = vocab_list.index("[PAD]") if pad_token_id is None else pad_token_id
        self.paired_trimmer = ShrinkLongestTrimmer(max_length - 3, axis=1)  # Allow room for special tokens
        self.max_length = max_length
        self.padding = padding
        self.truncation = truncation
        self.pad_to_multiple_of = pad_to_multiple_of
        self.return_token_type_ids = return_token_type_ids
        self.return_attention_mask = return_attention_mask

    @classmethod
    def from_tokenizer(cls, tokenizer: "PreTrainedTokenizerBase", **kwargs):  # noqa: F821
        """
        Initialize a `TFBertTokenizer` from an existing `Tokenizer`.

        Args:
            tokenizer (`PreTrainedTokenizerBase`):
                The tokenizer to use to initialize the `TFBertTokenizer`.

        Examples:

        ```python
        from transformers import AutoTokenizer, TFBertTokenizer

        tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-uncased")
        tf_tokenizer = TFBertTokenizer.from_tokenizer(tokenizer)
        ```
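
        The resulting in-graph tokenizer can then be called directly on a batch of strings, for example
        `tf_tokenizer(["hello world", "second sentence"])` (an illustrative input).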
  109. """
        do_lower_case = kwargs.pop("do_lower_case", None)
        do_lower_case = tokenizer.do_lower_case if do_lower_case is None else do_lower_case
        cls_token_id = kwargs.pop("cls_token_id", None)
        cls_token_id = tokenizer.cls_token_id if cls_token_id is None else cls_token_id
        sep_token_id = kwargs.pop("sep_token_id", None)
        sep_token_id = tokenizer.sep_token_id if sep_token_id is None else sep_token_id
        pad_token_id = kwargs.pop("pad_token_id", None)
        pad_token_id = tokenizer.pad_token_id if pad_token_id is None else pad_token_id

        vocab = tokenizer.get_vocab()
        vocab = sorted(vocab.items(), key=lambda x: x[1])
        vocab_list = [entry[0] for entry in vocab]
        return cls(
            vocab_list=vocab_list,
            do_lower_case=do_lower_case,
            cls_token_id=cls_token_id,
            sep_token_id=sep_token_id,
            pad_token_id=pad_token_id,
            **kwargs,
        )

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], *init_inputs, **kwargs):
        """
        Instantiate a `TFBertTokenizer` from a pre-trained tokenizer.

        Args:
            pretrained_model_name_or_path (`str` or `os.PathLike`):
                The name or path to the pre-trained tokenizer.

        Examples:

        ```python
        from transformers import TFBertTokenizer

        tf_tokenizer = TFBertTokenizer.from_pretrained("google-bert/bert-base-uncased")
        ```
        """
        try:
            tokenizer = BertTokenizer.from_pretrained(pretrained_model_name_or_path, *init_inputs, **kwargs)
        except:  # noqa: E722
            from .tokenization_bert_fast import BertTokenizerFast

            tokenizer = BertTokenizerFast.from_pretrained(pretrained_model_name_or_path, *init_inputs, **kwargs)
        return cls.from_tokenizer(tokenizer, **kwargs)

    def unpaired_tokenize(self, texts):
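        # Case-fold the inputs if lowercasing is enabled, wordpiece-tokenize the 1-D batch of strings,
        # then merge the per-word token dimension so each text becomes a flat sequence of token IDs.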
        if self.do_lower_case:
            texts = case_fold_utf8(texts)
        tokens = self.tf_tokenizer.tokenize(texts)
        return tokens.merge_dims(1, -1)

    def call(
        self,
        text,
        text_pair=None,
        padding=None,
        truncation=None,
        max_length=None,
        pad_to_multiple_of=None,
        return_token_type_ids=None,
        return_attention_mask=None,
    ):
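        """
        Tokenize `text` (and, optionally, `text_pair`) in-graph. Any argument left as `None` falls back to the
        value set at construction time; see the class docstring for their meaning. Returns a dict containing
        `input_ids` and, depending on the configuration, `attention_mask` and `token_type_ids`.
        """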
        if padding is None:
            padding = self.padding
        if padding not in ("longest", "max_length"):
            raise ValueError("Padding must be either 'longest' or 'max_length'!")
        if max_length is not None and text_pair is not None:
            # Because we have to instantiate a Trimmer to do it properly
            raise ValueError("max_length cannot be overridden at call time when truncating paired texts!")
        if max_length is None:
            max_length = self.max_length
        if truncation is None:
            truncation = self.truncation
        if pad_to_multiple_of is None:
            pad_to_multiple_of = self.pad_to_multiple_of
        if return_token_type_ids is None:
            return_token_type_ids = self.return_token_type_ids
        if return_attention_mask is None:
            return_attention_mask = self.return_attention_mask
        if not isinstance(text, tf.Tensor):
            text = tf.convert_to_tensor(text)
        if text_pair is not None and not isinstance(text_pair, tf.Tensor):
            text_pair = tf.convert_to_tensor(text_pair)
        if text_pair is not None:
            if text.shape.rank > 1:
                raise ValueError("text argument should not be multidimensional when a text pair is supplied!")
            if text_pair.shape.rank > 1:
                raise ValueError("text_pair should not be multidimensional!")
        if text.shape.rank == 2:
            text, text_pair = text[:, 0], text[:, 1]
        text = self.unpaired_tokenize(text)
        if text_pair is None:  # Unpaired text
            if truncation:
                text = text[:, : max_length - 2]  # Allow room for special tokens
            input_ids, token_type_ids = combine_segments(
                (text,), start_of_sequence_id=self.cls_token_id, end_of_segment_id=self.sep_token_id
            )
        else:  # Paired text
            text_pair = self.unpaired_tokenize(text_pair)
            if truncation:
                text, text_pair = self.paired_trimmer.trim([text, text_pair])
            input_ids, token_type_ids = combine_segments(
                (text, text_pair), start_of_sequence_id=self.cls_token_id, end_of_segment_id=self.sep_token_id
            )
        if padding == "longest":
            pad_length = input_ids.bounding_shape(axis=1)
            if pad_to_multiple_of is not None:
                # No ceiling division in tensorflow, so we negate floordiv instead
                pad_length = pad_to_multiple_of * (-tf.math.floordiv(-pad_length, pad_to_multiple_of))
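                # e.g. with pad_length=5 and pad_to_multiple_of=8: 8 * -(-5 // 8) = 8 * 1 = 8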
        else:
            pad_length = max_length

        input_ids, attention_mask = pad_model_inputs(
            input_ids, max_seq_length=pad_length, pad_value=self.pad_token_id
        )
        output = {"input_ids": input_ids}
        if return_attention_mask:
            output["attention_mask"] = attention_mask
        if return_token_type_ids:
            token_type_ids, _ = pad_model_inputs(
                token_type_ids, max_seq_length=pad_length, pad_value=self.pad_token_id
            )
            output["token_type_ids"] = token_type_ids
        return output

    def get_config(self):
        return {
            "vocab_list": self.vocab_list,
            "do_lower_case": self.do_lower_case,
            "cls_token_id": self.cls_token_id,
            "sep_token_id": self.sep_token_id,
            "pad_token_id": self.pad_token_id,
        }