# training_args.py

from __future__ import annotations

import logging
from dataclasses import dataclass, field

from transformers import TrainingArguments as TransformersTrainingArguments
from transformers.training_args import ParallelMode
from transformers.utils import ExplicitEnum

logger = logging.getLogger(__name__)


class BatchSamplers(ExplicitEnum):
    """
    Stores the acceptable string identifiers for batch samplers.

    The batch sampler is responsible for determining how samples are grouped into batches during training.
    Valid options are:

    - ``BatchSamplers.BATCH_SAMPLER``: **[default]** Uses :class:`~sentence_transformers.sampler.DefaultBatchSampler`,
      the default PyTorch batch sampler.
    - ``BatchSamplers.NO_DUPLICATES``: Uses :class:`~sentence_transformers.sampler.NoDuplicatesBatchSampler`,
      ensuring no duplicate samples in a batch. Recommended for losses that use in-batch negatives, such as:

      - :class:`~sentence_transformers.losses.MultipleNegativesRankingLoss`
      - :class:`~sentence_transformers.losses.CachedMultipleNegativesRankingLoss`
      - :class:`~sentence_transformers.losses.MultipleNegativesSymmetricRankingLoss`
      - :class:`~sentence_transformers.losses.CachedMultipleNegativesSymmetricRankingLoss`
      - :class:`~sentence_transformers.losses.MegaBatchMarginLoss`
      - :class:`~sentence_transformers.losses.GISTEmbedLoss`
      - :class:`~sentence_transformers.losses.CachedGISTEmbedLoss`

    - ``BatchSamplers.GROUP_BY_LABEL``: Uses :class:`~sentence_transformers.sampler.GroupByLabelBatchSampler`,
      ensuring that each batch has 2+ samples from the same label. Recommended for losses that require multiple
      samples from the same label, such as:

      - :class:`~sentence_transformers.losses.BatchAllTripletLoss`
      - :class:`~sentence_transformers.losses.BatchHardSoftMarginTripletLoss`
      - :class:`~sentence_transformers.losses.BatchHardTripletLoss`
      - :class:`~sentence_transformers.losses.BatchSemiHardTripletLoss`

    If you want to use a custom batch sampler, you can create a new Trainer class that inherits from
    :class:`~sentence_transformers.trainer.SentenceTransformerTrainer` and overrides the
    :meth:`~sentence_transformers.trainer.SentenceTransformerTrainer.get_batch_sampler` method. The
    method must return a class instance that supports the ``__iter__`` and ``__len__`` methods: the former
    should yield a list of indices for each batch, and the latter should return the number of batches.
    A minimal sketch of such an override follows the usage example below.
    Usage:
        ::

            from sentence_transformers import SentenceTransformer, SentenceTransformerTrainer, SentenceTransformerTrainingArguments
            from sentence_transformers.training_args import BatchSamplers
            from sentence_transformers.losses import MultipleNegativesRankingLoss
            from datasets import Dataset

            model = SentenceTransformer("microsoft/mpnet-base")
            train_dataset = Dataset.from_dict({
                "anchor": ["It's nice weather outside today.", "He drove to work."],
                "positive": ["It's so sunny.", "He took the car to the office."],
            })
            loss = MultipleNegativesRankingLoss(model)
            args = SentenceTransformerTrainingArguments(
                output_dir="checkpoints",
                batch_sampler=BatchSamplers.NO_DUPLICATES,
            )
            trainer = SentenceTransformerTrainer(
                model=model,
                args=args,
                train_dataset=train_dataset,
                loss=loss,
            )
            trainer.train()
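    A minimal sketch of such a custom batch sampler override (the ``ShuffledBatchSampler`` class below is a
    hypothetical example, and the exact ``get_batch_sampler`` signature is assumed from recent versions of
    :class:`~sentence_transformers.trainer.SentenceTransformerTrainer`; verify it against your installed version)::

        import torch

        from sentence_transformers import SentenceTransformerTrainer

        class ShuffledBatchSampler:
            # Hypothetical sampler: yields shuffled batches of dataset indices.
            def __init__(self, dataset_size: int, batch_size: int):
                self.dataset_size = dataset_size
                self.batch_size = batch_size

            def __iter__(self):
                # Yield a list of dataset indices for each batch.
                indices = torch.randperm(self.dataset_size).tolist()
                for start in range(0, self.dataset_size, self.batch_size):
                    yield indices[start : start + self.batch_size]

            def __len__(self):
                # Return the number of batches.
                return (self.dataset_size + self.batch_size - 1) // self.batch_size

        class CustomTrainer(SentenceTransformerTrainer):
            def get_batch_sampler(self, dataset, batch_size, drop_last, valid_label_columns=None, generator=None):
                # Signature assumed to mirror the parent method; drop_last and generator are
                # ignored in this simplified sketch.
                return ShuffledBatchSampler(len(dataset), batch_size)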
  59. """
  60. BATCH_SAMPLER = "batch_sampler"
  61. NO_DUPLICATES = "no_duplicates"
  62. GROUP_BY_LABEL = "group_by_label"


class MultiDatasetBatchSamplers(ExplicitEnum):
    """
    Stores the acceptable string identifiers for multi-dataset batch samplers.

    The multi-dataset batch sampler is responsible for determining in what order batches are sampled from multiple
    datasets during training. Valid options are:

    - ``MultiDatasetBatchSamplers.ROUND_ROBIN``: Uses :class:`~sentence_transformers.sampler.RoundRobinBatchSampler`,
      which uses round-robin sampling from each dataset until one is exhausted.
      With this strategy, it's likely that not all samples from each dataset are used, but each dataset is sampled
      from equally.
    - ``MultiDatasetBatchSamplers.PROPORTIONAL``: **[default]** Uses :class:`~sentence_transformers.sampler.ProportionalBatchSampler`,
      which samples from each dataset in proportion to its size.
      With this strategy, all samples from each dataset are used and larger datasets are sampled from more frequently.
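    As an illustrative calculation (numbers not taken from this module): with two datasets of 100 and 300 samples
    and a batch size of 10, round-robin sampling draws roughly 10 batches from each dataset (about 20 in total)
    before the smaller dataset is exhausted, whereas proportional sampling uses every sample, yielding roughly
    40 batches, 30 of them from the larger dataset.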
    Usage:
        ::

            from sentence_transformers import SentenceTransformer, SentenceTransformerTrainer, SentenceTransformerTrainingArguments
            from sentence_transformers.training_args import MultiDatasetBatchSamplers
            from sentence_transformers.losses import CoSENTLoss
            from datasets import Dataset, DatasetDict

            model = SentenceTransformer("microsoft/mpnet-base")
            train_general = Dataset.from_dict({
                "sentence_A": ["It's nice weather outside today.", "He drove to work."],
                "sentence_B": ["It's so sunny.", "He took the car to the bank."],
                "score": [0.9, 0.4],
            })
            train_medical = Dataset.from_dict({
                "sentence_A": ["The patient has a fever.", "The doctor prescribed medication.", "The patient is sweating."],
                "sentence_B": ["The patient feels hot.", "The medication was given to the patient.", "The patient is perspiring."],
                "score": [0.8, 0.6, 0.7],
            })
            train_legal = Dataset.from_dict({
                "sentence_A": ["This contract is legally binding.", "The parties agree to the terms and conditions."],
                "sentence_B": ["Both parties acknowledge their obligations.", "By signing this agreement, the parties enter into a legal relationship."],
                "score": [0.7, 0.8],
            })
            train_dataset = DatasetDict({
                "general": train_general,
                "medical": train_medical,
                "legal": train_legal,
            })
            loss = CoSENTLoss(model)
            args = SentenceTransformerTrainingArguments(
                output_dir="checkpoints",
                multi_dataset_batch_sampler=MultiDatasetBatchSamplers.PROPORTIONAL,
            )
            trainer = SentenceTransformerTrainer(
                model=model,
                args=args,
                train_dataset=train_dataset,
                loss=loss,
            )
            trainer.train()
  114. """
  115. ROUND_ROBIN = "round_robin" # Round-robin sampling from each dataset
  116. PROPORTIONAL = "proportional" # Sample from each dataset in proportion to its size [default]


@dataclass
class SentenceTransformerTrainingArguments(TransformersTrainingArguments):
    """
    SentenceTransformerTrainingArguments extends :class:`~transformers.TrainingArguments` with additional arguments
    specific to Sentence Transformers. See :class:`~transformers.TrainingArguments` for the complete list of
    available arguments.

    Args:
        output_dir (`str`):
            The output directory where the model checkpoints will be written.
        batch_sampler (Union[:class:`~sentence_transformers.training_args.BatchSamplers`, `str`], *optional*):
            The batch sampler to use. See :class:`~sentence_transformers.training_args.BatchSamplers` for valid options.
            Defaults to ``BatchSamplers.BATCH_SAMPLER``.
        multi_dataset_batch_sampler (Union[:class:`~sentence_transformers.training_args.MultiDatasetBatchSamplers`, `str`], *optional*):
            The multi-dataset batch sampler to use. See :class:`~sentence_transformers.training_args.MultiDatasetBatchSamplers`
            for valid options. Defaults to ``MultiDatasetBatchSamplers.PROPORTIONAL``.
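
    Example:
        A minimal sketch showing only the arguments defined in this module (the ``output_dir`` value is an
        arbitrary placeholder)::

            from sentence_transformers import SentenceTransformerTrainingArguments
            from sentence_transformers.training_args import BatchSamplers, MultiDatasetBatchSamplers

            args = SentenceTransformerTrainingArguments(
                output_dir="checkpoints",
                batch_sampler=BatchSamplers.NO_DUPLICATES,
                multi_dataset_batch_sampler=MultiDatasetBatchSamplers.PROPORTIONAL,
            )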
  132. """
    batch_sampler: BatchSamplers | str = field(
        default=BatchSamplers.BATCH_SAMPLER, metadata={"help": "The batch sampler to use."}
    )
    multi_dataset_batch_sampler: MultiDatasetBatchSamplers | str = field(
        default=MultiDatasetBatchSamplers.PROPORTIONAL, metadata={"help": "The multi-dataset batch sampler to use."}
    )
    def __post_init__(self):
        super().__post_init__()

        self.batch_sampler = BatchSamplers(self.batch_sampler)
        self.multi_dataset_batch_sampler = MultiDatasetBatchSamplers(self.multi_dataset_batch_sampler)

        # The `compute_loss` method in `SentenceTransformerTrainer` is overridden to only compute the prediction loss,
        # so we set `prediction_loss_only` to `True` here.
        self.prediction_loss_only = True

        # Disable broadcasting of buffers to avoid `RuntimeError: one of the variables needed for gradient computation
        # has been modified by an inplace operation.` when training with DDP & a BertModel-based model.
        self.ddp_broadcast_buffers = False

        if self.parallel_mode == ParallelMode.NOT_DISTRIBUTED:
            # If output_dir is "unused", then this instance is created to compare training arguments vs the defaults,
            # so we don't have to warn.
            if self.output_dir != "unused":
                logger.warning(
                    "Currently using DataParallel (DP) for multi-gpu training, while DistributedDataParallel (DDP) is recommended for faster training. "
                    "See https://sbert.net/docs/sentence_transformer/training/distributed.html for more information."
                )
        elif self.parallel_mode == ParallelMode.DISTRIBUTED and not self.dataloader_drop_last:
            # If output_dir is "unused", then this instance is created to compare training arguments vs the defaults,
            # so we don't have to warn.
            if self.output_dir != "unused":
                logger.warning(
                    "When using DistributedDataParallel (DDP), it is recommended to set `dataloader_drop_last=True` to avoid hanging issues with an uneven last batch. "
                    "Setting `dataloader_drop_last=True`."
                )
            self.dataloader_drop_last = True