glue.py 6.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161
  1. # Copyright 2020 The HuggingFace Team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import os
  15. import time
  16. import warnings
  17. from dataclasses import dataclass, field
  18. from enum import Enum
  19. from typing import List, Optional, Union
  20. import torch
  21. from filelock import FileLock
  22. from torch.utils.data import Dataset
  23. from ...tokenization_utils_base import PreTrainedTokenizerBase
  24. from ...utils import logging
  25. from ..processors.glue import glue_convert_examples_to_features, glue_output_modes, glue_processors
  26. from ..processors.utils import InputFeatures
# Module-level logger, obtained from the package's `logging` utilities and named after this module.
logger = logging.get_logger(__name__)
  28. @dataclass
  29. class GlueDataTrainingArguments:
  30. """
  31. Arguments pertaining to what data we are going to input our model for training and eval.
  32. Using `HfArgumentParser` we can turn this class into argparse arguments to be able to specify them on the command
  33. line.
  34. """
  35. task_name: str = field(metadata={"help": "The name of the task to train on: " + ", ".join(glue_processors.keys())})
  36. data_dir: str = field(
  37. metadata={"help": "The input data dir. Should contain the .tsv files (or other data files) for the task."}
  38. )
  39. max_seq_length: int = field(
  40. default=128,
  41. metadata={
  42. "help": (
  43. "The maximum total input sequence length after tokenization. Sequences longer "
  44. "than this will be truncated, sequences shorter will be padded."
  45. )
  46. },
  47. )
  48. overwrite_cache: bool = field(
  49. default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
  50. )
  51. def __post_init__(self):
  52. self.task_name = self.task_name.lower()
  53. class Split(Enum):
  54. train = "train"
  55. dev = "dev"
  56. test = "test"
  57. class GlueDataset(Dataset):
  58. """
  59. This will be superseded by a framework-agnostic approach soon.
  60. """
  61. args: GlueDataTrainingArguments
  62. output_mode: str
  63. features: List[InputFeatures]
  64. def __init__(
  65. self,
  66. args: GlueDataTrainingArguments,
  67. tokenizer: PreTrainedTokenizerBase,
  68. limit_length: Optional[int] = None,
  69. mode: Union[str, Split] = Split.train,
  70. cache_dir: Optional[str] = None,
  71. ):
  72. warnings.warn(
  73. "This dataset will be removed from the library soon, preprocessing should be handled with the 🤗 Datasets "
  74. "library. You can have a look at this example script for pointers: "
  75. "https://github.com/huggingface/transformers/blob/main/examples/pytorch/text-classification/run_glue.py",
  76. FutureWarning,
  77. )
  78. self.args = args
  79. self.processor = glue_processors[args.task_name]()
  80. self.output_mode = glue_output_modes[args.task_name]
  81. if isinstance(mode, str):
  82. try:
  83. mode = Split[mode]
  84. except KeyError:
  85. raise KeyError("mode is not a valid split name")
  86. # Load data features from cache or dataset file
  87. cached_features_file = os.path.join(
  88. cache_dir if cache_dir is not None else args.data_dir,
  89. f"cached_{mode.value}_{tokenizer.__class__.__name__}_{args.max_seq_length}_{args.task_name}",
  90. )
  91. label_list = self.processor.get_labels()
  92. if args.task_name in ["mnli", "mnli-mm"] and tokenizer.__class__.__name__ in (
  93. "RobertaTokenizer",
  94. "RobertaTokenizerFast",
  95. "XLMRobertaTokenizer",
  96. "BartTokenizer",
  97. "BartTokenizerFast",
  98. ):
  99. # HACK(label indices are swapped in RoBERTa pretrained model)
  100. label_list[1], label_list[2] = label_list[2], label_list[1]
  101. self.label_list = label_list
  102. # Make sure only the first process in distributed training processes the dataset,
  103. # and the others will use the cache.
  104. lock_path = cached_features_file + ".lock"
  105. with FileLock(lock_path):
  106. if os.path.exists(cached_features_file) and not args.overwrite_cache:
  107. start = time.time()
  108. self.features = torch.load(cached_features_file)
  109. logger.info(
  110. f"Loading features from cached file {cached_features_file} [took %.3f s]", time.time() - start
  111. )
  112. else:
  113. logger.info(f"Creating features from dataset file at {args.data_dir}")
  114. if mode == Split.dev:
  115. examples = self.processor.get_dev_examples(args.data_dir)
  116. elif mode == Split.test:
  117. examples = self.processor.get_test_examples(args.data_dir)
  118. else:
  119. examples = self.processor.get_train_examples(args.data_dir)
  120. if limit_length is not None:
  121. examples = examples[:limit_length]
  122. self.features = glue_convert_examples_to_features(
  123. examples,
  124. tokenizer,
  125. max_length=args.max_seq_length,
  126. label_list=label_list,
  127. output_mode=self.output_mode,
  128. )
  129. start = time.time()
  130. torch.save(self.features, cached_features_file)
  131. # ^ This seems to take a lot of time so I want to investigate why and how we can improve.
  132. logger.info(
  133. f"Saving features into cached file {cached_features_file} [took {time.time() - start:.3f} s]"
  134. )
  135. def __len__(self):
  136. return len(self.features)
  137. def __getitem__(self, i) -> InputFeatures:
  138. return self.features[i]
  139. def get_labels(self):
  140. return self.label_list