# DenoisingAutoEncoderDataset.py
  1. from __future__ import annotations
  2. import numpy as np
  3. from torch.utils.data import Dataset
  4. from transformers.utils.import_utils import NLTK_IMPORT_ERROR, is_nltk_available
  5. from sentence_transformers.readers.InputExample import InputExample
  6. class DenoisingAutoEncoderDataset(Dataset):
  7. """
  8. The DenoisingAutoEncoderDataset returns InputExamples in the format: texts=[noise_fn(sentence), sentence]
  9. It is used in combination with the DenoisingAutoEncoderLoss: Here, a decoder tries to re-construct the
  10. sentence without noise.
  11. Args:
  12. sentences: A list of sentences
  13. noise_fn: A noise function: Given a string, it returns a string
  14. with noise, e.g. deleted words
  15. """
  16. def __init__(self, sentences: list[str], noise_fn=lambda s: DenoisingAutoEncoderDataset.delete(s)):
  17. if not is_nltk_available():
  18. raise ImportError(NLTK_IMPORT_ERROR.format(self.__class__.__name__))
  19. self.sentences = sentences
  20. self.noise_fn = noise_fn
  21. def __getitem__(self, item):
  22. sent = self.sentences[item]
  23. return InputExample(texts=[self.noise_fn(sent), sent])
  24. def __len__(self):
  25. return len(self.sentences)
  26. # Deletion noise.
  27. @staticmethod
  28. def delete(text, del_ratio=0.6):
  29. from nltk import word_tokenize
  30. from nltk.tokenize.treebank import TreebankWordDetokenizer
  31. words = word_tokenize(text)
  32. n = len(words)
  33. if n == 0:
  34. return text
  35. keep_or_not = np.random.rand(n) > del_ratio
  36. if sum(keep_or_not) == 0:
  37. keep_or_not[np.random.choice(n)] = True # guarantee that at least one word remains
  38. words_processed = TreebankWordDetokenizer().detokenize(np.array(words)[keep_or_not])
  39. return words_processed