from __future__ import annotations

import gzip
import logging
import random

from torch.utils.data import Dataset

from sentence_transformers import SentenceTransformer
from sentence_transformers.readers import InputExample

logger = logging.getLogger(__name__)


class ParallelSentencesDataset(Dataset):
    """
    This dataset reader can be used to read in parallel sentences, i.e., it reads a file with tab-separated
    sentences, where each line contains the same sentence in different languages. For example, a file with the
    columns EN\tDE\tES can look like this:

        hello world     hallo welt      hola mundo
        second sentence zweiter satz    segunda oración

    The sentence in the first column is mapped to a sentence embedding using the given teacher model. For
    example, the teacher model can be a mono-lingual sentence embedding model for English. The sentences in the
    other languages are then mapped to this same English sentence embedding.

    When getting a sample from the dataset, we get one sentence with the corresponding sentence embedding for
    this sentence.

    teacher_model can be any class that implements an encode function. The encode function gets a list of
    sentences and returns a list of sentence embeddings.
    """

    def __init__(
        self,
        student_model: SentenceTransformer,
        teacher_model: SentenceTransformer,
        batch_size: int = 8,
        use_embedding_cache: bool = True,
    ):
        """
        Parallel sentences dataset reader to train a student model given a teacher model

        Args:
            student_model (SentenceTransformer): The student sentence embedding model that should be trained.
            teacher_model (SentenceTransformer): The teacher model that provides the sentence embeddings for
                the first column in the dataset file.
            batch_size (int, optional): The batch size used when encoding with the teacher model. Defaults to 8.
            use_embedding_cache (bool, optional): Whether to cache teacher embeddings per sentence. Defaults to True.
        """
        self.student_model = student_model
        self.teacher_model = teacher_model
        self.datasets = []  # One list of (source_sentence, target_sentences) pairs per loaded dataset
        self.datasets_iterator = []  # Current read position within each dataset
        self.datasets_tokenized = []
        self.dataset_indices = []  # Dataset ids, repeated `weight` times, for weighted sampling
        self.copy_dataset_indices = []
        self.cache = []  # InputExamples with teacher embeddings as labels, served by __getitem__
        self.batch_size = batch_size
        self.use_embedding_cache = use_embedding_cache
        self.embedding_cache = {}
        self.num_sentences = 0

    def load_data(
        self, filepath: str, weight: int = 100, max_sentences: int | None = None, max_sentence_length: int = 128
    ) -> None:
        """
        Reads in a tab-separated .txt/.csv/.tsv or .gz file. The different columns contain the different
        translations of the sentence in the first column.

        Args:
            filepath (str): Filepath to the file.
            weight (int, optional): If more than one dataset is loaded with load_data, the relative frequency
                at which data should be sampled from this dataset. Defaults to 100.
            max_sentences (int, optional): Maximum number of lines to be read from the file. Defaults to None.
            max_sentence_length (int, optional): Skip the example if one of the sentences has more characters
                than max_sentence_length. Defaults to 128.

        Returns:
            None
        """
        logger.info("Load %s", filepath)
        parallel_sentences = []

        # Transparently handle gzip-compressed and plain-text files
        with gzip.open(filepath, "rt", encoding="utf8") if filepath.endswith(".gz") else open(
            filepath, encoding="utf8"
        ) as fIn:
            count = 0
            for line in fIn:
                sentences = line.strip().split("\t")
                if (
                    max_sentence_length is not None
                    and max_sentence_length > 0
                    and max(len(sent) for sent in sentences) > max_sentence_length
                ):
                    continue

                parallel_sentences.append(sentences)
                count += 1
                if max_sentences is not None and max_sentences > 0 and count >= max_sentences:
                    break

        self.add_dataset(
            parallel_sentences, weight=weight, max_sentences=max_sentences, max_sentence_length=max_sentence_length
        )
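
    # A hedged sketch of loading several corpora with sampling weights (file names are
    # hypothetical): a dataset loaded with weight=200 is drawn from twice as often per
    # generation round as one loaded with weight=100.
    #
    #     data.load_data("parallel-en-de.tsv.gz", weight=200)
    #     data.load_data("parallel-en-es.tsv.gz", weight=100)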

    def add_dataset(
        self,
        parallel_sentences: list[list[str]],
        weight: int = 100,
        max_sentences: int | None = None,
        max_sentence_length: int = 128,
    ):
        # Map each source sentence (first column) to the set of all its translations,
        # including the source sentence itself
        sentences_map = {}
        for sentences in parallel_sentences:
            if (
                max_sentence_length is not None
                and max_sentence_length > 0
                and max(len(sent) for sent in sentences) > max_sentence_length
            ):
                continue

            source_sentence = sentences[0]
            if source_sentence not in sentences_map:
                sentences_map[source_sentence] = set()

            for sent in sentences:
                sentences_map[source_sentence].add(sent)

            if max_sentences is not None and max_sentences > 0 and len(sentences_map) >= max_sentences:
                break

        if len(sentences_map) == 0:
            return

        self.num_sentences += sum(len(sentences_map[sent]) for sent in sentences_map)

        dataset_id = len(self.datasets)
        self.datasets.append(list(sentences_map.items()))
        self.datasets_iterator.append(0)
        # Repeating the dataset id `weight` times implements weighted sampling across datasets
        self.dataset_indices.extend([dataset_id] * weight)

    def generate_data(self):
        source_sentences_list = []
        target_sentences_list = []

        # Draw the next entry from each dataset, proportionally to the dataset weights
        for data_idx in self.dataset_indices:
            src_sentence, trg_sentences = self.next_entry(data_idx)
            source_sentences_list.append(src_sentence)
            target_sentences_list.append(trg_sentences)

        # Compute the teacher embeddings for the source sentences
        src_embeddings = self.get_embeddings(source_sentences_list)

        # Each translation is labeled with the teacher embedding of its source sentence
        for src_embedding, trg_sentences in zip(src_embeddings, target_sentences_list):
            for trg_sentence in trg_sentences:
                self.cache.append(InputExample(texts=[trg_sentence], label=src_embedding))

        random.shuffle(self.cache)

    def next_entry(self, data_idx):
        source, target_sentences = self.datasets[data_idx][self.datasets_iterator[data_idx]]

        self.datasets_iterator[data_idx] += 1
        if self.datasets_iterator[data_idx] >= len(self.datasets[data_idx]):  # Restart iterator
            self.datasets_iterator[data_idx] = 0
            random.shuffle(self.datasets[data_idx])

        return source, target_sentences

    def get_embeddings(self, sentences):
        if not self.use_embedding_cache:
            return self.teacher_model.encode(
                sentences, batch_size=self.batch_size, show_progress_bar=False, convert_to_numpy=True
            )

        # Use caching: only encode sentences that have not been embedded before
        new_sentences = [sent for sent in sentences if sent not in self.embedding_cache]

        if len(new_sentences) > 0:
            new_embeddings = self.teacher_model.encode(
                new_sentences, batch_size=self.batch_size, show_progress_bar=False, convert_to_numpy=True
            )
            for sent, embedding in zip(new_sentences, new_embeddings):
                self.embedding_cache[sent] = embedding

        return [self.embedding_cache[sent] for sent in sentences]

    def __len__(self):
        return self.num_sentences

    def __getitem__(self, idx):
        # Regenerate and reshuffle the example cache lazily whenever it runs empty
        if len(self.cache) == 0:
            self.generate_data()

        return self.cache.pop()
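
# A minimal training sketch under the classic model.fit API (the loss choice and
# hyperparameters are illustrative assumptions, not prescribed by this module): the student
# is trained with MSE loss to reproduce the teacher embeddings stored as labels.
#
#     from torch.utils.data import DataLoader
#     from sentence_transformers import losses
#
#     train_dataloader = DataLoader(data, shuffle=True, batch_size=32)
#     train_loss = losses.MSELoss(model=student)
#     student.fit(train_objectives=[(train_dataloader, train_loss)], epochs=1)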