data_collator.py

  1. # Copyright 2020 The HuggingFace Team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import random
  15. import warnings
  16. from collections.abc import Mapping
  17. from dataclasses import dataclass
  18. from random import randint
  19. from typing import Any, Callable, Dict, List, NewType, Optional, Tuple, Union
  20. import numpy as np
  21. from ..models.bert import BertTokenizer, BertTokenizerFast
  22. from ..tokenization_utils_base import PreTrainedTokenizerBase
  23. from ..utils import PaddingStrategy
  24. InputDataClass = NewType("InputDataClass", Any)
  25. """
  26. A DataCollator is a function that takes a list of samples from a Dataset and collates them into a batch, as a dictionary
  27. of PyTorch/TensorFlow tensors or NumPy arrays.
  28. """
  29. DataCollator = NewType("DataCollator", Callable[[List[InputDataClass]], Dict[str, Any]])
  30. class DataCollatorMixin:
  31. def __call__(self, features, return_tensors=None):
  32. if return_tensors is None:
  33. return_tensors = self.return_tensors
  34. if return_tensors == "tf":
  35. return self.tf_call(features)
  36. elif return_tensors == "pt":
  37. return self.torch_call(features)
  38. elif return_tensors == "np":
  39. return self.numpy_call(features)
  40. else:
  41. raise ValueError(f"Framework '{return_tensors}' not recognized!")
  42. def pad_without_fast_tokenizer_warning(tokenizer, *pad_args, **pad_kwargs):
  43. """
  44. Pads without triggering the warning about how using the pad function is sub-optimal when using a fast tokenizer.
  45. """
  46. # To avoid errors when using Feature extractors
  47. if not hasattr(tokenizer, "deprecation_warnings"):
  48. return tokenizer.pad(*pad_args, **pad_kwargs)
  49. # Save the state of the warning, then disable it
  50. warning_state = tokenizer.deprecation_warnings.get("Asking-to-pad-a-fast-tokenizer", False)
  51. tokenizer.deprecation_warnings["Asking-to-pad-a-fast-tokenizer"] = True
  52. try:
  53. padded = tokenizer.pad(*pad_args, **pad_kwargs)
  54. finally:
  55. # Restore the state of the warning.
  56. tokenizer.deprecation_warnings["Asking-to-pad-a-fast-tokenizer"] = warning_state
  57. return padded
  58. def default_data_collator(features: List[InputDataClass], return_tensors="pt") -> Dict[str, Any]:
  59. """
  60. Very simple data collator that simply collates batches of dict-like objects and performs special handling for
  61. potential keys named:
  62. - `label`: handles a single value (int or float) per object
  63. - `label_ids`: handles a list of values per object
  64. Does not do any additional preprocessing: property names of the input object will be used as corresponding inputs
  65. to the model. See glue and ner for examples of how it's useful.
  66. """
  67. # In this function we'll make the assumption that all `features` in the batch
  68. # have the same attributes.
  69. # So we will look at the first element as a proxy for what attributes exist
  70. # on the whole batch.
  71. if return_tensors == "pt":
  72. return torch_default_data_collator(features)
  73. elif return_tensors == "tf":
  74. return tf_default_data_collator(features)
  75. elif return_tensors == "np":
  76. return numpy_default_data_collator(features)
  77. @dataclass
  78. class DefaultDataCollator(DataCollatorMixin):
  79. """
  80. Very simple data collator that simply collates batches of dict-like objects and performs special handling for
  81. potential keys named:
  82. - `label`: handles a single value (int or float) per object
  83. - `label_ids`: handles a list of values per object
  84. Does not do any additional preprocessing: property names of the input object will be used as corresponding inputs
  85. to the model. See glue and ner for examples of how it's useful.
  86. This is an object (like other data collators) rather than a pure function like default_data_collator. This can be
  87. helpful if you need to set a return_tensors value at initialization.
  88. Args:
  89. return_tensors (`str`, *optional*, defaults to `"pt"`):
  90. The type of Tensor to return. Allowable values are "np", "pt" and "tf".
  91. """
  92. return_tensors: str = "pt"
  93. def __call__(self, features: List[Dict[str, Any]], return_tensors=None) -> Dict[str, Any]:
  94. if return_tensors is None:
  95. return_tensors = self.return_tensors
  96. return default_data_collator(features, return_tensors)
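# A minimal usage sketch for the default collator above, assuming `torch` is installed. The
# `_default_collator_example` helper and the feature values are illustrative only.
def _default_collator_example():
    features = [
        {"input_ids": [101, 7592, 102], "attention_mask": [1, 1, 1], "label": 0},
        {"input_ids": [101, 2088, 102], "attention_mask": [1, 1, 1], "label": 1},
    ]
    batch = default_data_collator(features, return_tensors="pt")
    # `label` is renamed to `labels`, and every key is stacked into a tensor of shape (2, 3).
    return batch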
  97. def torch_default_data_collator(features: List[InputDataClass]) -> Dict[str, Any]:
  98. import torch
  99. if not isinstance(features[0], Mapping):
  100. features = [vars(f) for f in features]
  101. first = features[0]
  102. batch = {}
  103. # Special handling for labels.
  104. # Ensure that tensor is created with the correct type
  105. # (it should be automatically the case, but let's make sure of it.)
  106. if "label" in first and first["label"] is not None:
  107. label = first["label"].item() if isinstance(first["label"], torch.Tensor) else first["label"]
  108. dtype = torch.long if isinstance(label, int) else torch.float
  109. batch["labels"] = torch.tensor([f["label"] for f in features], dtype=dtype)
  110. elif "label_ids" in first and first["label_ids"] is not None:
  111. if isinstance(first["label_ids"], torch.Tensor):
  112. batch["labels"] = torch.stack([f["label_ids"] for f in features])
  113. else:
  114. dtype = torch.long if isinstance(first["label_ids"][0], int) else torch.float
  115. batch["labels"] = torch.tensor([f["label_ids"] for f in features], dtype=dtype)
  116. # Handling of all other possible keys.
  117. # Again, we will use the first element to figure out which key/values are not None for this model.
  118. for k, v in first.items():
  119. if k not in ("label", "label_ids") and v is not None and not isinstance(v, str):
  120. if isinstance(v, torch.Tensor):
  121. batch[k] = torch.stack([f[k] for f in features])
  122. elif isinstance(v, np.ndarray):
  123. batch[k] = torch.from_numpy(np.stack([f[k] for f in features]))
  124. else:
  125. batch[k] = torch.tensor([f[k] for f in features])
  126. return batch
  127. def tf_default_data_collator(features: List[InputDataClass]) -> Dict[str, Any]:
  128. import tensorflow as tf
  129. if not isinstance(features[0], Mapping):
  130. features = [vars(f) for f in features]
  131. first = features[0]
  132. batch = {}
  133. # Special handling for labels.
  134. # Ensure that tensor is created with the correct type
  135. # (it should be automatically the case, but let's make sure of it.)
  136. if "label" in first and first["label"] is not None:
  137. label_col_name = "label"
  138. elif "label_ids" in first and first["label_ids"] is not None:
  139. label_col_name = "label_ids"
  140. elif "labels" in first and first["labels"] is not None:
  141. label_col_name = "labels"
  142. else:
  143. label_col_name = None
  144. if label_col_name is not None:
  145. if isinstance(first[label_col_name], tf.Tensor):
  146. dtype = tf.int64 if first[label_col_name].dtype.is_integer else tf.float32
  147. elif isinstance(first[label_col_name], np.ndarray) or isinstance(first[label_col_name], np.generic):
  148. dtype = tf.int64 if np.issubdtype(first[label_col_name].dtype, np.integer) else tf.float32
  149. elif isinstance(first[label_col_name], (tuple, list)):
  150. dtype = tf.int64 if isinstance(first[label_col_name][0], int) else tf.float32
  151. else:
  152. dtype = tf.int64 if isinstance(first[label_col_name], int) else tf.float32
  153. batch["labels"] = tf.convert_to_tensor([f[label_col_name] for f in features], dtype=dtype)
  154. # Handling of all other possible keys.
  155. # Again, we will use the first element to figure out which key/values are not None for this model.
  156. for k, v in first.items():
  157. if k not in ("label", "label_ids", "labels") and v is not None and not isinstance(v, str):
  158. if isinstance(v, (tf.Tensor, np.ndarray)):
  159. batch[k] = tf.stack([f[k] for f in features])
  160. else:
  161. batch[k] = tf.convert_to_tensor([f[k] for f in features])
  162. return batch
  163. def numpy_default_data_collator(features: List[InputDataClass]) -> Dict[str, Any]:
  164. if not isinstance(features[0], Mapping):
  165. features = [vars(f) for f in features]
  166. first = features[0]
  167. batch = {}
  168. # Special handling for labels.
  169. # Ensure that tensor is created with the correct type
  170. # (it should be automatically the case, but let's make sure of it.)
  171. if "label" in first and first["label"] is not None:
  172. label = first["label"].item() if isinstance(first["label"], np.ndarray) else first["label"]
  173. dtype = np.int64 if isinstance(label, int) else np.float32
  174. batch["labels"] = np.array([f["label"] for f in features], dtype=dtype)
  175. elif "label_ids" in first and first["label_ids"] is not None:
  176. if isinstance(first["label_ids"], np.ndarray):
  177. batch["labels"] = np.stack([f["label_ids"] for f in features])
  178. else:
  179. dtype = np.int64 if isinstance(first["label_ids"][0], int) else np.float32
  180. batch["labels"] = np.array([f["label_ids"] for f in features], dtype=dtype)
  181. # Handling of all other possible keys.
  182. # Again, we will use the first element to figure out which key/values are not None for this model.
  183. for k, v in first.items():
  184. if k not in ("label", "label_ids") and v is not None and not isinstance(v, str):
  185. if isinstance(v, np.ndarray):
  186. batch[k] = np.stack([f[k] for f in features])
  187. else:
  188. batch[k] = np.array([f[k] for f in features])
  189. return batch
  190. @dataclass
  191. class DataCollatorWithPadding:
  192. """
  193. Data collator that will dynamically pad the inputs received.
  194. Args:
  195. tokenizer ([`PreTrainedTokenizer`] or [`PreTrainedTokenizerFast`]):
  196. The tokenizer used for encoding the data.
  197. padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `True`):
  198. Select a strategy to pad the returned sequences (according to the model's padding side and padding index)
  199. among:
  200. - `True` or `'longest'` (default): Pad to the longest sequence in the batch (or no padding if only a single
  201. sequence is provided).
  202. - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
  203. acceptable input length for the model if that argument is not provided.
  204. - `False` or `'do_not_pad'`: No padding (i.e., can output a batch with sequences of different lengths).
  205. max_length (`int`, *optional*):
  206. Maximum length of the returned list and optionally padding length (see above).
  207. pad_to_multiple_of (`int`, *optional*):
  208. If set will pad the sequence to a multiple of the provided value.
  209. This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >=
  210. 7.5 (Volta).
  211. return_tensors (`str`, *optional*, defaults to `"pt"`):
  212. The type of Tensor to return. Allowable values are "np", "pt" and "tf".
  213. """
  214. tokenizer: PreTrainedTokenizerBase
  215. padding: Union[bool, str, PaddingStrategy] = True
  216. max_length: Optional[int] = None
  217. pad_to_multiple_of: Optional[int] = None
  218. return_tensors: str = "pt"
  219. def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
  220. batch = pad_without_fast_tokenizer_warning(
  221. self.tokenizer,
  222. features,
  223. padding=self.padding,
  224. max_length=self.max_length,
  225. pad_to_multiple_of=self.pad_to_multiple_of,
  226. return_tensors=self.return_tensors,
  227. )
  228. if "label" in batch:
  229. batch["labels"] = batch["label"]
  230. del batch["label"]
  231. if "label_ids" in batch:
  232. batch["labels"] = batch["label_ids"]
  233. del batch["label_ids"]
  234. return batch
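# A minimal usage sketch for `DataCollatorWithPadding`, assuming `transformers`, `torch` and the
# "bert-base-uncased" checkpoint are available. The `_padding_collator_example` helper is illustrative only.
def _padding_collator_example():
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
    collator = DataCollatorWithPadding(tokenizer=tokenizer, pad_to_multiple_of=8)
    features = [tokenizer("a short sentence"), tokenizer("a noticeably longer example sentence")]
    batch = collator(features)
    # Both rows are padded to the same length (a multiple of 8) and returned as PyTorch tensors.
    return batch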
  235. @dataclass
  236. class DataCollatorForTokenClassification(DataCollatorMixin):
  237. """
  238. Data collator that will dynamically pad the inputs received, as well as the labels.
  239. Args:
  240. tokenizer ([`PreTrainedTokenizer`] or [`PreTrainedTokenizerFast`]):
  241. The tokenizer used for encoding the data.
  242. padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `True`):
  243. Select a strategy to pad the returned sequences (according to the model's padding side and padding index)
  244. among:
  245. - `True` or `'longest'` (default): Pad to the longest sequence in the batch (or no padding if only a single
  246. sequence is provided).
  247. - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
  248. acceptable input length for the model if that argument is not provided.
  249. - `False` or `'do_not_pad'`: No padding (i.e., can output a batch with sequences of different lengths).
  250. max_length (`int`, *optional*):
  251. Maximum length of the returned list and optionally padding length (see above).
  252. pad_to_multiple_of (`int`, *optional*):
  253. If set will pad the sequence to a multiple of the provided value.
  254. This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >=
  255. 7.5 (Volta).
  256. label_pad_token_id (`int`, *optional*, defaults to -100):
  257. The id to use when padding the labels (-100 will be automatically ignored by PyTorch loss functions).
  258. return_tensors (`str`, *optional*, defaults to `"pt"`):
  259. The type of Tensor to return. Allowable values are "np", "pt" and "tf".
  260. """
  261. tokenizer: PreTrainedTokenizerBase
  262. padding: Union[bool, str, PaddingStrategy] = True
  263. max_length: Optional[int] = None
  264. pad_to_multiple_of: Optional[int] = None
  265. label_pad_token_id: int = -100
  266. return_tensors: str = "pt"
  267. def torch_call(self, features):
  268. import torch
  269. label_name = "label" if "label" in features[0].keys() else "labels"
  270. labels = [feature[label_name] for feature in features] if label_name in features[0].keys() else None
  271. no_labels_features = [{k: v for k, v in feature.items() if k != label_name} for feature in features]
  272. batch = pad_without_fast_tokenizer_warning(
  273. self.tokenizer,
  274. no_labels_features,
  275. padding=self.padding,
  276. max_length=self.max_length,
  277. pad_to_multiple_of=self.pad_to_multiple_of,
  278. return_tensors="pt",
  279. )
  280. if labels is None:
  281. return batch
  282. sequence_length = batch["input_ids"].shape[1]
  283. padding_side = self.tokenizer.padding_side
  284. def to_list(tensor_or_iterable):
  285. if isinstance(tensor_or_iterable, torch.Tensor):
  286. return tensor_or_iterable.tolist()
  287. return list(tensor_or_iterable)
  288. if padding_side == "right":
  289. batch[label_name] = [
  290. to_list(label) + [self.label_pad_token_id] * (sequence_length - len(label)) for label in labels
  291. ]
  292. else:
  293. batch[label_name] = [
  294. [self.label_pad_token_id] * (sequence_length - len(label)) + to_list(label) for label in labels
  295. ]
  296. batch[label_name] = torch.tensor(batch[label_name], dtype=torch.int64)
  297. return batch
  298. def tf_call(self, features):
  299. import tensorflow as tf
  300. label_name = "label" if "label" in features[0].keys() else "labels"
  301. labels = [feature[label_name] for feature in features] if label_name in features[0].keys() else None
  302. batch = pad_without_fast_tokenizer_warning(
  303. self.tokenizer,
  304. features,
  305. padding=self.padding,
  306. max_length=self.max_length,
  307. pad_to_multiple_of=self.pad_to_multiple_of,
  308. # Conversion to tensors will fail if we have labels as they are not of the same length yet.
  309. return_tensors="tf" if labels is None else None,
  310. )
  311. if labels is None:
  312. return batch
  313. sequence_length = tf.convert_to_tensor(batch["input_ids"]).shape[1]
  314. padding_side = self.tokenizer.padding_side
  315. if padding_side == "right":
  316. batch["labels"] = [
  317. list(label) + [self.label_pad_token_id] * (sequence_length - len(label)) for label in labels
  318. ]
  319. else:
  320. batch["labels"] = [
  321. [self.label_pad_token_id] * (sequence_length - len(label)) + list(label) for label in labels
  322. ]
  323. batch = {k: tf.convert_to_tensor(v, dtype=tf.int64) for k, v in batch.items()}
  324. return batch
  325. def numpy_call(self, features):
  326. label_name = "label" if "label" in features[0].keys() else "labels"
  327. labels = [feature[label_name] for feature in features] if label_name in features[0].keys() else None
  328. batch = pad_without_fast_tokenizer_warning(
  329. self.tokenizer,
  330. features,
  331. padding=self.padding,
  332. max_length=self.max_length,
  333. pad_to_multiple_of=self.pad_to_multiple_of,
  334. # Conversion to tensors will fail if we have labels as they are not of the same length yet.
  335. return_tensors="np" if labels is None else None,
  336. )
  337. if labels is None:
  338. return batch
  339. sequence_length = np.array(batch["input_ids"]).shape[1]
  340. padding_side = self.tokenizer.padding_side
  341. if padding_side == "right":
  342. batch["labels"] = [
  343. list(label) + [self.label_pad_token_id] * (sequence_length - len(label)) for label in labels
  344. ]
  345. else:
  346. batch["labels"] = [
  347. [self.label_pad_token_id] * (sequence_length - len(label)) + list(label) for label in labels
  348. ]
  349. batch = {k: np.array(v, dtype=np.int64) for k, v in batch.items()}
  350. return batch
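# A minimal usage sketch for `DataCollatorForTokenClassification`, assuming `transformers`, `torch`
# and the "bert-base-uncased" checkpoint are available. The token and label ids below are made up;
# in practice they come from a tokenized, word-aligned token-classification dataset.
def _token_classification_collator_example():
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
    collator = DataCollatorForTokenClassification(tokenizer=tokenizer, label_pad_token_id=-100)
    features = [
        {"input_ids": [101, 7592, 102], "labels": [-100, 3, -100]},
        {"input_ids": [101, 7592, 2088, 102], "labels": [-100, 3, 4, -100]},
    ]
    batch = collator(features)
    # `input_ids` are padded with the tokenizer's pad token, `labels` with -100.
    return batch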
  351. def _torch_collate_batch(examples, tokenizer, pad_to_multiple_of: Optional[int] = None):
  352. """Collate `examples` into a batch, using the information in `tokenizer` for padding if necessary."""
  353. import torch
  354. # Tensorize if necessary.
  355. if isinstance(examples[0], (list, tuple, np.ndarray)):
  356. examples = [torch.tensor(e, dtype=torch.long) for e in examples]
  357. length_of_first = examples[0].size(0)
  358. # Check if padding is necessary.
  359. are_tensors_same_length = all(x.size(0) == length_of_first for x in examples)
  360. if are_tensors_same_length and (pad_to_multiple_of is None or length_of_first % pad_to_multiple_of == 0):
  361. if not isinstance(examples, torch.Tensor):
  362. return torch.stack(examples, dim=0)
  363. # If yes, check if we have a `pad_token`.
  364. if tokenizer._pad_token is None:
  365. raise ValueError(
  366. "You are attempting to pad samples but the tokenizer you are using"
  367. f" ({tokenizer.__class__.__name__}) does not have a pad token."
  368. )
  369. # Creating the full tensor and filling it with our data.
  370. max_length = max(x.size(0) for x in examples)
  371. if pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0):
  372. max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of
  373. result = examples[0].new_full([len(examples), max_length], tokenizer.pad_token_id)
  374. for i, example in enumerate(examples):
  375. if tokenizer.padding_side == "right":
  376. result[i, : example.shape[0]] = example
  377. else:
  378. result[i, -example.shape[0] :] = example
  379. return result
  380. def _tf_collate_batch(examples, tokenizer, pad_to_multiple_of: Optional[int] = None):
  381. """Collate `examples` into a batch, using the information in `tokenizer` for padding if necessary."""
  382. import tensorflow as tf
  383. # Tensorize if necessary.
  384. if isinstance(examples[0], (list, tuple)):
  385. examples = [tf.convert_to_tensor(e, dtype=tf.int64) for e in examples]
  386. # Check if padding is necessary.
  387. length_of_first = len(examples[0])
  388. are_tensors_same_length = all(len(x) == length_of_first for x in examples)
  389. if are_tensors_same_length and (pad_to_multiple_of is None or length_of_first % pad_to_multiple_of == 0):
  390. return tf.stack(examples, axis=0)
  391. # If yes, check if we have a `pad_token`.
  392. if tokenizer._pad_token is None:
  393. raise ValueError(
  394. "You are attempting to pad samples but the tokenizer you are using"
  395. f" ({tokenizer.__class__.__name__}) does not have a pad token."
  396. )
  397. # Creating the full tensor and filling it with our data.
  398. max_length = max(len(x) for x in examples)
  399. if pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0):
  400. max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of
  401. # result = examples[0].new_full([len(examples), max_length], tokenizer.pad_token_id)
  402. result = []
  403. rank = tf.rank(examples[0])
  404. paddings = np.zeros((rank, 2), dtype=np.int32)
  405. for example in examples:
  406. if tokenizer.padding_side == "right":
  407. paddings[0, 1] = max_length - len(example)
  408. else:
  409. paddings[0, 0] = max_length - len(example)
  410. result.append(tf.pad(example, paddings, constant_values=tokenizer.pad_token_id))
  411. return tf.stack(result, axis=0)
  412. def _numpy_collate_batch(examples, tokenizer, pad_to_multiple_of: Optional[int] = None):
  413. """Collate `examples` into a batch, using the information in `tokenizer` for padding if necessary."""
  414. # Tensorize if necessary.
  415. if isinstance(examples[0], (list, tuple)):
  416. examples = [np.array(e, dtype=np.int64) for e in examples]
  417. # Check if padding is necessary.
  418. length_of_first = len(examples[0])
  419. are_tensors_same_length = all(len(x) == length_of_first for x in examples)
  420. if are_tensors_same_length and (pad_to_multiple_of is None or length_of_first % pad_to_multiple_of == 0):
  421. return np.stack(examples, axis=0)
  422. # If yes, check if we have a `pad_token`.
  423. if tokenizer._pad_token is None:
  424. raise ValueError(
  425. "You are attempting to pad samples but the tokenizer you are using"
  426. f" ({tokenizer.__class__.__name__}) does not have a pad token."
  427. )
  428. # Creating the full tensor and filling it with our data.
  429. max_length = max(len(x) for x in examples)
  430. if pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0):
  431. max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of
  432. result = np.full(shape=(len(examples), max_length), fill_value=tokenizer.pad_token_id, dtype=examples[0].dtype)
  433. for i, example in enumerate(examples):
  434. if tokenizer.padding_side == "right":
  435. result[i, : example.shape[0]] = example
  436. else:
  437. result[i, -example.shape[0] :] = example
  438. return result
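# The three collate helpers above round the padded length up to the next multiple of
# `pad_to_multiple_of` via `((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of`,
# applied only when `max_length` is not already a multiple. A small worked example of that
# arithmetic (the `_round_up_example` helper is illustrative only):
def _round_up_example(max_length=13, pad_to_multiple_of=8):
    if max_length % pad_to_multiple_of != 0:
        max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of
    return max_length  # 13 -> 16, while an exact multiple such as 16 is left unchanged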
  439. def tolist(x):
  440. if isinstance(x, list):
  441. return x
  442. elif hasattr(x, "numpy"): # Checks for TF tensors without needing the import
  443. x = x.numpy()
  444. return x.tolist()
  445. @dataclass
  446. class DataCollatorForSeq2Seq:
  447. """
  448. Data collator that will dynamically pad the inputs received, as well as the labels.
  449. Args:
  450. tokenizer ([`PreTrainedTokenizer`] or [`PreTrainedTokenizerFast`]):
  451. The tokenizer used for encoding the data.
  452. model ([`PreTrainedModel`], *optional*):
  453. The model that is being trained. If set and the model has the *prepare_decoder_input_ids_from_labels*
  454. method, it is used to prepare the *decoder_input_ids*.
  455. This is useful when using *label_smoothing* to avoid calculating loss twice.
  456. padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `True`):
  457. Select a strategy to pad the returned sequences (according to the model's padding side and padding index)
  458. among:
  459. - `True` or `'longest'` (default): Pad to the longest sequence in the batch (or no padding if only a single
  460. sequence is provided).
  461. - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
  462. acceptable input length for the model if that argument is not provided.
  463. - `False` or `'do_not_pad'`: No padding (i.e., can output a batch with sequences of different lengths).
  464. max_length (`int`, *optional*):
  465. Maximum length of the returned list and optionally padding length (see above).
  466. pad_to_multiple_of (`int`, *optional*):
  467. If set will pad the sequence to a multiple of the provided value.
  468. This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >=
  469. 7.5 (Volta).
  470. label_pad_token_id (`int`, *optional*, defaults to -100):
  471. The id to use when padding the labels (-100 will be automatically ignored by PyTorch loss functions).
  472. return_tensors (`str`, *optional*, defaults to `"pt"`):
  473. The type of Tensor to return. Allowable values are "np", "pt" and "tf".
  474. """
  475. tokenizer: PreTrainedTokenizerBase
  476. model: Optional[Any] = None
  477. padding: Union[bool, str, PaddingStrategy] = True
  478. max_length: Optional[int] = None
  479. pad_to_multiple_of: Optional[int] = None
  480. label_pad_token_id: int = -100
  481. return_tensors: str = "pt"
  482. def __call__(self, features, return_tensors=None):
  483. if return_tensors is None:
  484. return_tensors = self.return_tensors
  485. label_name = "label" if "label" in features[0].keys() else "labels"
  486. labels = [feature[label_name] for feature in features] if label_name in features[0].keys() else None
  487. # reconvert list[None] to None if necessary
  488. # this might occur when we pass {..., "labels": None}
  489. if labels is not None and all(label is None for label in labels):
  490. labels = None
  491. non_labels_features = [{k: v for k, v in feature.items() if k != label_name} for feature in features]
  492. # run through tokenizer without labels to ensure no side effects
  493. batch = pad_without_fast_tokenizer_warning(
  494. self.tokenizer,
  495. non_labels_features,
  496. padding=self.padding,
  497. max_length=self.max_length,
  498. pad_to_multiple_of=self.pad_to_multiple_of,
  499. return_tensors=return_tensors,
  500. )
  501. # we have to pad the labels manually as we cannot rely on `tokenizer.pad` and we need them to be of the same length to return tensors
  502. no_padding = self.padding is False or self.padding == PaddingStrategy.DO_NOT_PAD
  503. if labels is not None:
  504. if no_padding:
  505. if isinstance(features[0][label_name], list):
  506. batch["labels"] = list(labels)
  507. else:
  508. batch["labels"] = [np.concatenate([label, []]) for label in labels]
  509. else:
  510. max_padding = self.padding == PaddingStrategy.MAX_LENGTH and self.max_length is not None
  511. max_label_length = max(len(l) for l in labels) if not max_padding else self.max_length
  512. if self.pad_to_multiple_of is not None:
  513. max_label_length = (
  514. (max_label_length + self.pad_to_multiple_of - 1)
  515. // self.pad_to_multiple_of
  516. * self.pad_to_multiple_of
  517. )
  518. padding_side = self.tokenizer.padding_side
  519. if isinstance(features[0][label_name], list):
  520. batch["labels"] = [
  521. label + [self.label_pad_token_id] * (max_label_length - len(label))
  522. if padding_side == "right"
  523. else [self.label_pad_token_id] * (max_label_length - len(label)) + label
  524. for label in labels
  525. ]
  526. else:
  527. batch["labels"] = [
  528. np.concatenate(
  529. [
  530. label,
  531. np.array([self.label_pad_token_id] * (max_label_length - len(label)), dtype=np.int64),
  532. ]
  533. )
  534. if padding_side == "right"
  535. else np.concatenate(
  536. [
  537. np.array([self.label_pad_token_id] * (max_label_length - len(label)), dtype=np.int64),
  538. label,
  539. ]
  540. )
  541. for label in labels
  542. ]
  543. # reintroduce side effects via tokenizer that return respective datatypes for the `return_tensors` argument
  544. if batch.get("labels", None) is not None:
  545. if return_tensors == "pt":
  546. import torch
  547. batch["labels"] = torch.tensor(batch["labels"], dtype=torch.int64)
  548. elif return_tensors == "tf":
  549. import tensorflow as tf
  550. batch["labels"] = tf.constant(batch["labels"], dtype=tf.int64)
  551. else:
  552. batch["labels"] = np.array(batch["labels"], dtype=np.int64)
  553. else:
  554. batch["labels"] = None
  555. # prepare decoder_input_ids
  556. if (
  557. labels is not None
  558. and self.model is not None
  559. and hasattr(self.model, "prepare_decoder_input_ids_from_labels")
  560. ):
  561. decoder_input_ids = self.model.prepare_decoder_input_ids_from_labels(labels=batch["labels"])
  562. batch["decoder_input_ids"] = decoder_input_ids
  563. return batch
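# A minimal usage sketch for `DataCollatorForSeq2Seq`, assuming `transformers`, `torch` and a
# seq2seq checkpoint such as "t5-small" are available. Passing the model lets the collator build
# `decoder_input_ids` from the padded labels. The `_seq2seq_collator_example` helper is illustrative only.
def _seq2seq_collator_example():
    from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("t5-small")
    model = AutoModelForSeq2SeqLM.from_pretrained("t5-small")
    collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model, label_pad_token_id=-100)
    features = [
        {"input_ids": tokenizer("translate English to German: hello").input_ids,
         "labels": tokenizer("hallo").input_ids},
        {"input_ids": tokenizer("translate English to German: thank you very much").input_ids,
         "labels": tokenizer("vielen dank").input_ids},
    ]
    batch = collator(features)
    # `labels` are padded with -100 and `decoder_input_ids` are derived from them by the model.
    return batch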
  564. @dataclass
  565. class DataCollatorForLanguageModeling(DataCollatorMixin):
  566. """
  567. Data collator used for language modeling. Inputs are dynamically padded to the maximum length of a batch if they
  568. are not all of the same length.
  569. Args:
  570. tokenizer ([`PreTrainedTokenizer`] or [`PreTrainedTokenizerFast`]):
  571. The tokenizer used for encoding the data.
  572. mlm (`bool`, *optional*, defaults to `True`):
  573. Whether or not to use masked language modeling. If set to `False`, the labels are the same as the inputs
  574. with the padding tokens ignored (by setting them to -100). Otherwise, the labels are -100 for non-masked
  575. tokens and the value to predict for the masked token.
  576. mlm_probability (`float`, *optional*, defaults to 0.15):
  577. The probability with which to (randomly) mask tokens in the input, when `mlm` is set to `True`.
  578. pad_to_multiple_of (`int`, *optional*):
  579. If set will pad the sequence to a multiple of the provided value.
  580. return_tensors (`str`):
  581. The type of Tensor to return. Allowable values are "np", "pt" and "tf".
  582. <Tip>
  583. For best performance, this data collator should be used with a dataset having items that are dictionaries or
  584. BatchEncoding, with the `"special_tokens_mask"` key, as returned by a [`PreTrainedTokenizer`] or a
  585. [`PreTrainedTokenizerFast`] with the argument `return_special_tokens_mask=True`.
  586. </Tip>"""
  587. tokenizer: PreTrainedTokenizerBase
  588. mlm: bool = True
  589. mlm_probability: float = 0.15
  590. pad_to_multiple_of: Optional[int] = None
  591. tf_experimental_compile: bool = False
  592. return_tensors: str = "pt"
  593. def __post_init__(self):
  594. if self.mlm and self.tokenizer.mask_token is None:
  595. raise ValueError(
  596. "This tokenizer does not have a mask token which is necessary for masked language modeling. "
  597. "You should pass `mlm=False` to train on causal language modeling instead."
  598. )
  599. if self.tf_experimental_compile:
  600. import tensorflow as tf
  601. self.tf_mask_tokens = tf.function(self.tf_mask_tokens, jit_compile=True)
  602. @staticmethod
  603. def tf_bernoulli(shape, probability):
  604. import tensorflow as tf
  605. prob_matrix = tf.fill(shape, probability)
  606. return tf.cast(prob_matrix - tf.random.uniform(shape, 0, 1) >= 0, tf.bool)
  607. def tf_mask_tokens(
  608. self, inputs: Any, vocab_size, mask_token_id, special_tokens_mask: Optional[Any] = None
  609. ) -> Tuple[Any, Any]:
  610. """
  611. Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original.
  612. """
  613. import tensorflow as tf
  614. mask_token_id = tf.cast(mask_token_id, inputs.dtype)
  615. input_shape = tf.shape(inputs)
  616. # 1 for a special token, 0 for a normal token in the special tokens mask
  617. # We sample a few tokens in each sequence for MLM training (with probability `self.mlm_probability`)
  618. masked_indices = self.tf_bernoulli(input_shape, self.mlm_probability) & ~special_tokens_mask
  619. # Replace unmasked indices with -100 in the labels since we only compute loss on masked tokens
  620. labels = tf.where(masked_indices, inputs, -100)
  621. # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
  622. indices_replaced = self.tf_bernoulli(input_shape, 0.8) & masked_indices
  623. inputs = tf.where(indices_replaced, mask_token_id, inputs)
  624. # 10% of the time, we replace masked input tokens with random word
  625. indices_random = self.tf_bernoulli(input_shape, 0.5) & masked_indices & ~indices_replaced
  626. random_words = tf.random.uniform(input_shape, maxval=vocab_size, dtype=inputs.dtype)
  627. inputs = tf.where(indices_random, random_words, inputs)
  628. # The rest of the time (10% of the time) we keep the masked input tokens unchanged
  629. return inputs, labels
  630. def tf_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
  631. import tensorflow as tf
  632. # Handle dict or lists with proper padding and conversion to tensor.
  633. if isinstance(examples[0], Mapping):
  634. batch = pad_without_fast_tokenizer_warning(
  635. self.tokenizer, examples, return_tensors="tf", pad_to_multiple_of=self.pad_to_multiple_of
  636. )
  637. else:
  638. batch = {
  639. "input_ids": _tf_collate_batch(examples, self.tokenizer, pad_to_multiple_of=self.pad_to_multiple_of)
  640. }
  641. # If special token mask has been preprocessed, pop it from the dict.
  642. special_tokens_mask = batch.pop("special_tokens_mask", None)
  643. if self.mlm:
  644. if special_tokens_mask is None:
  645. special_tokens_mask = [
  646. self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True)
  647. for val in batch["input_ids"].numpy().tolist()
  648. ]
  649. # Cannot directly create as bool
  650. special_tokens_mask = tf.cast(tf.convert_to_tensor(special_tokens_mask, dtype=tf.int64), tf.bool)
  651. else:
  652. special_tokens_mask = tf.cast(special_tokens_mask, tf.bool)
  653. batch["input_ids"], batch["labels"] = self.tf_mask_tokens(
  654. tf.cast(batch["input_ids"], tf.int64),
  655. special_tokens_mask=special_tokens_mask,
  656. mask_token_id=self.tokenizer.mask_token_id,
  657. vocab_size=len(self.tokenizer),
  658. )
  659. else:
  660. labels = batch["input_ids"]
  661. if self.tokenizer.pad_token_id is not None:
  662. # Replace self.tokenizer.pad_token_id with -100
  663. labels = tf.where(labels == self.tokenizer.pad_token_id, -100, labels)
  664. else:
  665. labels = tf.identity(labels) # Makes a copy, just in case
  666. batch["labels"] = labels
  667. return batch
  668. def torch_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
  669. # Handle dict or lists with proper padding and conversion to tensor.
  670. if isinstance(examples[0], Mapping):
  671. batch = pad_without_fast_tokenizer_warning(
  672. self.tokenizer, examples, return_tensors="pt", pad_to_multiple_of=self.pad_to_multiple_of
  673. )
  674. else:
  675. batch = {
  676. "input_ids": _torch_collate_batch(examples, self.tokenizer, pad_to_multiple_of=self.pad_to_multiple_of)
  677. }
  678. # If special token mask has been preprocessed, pop it from the dict.
  679. special_tokens_mask = batch.pop("special_tokens_mask", None)
  680. if self.mlm:
  681. batch["input_ids"], batch["labels"] = self.torch_mask_tokens(
  682. batch["input_ids"], special_tokens_mask=special_tokens_mask
  683. )
  684. else:
  685. labels = batch["input_ids"].clone()
  686. if self.tokenizer.pad_token_id is not None:
  687. labels[labels == self.tokenizer.pad_token_id] = -100
  688. batch["labels"] = labels
  689. return batch
  690. def torch_mask_tokens(self, inputs: Any, special_tokens_mask: Optional[Any] = None) -> Tuple[Any, Any]:
  691. """
  692. Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original.
  693. """
  694. import torch
  695. labels = inputs.clone()
  696. # We sample a few tokens in each sequence for MLM training (with probability `self.mlm_probability`)
  697. probability_matrix = torch.full(labels.shape, self.mlm_probability)
  698. if special_tokens_mask is None:
  699. special_tokens_mask = [
  700. self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist()
  701. ]
  702. special_tokens_mask = torch.tensor(special_tokens_mask, dtype=torch.bool)
  703. else:
  704. special_tokens_mask = special_tokens_mask.bool()
  705. probability_matrix.masked_fill_(special_tokens_mask, value=0.0)
  706. masked_indices = torch.bernoulli(probability_matrix).bool()
  707. labels[~masked_indices] = -100 # We only compute loss on masked tokens
  708. # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
  709. indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.8)).bool() & masked_indices
  710. inputs[indices_replaced] = self.tokenizer.convert_tokens_to_ids(self.tokenizer.mask_token)
  711. # 10% of the time, we replace masked input tokens with random word
  712. indices_random = torch.bernoulli(torch.full(labels.shape, 0.5)).bool() & masked_indices & ~indices_replaced
  713. random_words = torch.randint(len(self.tokenizer), labels.shape, dtype=torch.long)
  714. inputs[indices_random] = random_words[indices_random]
  715. # The rest of the time (10% of the time) we keep the masked input tokens unchanged
  716. return inputs, labels
  717. def numpy_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
  718. # Handle dict or lists with proper padding and conversion to tensor.
  719. if isinstance(examples[0], Mapping):
  720. batch = pad_without_fast_tokenizer_warning(
  721. self.tokenizer, examples, return_tensors="np", pad_to_multiple_of=self.pad_to_multiple_of
  722. )
  723. else:
  724. batch = {
  725. "input_ids": _numpy_collate_batch(examples, self.tokenizer, pad_to_multiple_of=self.pad_to_multiple_of)
  726. }
  727. # If special token mask has been preprocessed, pop it from the dict.
  728. special_tokens_mask = batch.pop("special_tokens_mask", None)
  729. if self.mlm:
  730. batch["input_ids"], batch["labels"] = self.numpy_mask_tokens(
  731. batch["input_ids"], special_tokens_mask=special_tokens_mask
  732. )
  733. else:
  734. labels = np.copy(batch["input_ids"])
  735. if self.tokenizer.pad_token_id is not None:
  736. labels[labels == self.tokenizer.pad_token_id] = -100
  737. batch["labels"] = labels
  738. return batch
  739. def numpy_mask_tokens(self, inputs: Any, special_tokens_mask: Optional[Any] = None) -> Tuple[Any, Any]:
  740. """
  741. Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original.
  742. """
  743. labels = np.copy(inputs)
  744. # We sample a few tokens in each sequence for MLM training (with probability `self.mlm_probability`)
  745. probability_matrix = np.full(labels.shape, self.mlm_probability)
  746. if special_tokens_mask is None:
  747. special_tokens_mask = [
  748. self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist()
  749. ]
  750. special_tokens_mask = np.array(special_tokens_mask, dtype=bool)
  751. else:
  752. special_tokens_mask = special_tokens_mask.astype(bool)
  753. probability_matrix[special_tokens_mask] = 0
  754. # Numpy doesn't have bernoulli, so we use a binomial with 1 trial
  755. masked_indices = np.random.binomial(1, probability_matrix, size=probability_matrix.shape).astype(bool)
  756. labels[~masked_indices] = -100 # We only compute loss on masked tokens
  757. # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
  758. indices_replaced = np.random.binomial(1, 0.8, size=labels.shape).astype(bool) & masked_indices
  759. inputs[indices_replaced] = self.tokenizer.mask_token_id
  760. # 10% of the time, we replace masked input tokens with random word
  761. # indices_random = torch.bernoulli(torch.full(labels.shape, 0.5)).bool() & masked_indices & ~indices_replaced
  762. indices_random = (
  763. np.random.binomial(1, 0.5, size=labels.shape).astype(bool) & masked_indices & ~indices_replaced
  764. )
  765. random_words = np.random.randint(
  766. low=0, high=len(self.tokenizer), size=np.count_nonzero(indices_random), dtype=np.int64
  767. )
  768. inputs[indices_random] = random_words
  769. # The rest of the time (10% of the time) we keep the masked input tokens unchanged
  770. return inputs, labels
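# A minimal usage sketch for `DataCollatorForLanguageModeling`, assuming `transformers`, `torch`
# and the "bert-base-uncased" checkpoint are available. With `mlm=True`, roughly 15% of the
# non-special tokens are selected; of those, 80% become [MASK], 10% random ids, 10% unchanged.
# The `_mlm_collator_example` helper is illustrative only.
def _mlm_collator_example():
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
    collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=True, mlm_probability=0.15)
    examples = [
        tokenizer("the quick brown fox", return_special_tokens_mask=True),
        tokenizer("jumps over the lazy dog", return_special_tokens_mask=True),
    ]
    batch = collator(examples)
    # `labels` hold the original ids at masked positions and -100 everywhere else.
    return batch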
  771. @dataclass
  772. class DataCollatorForWholeWordMask(DataCollatorForLanguageModeling):
  773. """
  774. Data collator used for language modeling that masks entire words.
  775. - collates batches of tensors, honoring their tokenizer's pad_token
  776. - preprocesses batches for masked language modeling
  777. <Tip>
  778. This collator relies on details of the implementation of subword tokenization by [`BertTokenizer`], specifically
  779. that subword tokens are prefixed with *##*. For tokenizers that do not adhere to this scheme, this collator will
  780. produce an output that is roughly equivalent to [`.DataCollatorForLanguageModeling`].
  781. </Tip>"""
  782. def torch_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
  783. if isinstance(examples[0], Mapping):
  784. input_ids = [e["input_ids"] for e in examples]
  785. else:
  786. input_ids = examples
  787. examples = [{"input_ids": e} for e in examples]
  788. batch_input = _torch_collate_batch(input_ids, self.tokenizer, pad_to_multiple_of=self.pad_to_multiple_of)
  789. mask_labels = []
  790. for e in examples:
  791. ref_tokens = []
  792. for id in tolist(e["input_ids"]):
  793. token = self.tokenizer._convert_id_to_token(id)
  794. ref_tokens.append(token)
  795. # For Chinese tokens, we need extra info to mark sub-words, e.g. [喜,欢] -> [喜,##欢]
  796. if "chinese_ref" in e:
  797. ref_pos = tolist(e["chinese_ref"])
  798. len_seq = len(e["input_ids"])
  799. for i in range(len_seq):
  800. if i in ref_pos:
  801. ref_tokens[i] = "##" + ref_tokens[i]
  802. mask_labels.append(self._whole_word_mask(ref_tokens))
  803. batch_mask = _torch_collate_batch(mask_labels, self.tokenizer, pad_to_multiple_of=self.pad_to_multiple_of)
  804. inputs, labels = self.torch_mask_tokens(batch_input, batch_mask)
  805. return {"input_ids": inputs, "labels": labels}
  806. def tf_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
  807. import tensorflow as tf
  808. if isinstance(examples[0], Mapping):
  809. input_ids = [e["input_ids"] for e in examples]
  810. else:
  811. input_ids = examples
  812. examples = [{"input_ids": e} for e in examples]
  813. batch_input = _tf_collate_batch(input_ids, self.tokenizer, pad_to_multiple_of=self.pad_to_multiple_of)
  814. mask_labels = []
  815. for e in examples:
  816. ref_tokens = []
  817. for id in tolist(e["input_ids"]):
  818. token = self.tokenizer._convert_id_to_token(id)
  819. ref_tokens.append(token)
  820. # For Chinese tokens, we need extra info to mark sub-words, e.g. [喜,欢] -> [喜,##欢]
  821. if "chinese_ref" in e:
  822. ref_pos = tolist(e["chinese_ref"])
  823. len_seq = len(e["input_ids"])
  824. for i in range(len_seq):
  825. if i in ref_pos:
  826. ref_tokens[i] = "##" + ref_tokens[i]
  827. mask_labels.append(self._whole_word_mask(ref_tokens))
  828. batch_mask = _tf_collate_batch(mask_labels, self.tokenizer, pad_to_multiple_of=self.pad_to_multiple_of)
  829. inputs, labels = self.tf_mask_tokens(tf.cast(batch_input, tf.int64), batch_mask)
  830. return {"input_ids": inputs, "labels": labels}
  831. def numpy_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
  832. if isinstance(examples[0], Mapping):
  833. input_ids = [e["input_ids"] for e in examples]
  834. else:
  835. input_ids = examples
  836. examples = [{"input_ids": e} for e in examples]
  837. batch_input = _numpy_collate_batch(input_ids, self.tokenizer, pad_to_multiple_of=self.pad_to_multiple_of)
  838. mask_labels = []
  839. for e in examples:
  840. ref_tokens = []
  841. for id in tolist(e["input_ids"]):
  842. token = self.tokenizer._convert_id_to_token(id)
  843. ref_tokens.append(token)
  844. # For Chinese tokens, we need extra info to mark sub-words, e.g. [喜,欢] -> [喜,##欢]
  845. if "chinese_ref" in e:
  846. ref_pos = tolist(e["chinese_ref"])
  847. len_seq = len(e["input_ids"])
  848. for i in range(len_seq):
  849. if i in ref_pos:
  850. ref_tokens[i] = "##" + ref_tokens[i]
  851. mask_labels.append(self._whole_word_mask(ref_tokens))
  852. batch_mask = _numpy_collate_batch(mask_labels, self.tokenizer, pad_to_multiple_of=self.pad_to_multiple_of)
  853. inputs, labels = self.numpy_mask_tokens(batch_input, batch_mask)
  854. return {"input_ids": inputs, "labels": labels}
  855. def _whole_word_mask(self, input_tokens: List[str], max_predictions=512):
  856. """
  857. Get 0/1 labels for masked tokens with whole word mask proxy
  858. """
  859. if not isinstance(self.tokenizer, (BertTokenizer, BertTokenizerFast)):
  860. warnings.warn(
  861. "DataCollatorForWholeWordMask is only suitable for BertTokenizer-like tokenizers. "
  862. "Please refer to the documentation for more information."
  863. )
  864. cand_indexes = []
  865. for i, token in enumerate(input_tokens):
  866. if token == "[CLS]" or token == "[SEP]":
  867. continue
  868. if len(cand_indexes) >= 1 and token.startswith("##"):
  869. cand_indexes[-1].append(i)
  870. else:
  871. cand_indexes.append([i])
  872. random.shuffle(cand_indexes)
  873. num_to_predict = min(max_predictions, max(1, int(round(len(input_tokens) * self.mlm_probability))))
  874. masked_lms = []
  875. covered_indexes = set()
  876. for index_set in cand_indexes:
  877. if len(masked_lms) >= num_to_predict:
  878. break
  879. # If adding a whole-word mask would exceed the maximum number of
  880. # predictions, then just skip this candidate.
  881. if len(masked_lms) + len(index_set) > num_to_predict:
  882. continue
  883. is_any_index_covered = False
  884. for index in index_set:
  885. if index in covered_indexes:
  886. is_any_index_covered = True
  887. break
  888. if is_any_index_covered:
  889. continue
  890. for index in index_set:
  891. covered_indexes.add(index)
  892. masked_lms.append(index)
  893. if len(covered_indexes) != len(masked_lms):
  894. raise ValueError("Length of covered_indexes is not equal to length of masked_lms.")
  895. mask_labels = [1 if i in covered_indexes else 0 for i in range(len(input_tokens))]
  896. return mask_labels
  897. def torch_mask_tokens(self, inputs: Any, mask_labels: Any) -> Tuple[Any, Any]:
  898. """
  899. Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original. Setting
  900. 'mask_labels' means we use whole word masking (wwm): we directly mask indices according to its reference.
  901. """
  902. import torch
  903. if self.tokenizer.mask_token is None:
  904. raise ValueError(
  905. "This tokenizer does not have a mask token which is necessary for masked language modeling. Remove the"
  906. " --mlm flag if you want to use this tokenizer."
  907. )
  908. labels = inputs.clone()
  909. # We sample a few tokens in each sequence for masked-LM training (with probability `args.mlm_probability`, which defaults to 0.15 in Bert/RoBERTa)
  910. probability_matrix = mask_labels
  911. special_tokens_mask = [
  912. self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist()
  913. ]
  914. probability_matrix.masked_fill_(torch.tensor(special_tokens_mask, dtype=torch.bool), value=0.0)
  915. if self.tokenizer._pad_token is not None:
  916. padding_mask = labels.eq(self.tokenizer.pad_token_id)
  917. probability_matrix.masked_fill_(padding_mask, value=0.0)
  918. masked_indices = probability_matrix.bool()
  919. labels[~masked_indices] = -100 # We only compute loss on masked tokens
  920. # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
  921. indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.8)).bool() & masked_indices
  922. inputs[indices_replaced] = self.tokenizer.convert_tokens_to_ids(self.tokenizer.mask_token)
  923. # 10% of the time, we replace masked input tokens with random word
  924. indices_random = torch.bernoulli(torch.full(labels.shape, 0.5)).bool() & masked_indices & ~indices_replaced
  925. random_words = torch.randint(len(self.tokenizer), labels.shape, dtype=torch.long)
  926. inputs[indices_random] = random_words[indices_random]
  927. # The rest of the time (10% of the time) we keep the masked input tokens unchanged
  928. return inputs, labels
  929. def tf_mask_tokens(self, inputs: Any, mask_labels: Any) -> Tuple[Any, Any]:
  930. """
  931. Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original. Setting
  932. 'mask_labels' means we use whole word masking (wwm): we directly mask indices according to its reference.
  933. """
  934. import tensorflow as tf
  935. input_shape = tf.shape(inputs)
  936. if self.tokenizer.mask_token is None:
  937. raise ValueError(
  938. "This tokenizer does not have a mask token which is necessary for masked language modeling. Remove the"
  939. " --mlm flag if you want to use this tokenizer."
  940. )
  941. labels = tf.identity(inputs)
  942. # We sample a few tokens in each sequence for masked-LM training (with probability `args.mlm_probability`, which defaults to 0.15 in Bert/RoBERTa)
  943. masked_indices = tf.cast(mask_labels, tf.bool)
  944. special_tokens_mask = [
  945. self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels
  946. ]
  947. masked_indices = masked_indices & ~tf.cast(special_tokens_mask, dtype=tf.bool)
  948. if self.tokenizer._pad_token is not None:
  949. padding_mask = inputs == self.tokenizer.pad_token_id
  950. masked_indices = masked_indices & ~padding_mask
  951. # Replace unmasked indices with -100 in the labels since we only compute loss on masked tokens
  952. labels = tf.where(masked_indices, inputs, -100)
  953. # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
  954. indices_replaced = self.tf_bernoulli(input_shape, 0.8) & masked_indices
  955. inputs = tf.where(indices_replaced, self.tokenizer.mask_token_id, inputs)
  956. # 10% of the time, we replace masked input tokens with random word
  957. indices_random = self.tf_bernoulli(input_shape, 0.5) & masked_indices & ~indices_replaced
  958. random_words = tf.random.uniform(input_shape, maxval=len(self.tokenizer), dtype=tf.int64)
  959. inputs = tf.where(indices_random, random_words, inputs)
  960. # The rest of the time (10% of the time) we keep the masked input tokens unchanged
  961. return inputs, labels
  962. def numpy_mask_tokens(self, inputs: Any, mask_labels: Any) -> Tuple[Any, Any]:
  963. """
  964. Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original. Setting
  965. 'mask_labels' means we use whole word masking (wwm): we directly mask indices according to its reference.
  966. """
  967. if self.tokenizer.mask_token is None:
  968. raise ValueError(
  969. "This tokenizer does not have a mask token which is necessary for masked language modeling. Remove the"
  970. " --mlm flag if you want to use this tokenizer."
  971. )
  972. labels = np.copy(inputs)
  973. # We sample a few tokens in each sequence for masked-LM training (with probability args.mlm_probability defaults to 0.15 in Bert/RoBERTa)
  974. masked_indices = mask_labels.astype(bool)
  975. special_tokens_mask = [
  976. self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist()
  977. ]
  978. masked_indices[np.array(special_tokens_mask, dtype=bool)] = 0
  979. if self.tokenizer._pad_token is not None:
  980. padding_mask = labels == self.tokenizer.pad_token_id
  981. masked_indices[padding_mask] = 0
  982. labels[~masked_indices] = -100 # We only compute loss on masked tokens
  983. # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
  984. indices_replaced = np.random.binomial(1, 0.8, size=labels.shape).astype(bool) & masked_indices
  985. inputs[indices_replaced] = self.tokenizer.convert_tokens_to_ids(self.tokenizer.mask_token)
  986. # 10% of the time, we replace masked input tokens with random word
  987. # indices_random = torch.bernoulli(torch.full(labels.shape, 0.5)).bool() & masked_indices & ~indices_replaced
  988. indices_random = (
  989. np.random.binomial(1, 0.5, size=labels.shape).astype(bool) & masked_indices & ~indices_replaced
  990. )
  991. random_words = np.random.randint(low=0, high=len(self.tokenizer), size=labels.shape, dtype=np.int64)
  992. inputs[indices_random] = random_words[indices_random]
  993. # The rest of the time (10% of the time) we keep the masked input tokens unchanged
  994. return inputs, labels
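    # Sketch of the whole-word-mask input this method expects (illustrative values, not part of the
    # original code): for the word-piece sequence ["[CLS]", "play", "##ing", "ball", "[SEP]"], a whole
    # word mask selecting "playing" would be mask_labels = [0, 1, 1, 0, 0]; every sub-token of a chosen
    # word is masked together, and special/pad positions are cleared from `masked_indices` before the
    # 80/10/10 replacement step.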

@dataclass
class DataCollatorForSOP(DataCollatorForLanguageModeling):
    """
    Data collator used for sentence order prediction task.

    - collates batches of tensors, honoring their tokenizer's pad_token
    - preprocesses batches for both masked language modeling and sentence order prediction
    """

    def __init__(self, *args, **kwargs):
        warnings.warn(
            "DataCollatorForSOP is deprecated and will be removed in a future version, you can now use "
            "DataCollatorForLanguageModeling instead.",
            FutureWarning,
        )

    def __call__(self, examples: List[Dict[str, Any]]) -> Dict[str, Any]:
        import torch
        from torch.nn.utils.rnn import pad_sequence

        input_ids = [example["input_ids"] for example in examples]
        input_ids = _torch_collate_batch(input_ids, self.tokenizer)
        input_ids, labels, attention_mask = self.mask_tokens(input_ids)

        token_type_ids = [example["token_type_ids"] for example in examples]
        # The size of token_type_ids varies because of randomness, so we pad to the end with the pad token,
        # as in the original implementation
        token_type_ids = pad_sequence(token_type_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id)

        sop_label_list = [example["sentence_order_label"] for example in examples]
        sentence_order_label = torch.stack(sop_label_list)

        return {
            "input_ids": input_ids,
            "labels": labels,
            "attention_mask": attention_mask,
            "token_type_ids": token_type_ids,
            "sentence_order_label": sentence_order_label,
        }
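    # Expected per-example features and resulting batch (illustrative note, not part of the original
    # code): each example should provide "input_ids", "token_type_ids" and a scalar
    # "sentence_order_label" tensor; the returned batch additionally carries the "labels" and
    # "attention_mask" produced by `mask_tokens` below, with "sentence_order_label" stacked to shape
    # (batch_size,).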
    def mask_tokens(self, inputs: Any) -> Tuple[Any, Any, Any]:
        """
        Prepare masked tokens inputs/labels/attention_mask for masked language modeling: 80% MASK, 10% random, 10%
        original. N-gram not applied yet.
        """
        import torch

        if self.tokenizer.mask_token is None:
            raise ValueError(
                "This tokenizer does not have a mask token which is necessary for masked language modeling. Remove the"
                " --mlm flag if you want to use this tokenizer."
            )

        labels = inputs.clone()
        # We sample a few tokens in each sequence for masked-LM training (with probability `args.mlm_probability`,
        # which defaults to 0.15 in BERT/RoBERTa)
        probability_matrix = torch.full(labels.shape, self.mlm_probability)
        special_tokens_mask = [
            self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist()
        ]
        probability_matrix.masked_fill_(torch.tensor(special_tokens_mask, dtype=torch.bool), value=0.0)
        if self.tokenizer._pad_token is not None:
            padding_mask = labels.eq(self.tokenizer.pad_token_id)
            probability_matrix.masked_fill_(padding_mask, value=0.0)
        masked_indices = torch.bernoulli(probability_matrix).bool()
        # Here a masked position is `1`, but in the ALBERT attention mask `0` means masked, so we invert the values
        attention_mask = (~masked_indices).float()
        if self.tokenizer._pad_token is not None:
            attention_padding_mask = labels.eq(self.tokenizer.pad_token_id)
            attention_mask.masked_fill_(attention_padding_mask, value=1.0)
        labels[~masked_indices] = -100  # We only compute loss on masked tokens; -100 is the default ignore index for cross-entropy

        # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
        indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.8)).bool() & masked_indices
        inputs[indices_replaced] = self.tokenizer.convert_tokens_to_ids(self.tokenizer.mask_token)

        # 10% of the time, we replace masked input tokens with random word
        indices_random = torch.bernoulli(torch.full(labels.shape, 0.5)).bool() & masked_indices & ~indices_replaced
        random_words = torch.randint(len(self.tokenizer), labels.shape, dtype=torch.long)
        inputs[indices_random] = random_words[indices_random]

        # The rest of the time (10% of the time) we keep the masked input tokens unchanged
        return inputs, labels, attention_mask
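    # Illustration of the attention_mask convention above (explanatory note, not part of the original
    # code): if masked_indices for a row is [0, 1, 0, 0], attention_mask starts as [1.0, 0.0, 1.0, 1.0]
    # (ALBERT uses 0 for positions it must not attend to), and padding positions are subsequently forced
    # back to 1.0 by the masked_fill_ call.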

@dataclass
class DataCollatorForPermutationLanguageModeling(DataCollatorMixin):
    """
    Data collator used for permutation language modeling.

    - collates batches of tensors, honoring their tokenizer's pad_token
    - preprocesses batches for permutation language modeling with procedures specific to XLNet
    """

    tokenizer: PreTrainedTokenizerBase
    plm_probability: float = 1 / 6
    max_span_length: int = 5  # maximum length of a span of masked tokens
    return_tensors: str = "pt"

    def torch_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
        if isinstance(examples[0], Mapping):
            examples = [e["input_ids"] for e in examples]
        batch = _torch_collate_batch(examples, self.tokenizer)
        inputs, perm_mask, target_mapping, labels = self.torch_mask_tokens(batch)
        return {"input_ids": inputs, "perm_mask": perm_mask, "target_mapping": target_mapping, "labels": labels}

    def tf_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
        if isinstance(examples[0], Mapping):
            examples = [e["input_ids"] for e in examples]
        batch = _tf_collate_batch(examples, self.tokenizer)
        inputs, perm_mask, target_mapping, labels = self.tf_mask_tokens(batch)
        return {"input_ids": inputs, "perm_mask": perm_mask, "target_mapping": target_mapping, "labels": labels}

    def numpy_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
        if isinstance(examples[0], Mapping):
            examples = [e["input_ids"] for e in examples]
        batch = _numpy_collate_batch(examples, self.tokenizer)
        inputs, perm_mask, target_mapping, labels = self.numpy_mask_tokens(batch)
        return {"input_ids": inputs, "perm_mask": perm_mask, "target_mapping": target_mapping, "labels": labels}
    def torch_mask_tokens(self, inputs: Any) -> Tuple[Any, Any, Any, Any]:
        """
        The masked tokens to be predicted for a particular sequence are determined by the following algorithm:

            0. Start from the beginning of the sequence by setting `cur_len = 0` (number of tokens processed so far).
            1. Sample a `span_length` from the interval `[1, max_span_length]` (length of span of tokens to be masked)
            2. Reserve a context of length `context_length = span_length / plm_probability` to surround span to be
               masked
            3. Sample a starting point `start_index` from the interval `[cur_len, cur_len + context_length -
               span_length]` and mask tokens `start_index:start_index + span_length`
            4. Set `cur_len = cur_len + context_length`. If `cur_len < max_len` (i.e. there are tokens remaining in the
               sequence to be processed), repeat from Step 1.
        """
        import torch

        if self.tokenizer.mask_token is None:
            raise ValueError(
                "This tokenizer does not have a mask token which is necessary for permutation language modeling."
                " Please add a mask token if you want to use this tokenizer."
            )
        if inputs.size(1) % 2 != 0:
            raise ValueError(
                "This collator requires that sequence lengths be even to create a leakage-free perm_mask. Please see"
                " relevant comments in source code for details."
            )

        labels = inputs.clone()
        # Creating the mask and target_mapping tensors
        masked_indices = torch.full(labels.shape, 0, dtype=torch.bool)
        target_mapping = torch.zeros((labels.size(0), labels.size(1), labels.size(1)), dtype=torch.float32)

        for i in range(labels.size(0)):
            # Start from the beginning of the sequence by setting `cur_len = 0` (number of tokens processed so far).
            cur_len = 0
            max_len = labels.size(1)

            while cur_len < max_len:
                # Sample a `span_length` from the interval `[1, max_span_length]` (length of span of tokens to be masked)
                span_length = torch.randint(1, self.max_span_length + 1, (1,)).item()
                # Reserve a context of length `context_length = span_length / plm_probability` to surround the span to be masked
                context_length = int(span_length / self.plm_probability)
                # Sample a starting point `start_index` from the interval `[cur_len, cur_len + context_length - span_length]` and mask tokens `start_index:start_index + span_length`
                start_index = cur_len + torch.randint(context_length - span_length + 1, (1,)).item()
                masked_indices[i, start_index : start_index + span_length] = 1
                # Set `cur_len = cur_len + context_length`
                cur_len += context_length

            # Since we're replacing non-masked tokens with -100 in the labels tensor instead of skipping them altogether,
            # the i-th prediction corresponds to the i-th token.
            target_mapping[i] = torch.eye(labels.size(1))

        special_tokens_mask = torch.tensor(
            [self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist()],
            dtype=torch.bool,
        )
        masked_indices.masked_fill_(special_tokens_mask, value=0.0)
        if self.tokenizer._pad_token is not None:
            padding_mask = labels.eq(self.tokenizer.pad_token_id)
            masked_indices.masked_fill_(padding_mask, value=0.0)

        # Mask indicating non-functional tokens, where functional tokens are [SEP], [CLS], padding, etc.
        non_func_mask = ~(padding_mask | special_tokens_mask)

        inputs[masked_indices] = self.tokenizer.mask_token_id
        labels[~masked_indices] = -100  # We only compute loss on masked tokens

        perm_mask = torch.zeros((labels.size(0), labels.size(1), labels.size(1)), dtype=torch.float32)

        for i in range(labels.size(0)):
            # Generate permutation indices i.e. sample a random factorisation order for the sequence. This will
            # determine which tokens a given token can attend to (encoded in `perm_mask`).
            # Note: Length of token sequence being permuted has to be less than or equal to reused sequence length
            # (see documentation for `mems`), otherwise information may leak through due to reuse. In this implementation,
            # we assume that reused length is half of sequence length and permutation length is equal to reused length.
            # This requires that the sequence length be even.

            # Create a linear factorisation order
            perm_index = torch.arange(labels.size(1))
            # Split this into two halves, assuming that half the sequence is reused each time
            perm_index = perm_index.reshape((-1, labels.size(1) // 2)).transpose(0, 1)
            # Permute the two halves such that they do not cross over
            perm_index = perm_index[torch.randperm(labels.size(1) // 2)]
            # Flatten this out into the desired permuted factorisation order
            perm_index = torch.flatten(perm_index.transpose(0, 1))
            # Set the permutation indices of non-masked (non-functional) tokens to the
            # smallest index (-1) so that:
            # (1) They can be seen by all other positions
            # (2) They cannot see masked positions, so there won't be information leak
            perm_index.masked_fill_(~masked_indices[i] & non_func_mask[i], -1)
            # The logic for whether the i-th token can attend to the j-th token based on the factorisation order:
            # 0 (can attend): If perm_index[i] > perm_index[j] or j is neither masked nor a functional token
            # 1 (cannot attend): If perm_index[i] <= perm_index[j] and j is either masked or a functional token
            perm_mask[i] = (
                perm_index.reshape((labels.size(1), 1)) <= perm_index.reshape((1, labels.size(1)))
            ) & masked_indices[i]

        return inputs.long(), perm_mask, target_mapping, labels.long()
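    # How to read the `perm_mask` built above (explanatory note, not part of the original code):
    # perm_mask[b, i, j] == 1 means that, in example `b`, position `i` may NOT attend to position `j`;
    # 0 means attention is allowed. Because the comparison is AND-ed with `masked_indices[b]`, only
    # masked positions can ever be hidden, and unmasked non-functional tokens (perm_index == -1) stay
    # visible to every position.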
    def tf_mask_tokens(self, inputs: Any) -> Tuple[Any, Any, Any, Any]:
        """
        The masked tokens to be predicted for a particular sequence are determined by the following algorithm:

            0. Start from the beginning of the sequence by setting `cur_len = 0` (number of tokens processed so far).
            1. Sample a `span_length` from the interval `[1, max_span_length]` (length of span of tokens to be masked)
            2. Reserve a context of length `context_length = span_length / plm_probability` to surround span to be
               masked
            3. Sample a starting point `start_index` from the interval `[cur_len, cur_len + context_length -
               span_length]` and mask tokens `start_index:start_index + span_length`
            4. Set `cur_len = cur_len + context_length`. If `cur_len < max_len` (i.e. there are tokens remaining in the
               sequence to be processed), repeat from Step 1.
        """
        import tensorflow as tf

        if self.tokenizer.mask_token is None:
            raise ValueError(
                "This tokenizer does not have a mask token which is necessary for permutation language modeling."
                " Please add a mask token if you want to use this tokenizer."
            )
        if tf.shape(inputs)[1] % 2 != 0:
            raise ValueError(
                "This collator requires that sequence lengths be even to create a leakage-free perm_mask. Please see"
                " relevant comments in source code for details."
            )

        labels = tf.identity(inputs)
        # Creating the mask and target_mapping tensors
        masked_indices = np.full(labels.shape.as_list(), 0, dtype=bool)
        labels_shape = tf.shape(labels)
        target_mapping = np.zeros((labels_shape[0], labels_shape[1], labels_shape[1]), dtype=np.float32)

        for i in range(len(labels)):
            # Start from the beginning of the sequence by setting `cur_len = 0` (number of tokens processed so far).
            cur_len = 0
            max_len = tf.shape(labels)[1]

            while cur_len < max_len:
                # Sample a `span_length` from the interval `[1, max_span_length]` (length of span of tokens to be masked)
                span_length = randint(1, self.max_span_length + 1)
                # Reserve a context of length `context_length = span_length / plm_probability` to surround the span to be masked
                context_length = int(span_length / self.plm_probability)
                # Sample a starting point `start_index` from the interval `[cur_len, cur_len + context_length - span_length]` and mask tokens `start_index:start_index + span_length`
                start_index = cur_len + randint(0, context_length - span_length + 1)
                masked_indices[i, start_index : start_index + span_length] = 1
                # Set `cur_len = cur_len + context_length`
                cur_len += context_length

            # Since we're replacing non-masked tokens with -100 in the labels tensor instead of skipping them altogether,
            # the i-th prediction corresponds to the i-th token.
            target_mapping[i] = np.eye(labels_shape[1])

        masked_indices = tf.cast(tf.convert_to_tensor(masked_indices), dtype=tf.bool)
        target_mapping = tf.convert_to_tensor(target_mapping)
        special_tokens_mask = tf.convert_to_tensor(
            [
                self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True)
                for val in labels.numpy().tolist()
            ],
        )
        special_tokens_mask = tf.cast(special_tokens_mask, dtype=tf.bool)
        masked_indices = masked_indices & ~special_tokens_mask
        if self.tokenizer._pad_token is not None:
            padding_mask = labels == self.tokenizer.pad_token_id
            masked_indices = masked_indices & ~padding_mask

        # Mask indicating non-functional tokens, where functional tokens are [SEP], [CLS], padding, etc.
        non_func_mask = ~(padding_mask | special_tokens_mask)

        inputs = tf.where(masked_indices, self.tokenizer.mask_token_id, inputs)
        labels = tf.where(masked_indices, labels, -100)  # We only compute loss on masked tokens

        perm_mask = []

        for i in range(len(labels)):
            # Generate permutation indices i.e. sample a random factorisation order for the sequence. This will
            # determine which tokens a given token can attend to (encoded in `perm_mask`).
            # Note: Length of token sequence being permuted has to be less than or equal to reused sequence length
            # (see documentation for `mems`), otherwise information may leak through due to reuse. In this implementation,
            # we assume that reused length is half of sequence length and permutation length is equal to reused length.
            # This requires that the sequence length be even.

            # Create a linear factorisation order
            # tf.range is the equivalent of torch.arange
            perm_index = tf.range(labels_shape[1])
            # Split this into two halves, assuming that half the sequence is reused each time
            perm_index = tf.transpose(tf.reshape(perm_index, (-1, labels_shape[1] // 2)))
            # Permute the two halves such that they do not cross over
            perm_index = tf.random.shuffle(perm_index)  # Shuffles along the first dimension
            # Flatten this out into the desired permuted factorisation order
            perm_index = tf.reshape(tf.transpose(perm_index), (-1,))
            # Set the permutation indices of non-masked (non-functional) tokens to the
            # smallest index (-1) so that:
            # (1) They can be seen by all other positions
            # (2) They cannot see masked positions, so there won't be information leak
            perm_index = tf.where(~masked_indices[i] & non_func_mask[i], -1, perm_index)
            # The logic for whether the i-th token can attend to the j-th token based on the factorisation order:
            # 0 (can attend): If perm_index[i] > perm_index[j] or j is neither masked nor a functional token
            # 1 (cannot attend): If perm_index[i] <= perm_index[j] and j is either masked or a functional token
            perm_mask.append(
                (tf.reshape(perm_index, (labels_shape[1], 1)) <= tf.reshape(perm_index, (1, labels_shape[1])))
                & masked_indices[i]
            )
        perm_mask = tf.stack(perm_mask, axis=0)

        return tf.cast(inputs, tf.int64), tf.cast(perm_mask, tf.float32), target_mapping, tf.cast(labels, tf.int64)
    def numpy_mask_tokens(self, inputs: Any) -> Tuple[Any, Any, Any, Any]:
        """
        The masked tokens to be predicted for a particular sequence are determined by the following algorithm:

            0. Start from the beginning of the sequence by setting `cur_len = 0` (number of tokens processed so far).
            1. Sample a `span_length` from the interval `[1, max_span_length]` (length of span of tokens to be masked)
            2. Reserve a context of length `context_length = span_length / plm_probability` to surround span to be
               masked
            3. Sample a starting point `start_index` from the interval `[cur_len, cur_len + context_length -
               span_length]` and mask tokens `start_index:start_index + span_length`
            4. Set `cur_len = cur_len + context_length`. If `cur_len < max_len` (i.e. there are tokens remaining in the
               sequence to be processed), repeat from Step 1.
        """
        if self.tokenizer.mask_token is None:
            raise ValueError(
                "This tokenizer does not have a mask token which is necessary for permutation language modeling."
                " Please add a mask token if you want to use this tokenizer."
            )
        if inputs.shape[1] % 2 != 0:
            raise ValueError(
                "This collator requires that sequence lengths be even to create a leakage-free perm_mask. Please see"
                " relevant comments in source code for details."
            )

        labels = np.copy(inputs)
        # Creating the mask and target_mapping tensors
        masked_indices = np.full(labels.shape, 0, dtype=bool)
        target_mapping = np.zeros((labels.shape[0], labels.shape[1], labels.shape[1]), dtype=np.float32)

        for i in range(labels.shape[0]):
            # Start from the beginning of the sequence by setting `cur_len = 0` (number of tokens processed so far).
            cur_len = 0
            max_len = labels.shape[1]

            while cur_len < max_len:
                # Sample a `span_length` from the interval `[1, max_span_length]` (length of span of tokens to be masked)
                span_length = randint(1, self.max_span_length + 1)
                # Reserve a context of length `context_length = span_length / plm_probability` to surround the span to be masked
                context_length = int(span_length / self.plm_probability)
                # Sample a starting point `start_index` from the interval `[cur_len, cur_len + context_length - span_length]` and mask tokens `start_index:start_index + span_length`
                start_index = cur_len + randint(0, context_length - span_length + 1)
                masked_indices[i, start_index : start_index + span_length] = 1
                # Set `cur_len = cur_len + context_length`
                cur_len += context_length

            # Since we're replacing non-masked tokens with -100 in the labels tensor instead of skipping them altogether,
            # the i-th prediction corresponds to the i-th token.
            target_mapping[i] = np.eye(labels.shape[1])

        special_tokens_mask = np.array(
            [self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist()],
            dtype=bool,
        )
        masked_indices[special_tokens_mask] = 0
        if self.tokenizer._pad_token is not None:
            padding_mask = labels == self.tokenizer.pad_token_id
            masked_indices[padding_mask] = 0

        # Mask indicating non-functional tokens, where functional tokens are [SEP], [CLS], padding, etc.
        non_func_mask = ~(padding_mask | special_tokens_mask)

        inputs[masked_indices] = self.tokenizer.mask_token_id
        labels[~masked_indices] = -100  # We only compute loss on masked tokens

        perm_mask = np.zeros((labels.shape[0], labels.shape[1], labels.shape[1]), dtype=np.float32)

        for i in range(labels.shape[0]):
            # Generate permutation indices i.e. sample a random factorisation order for the sequence. This will
            # determine which tokens a given token can attend to (encoded in `perm_mask`).
            # Note: Length of token sequence being permuted has to be less than or equal to reused sequence length
            # (see documentation for `mems`), otherwise information may leak through due to reuse. In this implementation,
            # we assume that reused length is half of sequence length and permutation length is equal to reused length.
            # This requires that the sequence length be even.

            # Create a linear factorisation order
            perm_index = np.arange(labels.shape[1])
            # Split this into two halves, assuming that half the sequence is reused each time
            perm_index = perm_index.reshape((-1, labels.shape[1] // 2)).T
            # Permute the two halves such that they do not cross over
            np.random.shuffle(perm_index)
            # Flatten this out into the desired permuted factorisation order
            perm_index = perm_index.T.flatten()
            # Set the permutation indices of non-masked (non-functional) tokens to the
            # smallest index (-1) so that:
            # (1) They can be seen by all other positions
            # (2) They cannot see masked positions, so there won't be information leak
            perm_index[~masked_indices[i] & non_func_mask[i]] = -1
            # The logic for whether the i-th token can attend to the j-th token based on the factorisation order:
            # 0 (can attend): If perm_index[i] > perm_index[j] or j is neither masked nor a functional token
            # 1 (cannot attend): If perm_index[i] <= perm_index[j] and j is either masked or a functional token
            perm_mask[i] = (
                perm_index.reshape((labels.shape[1], 1)) <= perm_index.reshape((1, labels.shape[1]))
            ) & masked_indices[i]

        return inputs.astype(np.int64), perm_mask, target_mapping, labels.astype(np.int64)
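# Minimal usage sketch for DataCollatorForPermutationLanguageModeling (illustrative only, not part of the
# original code; assumes an XLNet-style tokenizer that defines a mask token and that the collated sequence
# length is even, as required by the checks above):
#
#     from transformers import AutoTokenizer
#
#     tokenizer = AutoTokenizer.from_pretrained("xlnet-base-cased")
#     collator = DataCollatorForPermutationLanguageModeling(tokenizer=tokenizer, plm_probability=1 / 6)
#     batch = collator([{"input_ids": ids} for ids in even_length_id_lists])  # `even_length_id_lists` is hypothetical
#     # batch contains "input_ids", "perm_mask", "target_mapping" and "labels"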

@dataclass
class DataCollatorWithFlattening(DefaultDataCollator):
    """
    Data collator used for padding-free training. It does the following:

    - concatenates the entire mini-batch into a single long sequence of shape [1, total_tokens]
    - uses `separator_id` (default -100) to mark the boundary between sequences within the concatenated `labels`
    - adds no padding and returns `input_ids`, `labels` and (optionally) `position_ids`
    """

    def __init__(self, *args, return_position_ids=True, separator_id=-100, **kwargs):
        super().__init__(*args, **kwargs)
        self.return_position_ids = return_position_ids
        self.separator_id = separator_id
        warnings.warn(
            "Using `DataCollatorWithFlattening` will flatten the entire mini batch into a single long sequence. "
            "Make sure your attention computation is able to handle it!"
        )

    def __call__(self, features, return_tensors=None, separator_id=None):
        if return_tensors is None:
            return_tensors = self.return_tensors
        if separator_id is None:
            separator_id = self.separator_id
        is_labels_provided = "labels" in features[0]
        ret = {"input_ids": [], "labels": []}
        if self.return_position_ids:
            ret.update({"position_ids": []})
        for idx in range(0, len(features)):
            ret["input_ids"] += features[idx]["input_ids"]
            if is_labels_provided:
                ret["labels"] += [separator_id] + features[idx]["labels"][1:]
            else:
                ret["labels"] += [separator_id] + features[idx]["input_ids"][1:]
            if self.return_position_ids:
                ret["position_ids"] += list(range(len(features[idx]["input_ids"])))
        return default_data_collator([ret], return_tensors)
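# Illustration of the flattening behaviour above (hypothetical token ids, not part of the original code):
# given features = [{"input_ids": [1, 2, 3]}, {"input_ids": [4, 5]}] and the default separator_id=-100,
# the collator returns a single row:
#     input_ids    -> [[1, 2, 3, 4, 5]]
#     labels       -> [[-100, 2, 3, -100, 5]]  (the first token of each sequence is replaced by the separator)
#     position_ids -> [[0, 1, 2, 0, 1]]        (positions restart at 0 for every original sequence)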