glue.py 23 KB


  1. # coding=utf-8
  2. # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
  3. # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. """GLUE processors and helpers"""
  17. import os
  18. import warnings
  19. from dataclasses import asdict
  20. from enum import Enum
  21. from typing import List, Optional, Union
  22. from ...tokenization_utils import PreTrainedTokenizer
  23. from ...utils import is_tf_available, logging
  24. from .utils import DataProcessor, InputExample, InputFeatures
  25. if is_tf_available():
  26. import tensorflow as tf
# Module-level logger, obtained from the library's own `logging` helper.
logger = logging.get_logger(__name__)

# Message template for the `FutureWarning` raised by the conversion function and
# by every processor's `__init__` in this module. The `{0}` placeholder is filled
# with the kind of object being deprecated ("function" or "processor").
DEPRECATION_WARNING = (
    "This {0} will be removed from the library soon, preprocessing should be handled with the 🤗 Datasets "
    "library. You can have a look at this example script for pointers: "
    "https://github.com/huggingface/transformers/blob/main/examples/pytorch/text-classification/run_glue.py"
)
  33. def glue_convert_examples_to_features(
  34. examples: Union[List[InputExample], "tf.data.Dataset"],
  35. tokenizer: PreTrainedTokenizer,
  36. max_length: Optional[int] = None,
  37. task=None,
  38. label_list=None,
  39. output_mode=None,
  40. ):
  41. """
  42. Loads a data file into a list of `InputFeatures`
  43. Args:
  44. examples: List of `InputExamples` or `tf.data.Dataset` containing the examples.
  45. tokenizer: Instance of a tokenizer that will tokenize the examples
  46. max_length: Maximum example length. Defaults to the tokenizer's max_len
  47. task: GLUE task
  48. label_list: List of labels. Can be obtained from the processor using the `processor.get_labels()` method
  49. output_mode: String indicating the output mode. Either `regression` or `classification`
  50. Returns:
  51. If the `examples` input is a `tf.data.Dataset`, will return a `tf.data.Dataset` containing the task-specific
  52. features. If the input is a list of `InputExamples`, will return a list of task-specific `InputFeatures` which
  53. can be fed to the model.
  54. """
  55. warnings.warn(DEPRECATION_WARNING.format("function"), FutureWarning)
  56. if is_tf_available() and isinstance(examples, tf.data.Dataset):
  57. if task is None:
  58. raise ValueError("When calling glue_convert_examples_to_features from TF, the task parameter is required.")
  59. return _tf_glue_convert_examples_to_features(examples, tokenizer, max_length=max_length, task=task)
  60. return _glue_convert_examples_to_features(
  61. examples, tokenizer, max_length=max_length, task=task, label_list=label_list, output_mode=output_mode
  62. )
  63. if is_tf_available():
  64. def _tf_glue_convert_examples_to_features(
  65. examples: tf.data.Dataset,
  66. tokenizer: PreTrainedTokenizer,
  67. task=str,
  68. max_length: Optional[int] = None,
  69. ) -> tf.data.Dataset:
  70. """
  71. Returns:
  72. A `tf.data.Dataset` containing the task-specific features.
  73. """
  74. processor = glue_processors[task]()
  75. examples = [processor.tfds_map(processor.get_example_from_tensor_dict(example)) for example in examples]
  76. features = glue_convert_examples_to_features(examples, tokenizer, max_length=max_length, task=task)
  77. label_type = tf.float32 if task == "sts-b" else tf.int64
  78. def gen():
  79. for ex in features:
  80. d = {k: v for k, v in asdict(ex).items() if v is not None}
  81. label = d.pop("label")
  82. yield (d, label)
  83. input_names = tokenizer.model_input_names
  84. return tf.data.Dataset.from_generator(
  85. gen,
  86. ({k: tf.int32 for k in input_names}, label_type),
  87. ({k: tf.TensorShape([None]) for k in input_names}, tf.TensorShape([])),
  88. )
  89. def _glue_convert_examples_to_features(
  90. examples: List[InputExample],
  91. tokenizer: PreTrainedTokenizer,
  92. max_length: Optional[int] = None,
  93. task=None,
  94. label_list=None,
  95. output_mode=None,
  96. ):
  97. if max_length is None:
  98. max_length = tokenizer.model_max_length
  99. if task is not None:
  100. processor = glue_processors[task]()
  101. if label_list is None:
  102. label_list = processor.get_labels()
  103. logger.info(f"Using label list {label_list} for task {task}")
  104. if output_mode is None:
  105. output_mode = glue_output_modes[task]
  106. logger.info(f"Using output mode {output_mode} for task {task}")
  107. label_map = {label: i for i, label in enumerate(label_list)}
  108. def label_from_example(example: InputExample) -> Union[int, float, None]:
  109. if example.label is None:
  110. return None
  111. if output_mode == "classification":
  112. return label_map[example.label]
  113. elif output_mode == "regression":
  114. return float(example.label)
  115. raise KeyError(output_mode)
  116. labels = [label_from_example(example) for example in examples]
  117. batch_encoding = tokenizer(
  118. [(example.text_a, example.text_b) for example in examples],
  119. max_length=max_length,
  120. padding="max_length",
  121. truncation=True,
  122. )
  123. features = []
  124. for i in range(len(examples)):
  125. inputs = {k: batch_encoding[k][i] for k in batch_encoding}
  126. feature = InputFeatures(**inputs, label=labels[i])
  127. features.append(feature)
  128. for i, example in enumerate(examples[:5]):
  129. logger.info("*** Example ***")
  130. logger.info(f"guid: {example.guid}")
  131. logger.info(f"features: {features[i]}")
  132. return features
class OutputMode(Enum):
    """Output modes for GLUE tasks; each member's value is the same string as its name."""

    # The two values match the strings stored in `glue_output_modes` and compared against
    # in `_glue_convert_examples_to_features`.
    classification = "classification"
    regression = "regression"
  136. class MrpcProcessor(DataProcessor):
  137. """Processor for the MRPC data set (GLUE version)."""
  138. def __init__(self, *args, **kwargs):
  139. super().__init__(*args, **kwargs)
  140. warnings.warn(DEPRECATION_WARNING.format("processor"), FutureWarning)
  141. def get_example_from_tensor_dict(self, tensor_dict):
  142. """See base class."""
  143. return InputExample(
  144. tensor_dict["idx"].numpy(),
  145. tensor_dict["sentence1"].numpy().decode("utf-8"),
  146. tensor_dict["sentence2"].numpy().decode("utf-8"),
  147. str(tensor_dict["label"].numpy()),
  148. )
  149. def get_train_examples(self, data_dir):
  150. """See base class."""
  151. logger.info(f"LOOKING AT {os.path.join(data_dir, 'train.tsv')}")
  152. return self._create_examples(self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
  153. def get_dev_examples(self, data_dir):
  154. """See base class."""
  155. return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
  156. def get_test_examples(self, data_dir):
  157. """See base class."""
  158. return self._create_examples(self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")
  159. def get_labels(self):
  160. """See base class."""
  161. return ["0", "1"]
  162. def _create_examples(self, lines, set_type):
  163. """Creates examples for the training, dev and test sets."""
  164. examples = []
  165. for i, line in enumerate(lines):
  166. if i == 0:
  167. continue
  168. guid = f"{set_type}-{i}"
  169. text_a = line[3]
  170. text_b = line[4]
  171. label = None if set_type == "test" else line[0]
  172. examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
  173. return examples
  174. class MnliProcessor(DataProcessor):
  175. """Processor for the MultiNLI data set (GLUE version)."""
  176. def __init__(self, *args, **kwargs):
  177. super().__init__(*args, **kwargs)
  178. warnings.warn(DEPRECATION_WARNING.format("processor"), FutureWarning)
  179. def get_example_from_tensor_dict(self, tensor_dict):
  180. """See base class."""
  181. return InputExample(
  182. tensor_dict["idx"].numpy(),
  183. tensor_dict["premise"].numpy().decode("utf-8"),
  184. tensor_dict["hypothesis"].numpy().decode("utf-8"),
  185. str(tensor_dict["label"].numpy()),
  186. )
  187. def get_train_examples(self, data_dir):
  188. """See base class."""
  189. return self._create_examples(self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
  190. def get_dev_examples(self, data_dir):
  191. """See base class."""
  192. return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev_matched.tsv")), "dev_matched")
  193. def get_test_examples(self, data_dir):
  194. """See base class."""
  195. return self._create_examples(self._read_tsv(os.path.join(data_dir, "test_matched.tsv")), "test_matched")
  196. def get_labels(self):
  197. """See base class."""
  198. return ["contradiction", "entailment", "neutral"]
  199. def _create_examples(self, lines, set_type):
  200. """Creates examples for the training, dev and test sets."""
  201. examples = []
  202. for i, line in enumerate(lines):
  203. if i == 0:
  204. continue
  205. guid = f"{set_type}-{line[0]}"
  206. text_a = line[8]
  207. text_b = line[9]
  208. label = None if set_type.startswith("test") else line[-1]
  209. examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
  210. return examples
  211. class MnliMismatchedProcessor(MnliProcessor):
  212. """Processor for the MultiNLI Mismatched data set (GLUE version)."""
  213. def __init__(self, *args, **kwargs):
  214. super().__init__(*args, **kwargs)
  215. warnings.warn(DEPRECATION_WARNING.format("processor"), FutureWarning)
  216. def get_dev_examples(self, data_dir):
  217. """See base class."""
  218. return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev_mismatched.tsv")), "dev_mismatched")
  219. def get_test_examples(self, data_dir):
  220. """See base class."""
  221. return self._create_examples(self._read_tsv(os.path.join(data_dir, "test_mismatched.tsv")), "test_mismatched")
  222. class ColaProcessor(DataProcessor):
  223. """Processor for the CoLA data set (GLUE version)."""
  224. def __init__(self, *args, **kwargs):
  225. super().__init__(*args, **kwargs)
  226. warnings.warn(DEPRECATION_WARNING.format("processor"), FutureWarning)
  227. def get_example_from_tensor_dict(self, tensor_dict):
  228. """See base class."""
  229. return InputExample(
  230. tensor_dict["idx"].numpy(),
  231. tensor_dict["sentence"].numpy().decode("utf-8"),
  232. None,
  233. str(tensor_dict["label"].numpy()),
  234. )
  235. def get_train_examples(self, data_dir):
  236. """See base class."""
  237. return self._create_examples(self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
  238. def get_dev_examples(self, data_dir):
  239. """See base class."""
  240. return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
  241. def get_test_examples(self, data_dir):
  242. """See base class."""
  243. return self._create_examples(self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")
  244. def get_labels(self):
  245. """See base class."""
  246. return ["0", "1"]
  247. def _create_examples(self, lines, set_type):
  248. """Creates examples for the training, dev and test sets."""
  249. test_mode = set_type == "test"
  250. if test_mode:
  251. lines = lines[1:]
  252. text_index = 1 if test_mode else 3
  253. examples = []
  254. for i, line in enumerate(lines):
  255. guid = f"{set_type}-{i}"
  256. text_a = line[text_index]
  257. label = None if test_mode else line[1]
  258. examples.append(InputExample(guid=guid, text_a=text_a, text_b=None, label=label))
  259. return examples
  260. class Sst2Processor(DataProcessor):
  261. """Processor for the SST-2 data set (GLUE version)."""
  262. def __init__(self, *args, **kwargs):
  263. super().__init__(*args, **kwargs)
  264. warnings.warn(DEPRECATION_WARNING.format("processor"), FutureWarning)
  265. def get_example_from_tensor_dict(self, tensor_dict):
  266. """See base class."""
  267. return InputExample(
  268. tensor_dict["idx"].numpy(),
  269. tensor_dict["sentence"].numpy().decode("utf-8"),
  270. None,
  271. str(tensor_dict["label"].numpy()),
  272. )
  273. def get_train_examples(self, data_dir):
  274. """See base class."""
  275. return self._create_examples(self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
  276. def get_dev_examples(self, data_dir):
  277. """See base class."""
  278. return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
  279. def get_test_examples(self, data_dir):
  280. """See base class."""
  281. return self._create_examples(self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")
  282. def get_labels(self):
  283. """See base class."""
  284. return ["0", "1"]
  285. def _create_examples(self, lines, set_type):
  286. """Creates examples for the training, dev and test sets."""
  287. examples = []
  288. text_index = 1 if set_type == "test" else 0
  289. for i, line in enumerate(lines):
  290. if i == 0:
  291. continue
  292. guid = f"{set_type}-{i}"
  293. text_a = line[text_index]
  294. label = None if set_type == "test" else line[1]
  295. examples.append(InputExample(guid=guid, text_a=text_a, text_b=None, label=label))
  296. return examples
  297. class StsbProcessor(DataProcessor):
  298. """Processor for the STS-B data set (GLUE version)."""
  299. def __init__(self, *args, **kwargs):
  300. super().__init__(*args, **kwargs)
  301. warnings.warn(DEPRECATION_WARNING.format("processor"), FutureWarning)
  302. def get_example_from_tensor_dict(self, tensor_dict):
  303. """See base class."""
  304. return InputExample(
  305. tensor_dict["idx"].numpy(),
  306. tensor_dict["sentence1"].numpy().decode("utf-8"),
  307. tensor_dict["sentence2"].numpy().decode("utf-8"),
  308. str(tensor_dict["label"].numpy()),
  309. )
  310. def get_train_examples(self, data_dir):
  311. """See base class."""
  312. return self._create_examples(self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
  313. def get_dev_examples(self, data_dir):
  314. """See base class."""
  315. return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
  316. def get_test_examples(self, data_dir):
  317. """See base class."""
  318. return self._create_examples(self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")
  319. def get_labels(self):
  320. """See base class."""
  321. return [None]
  322. def _create_examples(self, lines, set_type):
  323. """Creates examples for the training, dev and test sets."""
  324. examples = []
  325. for i, line in enumerate(lines):
  326. if i == 0:
  327. continue
  328. guid = f"{set_type}-{line[0]}"
  329. text_a = line[7]
  330. text_b = line[8]
  331. label = None if set_type == "test" else line[-1]
  332. examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
  333. return examples
  334. class QqpProcessor(DataProcessor):
  335. """Processor for the QQP data set (GLUE version)."""
  336. def __init__(self, *args, **kwargs):
  337. super().__init__(*args, **kwargs)
  338. warnings.warn(DEPRECATION_WARNING.format("processor"), FutureWarning)
  339. def get_example_from_tensor_dict(self, tensor_dict):
  340. """See base class."""
  341. return InputExample(
  342. tensor_dict["idx"].numpy(),
  343. tensor_dict["question1"].numpy().decode("utf-8"),
  344. tensor_dict["question2"].numpy().decode("utf-8"),
  345. str(tensor_dict["label"].numpy()),
  346. )
  347. def get_train_examples(self, data_dir):
  348. """See base class."""
  349. return self._create_examples(self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
  350. def get_dev_examples(self, data_dir):
  351. """See base class."""
  352. return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
  353. def get_test_examples(self, data_dir):
  354. """See base class."""
  355. return self._create_examples(self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")
  356. def get_labels(self):
  357. """See base class."""
  358. return ["0", "1"]
  359. def _create_examples(self, lines, set_type):
  360. """Creates examples for the training, dev and test sets."""
  361. test_mode = set_type == "test"
  362. q1_index = 1 if test_mode else 3
  363. q2_index = 2 if test_mode else 4
  364. examples = []
  365. for i, line in enumerate(lines):
  366. if i == 0:
  367. continue
  368. guid = f"{set_type}-{line[0]}"
  369. try:
  370. text_a = line[q1_index]
  371. text_b = line[q2_index]
  372. label = None if test_mode else line[5]
  373. except IndexError:
  374. continue
  375. examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
  376. return examples
  377. class QnliProcessor(DataProcessor):
  378. """Processor for the QNLI data set (GLUE version)."""
  379. def __init__(self, *args, **kwargs):
  380. super().__init__(*args, **kwargs)
  381. warnings.warn(DEPRECATION_WARNING.format("processor"), FutureWarning)
  382. def get_example_from_tensor_dict(self, tensor_dict):
  383. """See base class."""
  384. return InputExample(
  385. tensor_dict["idx"].numpy(),
  386. tensor_dict["question"].numpy().decode("utf-8"),
  387. tensor_dict["sentence"].numpy().decode("utf-8"),
  388. str(tensor_dict["label"].numpy()),
  389. )
  390. def get_train_examples(self, data_dir):
  391. """See base class."""
  392. return self._create_examples(self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
  393. def get_dev_examples(self, data_dir):
  394. """See base class."""
  395. return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
  396. def get_test_examples(self, data_dir):
  397. """See base class."""
  398. return self._create_examples(self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")
  399. def get_labels(self):
  400. """See base class."""
  401. return ["entailment", "not_entailment"]
  402. def _create_examples(self, lines, set_type):
  403. """Creates examples for the training, dev and test sets."""
  404. examples = []
  405. for i, line in enumerate(lines):
  406. if i == 0:
  407. continue
  408. guid = f"{set_type}-{line[0]}"
  409. text_a = line[1]
  410. text_b = line[2]
  411. label = None if set_type == "test" else line[-1]
  412. examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
  413. return examples
  414. class RteProcessor(DataProcessor):
  415. """Processor for the RTE data set (GLUE version)."""
  416. def __init__(self, *args, **kwargs):
  417. super().__init__(*args, **kwargs)
  418. warnings.warn(DEPRECATION_WARNING.format("processor"), FutureWarning)
  419. def get_example_from_tensor_dict(self, tensor_dict):
  420. """See base class."""
  421. return InputExample(
  422. tensor_dict["idx"].numpy(),
  423. tensor_dict["sentence1"].numpy().decode("utf-8"),
  424. tensor_dict["sentence2"].numpy().decode("utf-8"),
  425. str(tensor_dict["label"].numpy()),
  426. )
  427. def get_train_examples(self, data_dir):
  428. """See base class."""
  429. return self._create_examples(self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
  430. def get_dev_examples(self, data_dir):
  431. """See base class."""
  432. return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
  433. def get_test_examples(self, data_dir):
  434. """See base class."""
  435. return self._create_examples(self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")
  436. def get_labels(self):
  437. """See base class."""
  438. return ["entailment", "not_entailment"]
  439. def _create_examples(self, lines, set_type):
  440. """Creates examples for the training, dev and test sets."""
  441. examples = []
  442. for i, line in enumerate(lines):
  443. if i == 0:
  444. continue
  445. guid = f"{set_type}-{line[0]}"
  446. text_a = line[1]
  447. text_b = line[2]
  448. label = None if set_type == "test" else line[-1]
  449. examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
  450. return examples
  451. class WnliProcessor(DataProcessor):
  452. """Processor for the WNLI data set (GLUE version)."""
  453. def __init__(self, *args, **kwargs):
  454. super().__init__(*args, **kwargs)
  455. warnings.warn(DEPRECATION_WARNING.format("processor"), FutureWarning)
  456. def get_example_from_tensor_dict(self, tensor_dict):
  457. """See base class."""
  458. return InputExample(
  459. tensor_dict["idx"].numpy(),
  460. tensor_dict["sentence1"].numpy().decode("utf-8"),
  461. tensor_dict["sentence2"].numpy().decode("utf-8"),
  462. str(tensor_dict["label"].numpy()),
  463. )
  464. def get_train_examples(self, data_dir):
  465. """See base class."""
  466. return self._create_examples(self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
  467. def get_dev_examples(self, data_dir):
  468. """See base class."""
  469. return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
  470. def get_test_examples(self, data_dir):
  471. """See base class."""
  472. return self._create_examples(self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")
  473. def get_labels(self):
  474. """See base class."""
  475. return ["0", "1"]
  476. def _create_examples(self, lines, set_type):
  477. """Creates examples for the training, dev and test sets."""
  478. examples = []
  479. for i, line in enumerate(lines):
  480. if i == 0:
  481. continue
  482. guid = f"{set_type}-{line[0]}"
  483. text_a = line[1]
  484. text_b = line[2]
  485. label = None if set_type == "test" else line[-1]
  486. examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
  487. return examples
  488. glue_tasks_num_labels = {
  489. "cola": 2,
  490. "mnli": 3,
  491. "mrpc": 2,
  492. "sst-2": 2,
  493. "sts-b": 1,
  494. "qqp": 2,
  495. "qnli": 2,
  496. "rte": 2,
  497. "wnli": 2,
  498. }
# Maps each GLUE task name to its processor class (the class itself, not an instance).
# The conversion helpers in this module look a task up here and call the result to
# instantiate the processor.
glue_processors = {
    "cola": ColaProcessor,
    "mnli": MnliProcessor,
    "mnli-mm": MnliMismatchedProcessor,
    "mrpc": MrpcProcessor,
    "sst-2": Sst2Processor,
    "sts-b": StsbProcessor,
    "qqp": QqpProcessor,
    "qnli": QnliProcessor,
    "rte": RteProcessor,
    "wnli": WnliProcessor,
}
  511. glue_output_modes = {
  512. "cola": "classification",
  513. "mnli": "classification",
  514. "mnli-mm": "classification",
  515. "mrpc": "classification",
  516. "sst-2": "classification",
  517. "sts-b": "regression",
  518. "qqp": "classification",
  519. "qnli": "classification",
  520. "rte": "classification",
  521. "wnli": "classification",
  522. }