ParallelSentencesDataset.py

from __future__ import annotations

import gzip
import logging
import random

from torch.utils.data import Dataset

from sentence_transformers import SentenceTransformer
from sentence_transformers.readers import InputExample

logger = logging.getLogger(__name__)


class ParallelSentencesDataset(Dataset):
    """
    This dataset reader can be used to read in parallel sentences, i.e., it reads a file with tab-separated sentences, where the same
    sentence appears in different languages. For example, the file can look like this (EN\tDE\tES):
        hello world     hallo welt      hola mundo
        second sentence zweiter satz    segunda oración

    The sentence in the first column will be mapped to a sentence embedding using the given embedder. For example,
    the embedder can be a monolingual sentence embedding method for English. The sentences in the other languages will also be
    mapped to this English sentence embedding.

    When getting a sample from the dataset, we get one sentence with the corresponding sentence embedding for this sentence.

    teacher_model can be any class that implements an encode function. The encode function gets a list of sentences and
    returns a list of sentence embeddings.
    """
    def __init__(
        self,
        student_model: SentenceTransformer,
        teacher_model: SentenceTransformer,
        batch_size: int = 8,
        use_embedding_cache: bool = True,
    ):
        """
        Parallel sentences dataset reader to train a student model given a teacher model

        Args:
            student_model (SentenceTransformer): The student sentence embedding model that should be trained.
            teacher_model (SentenceTransformer): The teacher model that provides the sentence embeddings for the first column in the dataset file.
            batch_size (int, optional): The batch size for training. Defaults to 8.
            use_embedding_cache (bool, optional): Whether to use an embedding cache. Defaults to True.
        """
        self.student_model = student_model
        self.teacher_model = teacher_model
        self.datasets = []
        self.datasets_iterator = []
        self.datasets_tokenized = []
        self.dataset_indices = []
        self.copy_dataset_indices = []
        self.cache = []
        self.batch_size = batch_size
        self.use_embedding_cache = use_embedding_cache
        self.embedding_cache = {}
        self.num_sentences = 0
    def load_data(
        self, filepath: str, weight: int = 100, max_sentences: int = None, max_sentence_length: int = 128
    ) -> None:
        """
        Reads in a tab-separated .txt/.csv/.tsv or .gz file. The other columns contain the translations of the sentence in the first column.

        Args:
            filepath (str): Filepath to the file.
            weight (int, optional): If more than one dataset is loaded with load_data, specifies the frequency at which data should be sampled from this dataset. Defaults to 100.
            max_sentences (int, optional): Maximum number of lines to be read from the filepath. Defaults to None.
            max_sentence_length (int, optional): Skip the example if one of the sentences has more characters than max_sentence_length. Defaults to 128.

        Returns:
            None
        """
        logger.info("Load " + filepath)
        parallel_sentences = []

        with gzip.open(filepath, "rt", encoding="utf8") if filepath.endswith(".gz") else open(
            filepath, encoding="utf8"
        ) as fIn:
            count = 0
            for line in fIn:
                sentences = line.strip().split("\t")
                if (
                    max_sentence_length is not None
                    and max_sentence_length > 0
                    and max([len(sent) for sent in sentences]) > max_sentence_length
                ):
                    continue

                parallel_sentences.append(sentences)
                count += 1
                if max_sentences is not None and max_sentences > 0 and count >= max_sentences:
                    break

        self.add_dataset(
            parallel_sentences, weight=weight, max_sentences=max_sentences, max_sentence_length=max_sentence_length
        )
    def add_dataset(
        self,
        parallel_sentences: list[list[str]],
        weight: int = 100,
        max_sentences: int = None,
        max_sentence_length: int = 128,
    ):
        sentences_map = {}
        for sentences in parallel_sentences:
            if (
                max_sentence_length is not None
                and max_sentence_length > 0
                and max([len(sent) for sent in sentences]) > max_sentence_length
            ):
                continue

            source_sentence = sentences[0]
            if source_sentence not in sentences_map:
                sentences_map[source_sentence] = set()

            for sent in sentences:
                sentences_map[source_sentence].add(sent)

            if max_sentences is not None and max_sentences > 0 and len(sentences_map) >= max_sentences:
                break

        if len(sentences_map) == 0:
            return

        self.num_sentences += sum([len(sentences_map[sent]) for sent in sentences_map])

        dataset_id = len(self.datasets)
        self.datasets.append(list(sentences_map.items()))
        self.datasets_iterator.append(0)
        self.dataset_indices.extend([dataset_id] * weight)
    def generate_data(self):
        source_sentences_list = []
        target_sentences_list = []
        for data_idx in self.dataset_indices:
            src_sentence, trg_sentences = self.next_entry(data_idx)
            source_sentences_list.append(src_sentence)
            target_sentences_list.append(trg_sentences)

        # Generate embeddings
        src_embeddings = self.get_embeddings(source_sentences_list)

        for src_embedding, trg_sentences in zip(src_embeddings, target_sentences_list):
            for trg_sentence in trg_sentences:
                self.cache.append(InputExample(texts=[trg_sentence], label=src_embedding))

        random.shuffle(self.cache)
    def next_entry(self, data_idx):
        source, target_sentences = self.datasets[data_idx][self.datasets_iterator[data_idx]]

        self.datasets_iterator[data_idx] += 1
        if self.datasets_iterator[data_idx] >= len(self.datasets[data_idx]):  # Restart iterator
            self.datasets_iterator[data_idx] = 0
            random.shuffle(self.datasets[data_idx])

        return source, target_sentences
    def get_embeddings(self, sentences):
        if not self.use_embedding_cache:
            return self.teacher_model.encode(
                sentences, batch_size=self.batch_size, show_progress_bar=False, convert_to_numpy=True
            )

        # Use caching
        new_sentences = []
        for sent in sentences:
            if sent not in self.embedding_cache:
                new_sentences.append(sent)

        if len(new_sentences) > 0:
            new_embeddings = self.teacher_model.encode(
                new_sentences, batch_size=self.batch_size, show_progress_bar=False, convert_to_numpy=True
            )
            for sent, embedding in zip(new_sentences, new_embeddings):
                self.embedding_cache[sent] = embedding

        return [self.embedding_cache[sent] for sent in sentences]
    def __len__(self):
        return self.num_sentences

    def __getitem__(self, idx):
        if len(self.cache) == 0:
            self.generate_data()

        return self.cache.pop()
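

# Usage sketch (illustrative, not part of the original class): one typical way this dataset is
# combined with a DataLoader and an MSE loss to distil a monolingual teacher into a multilingual
# student, matching the class docstring above. The model names and the file path
# "parallel-sentences.tsv.gz" are placeholder assumptions; any SentenceTransformer checkpoints
# and any tab-separated parallel file in the described format should work.
if __name__ == "__main__":
    from torch.utils.data import DataLoader

    from sentence_transformers import losses

    teacher_model = SentenceTransformer("paraphrase-distilroberta-base-v2")  # assumed teacher checkpoint
    student_model = SentenceTransformer("xlm-roberta-base")  # assumed student checkpoint

    train_data = ParallelSentencesDataset(student_model=student_model, teacher_model=teacher_model)
    train_data.load_data("parallel-sentences.tsv.gz", weight=100, max_sentence_length=250)

    train_dataloader = DataLoader(train_data, shuffle=True, batch_size=32)
    train_loss = losses.MSELoss(model=student_model)

    # Each item pairs a target-language sentence with the teacher embedding of its English source
    # sentence (stored as the InputExample label), which MSELoss regresses the student onto.
    student_model.fit(train_objectives=[(train_dataloader, train_loss)], epochs=1, warmup_steps=1000)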