language_modeling.py 23 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530
  1. # Copyright 2020 The HuggingFace Team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import json
  15. import os
  16. import pickle
  17. import random
  18. import time
  19. import warnings
  20. from typing import Dict, List, Optional
  21. import torch
  22. from filelock import FileLock
  23. from torch.utils.data import Dataset
  24. from ...tokenization_utils import PreTrainedTokenizer
  25. from ...utils import logging
  26. logger = logging.get_logger(__name__)
  27. DEPRECATION_WARNING = (
  28. "This dataset will be removed from the library soon, preprocessing should be handled with the 🤗 Datasets "
  29. "library. You can have a look at this example script for pointers: {0}"
  30. )
  31. class TextDataset(Dataset):
  32. """
  33. This will be superseded by a framework-agnostic approach soon.
  34. """
  35. def __init__(
  36. self,
  37. tokenizer: PreTrainedTokenizer,
  38. file_path: str,
  39. block_size: int,
  40. overwrite_cache=False,
  41. cache_dir: Optional[str] = None,
  42. ):
  43. warnings.warn(
  44. DEPRECATION_WARNING.format(
  45. "https://github.com/huggingface/transformers/blob/main/examples/pytorch/language-modeling/run_mlm.py"
  46. ),
  47. FutureWarning,
  48. )
  49. if os.path.isfile(file_path) is False:
  50. raise ValueError(f"Input file path {file_path} not found")
  51. block_size = block_size - tokenizer.num_special_tokens_to_add(pair=False)
  52. directory, filename = os.path.split(file_path)
  53. cached_features_file = os.path.join(
  54. cache_dir if cache_dir is not None else directory,
  55. f"cached_lm_{tokenizer.__class__.__name__}_{block_size}_{filename}",
  56. )
  57. # Make sure only the first process in distributed training processes the dataset,
  58. # and the others will use the cache.
  59. lock_path = cached_features_file + ".lock"
  60. with FileLock(lock_path):
  61. if os.path.exists(cached_features_file) and not overwrite_cache:
  62. start = time.time()
  63. with open(cached_features_file, "rb") as handle:
  64. self.examples = pickle.load(handle)
  65. logger.info(
  66. f"Loading features from cached file {cached_features_file} [took %.3f s]", time.time() - start
  67. )
  68. else:
  69. logger.info(f"Creating features from dataset file at {directory}")
  70. self.examples = []
  71. with open(file_path, encoding="utf-8") as f:
  72. text = f.read()
  73. tokenized_text = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(text))
  74. for i in range(0, len(tokenized_text) - block_size + 1, block_size): # Truncate in block of block_size
  75. self.examples.append(
  76. tokenizer.build_inputs_with_special_tokens(tokenized_text[i : i + block_size])
  77. )
  78. # Note that we are losing the last truncated example here for the sake of simplicity (no padding)
  79. # If your dataset is small, first you should look for a bigger one :-) and second you
  80. # can change this behavior by adding (model specific) padding.
  81. start = time.time()
  82. with open(cached_features_file, "wb") as handle:
  83. pickle.dump(self.examples, handle, protocol=pickle.HIGHEST_PROTOCOL)
  84. logger.info(
  85. f"Saving features into cached file {cached_features_file} [took {time.time() - start:.3f} s]"
  86. )
  87. def __len__(self):
  88. return len(self.examples)
  89. def __getitem__(self, i) -> torch.Tensor:
  90. return torch.tensor(self.examples[i], dtype=torch.long)
  91. class LineByLineTextDataset(Dataset):
  92. """
  93. This will be superseded by a framework-agnostic approach soon.
  94. """
  95. def __init__(self, tokenizer: PreTrainedTokenizer, file_path: str, block_size: int):
  96. warnings.warn(
  97. DEPRECATION_WARNING.format(
  98. "https://github.com/huggingface/transformers/blob/main/examples/pytorch/language-modeling/run_mlm.py"
  99. ),
  100. FutureWarning,
  101. )
  102. if os.path.isfile(file_path) is False:
  103. raise ValueError(f"Input file path {file_path} not found")
  104. # Here, we do not cache the features, operating under the assumption
  105. # that we will soon use fast multithreaded tokenizers from the
  106. # `tokenizers` repo everywhere =)
  107. logger.info(f"Creating features from dataset file at {file_path}")
  108. with open(file_path, encoding="utf-8") as f:
  109. lines = [line for line in f.read().splitlines() if (len(line) > 0 and not line.isspace())]
  110. batch_encoding = tokenizer(lines, add_special_tokens=True, truncation=True, max_length=block_size)
  111. self.examples = batch_encoding["input_ids"]
  112. self.examples = [{"input_ids": torch.tensor(e, dtype=torch.long)} for e in self.examples]
  113. def __len__(self):
  114. return len(self.examples)
  115. def __getitem__(self, i) -> Dict[str, torch.tensor]:
  116. return self.examples[i]
  117. class LineByLineWithRefDataset(Dataset):
  118. """
  119. This will be superseded by a framework-agnostic approach soon.
  120. """
  121. def __init__(self, tokenizer: PreTrainedTokenizer, file_path: str, block_size: int, ref_path: str):
  122. warnings.warn(
  123. DEPRECATION_WARNING.format(
  124. "https://github.com/huggingface/transformers/blob/main/examples/pytorch/language-modeling/run_mlm_wwm.py"
  125. ),
  126. FutureWarning,
  127. )
  128. if os.path.isfile(file_path) is False:
  129. raise ValueError(f"Input file path {file_path} not found")
  130. if os.path.isfile(ref_path) is False:
  131. raise ValueError(f"Ref file path {file_path} not found")
  132. # Here, we do not cache the features, operating under the assumption
  133. # that we will soon use fast multithreaded tokenizers from the
  134. # `tokenizers` repo everywhere =)
  135. logger.info(f"Creating features from dataset file at {file_path}")
  136. logger.info(f"Use ref segment results at {ref_path}")
  137. with open(file_path, encoding="utf-8") as f:
  138. data = f.readlines() # use this method to avoid delimiter '\u2029' to split a line
  139. data = [line.strip() for line in data if len(line) > 0 and not line.isspace()]
  140. # Get ref inf from file
  141. with open(ref_path, encoding="utf-8") as f:
  142. ref = [json.loads(line) for line in f.read().splitlines() if (len(line) > 0 and not line.isspace())]
  143. if len(data) != len(ref):
  144. raise ValueError(
  145. f"Length of Input file should be equal to Ref file. But the length of {file_path} is {len(data)} "
  146. f"while length of {ref_path} is {len(ref)}"
  147. )
  148. batch_encoding = tokenizer(data, add_special_tokens=True, truncation=True, max_length=block_size)
  149. self.examples = batch_encoding["input_ids"]
  150. self.examples = [{"input_ids": torch.tensor(e, dtype=torch.long)} for e in self.examples]
  151. n = len(self.examples)
  152. for i in range(n):
  153. self.examples[i]["chinese_ref"] = torch.tensor(ref[i], dtype=torch.long)
  154. def __len__(self):
  155. return len(self.examples)
  156. def __getitem__(self, i) -> Dict[str, torch.tensor]:
  157. return self.examples[i]
  158. class LineByLineWithSOPTextDataset(Dataset):
  159. """
  160. Dataset for sentence order prediction task, prepare sentence pairs for SOP task
  161. """
  162. def __init__(self, tokenizer: PreTrainedTokenizer, file_dir: str, block_size: int):
  163. warnings.warn(
  164. DEPRECATION_WARNING.format(
  165. "https://github.com/huggingface/transformers/blob/main/examples/pytorch/language-modeling/run_mlm.py"
  166. ),
  167. FutureWarning,
  168. )
  169. if os.path.isdir(file_dir) is False:
  170. raise ValueError(f"{file_dir} is not a directory")
  171. logger.info(f"Creating features from dataset file folder at {file_dir}")
  172. self.examples = []
  173. # TODO: randomness could apply a random seed, ex. rng = random.Random(random_seed)
  174. # file path looks like ./dataset/wiki_1, ./dataset/wiki_2
  175. for file_name in os.listdir(file_dir):
  176. file_path = os.path.join(file_dir, file_name)
  177. if os.path.isfile(file_path) is False:
  178. raise ValueError(f"{file_path} is not a file")
  179. article_open = False
  180. with open(file_path, encoding="utf-8") as f:
  181. original_lines = f.readlines()
  182. article_lines = []
  183. for line in original_lines:
  184. if "<doc id=" in line:
  185. article_open = True
  186. elif "</doc>" in line:
  187. article_open = False
  188. document = [
  189. tokenizer.convert_tokens_to_ids(tokenizer.tokenize(line))
  190. for line in article_lines[1:]
  191. if (len(line) > 0 and not line.isspace())
  192. ]
  193. examples = self.create_examples_from_document(document, block_size, tokenizer)
  194. self.examples.extend(examples)
  195. article_lines = []
  196. else:
  197. if article_open:
  198. article_lines.append(line)
  199. logger.info("Dataset parse finished.")
  200. def create_examples_from_document(self, document, block_size, tokenizer, short_seq_prob=0.1):
  201. """Creates examples for a single document."""
  202. # Account for special tokens
  203. max_num_tokens = block_size - tokenizer.num_special_tokens_to_add(pair=True)
  204. # We *usually* want to fill up the entire sequence since we are padding
  205. # to `block_size` anyways, so short sequences are generally wasted
  206. # computation. However, we *sometimes*
  207. # (i.e., short_seq_prob == 0.1 == 10% of the time) want to use shorter
  208. # sequences to minimize the mismatch between pretraining and fine-tuning.
  209. # The `target_seq_length` is just a rough target however, whereas
  210. # `block_size` is a hard limit.
  211. target_seq_length = max_num_tokens
  212. if random.random() < short_seq_prob:
  213. target_seq_length = random.randint(2, max_num_tokens)
  214. # We DON'T just concatenate all of the tokens from a document into a long
  215. # sequence and choose an arbitrary split point because this would make the
  216. # next sentence prediction task too easy. Instead, we split the input into
  217. # segments "A" and "B" based on the actual "sentences" provided by the user
  218. # input.
  219. examples = []
  220. current_chunk = [] # a buffer stored current working segments
  221. current_length = 0
  222. i = 0
  223. while i < len(document):
  224. segment = document[i] # get a segment
  225. if not segment:
  226. i += 1
  227. continue
  228. current_chunk.append(segment) # add a segment to current chunk
  229. current_length += len(segment) # overall token length
  230. # if current length goes to the target length or reaches the end of file, start building token a and b
  231. if i == len(document) - 1 or current_length >= target_seq_length:
  232. if current_chunk:
  233. # `a_end` is how many segments from `current_chunk` go into the `A` (first) sentence.
  234. a_end = 1
  235. # if current chunk has more than 2 sentences, pick part of it `A` (first) sentence
  236. if len(current_chunk) >= 2:
  237. a_end = random.randint(1, len(current_chunk) - 1)
  238. # token a
  239. tokens_a = []
  240. for j in range(a_end):
  241. tokens_a.extend(current_chunk[j])
  242. # token b
  243. tokens_b = []
  244. for j in range(a_end, len(current_chunk)):
  245. tokens_b.extend(current_chunk[j])
  246. if len(tokens_a) == 0 or len(tokens_b) == 0:
  247. continue
  248. # switch tokens_a and tokens_b randomly
  249. if random.random() < 0.5:
  250. is_next = False
  251. tokens_a, tokens_b = tokens_b, tokens_a
  252. else:
  253. is_next = True
  254. def truncate_seq_pair(tokens_a, tokens_b, max_num_tokens):
  255. """Truncates a pair of sequences to a maximum sequence length."""
  256. while True:
  257. total_length = len(tokens_a) + len(tokens_b)
  258. if total_length <= max_num_tokens:
  259. break
  260. trunc_tokens = tokens_a if len(tokens_a) > len(tokens_b) else tokens_b
  261. if not (len(trunc_tokens) >= 1):
  262. raise ValueError("Sequence length to be truncated must be no less than one")
  263. # We want to sometimes truncate from the front and sometimes from the
  264. # back to add more randomness and avoid biases.
  265. if random.random() < 0.5:
  266. del trunc_tokens[0]
  267. else:
  268. trunc_tokens.pop()
  269. truncate_seq_pair(tokens_a, tokens_b, max_num_tokens)
  270. if not (len(tokens_a) >= 1):
  271. raise ValueError(f"Length of sequence a is {len(tokens_a)} which must be no less than 1")
  272. if not (len(tokens_b) >= 1):
  273. raise ValueError(f"Length of sequence b is {len(tokens_b)} which must be no less than 1")
  274. # add special tokens
  275. input_ids = tokenizer.build_inputs_with_special_tokens(tokens_a, tokens_b)
  276. # add token type ids, 0 for sentence a, 1 for sentence b
  277. token_type_ids = tokenizer.create_token_type_ids_from_sequences(tokens_a, tokens_b)
  278. example = {
  279. "input_ids": torch.tensor(input_ids, dtype=torch.long),
  280. "token_type_ids": torch.tensor(token_type_ids, dtype=torch.long),
  281. "sentence_order_label": torch.tensor(0 if is_next else 1, dtype=torch.long),
  282. }
  283. examples.append(example)
  284. current_chunk = [] # clear current chunk
  285. current_length = 0 # reset current text length
  286. i += 1 # go to next line
  287. return examples
  288. def __len__(self):
  289. return len(self.examples)
  290. def __getitem__(self, i) -> Dict[str, torch.tensor]:
  291. return self.examples[i]
  292. class TextDatasetForNextSentencePrediction(Dataset):
  293. """
  294. This will be superseded by a framework-agnostic approach soon.
  295. """
  296. def __init__(
  297. self,
  298. tokenizer: PreTrainedTokenizer,
  299. file_path: str,
  300. block_size: int,
  301. overwrite_cache=False,
  302. short_seq_probability=0.1,
  303. nsp_probability=0.5,
  304. ):
  305. warnings.warn(
  306. DEPRECATION_WARNING.format(
  307. "https://github.com/huggingface/transformers/blob/main/examples/pytorch/language-modeling/run_mlm.py"
  308. ),
  309. FutureWarning,
  310. )
  311. if not os.path.isfile(file_path):
  312. raise ValueError(f"Input file path {file_path} not found")
  313. self.short_seq_probability = short_seq_probability
  314. self.nsp_probability = nsp_probability
  315. directory, filename = os.path.split(file_path)
  316. cached_features_file = os.path.join(
  317. directory,
  318. f"cached_nsp_{tokenizer.__class__.__name__}_{block_size}_{filename}",
  319. )
  320. self.tokenizer = tokenizer
  321. # Make sure only the first process in distributed training processes the dataset,
  322. # and the others will use the cache.
  323. lock_path = cached_features_file + ".lock"
  324. # Input file format:
  325. # (1) One sentence per line. These should ideally be actual sentences, not
  326. # entire paragraphs or arbitrary spans of text. (Because we use the
  327. # sentence boundaries for the "next sentence prediction" task).
  328. # (2) Blank lines between documents. Document boundaries are needed so
  329. # that the "next sentence prediction" task doesn't span between documents.
  330. #
  331. # Example:
  332. # I am very happy.
  333. # Here is the second sentence.
  334. #
  335. # A new document.
  336. with FileLock(lock_path):
  337. if os.path.exists(cached_features_file) and not overwrite_cache:
  338. start = time.time()
  339. with open(cached_features_file, "rb") as handle:
  340. self.examples = pickle.load(handle)
  341. logger.info(
  342. f"Loading features from cached file {cached_features_file} [took %.3f s]", time.time() - start
  343. )
  344. else:
  345. logger.info(f"Creating features from dataset file at {directory}")
  346. self.documents = [[]]
  347. with open(file_path, encoding="utf-8") as f:
  348. while True:
  349. line = f.readline()
  350. if not line:
  351. break
  352. line = line.strip()
  353. # Empty lines are used as document delimiters
  354. if not line and len(self.documents[-1]) != 0:
  355. self.documents.append([])
  356. tokens = tokenizer.tokenize(line)
  357. tokens = tokenizer.convert_tokens_to_ids(tokens)
  358. if tokens:
  359. self.documents[-1].append(tokens)
  360. logger.info(f"Creating examples from {len(self.documents)} documents.")
  361. self.examples = []
  362. for doc_index, document in enumerate(self.documents):
  363. self.create_examples_from_document(document, doc_index, block_size)
  364. start = time.time()
  365. with open(cached_features_file, "wb") as handle:
  366. pickle.dump(self.examples, handle, protocol=pickle.HIGHEST_PROTOCOL)
  367. logger.info(
  368. f"Saving features into cached file {cached_features_file} [took {time.time() - start:.3f} s]"
  369. )
  370. def create_examples_from_document(self, document: List[List[int]], doc_index: int, block_size: int):
  371. """Creates examples for a single document."""
  372. max_num_tokens = block_size - self.tokenizer.num_special_tokens_to_add(pair=True)
  373. # We *usually* want to fill up the entire sequence since we are padding
  374. # to `block_size` anyways, so short sequences are generally wasted
  375. # computation. However, we *sometimes*
  376. # (i.e., short_seq_prob == 0.1 == 10% of the time) want to use shorter
  377. # sequences to minimize the mismatch between pretraining and fine-tuning.
  378. # The `target_seq_length` is just a rough target however, whereas
  379. # `block_size` is a hard limit.
  380. target_seq_length = max_num_tokens
  381. if random.random() < self.short_seq_probability:
  382. target_seq_length = random.randint(2, max_num_tokens)
  383. current_chunk = [] # a buffer stored current working segments
  384. current_length = 0
  385. i = 0
  386. while i < len(document):
  387. segment = document[i]
  388. current_chunk.append(segment)
  389. current_length += len(segment)
  390. if i == len(document) - 1 or current_length >= target_seq_length:
  391. if current_chunk:
  392. # `a_end` is how many segments from `current_chunk` go into the `A`
  393. # (first) sentence.
  394. a_end = 1
  395. if len(current_chunk) >= 2:
  396. a_end = random.randint(1, len(current_chunk) - 1)
  397. tokens_a = []
  398. for j in range(a_end):
  399. tokens_a.extend(current_chunk[j])
  400. tokens_b = []
  401. if len(current_chunk) == 1 or random.random() < self.nsp_probability:
  402. is_random_next = True
  403. target_b_length = target_seq_length - len(tokens_a)
  404. # This should rarely go for more than one iteration for large
  405. # corpora. However, just to be careful, we try to make sure that
  406. # the random document is not the same as the document
  407. # we're processing.
  408. for _ in range(10):
  409. random_document_index = random.randint(0, len(self.documents) - 1)
  410. if random_document_index != doc_index:
  411. break
  412. random_document = self.documents[random_document_index]
  413. random_start = random.randint(0, len(random_document) - 1)
  414. for j in range(random_start, len(random_document)):
  415. tokens_b.extend(random_document[j])
  416. if len(tokens_b) >= target_b_length:
  417. break
  418. # We didn't actually use these segments so we "put them back" so
  419. # they don't go to waste.
  420. num_unused_segments = len(current_chunk) - a_end
  421. i -= num_unused_segments
  422. # Actual next
  423. else:
  424. is_random_next = False
  425. for j in range(a_end, len(current_chunk)):
  426. tokens_b.extend(current_chunk[j])
  427. if not (len(tokens_a) >= 1):
  428. raise ValueError(f"Length of sequence a is {len(tokens_a)} which must be no less than 1")
  429. if not (len(tokens_b) >= 1):
  430. raise ValueError(f"Length of sequence b is {len(tokens_b)} which must be no less than 1")
  431. # add special tokens
  432. input_ids = self.tokenizer.build_inputs_with_special_tokens(tokens_a, tokens_b)
  433. # add token type ids, 0 for sentence a, 1 for sentence b
  434. token_type_ids = self.tokenizer.create_token_type_ids_from_sequences(tokens_a, tokens_b)
  435. example = {
  436. "input_ids": torch.tensor(input_ids, dtype=torch.long),
  437. "token_type_ids": torch.tensor(token_type_ids, dtype=torch.long),
  438. "next_sentence_label": torch.tensor(1 if is_random_next else 0, dtype=torch.long),
  439. }
  440. self.examples.append(example)
  441. current_chunk = []
  442. current_length = 0
  443. i += 1
  444. def __len__(self):
  445. return len(self.examples)
  446. def __getitem__(self, i):
  447. return self.examples[i]