NoDuplicatesDataLoader.py 1.5 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546
  1. from __future__ import annotations
  2. import math
  3. import random
  4. class NoDuplicatesDataLoader:
  5. def __init__(self, train_examples, batch_size):
  6. """
  7. A special data loader to be used with MultipleNegativesRankingLoss.
  8. The data loader ensures that there are no duplicate sentences within the same batch
  9. """
  10. self.batch_size = batch_size
  11. self.data_pointer = 0
  12. self.collate_fn = None
  13. self.train_examples = train_examples
  14. random.shuffle(self.train_examples)
  15. def __iter__(self):
  16. for _ in range(self.__len__()):
  17. batch = []
  18. texts_in_batch = set()
  19. while len(batch) < self.batch_size:
  20. example = self.train_examples[self.data_pointer]
  21. valid_example = True
  22. for text in example.texts:
  23. if text.strip().lower() in texts_in_batch:
  24. valid_example = False
  25. break
  26. if valid_example:
  27. batch.append(example)
  28. for text in example.texts:
  29. texts_in_batch.add(text.strip().lower())
  30. self.data_pointer += 1
  31. if self.data_pointer >= len(self.train_examples):
  32. self.data_pointer = 0
  33. random.shuffle(self.train_examples)
  34. yield self.collate_fn(batch) if self.collate_fn is not None else batch
  35. def __len__(self):
  36. return math.floor(len(self.train_examples) / self.batch_size)