from __future__ import annotations

import logging
from dataclasses import dataclass, field

from transformers import TrainingArguments as TransformersTrainingArguments
from transformers.training_args import ParallelMode
from transformers.utils import ExplicitEnum

logger = logging.getLogger(__name__)


class BatchSamplers(ExplicitEnum):
    """
    Stores the acceptable string identifiers for batch samplers.

    The batch sampler is responsible for determining how samples are grouped into batches during training.
    Valid options are:

    - ``BatchSamplers.BATCH_SAMPLER``: **[default]** Uses :class:`~sentence_transformers.sampler.DefaultBatchSampler`, the default
      PyTorch batch sampler.
    - ``BatchSamplers.NO_DUPLICATES``: Uses :class:`~sentence_transformers.sampler.NoDuplicatesBatchSampler`,
      ensuring no duplicate samples in a batch. Recommended for losses that use in-batch negatives, such as:

      - :class:`~sentence_transformers.losses.MultipleNegativesRankingLoss`
      - :class:`~sentence_transformers.losses.CachedMultipleNegativesRankingLoss`
      - :class:`~sentence_transformers.losses.MultipleNegativesSymmetricRankingLoss`
      - :class:`~sentence_transformers.losses.CachedMultipleNegativesSymmetricRankingLoss`
      - :class:`~sentence_transformers.losses.MegaBatchMarginLoss`
      - :class:`~sentence_transformers.losses.GISTEmbedLoss`
      - :class:`~sentence_transformers.losses.CachedGISTEmbedLoss`

    - ``BatchSamplers.GROUP_BY_LABEL``: Uses :class:`~sentence_transformers.sampler.GroupByLabelBatchSampler`,
      ensuring that each batch has 2+ samples from the same label. Recommended for losses that require multiple
      samples from the same label, such as:

      - :class:`~sentence_transformers.losses.BatchAllTripletLoss`
      - :class:`~sentence_transformers.losses.BatchHardSoftMarginTripletLoss`
      - :class:`~sentence_transformers.losses.BatchHardTripletLoss`
      - :class:`~sentence_transformers.losses.BatchSemiHardTripletLoss`

    If you want to use a custom batch sampler, you can create a new Trainer class that inherits from
    :class:`~sentence_transformers.trainer.SentenceTransformerTrainer` and overrides the
    :meth:`~sentence_transformers.trainer.SentenceTransformerTrainer.get_batch_sampler` method. The
    method must return a class instance that supports ``__iter__`` and ``__len__`` methods. The former
    should yield a list of indices for each batch, and the latter should return the number of batches.
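
    For example, a minimal sketch of such an override (the subclass name and argument names below are illustrative;
    check the installed version for the exact ``get_batch_sampler`` signature)::

        from torch.utils.data import BatchSampler, SequentialSampler

        from sentence_transformers import SentenceTransformerTrainer

        class SequentialBatchTrainer(SentenceTransformerTrainer):
            # NOTE: hypothetical subclass with a sketched signature, shown only to illustrate the override.
            def get_batch_sampler(self, dataset, batch_size, drop_last, valid_label_columns=None, generator=None):
                # torch's BatchSampler yields lists of indices (__iter__) and reports the number of
                # batches (__len__), which is exactly what the trainer requires.
                return BatchSampler(SequentialSampler(range(len(dataset))), batch_size=batch_size, drop_last=drop_last)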

    Usage:
        ::

            from sentence_transformers import SentenceTransformer, SentenceTransformerTrainer, SentenceTransformerTrainingArguments
            from sentence_transformers.training_args import BatchSamplers
            from sentence_transformers.losses import MultipleNegativesRankingLoss
            from datasets import Dataset

            model = SentenceTransformer("microsoft/mpnet-base")
            train_dataset = Dataset.from_dict({
                "anchor": ["It's nice weather outside today.", "He drove to work."],
                "positive": ["It's so sunny.", "He took the car to the office."],
            })
            loss = MultipleNegativesRankingLoss(model)
            args = SentenceTransformerTrainingArguments(
                output_dir="checkpoints",
                batch_sampler=BatchSamplers.NO_DUPLICATES,
            )
            trainer = SentenceTransformerTrainer(
                model=model,
                args=args,
                train_dataset=train_dataset,
                loss=loss,
            )
            trainer.train()
    """

    BATCH_SAMPLER = "batch_sampler"
    NO_DUPLICATES = "no_duplicates"
    GROUP_BY_LABEL = "group_by_label"


class MultiDatasetBatchSamplers(ExplicitEnum):
    """
    Stores the acceptable string identifiers for multi-dataset batch samplers.

    The multi-dataset batch sampler is responsible for determining in what order batches are sampled from multiple
    datasets during training. Valid options are:

    - ``MultiDatasetBatchSamplers.ROUND_ROBIN``: Uses :class:`~sentence_transformers.sampler.RoundRobinBatchSampler`,
      which uses round-robin sampling from each dataset until one is exhausted.
      With this strategy, it's likely that not all samples from each dataset are used, but each dataset is sampled
      from equally.
    - ``MultiDatasetBatchSamplers.PROPORTIONAL``: **[default]** Uses :class:`~sentence_transformers.sampler.ProportionalBatchSampler`,
      which samples from each dataset in proportion to its size.
      With this strategy, all samples from each dataset are used and larger datasets are sampled from more frequently.
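
    For example (an illustrative scenario, not output from the library): with a batch size of 2, one dataset of
    2 samples, and another of 8 samples, round-robin sampling yields one batch from each dataset before the smaller
    dataset is exhausted, leaving most of the larger dataset unused, whereas proportional sampling yields all five
    batches, four of which come from the larger dataset.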

    Usage:
        ::

            from sentence_transformers import SentenceTransformer, SentenceTransformerTrainer, SentenceTransformerTrainingArguments
            from sentence_transformers.training_args import MultiDatasetBatchSamplers
            from sentence_transformers.losses import CoSENTLoss
            from datasets import Dataset, DatasetDict

            model = SentenceTransformer("microsoft/mpnet-base")
            train_general = Dataset.from_dict({
                "sentence_A": ["It's nice weather outside today.", "He drove to work."],
                "sentence_B": ["It's so sunny.", "He took the car to the bank."],
                "score": [0.9, 0.4],
            })
            train_medical = Dataset.from_dict({
                "sentence_A": ["The patient has a fever.", "The doctor prescribed medication.", "The patient is sweating."],
                "sentence_B": ["The patient feels hot.", "The medication was given to the patient.", "The patient is perspiring."],
                "score": [0.8, 0.6, 0.7],
            })
            train_legal = Dataset.from_dict({
                "sentence_A": ["This contract is legally binding.", "The parties agree to the terms and conditions."],
                "sentence_B": ["Both parties acknowledge their obligations.", "By signing this agreement, the parties enter into a legal relationship."],
                "score": [0.7, 0.8],
            })
            train_dataset = DatasetDict({
                "general": train_general,
                "medical": train_medical,
                "legal": train_legal,
            })
            loss = CoSENTLoss(model)
            args = SentenceTransformerTrainingArguments(
                output_dir="checkpoints",
                multi_dataset_batch_sampler=MultiDatasetBatchSamplers.PROPORTIONAL,
            )
            trainer = SentenceTransformerTrainer(
                model=model,
                args=args,
                train_dataset=train_dataset,
                loss=loss,
            )
            trainer.train()
    """

    ROUND_ROBIN = "round_robin"  # Round-robin sampling from each dataset
    PROPORTIONAL = "proportional"  # Sample from each dataset in proportion to its size [default]


@dataclass
class SentenceTransformerTrainingArguments(TransformersTrainingArguments):
    """
    SentenceTransformerTrainingArguments extends :class:`~transformers.TrainingArguments` with additional arguments
    specific to Sentence Transformers. See :class:`~transformers.TrainingArguments` for the complete list of
    available arguments.

    Args:
        output_dir (`str`):
            The output directory where the model checkpoints will be written.
        batch_sampler (Union[:class:`~sentence_transformers.training_args.BatchSamplers`, `str`], *optional*):
            The batch sampler to use. See :class:`~sentence_transformers.training_args.BatchSamplers` for valid options.
            Defaults to ``BatchSamplers.BATCH_SAMPLER``.
        multi_dataset_batch_sampler (Union[:class:`~sentence_transformers.training_args.MultiDatasetBatchSamplers`, `str`], *optional*):
            The multi-dataset batch sampler to use. See :class:`~sentence_transformers.training_args.MultiDatasetBatchSamplers`
            for valid options. Defaults to ``MultiDatasetBatchSamplers.PROPORTIONAL``.
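
    Example (a minimal construction sketch; the samplers may also be passed as their plain string values, and any
    other :class:`~transformers.TrainingArguments` option, e.g. ``num_train_epochs``, is accepted as usual)::

        from sentence_transformers import SentenceTransformerTrainingArguments
        from sentence_transformers.training_args import BatchSamplers, MultiDatasetBatchSamplers

        args = SentenceTransformerTrainingArguments(
            output_dir="checkpoints",
            batch_sampler=BatchSamplers.NO_DUPLICATES,  # or "no_duplicates"
            multi_dataset_batch_sampler=MultiDatasetBatchSamplers.PROPORTIONAL,  # or "proportional"
        )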
- """
- batch_sampler: BatchSamplers | str = field(
- default=BatchSamplers.BATCH_SAMPLER, metadata={"help": "The batch sampler to use."}
- )
- multi_dataset_batch_sampler: MultiDatasetBatchSamplers | str = field(
- default=MultiDatasetBatchSamplers.PROPORTIONAL, metadata={"help": "The multi-dataset batch sampler to use."}
- )

    def __post_init__(self):
        super().__post_init__()

        self.batch_sampler = BatchSamplers(self.batch_sampler)
        self.multi_dataset_batch_sampler = MultiDatasetBatchSamplers(self.multi_dataset_batch_sampler)

        # The `compute_loss` method in `SentenceTransformerTrainer` is overridden to only compute the prediction loss,
        # so we set `prediction_loss_only` to `True` here to avoid needlessly gathering model outputs during evaluation.
        self.prediction_loss_only = True

        # Disable broadcasting of buffers to avoid `RuntimeError: one of the variables needed for gradient computation
        # has been modified by an inplace operation.` when training with DDP & a BertModel-based model.
        self.ddp_broadcast_buffers = False

        if self.parallel_mode == ParallelMode.NOT_DISTRIBUTED and self.n_gpu > 1:
            # If output_dir is "unused", then this instance is created to compare training arguments vs the defaults,
            # so we don't have to warn.
            if self.output_dir != "unused":
                logger.warning(
                    "Currently using DataParallel (DP) for multi-gpu training, while DistributedDataParallel (DDP) is recommended for faster training. "
                    "See https://sbert.net/docs/sentence_transformer/training/distributed.html for more information."
                )
        elif self.parallel_mode == ParallelMode.DISTRIBUTED and not self.dataloader_drop_last:
            # If output_dir is "unused", then this instance is created to compare training arguments vs the defaults,
            # so we don't have to warn.
            if self.output_dir != "unused":
                logger.warning(
                    "When using DistributedDataParallel (DDP), it is recommended to set `dataloader_drop_last=True` to avoid hanging issues with an uneven last batch. "
                    "Setting `dataloader_drop_last=True`."
                )
            self.dataloader_drop_last = True