sampler.py

from __future__ import annotations

import logging
from collections import defaultdict
from itertools import accumulate, cycle
from typing import Any, Iterator

import torch
from torch.utils.data import BatchSampler, ConcatDataset, SubsetRandomSampler

from sentence_transformers.util import is_datasets_available

if is_datasets_available():
    from datasets import Dataset

logger = logging.getLogger(__name__)


class SetEpochMixin:
    """
    Required for a BatchSampler as the Trainer will call set_epoch on the BatchSampler at the beginning of each epoch.
    The BatchSampler can then set the generator seed accordingly.
    """

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.epoch = 0

    def set_epoch(self, epoch: int) -> None:
        self.epoch = epoch
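
# A minimal sketch of how the epoch is used (the loop below is illustrative, not the actual
# Trainer code): calling set_epoch before each epoch lets samplers that seed their generator
# with ``seed + epoch`` reshuffle differently every epoch.
#
#     for epoch in range(num_epochs):
#         batch_sampler.set_epoch(epoch)
#         for batch_indices in batch_sampler:
#             ...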


class DefaultBatchSampler(SetEpochMixin, BatchSampler):
    """
    This sampler is the default batch sampler used in the SentenceTransformer library.
    It is equivalent to the PyTorch BatchSampler.

    Args:
        sampler (Sampler or Iterable): The sampler used for sampling elements from the dataset,
            such as SubsetRandomSampler.
        batch_size (int): Number of samples per batch.
        drop_last (bool): If True, drop the last incomplete batch if the dataset size
            is not divisible by the batch size.
    """


class GroupByLabelBatchSampler(SetEpochMixin, BatchSampler):
    """
    This sampler groups samples by their labels and aims to create batches such that
    each batch contains samples where the labels are as homogeneous as possible.
    This sampler is meant to be used alongside the ``Batch...TripletLoss`` classes, which
    require that each batch contains at least 2 examples per label class.

    Recommended for:
        - :class:`~sentence_transformers.losses.BatchAllTripletLoss`
        - :class:`~sentence_transformers.losses.BatchHardSoftMarginTripletLoss`
        - :class:`~sentence_transformers.losses.BatchHardTripletLoss`
        - :class:`~sentence_transformers.losses.BatchSemiHardTripletLoss`

    Args:
        dataset (Dataset): The dataset to sample from.
        batch_size (int): Number of samples per batch. Must be divisible by 2.
        drop_last (bool): If True, drop the last incomplete batch if the dataset size
            is not divisible by the batch size.
        valid_label_columns (List[str]): List of column names to check for labels.
            The first column name from ``valid_label_columns`` found in the dataset will
            be used as the label column.
        generator (torch.Generator, optional): Optional random number generator for shuffling
            the indices.
        seed (int, optional): Seed for the random number generator to ensure reproducibility.
    """

    def __init__(
        self,
        dataset: Dataset,
        batch_size: int,
        drop_last: bool,
        valid_label_columns: list[str] = None,
        generator: torch.Generator = None,
        seed: int = 0,
    ) -> None:
        super().__init__(dataset, batch_size, drop_last)
        self.dataset = dataset
        self.batch_size = batch_size
        self.drop_last = drop_last
        self.generator = generator
        self.seed = seed

        if self.batch_size % 2 == 1:
            raise ValueError("The batch size for `GroupByLabelBatchSampler` must be divisible by 2.")

        labels = self._determine_labels_to_use(dataset, valid_label_columns)
        groups = defaultdict(list)
        for sample_idx, label in enumerate(labels):
            groups[label].append(sample_idx)

        # Keep only labels with at least 2 samples, truncating each group to an even number of
        # indices so that every label can contribute complete pairs to a batch.
        self.groups = {
            label: sample_indices[:num_samples]
            for label, sample_indices in groups.items()
            if (num_samples := len(sample_indices) // 2 * 2)
        }

    @staticmethod
    def _determine_labels_to_use(dataset: Dataset, valid_label_columns: list[str]) -> list[Any]:
        for column_name in valid_label_columns or []:
            if column_name in dataset.column_names:
                return dataset[column_name]
        raise ValueError(
            f"None of the valid_label_columns {valid_label_columns} are in the dataset, "
            f"which only has these columns: {dataset.column_names}."
        )

    def __iter__(self) -> Iterator[list[int]]:
        if self.generator and self.seed:
            self.generator.manual_seed(self.seed + self.epoch)

        partial_batch = []
        unique_labels = list(self.groups.keys())
        for label_idx in torch.randperm(len(self.groups), generator=self.generator):
            label = unique_labels[label_idx]
            samples = self.groups[label]
            partial_batch.extend(samples)
            while len(partial_batch) >= self.batch_size:
                yield partial_batch[: self.batch_size]
                partial_batch = partial_batch[self.batch_size :]

        if not self.drop_last and partial_batch:
            yield partial_batch
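
# Example usage (a minimal sketch; the toy dataset and its "text"/"label" columns are
# illustrative assumptions, only the sampler call itself comes from this module):
#
#     from datasets import Dataset
#
#     train_dataset = Dataset.from_dict({
#         "text": ["a", "b", "c", "d", "e", "f"],
#         "label": [0, 0, 1, 1, 1, 1],
#     })
#     batch_sampler = GroupByLabelBatchSampler(
#         train_dataset, batch_size=4, drop_last=False, valid_label_columns=["label"]
#     )
#     for batch_indices in batch_sampler:
#         ...  # batches are filled label group by label group, e.g. [2, 3, 4, 5]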


class NoDuplicatesBatchSampler(SetEpochMixin, BatchSampler):
    def __init__(
        self,
        dataset: Dataset,
        batch_size: int,
        drop_last: bool,
        valid_label_columns: list[str] = [],
        generator: torch.Generator = None,
        seed: int = 0,
    ) -> None:
        """
        This sampler creates batches such that each batch contains samples where the values are unique,
        even across columns. This is useful when losses consider other samples in a batch to be in-batch
        negatives, and you want to ensure that the negatives are not duplicates of the anchor/positive sample.

        Recommended for:
            - :class:`~sentence_transformers.losses.MultipleNegativesRankingLoss`
            - :class:`~sentence_transformers.losses.CachedMultipleNegativesRankingLoss`
            - :class:`~sentence_transformers.losses.MultipleNegativesSymmetricRankingLoss`
            - :class:`~sentence_transformers.losses.CachedMultipleNegativesSymmetricRankingLoss`
            - :class:`~sentence_transformers.losses.MegaBatchMarginLoss`
            - :class:`~sentence_transformers.losses.GISTEmbedLoss`
            - :class:`~sentence_transformers.losses.CachedGISTEmbedLoss`

        Args:
            dataset (Dataset): The dataset to sample from.
            batch_size (int): Number of samples per batch.
            drop_last (bool): If True, drop the last incomplete batch if the dataset size
                is not divisible by the batch size.
            valid_label_columns (List[str]): List of column names to check for labels.
                The first column name from ``valid_label_columns`` found in the dataset will
                be used as the label column.
            generator (torch.Generator, optional): Optional random number generator for shuffling
                the indices.
            seed (int, optional): Seed for the random number generator to ensure reproducibility.
        """
        super().__init__(dataset, batch_size, drop_last)
        if label_columns := set(dataset.column_names) & (set(valid_label_columns) | {"dataset_name"}):
            dataset = dataset.remove_columns(label_columns)
        self.dataset = dataset
        self.batch_size = batch_size
        self.drop_last = drop_last
        self.generator = generator
        self.seed = seed

    def __iter__(self) -> Iterator[list[int]]:
        """
        Iterate over the remaining non-yielded indices. For each index, check if the sample values are already in
        the batch. If not, add the sample values to the batch and keep going until the batch is full. Once the
        batch is full, yield the batch indices and continue with the next batch.
        """
        if self.generator and self.seed:
            self.generator.manual_seed(self.seed + self.epoch)

        remaining_indices = set(torch.randperm(len(self.dataset), generator=self.generator).tolist())
        while remaining_indices:
            batch_values = set()
            batch_indices = []
            for index in remaining_indices:
                sample_values = set(self.dataset[index].values())
                if sample_values & batch_values:
                    continue

                batch_indices.append(index)
                if len(batch_indices) == self.batch_size:
                    yield batch_indices
                    break

                batch_values.update(sample_values)

            else:
                # NOTE: some indices might still have been ignored here
                if not self.drop_last:
                    yield batch_indices

            remaining_indices -= set(batch_indices)

    def __len__(self) -> int:
        if self.drop_last:
            return len(self.dataset) // self.batch_size
        else:
            return (len(self.dataset) + self.batch_size - 1) // self.batch_size
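
# Example usage (a minimal sketch; the toy dataset and its "anchor"/"positive" columns are
# illustrative assumptions):
#
#     from datasets import Dataset
#
#     train_dataset = Dataset.from_dict({
#         "anchor": ["it is sunny", "it is sunny", "it rains"],
#         "positive": ["nice weather", "good weather", "bad weather"],
#     })
#     batch_sampler = NoDuplicatesBatchSampler(train_dataset, batch_size=2, drop_last=False)
#     for batch_indices in batch_sampler:
#         ...  # indices 0 and 1 never share a batch, as they duplicate the anchor value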


class RoundRobinBatchSampler(SetEpochMixin, BatchSampler):
    """
    Batch sampler that yields batches in a round-robin fashion from multiple batch samplers, until one is exhausted.
    With this sampler, it's unlikely that all samples from each dataset are used, but we do ensure that each dataset
    is sampled from equally.

    Args:
        dataset (ConcatDataset): A concatenation of multiple datasets.
        batch_samplers (List[BatchSampler]): A list of batch samplers, one for each dataset in the ConcatDataset.
        generator (torch.Generator, optional): A generator for reproducible sampling. Defaults to None.
        seed (int, optional): A seed for the generator. Defaults to None.
    """

    def __init__(
        self,
        dataset: ConcatDataset,
        batch_samplers: list[BatchSampler],
        generator: torch.Generator = None,
        seed: int = None,
    ) -> None:
        if len(dataset.datasets) != len(batch_samplers):
            raise ValueError("The number of batch samplers must match the number of datasets in the ConcatDataset.")

        super().__init__(dataset, batch_samplers[0].batch_size, batch_samplers[0].drop_last)
        self.dataset = dataset
        self.batch_samplers = batch_samplers
        self.generator = generator
        self.seed = seed

    def __iter__(self) -> Iterator[list[int]]:
        if self.generator and self.seed:
            self.generator.manual_seed(self.seed + self.epoch)

        num_samples = [len(dataset) for dataset in self.dataset.datasets]
        sample_offsets = [0] + list(accumulate(num_samples))

        batch_samplers = [iter(sampler) for sampler in self.batch_samplers]
        for dataset_idx in cycle(range(len(batch_samplers))):
            sample_offset = sample_offsets[dataset_idx]
            try:
                yield [idx + sample_offset for idx in next(batch_samplers[dataset_idx])]
            except StopIteration:
                # current iterator is apparently exhausted
                break

    def __len__(self) -> int:
        return min(len(sampler) for sampler in self.batch_samplers) * len(self.batch_samplers)
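
# Example usage (a minimal sketch; ``dataset_a`` and ``dataset_b`` are placeholders for
# map-style datasets, not values defined in this module):
#
#     concat_dataset = ConcatDataset([dataset_a, dataset_b])
#     batch_sampler = RoundRobinBatchSampler(
#         dataset=concat_dataset,
#         batch_samplers=[
#             DefaultBatchSampler(SubsetRandomSampler(range(len(dataset_a))), batch_size=8, drop_last=False),
#             DefaultBatchSampler(SubsetRandomSampler(range(len(dataset_b))), batch_size=8, drop_last=False),
#         ],
#     )
#     for batch_indices in batch_sampler:
#         ...  # batches alternate: dataset_a, dataset_b, dataset_a, ... until one sampler is exhausted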


class ProportionalBatchSampler(SetEpochMixin, BatchSampler):
    def __init__(
        self,
        dataset: ConcatDataset,
        batch_samplers: list[BatchSampler],
        generator: torch.Generator,
        seed: int,
    ) -> None:
        """
        Batch sampler that samples from each dataset in proportion to its size, until all are exhausted simultaneously.
        With this sampler, all samples from each dataset are used and larger datasets are sampled from more frequently.

        Args:
            dataset (ConcatDataset): A concatenation of multiple datasets.
            batch_samplers (List[BatchSampler]): A list of batch samplers, one for each dataset in the ConcatDataset.
            generator (torch.Generator): A generator for reproducible sampling.
            seed (int): A seed for the generator.
        """
        super().__init__(dataset, batch_samplers[0].batch_size, batch_samplers[0].drop_last)
        self.dataset = dataset
        self.batch_samplers = batch_samplers
        self.generator = generator
        self.seed = seed

    def __iter__(self) -> Iterator[list[int]]:
        self.generator.manual_seed(self.seed + self.epoch)

        num_samples = [len(dataset) for dataset in self.dataset.datasets]
        sample_offsets = [0] + list(accumulate(num_samples))

        num_batches = [len(sampler) for sampler in self.batch_samplers]
        dataset_indices = [idx for idx, length in enumerate(num_batches) for _ in range(length)]
        dataset_idx_sampler = SubsetRandomSampler(dataset_indices, generator=self.generator)

        batch_samplers = [iter(sampler) for sampler in self.batch_samplers]
        for dataset_idx in dataset_idx_sampler:
            sample_offset = sample_offsets[dataset_idx]
            try:
                yield [idx + sample_offset for idx in next(batch_samplers[dataset_idx])]
            except StopIteration:
                continue

    def __len__(self) -> int:
        return sum(len(sampler) for sampler in self.batch_samplers)
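
# Example usage (a minimal sketch; ``dataset_a`` and ``dataset_b`` are placeholders for
# map-style datasets, not values defined in this module):
#
#     concat_dataset = ConcatDataset([dataset_a, dataset_b])
#     batch_sampler = ProportionalBatchSampler(
#         dataset=concat_dataset,
#         batch_samplers=[
#             DefaultBatchSampler(SubsetRandomSampler(range(len(dataset_a))), batch_size=8, drop_last=False),
#             DefaultBatchSampler(SubsetRandomSampler(range(len(dataset_b))), batch_size=8, drop_last=False),
#         ],
#         generator=torch.Generator(),
#         seed=12,
#     )
#     for batch_indices in batch_sampler:
#         ...  # every batch comes from a single dataset; larger datasets contribute more batches overall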