- from __future__ import annotations
- import logging
- import os
- from collections import OrderedDict
- from contextlib import nullcontext
- from typing import TYPE_CHECKING, Any, Callable
- import torch
- from torch import nn
- from torch.utils.data import BatchSampler, ConcatDataset, DataLoader, SubsetRandomSampler
- from transformers import EvalPrediction, PreTrainedTokenizerBase, Trainer, TrainerCallback
- from transformers.data.data_collator import DataCollator
- from transformers.integrations import WandbCallback
- from transformers.trainer import TRAINING_ARGS_NAME
- from transformers.trainer_utils import EvalLoopOutput
- from sentence_transformers.data_collator import SentenceTransformerDataCollator
- from sentence_transformers.evaluation import SentenceEvaluator, SequentialEvaluator
- from sentence_transformers.losses.CoSENTLoss import CoSENTLoss
- from sentence_transformers.model_card import ModelCardCallback
- from sentence_transformers.models.Transformer import Transformer
- from sentence_transformers.sampler import (
- DefaultBatchSampler,
- GroupByLabelBatchSampler,
- NoDuplicatesBatchSampler,
- ProportionalBatchSampler,
- RoundRobinBatchSampler,
- )
- from sentence_transformers.training_args import (
- BatchSamplers,
- MultiDatasetBatchSamplers,
- SentenceTransformerTrainingArguments,
- )
- from sentence_transformers.util import disable_logging, is_datasets_available, is_training_available
- if is_datasets_available():
- from datasets import Dataset, DatasetDict, IterableDataset, IterableDatasetDict, Value
- logger = logging.getLogger(__name__)
- if TYPE_CHECKING:
- from sentence_transformers.SentenceTransformer import SentenceTransformer
- class SentenceTransformerTrainer(Trainer):
- """
- SentenceTransformerTrainer is a simple but feature-complete training and eval loop for PyTorch
- based on the 🤗 Transformers :class:`~transformers.Trainer`.
- This trainer integrates support for various :class:`transformers.TrainerCallback` subclasses, such as:
- - :class:`~transformers.integrations.WandbCallback` to automatically log training metrics to W&B if `wandb` is installed
- - :class:`~transformers.integrations.TensorBoardCallback` to log training metrics to TensorBoard if `tensorboard` is accessible.
- - :class:`~transformers.integrations.CodeCarbonCallback` to track the carbon emissions of your model during training if `codecarbon` is installed.
- - Note: These carbon emissions will be included in your automatically generated model card.
- See the Transformers `Callbacks <https://huggingface.co/docs/transformers/main/en/main_classes/callback>`_
- documentation for more information on the integrated callbacks and how to write your own callbacks.
- Args:
- model (:class:`~sentence_transformers.SentenceTransformer`, *optional*):
- The model to train, evaluate or use for predictions. If not provided, a `model_init` must be passed.
- args (:class:`~sentence_transformers.training_args.SentenceTransformerTrainingArguments`, *optional*):
- The arguments to tweak for training. Will default to a basic instance of
- :class:`~sentence_transformers.training_args.SentenceTransformerTrainingArguments` with the
- `output_dir` set to a directory named *tmp_trainer* in the current directory if not provided.
- train_dataset (Union[:class:`datasets.Dataset`, :class:`datasets.DatasetDict`, :class:`datasets.IterableDataset`, Dict[str, :class:`datasets.Dataset`]], *optional*):
- The dataset to use for training. Must have a format accepted by your loss function, see
- `Training Overview > Dataset Format <../../../docs/sentence_transformer/training_overview.html#dataset-format>`_.
- eval_dataset (Union[:class:`datasets.Dataset`, :class:`datasets.DatasetDict`, :class:`datasets.IterableDataset`, Dict[str, :class:`datasets.Dataset`]], *optional*):
- The dataset to use for evaluation. Must have a format accepted by your loss function, see
- `Training Overview > Dataset Format <../../../docs/sentence_transformer/training_overview.html#dataset-format>`_.
- loss (Optional[Union[:class:`torch.nn.Module`, Dict[str, :class:`torch.nn.Module`],\
- Callable[[:class:`~sentence_transformers.SentenceTransformer`], :class:`torch.nn.Module`],\
- Dict[str, Callable[[:class:`~sentence_transformers.SentenceTransformer`], :class:`torch.nn.Module`]]]], *optional*):
- The loss function to use for training. Can either be a loss class instance, a dictionary mapping dataset names to
- loss class instances, a function that returns a loss class instance given a model, or a dictionary mapping
- dataset names to functions that return a loss class instance given a model (see the sketch after this argument
- list). In practice, the latter two are primarily used for hyper-parameter optimization. Will default to
- :class:`~sentence_transformers.losses.CoSENTLoss` if no ``loss`` is provided.
- evaluator (Union[:class:`~sentence_transformers.evaluation.SentenceEvaluator`,\
- List[:class:`~sentence_transformers.evaluation.SentenceEvaluator`]], *optional*):
- The evaluator instance for useful evaluation metrics during training. You can use an ``evaluator`` with
- or without an ``eval_dataset``, and vice versa. Generally, the metrics that an ``evaluator`` returns
- are more useful than the loss value returned from the ``eval_dataset``. A list of evaluators will be
- wrapped in a :class:`~sentence_transformers.evaluation.SequentialEvaluator` to run them sequentially.
- callbacks (List of :class:`transformers.TrainerCallback`, *optional*):
- A list of callbacks to customize the training loop. These are added to the default callbacks detailed in the
- Transformers `Callbacks <https://huggingface.co/docs/transformers/main/en/main_classes/callback>`_ documentation.
- If you want to remove one of the default callbacks, use the :meth:`~transformers.Trainer.remove_callback` method.
- optimizers (`Tuple[:class:`torch.optim.Optimizer`, :class:`torch.optim.lr_scheduler.LambdaLR`]`, *optional*, defaults to `(None, None)`):
- A tuple containing the optimizer and the scheduler to use. Will default to an instance of :class:`torch.optim.AdamW`
- on your model and a scheduler given by :func:`transformers.get_linear_schedule_with_warmup` controlled by `args`.
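- The dictionary forms of ``loss`` map dataset names (the keys of a train/eval ``DatasetDict``) to losses. A minimal
- sketch of the dictionary-of-callables form (the dataset names and loss choices below are illustrative; this form is
- mostly useful together with ``model_init`` for hyper-parameter search)::
- from sentence_transformers import losses
- # Keys must match the keys of the train/eval DatasetDict; each callable receives the
- # (re)initialized model and must return a loss module.
- loss = {
- "natural-questions": lambda model: losses.MultipleNegativesRankingLoss(model),
- "stsb": lambda model: losses.CoSENTLoss(model),
- }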
- Important attributes:
- - **model** -- Always points to the core model. If using a transformers model, it will be a [`PreTrainedModel`]
- subclass.
- - **model_wrapped** -- Always points to the most external model in case one or more other modules wrap the
- original model. This is the model that should be used for the forward pass. For example, under `DeepSpeed`,
- the inner model is wrapped in `DeepSpeed` and then again in `torch.nn.DistributedDataParallel`. If the inner
- model hasn't been wrapped, then `self.model_wrapped` is the same as `self.model`.
- - **is_model_parallel** -- Whether or not a model has been switched to a model parallel mode (different from
- data parallelism, this means some of the model layers are split on different GPUs).
- - **place_model_on_device** -- Whether or not to automatically place the model on the device - it will be set
- to `False` if model parallel or deepspeed is used, or if the default
- `TrainingArguments.place_model_on_device` is overridden to return `False`.
- - **is_in_train** -- Whether or not a model is currently running `train` (e.g. when `evaluate` is called while
- in `train`)
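- Example (a minimal, illustrative sketch: the model name, dataset columns, and hyperparameters below are
- placeholders rather than recommendations)::
- from datasets import Dataset
- from sentence_transformers import SentenceTransformer, SentenceTransformerTrainer, losses
- from sentence_transformers.training_args import SentenceTransformerTrainingArguments
- model = SentenceTransformer("all-MiniLM-L6-v2")
- train_dataset = Dataset.from_dict({
- "sentence1": ["It's nice weather outside today.", "He drove to work."],
- "sentence2": ["It's so sunny.", "He took the car to the office."],
- "score": [1.0, 0.9],
- })
- loss = losses.CoSENTLoss(model)
- args = SentenceTransformerTrainingArguments(output_dir="checkpoints", num_train_epochs=1)
- # An `evaluator` and/or `eval_dataset` can be passed in the same way; omitted here for brevity.
- trainer = SentenceTransformerTrainer(model=model, args=args, train_dataset=train_dataset, loss=loss)
- trainer.train()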
- """
- def __init__(
- self,
- model: SentenceTransformer | None = None,
- args: SentenceTransformerTrainingArguments | None = None,
- train_dataset: Dataset | DatasetDict | IterableDataset | dict[str, Dataset] | None = None,
- eval_dataset: Dataset | DatasetDict | IterableDataset | dict[str, Dataset] | None = None,
- loss: nn.Module
- | dict[str, nn.Module]
- | Callable[[SentenceTransformer], torch.nn.Module]
- | dict[str, Callable[[SentenceTransformer], torch.nn.Module]]
- | None = None,
- evaluator: SentenceEvaluator | list[SentenceEvaluator] | None = None,
- data_collator: DataCollator | None = None,
- tokenizer: PreTrainedTokenizerBase | Callable | None = None,
- model_init: Callable[[], SentenceTransformer] | None = None,
- compute_metrics: Callable[[EvalPrediction], dict] | None = None,
- callbacks: list[TrainerCallback] | None = None,
- optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
- preprocess_logits_for_metrics: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] | None = None,
- ) -> None:
- if not is_training_available():
- raise RuntimeError(
- "To train a SentenceTransformer model, you need to install the `accelerate` and `datasets` modules. "
- "You can do so with the `train` extra:\n"
- 'pip install -U "sentence-transformers[train]"'
- )
- if args is None:
- output_dir = "tmp_trainer"
- logger.info(f"No `TrainingArguments` passed, using `output_dir={output_dir}`.")
- args = SentenceTransformerTrainingArguments(output_dir=output_dir)
- elif not isinstance(args, SentenceTransformerTrainingArguments):
- raise ValueError("Please use `TrainingArguments` imported from `sentence_transformers`.")
- if model is None:
- if model_init is not None:
- self.model_init = model_init
- model = self.call_model_init()
- else:
- raise RuntimeError("`Trainer` requires either a `model` or `model_init` argument")
- else:
- if model_init is not None:
- logger.warning(
- "`Trainer` requires either a `model` or `model_init` argument, but not both. `model_init` will"
- " overwrite your model when calling the `train` method."
- )
- self.model_init = model_init
- if compute_metrics is not None:
- logger.warning(
- "`compute_metrics` is currently not compatible with the SentenceTransformerTrainer. Please use the "
- "`evaluator` argument instead for detailed evaluation metrics, or the `eval_dataset` argument for "
- "the evaluation loss."
- )
- # Get a dictionary of the default training arguments, so we can determine which arguments have been changed
- # for the model card
- default_args_dict = SentenceTransformerTrainingArguments(output_dir="unused").to_dict()
- # If the model ID is set via the SentenceTransformerTrainingArguments, but not via the SentenceTransformerModelCardData,
- # then we can set it here for the model card regardless
- if args.hub_model_id and not model.model_card_data.model_id:
- model.model_card_data.set_model_id(args.hub_model_id)
- if tokenizer is None and isinstance(model.tokenizer, PreTrainedTokenizerBase):
- tokenizer = model.tokenizer
- if data_collator is None:
- data_collator = SentenceTransformerDataCollator(tokenize_fn=model.tokenize)
- for dataset_name, dataset in zip(["train", "eval"], [train_dataset, eval_dataset]):
- if isinstance(dataset, IterableDataset) and dataset.column_names is None:
- sample = next(iter(dataset))
- naive_type_mapping = {str: "string", int: "int64", float: "float32", bool: "bool"}
- example_features = {
- key: Value(naive_type_mapping.get(type(value), "null")) for key, value in sample.items()
- }
- raise ValueError(
- f"The provided `{dataset_name}_dataset` must have Features. Specify them with e.g.:\n"
- f"{dataset_name}_dataset = {dataset_name}_dataset.cast(Features({example_features}))\n"
- "or by providing the Features to the IterableDataset initialization method. See the Datasets "
- "documentation for more information on dataset Features: "
- "https://huggingface.co/docs/datasets/en/about_dataset_features"
- )
- if isinstance(train_dataset, dict) and not isinstance(train_dataset, DatasetDict):
- train_dataset = DatasetDict(train_dataset)
- if isinstance(eval_dataset, dict) and not isinstance(eval_dataset, DatasetDict):
- eval_dataset = DatasetDict(eval_dataset)
- super().__init__(
- model=None if self.model_init else model,
- args=args,
- data_collator=data_collator,
- train_dataset=train_dataset,
- eval_dataset=eval_dataset,
- tokenizer=tokenizer,
- model_init=model_init,
- compute_metrics=compute_metrics,
- callbacks=callbacks,
- optimizers=optimizers,
- preprocess_logits_for_metrics=preprocess_logits_for_metrics,
- )
- # Every Sentence Transformer model can always return a loss, so we set this to True
- # to avoid having to specify it in the data collator or model's forward
- self.can_return_loss = True
- self.model: SentenceTransformer
- self.args: SentenceTransformerTrainingArguments
- self.data_collator: SentenceTransformerDataCollator
- # Set the W&B project via environment variables if it's not already set
- if any([isinstance(callback, WandbCallback) for callback in self.callback_handler.callbacks]):
- os.environ.setdefault("WANDB_PROJECT", "sentence-transformers")
- if loss is None:
- logger.info("No `loss` passed, using `losses.CoSENTLoss` as a default option.")
- loss = CoSENTLoss(self.model)
- if isinstance(loss, dict):
- self.loss = {dataset_name: self.prepare_loss(loss_fn, model) for dataset_name, loss_fn in loss.items()}
- for dataset_name, dataset in zip(["train", "eval"], [train_dataset, eval_dataset]):
- if dataset is None:
- continue
- if not isinstance(dataset, dict):
- raise ValueError(
- f"If the provided `loss` is a dict, then the `{dataset_name}_dataset` must be a `DatasetDict`."
- )
- if missing := set(dataset.keys()) - set(loss.keys()):
- raise ValueError(
- f"If the provided `loss` is a dict, then all keys from the `{dataset_name}_dataset` dictionary must occur in `loss` also. "
- f"Currently, {sorted(missing)} occur{'s' if len(missing) == 1 else ''} in `{dataset_name}_dataset` but not in `loss`."
- )
- else:
- self.loss = self.prepare_loss(loss, model)
- # If evaluator is a list, we wrap it in a SequentialEvaluator
- if evaluator is not None and not isinstance(evaluator, SentenceEvaluator):
- evaluator = SequentialEvaluator(evaluator)
- self.evaluator = evaluator
- # Add a callback responsible for automatically tracking data required for the automatic model card generation
- model_card_callback = ModelCardCallback(self, default_args_dict)
- self.add_callback(model_card_callback)
- model_card_callback.on_init_end(self.args, self.state, self.control, self.model)
- def call_model_init(self, trial=None) -> SentenceTransformer:
- model = super().call_model_init(trial=trial)
- # If the Trainer already has a loss, then we'll want to override the model in the loss function
- if not hasattr(self, "loss"):
- return model
- # Multi-loss training:
- if isinstance(self.loss, dict):
- for key, loss_fn in self.loss.items():
- # If a loss function is not yet initialized, we initialize it here
- if not isinstance(loss_fn, torch.nn.Module):
- self.loss[key] = loss_fn(model)
- # Otherwise, we override the original model with the updated model in the loss function
- elif hasattr(loss_fn, "model"):
- self.loss = self.override_model_in_loss(self.loss, model)
- # Loss is a function accepting a model as an argument
- elif not isinstance(self.loss, torch.nn.Module):
- self.loss = self.loss(model)
- # Loss is an initialized torch.nn.Module
- elif hasattr(self.loss, "model"):
- self.loss = self.override_model_in_loss(self.loss, model)
- return model
- def override_model_in_loss(self, loss: torch.nn.Module, model: SentenceTransformer) -> torch.nn.Module:
- from sentence_transformers import SentenceTransformer
- for name, child in loss.named_children():
- if name == "model" and isinstance(child, SentenceTransformer):
- loss.model = model
- elif isinstance(child, torch.nn.Module):
- setattr(loss, name, self.override_model_in_loss(child, model))
- return loss
- def prepare_loss(
- self,
- loss: Callable[[SentenceTransformer], torch.nn.Module] | torch.nn.Module,
- model: SentenceTransformer,
- ) -> torch.nn.Module:
- if isinstance(loss, torch.nn.Module):
- return loss.to(model.device)
- return loss(model).to(model.device)
- def add_dataset_name_column(self, dataset_dict: DatasetDict) -> DatasetDict:
- for key, dataset in dataset_dict.items():
- if "dataset_name" not in dataset.column_names:
- dataset_dict[key] = dataset.add_column("dataset_name", [key] * len(dataset))
- return dataset_dict
- def compute_loss(
- self,
- model: SentenceTransformer,
- inputs: dict[str, torch.Tensor | Any],
- return_outputs: bool = False,
- ) -> torch.Tensor | tuple[torch.Tensor, dict[str, Any]]:
- """
- Computes the loss for the SentenceTransformer model.
- It uses ``self.loss`` to compute the loss, which can be a single loss function or a dictionary of loss functions
- for different datasets. If the loss is a dictionary, the dataset name is expected to be passed in the inputs
- under the key "dataset_name". This is done automatically in the ``add_dataset_name_column`` method.
- Note that even if ``return_outputs = True``, the outputs will be empty, as the SentenceTransformers losses do not
- return outputs.
- Args:
- model (SentenceTransformer): The SentenceTransformer model.
- inputs (Dict[str, Union[torch.Tensor, Any]]): The input data for the model.
- return_outputs (bool, optional): Whether to return the outputs along with the loss. Defaults to False.
- Returns:
- Union[torch.Tensor, Tuple[torch.Tensor, Dict[str, Any]]]: The computed loss. If `return_outputs` is True, returns a tuple of loss and outputs. Otherwise, returns only the loss.
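- Example (a minimal sketch of multi-dataset training with a loss dictionary; it assumes ``model``, ``args``,
- ``stsb_dataset`` and ``nli_dataset`` are already defined, and the dataset names are illustrative. The trainer adds
- the ``dataset_name`` column automatically, and this method uses it to pick the matching loss)::
- from datasets import DatasetDict
- from sentence_transformers import SentenceTransformerTrainer, losses
- train_dataset = DatasetDict({"stsb": stsb_dataset, "nli": nli_dataset})
- loss = {
- "stsb": losses.CoSENTLoss(model),
- "nli": losses.MultipleNegativesRankingLoss(model),
- }
- trainer = SentenceTransformerTrainer(model=model, args=args, train_dataset=train_dataset, loss=loss)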
- """
- dataset_name = inputs.pop("dataset_name", None)
- features, labels = self.collect_features(inputs)
- loss_fn = self.loss
- if isinstance(loss_fn, dict) and dataset_name:
- loss_fn = loss_fn[dataset_name]
- # Insert the wrapped (e.g. distributed or compiled) model into the loss function,
- # if the loss stores the model. Only called once per process
- if (
- model == self.model_wrapped
- and model != self.model # Only if the model is wrapped
- and hasattr(loss_fn, "model") # Only if the loss stores the model
- and loss_fn.model != model # Only if the wrapped model is not already stored
- ):
- loss_fn = self.override_model_in_loss(loss_fn, model)
- loss = loss_fn(features, labels)
- if return_outputs:
- # During prediction/evaluation, `compute_loss` will be called with `return_outputs=True`.
- # However, Sentence Transformer losses do not return outputs, so we return an empty dictionary.
- # This does not result in any problems, as the SentenceTransformerTrainingArguments sets
- # `prediction_loss_only=True` which means that the output is not used.
- return loss, {}
- return loss
- def collect_features(
- self, inputs: dict[str, torch.Tensor | Any]
- ) -> tuple[list[dict[str, torch.Tensor]], torch.Tensor | None]:
- """Turn the inputs from the dataloader into the separate model inputs & the labels.
- Example::
- >>> list(inputs.keys())
- ['return_loss', 'label', 'sentence_0_input_ids', 'sentence_0_token_type_ids', 'sentence_0_attention_mask', 'sentence_1_input_ids', 'sentence_1_token_type_ids', 'sentence_1_attention_mask']
- >>> features, labels = self.collect_features(inputs)
- >>> len(features)
- 2
- >>> list(features[0].keys())
- ['input_ids', 'token_type_ids', 'attention_mask']
- >>> list(features[1].keys())
- ['input_ids', 'token_type_ids', 'attention_mask']
- >>> torch.equal(labels, inputs["label"])
- True
- """
- # All inputs ending with `_input_ids` (Transformers), `_sentence_embedding` (BoW), `_pixel_values` (CLIPModel)
- # are considered to correspond to a feature
- features = []
- for column in inputs:
- if column.endswith("_input_ids"):
- prefix = column[: -len("input_ids")]
- elif column.endswith("_sentence_embedding"):
- prefix = column[: -len("sentence_embedding")]
- elif column.endswith("_pixel_values"):
- prefix = column[: -len("pixel_values")]
- else:
- continue
- features.append({key[len(prefix) :]: value for key, value in inputs.items() if key.startswith(prefix)})
- labels = inputs.get("label", None)
- return features, labels
- def evaluate(
- self,
- eval_dataset: Dataset | dict[str, Dataset] | None = None,
- ignore_keys: list[str] | None = None,
- metric_key_prefix: str = "eval",
- ) -> dict[str, float]:
- eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset
- if isinstance(eval_dataset, DatasetDict) and isinstance(self.loss, dict):
- eval_dataset = self.add_dataset_name_column(eval_dataset)
- return super().evaluate(eval_dataset, ignore_keys, metric_key_prefix)
- def evaluation_loop(
- self,
- dataloader: DataLoader,
- description: str,
- prediction_loss_only: bool | None = None,
- ignore_keys: list[str] | None = None,
- metric_key_prefix: str = "eval",
- ) -> EvalLoopOutput:
- output = super().evaluation_loop(
- dataloader=dataloader,
- description=description,
- prediction_loss_only=prediction_loss_only,
- ignore_keys=ignore_keys,
- metric_key_prefix=metric_key_prefix,
- )
- # If the evaluator is not defined, we can just return the output
- if self.evaluator is None:
- return output
- # If we are training and eval_dataset is a DatasetDict, then we should
- # 1) only run the evaluator for the first dataset
- # 2) prefix that single run with "eval", rather than e.g. "eval_multi_nli"
- if self.is_in_train and isinstance(self.eval_dataset, dict) and metric_key_prefix.startswith("eval_"):
- if metric_key_prefix[5:] == list(self.eval_dataset.keys())[0]:
- metric_key_prefix = "eval"
- else:
- return output
- with nullcontext() if self.is_local_process_zero() else disable_logging(logging.INFO):
- evaluator_metrics = self.evaluator(self.model)
- if not isinstance(evaluator_metrics, dict):
- evaluator_metrics = {"evaluator": evaluator_metrics}
- # Prefix all keys with metric_key_prefix + '_'
- for key in list(evaluator_metrics.keys()):
- if not key.startswith(f"{metric_key_prefix}_"):
- evaluator_metrics[f"{metric_key_prefix}_{key}"] = evaluator_metrics.pop(key)
- output.metrics.update(evaluator_metrics)
- return output
- def _load_best_model(self) -> None:
- # We want to ensure that this does not fail, and it may change if transformers updates how checkpoints are saved
- # Loading the best model is only supported for `transformers`-based models
- if not isinstance(self.model[0], Transformer):
- logger.info("Could not load best model, as the model is not a `transformers`-based model.")
- return
- try:
- if checkpoint := self.state.best_model_checkpoint:
- step = checkpoint.rsplit("-", 1)[-1]
- self.model.model_card_data.set_best_model_step(int(step))
- except Exception:
- pass
- # Override the model with the `transformers`-based auto_model, and restore the original SentenceTransformers
- # model with the loaded `transformers` model
- full_model = self.model
- self.model = self.model[0].auto_model
- try:
- return super()._load_best_model()
- finally:
- loaded_auto_model = self.model
- self.model = full_model
- self.model[0].auto_model = loaded_auto_model
- def validate_column_names(self, dataset: Dataset, dataset_name: str | None = None) -> None:
- if overlap := set(dataset.column_names) & {"return_loss", "dataset_name"}:
- raise ValueError(
- f"The following column names are invalid in your {dataset_name + ' ' if dataset_name else ''}dataset: {list(overlap)}."
- " Avoid using these column names, as they are reserved for internal use."
- )
- def get_batch_sampler(
- self,
- dataset: Dataset,
- batch_size: int,
- drop_last: bool,
- valid_label_columns: list[str] | None = None,
- generator: torch.Generator | None = None,
- ) -> BatchSampler | None:
- """
- Returns the appropriate batch sampler based on the ``batch_sampler`` argument in ``self.args``.
- This batch sampler class supports ``__len__`` and ``__iter__`` methods, and is used as the ``batch_sampler``
- to create the :class:`torch.utils.data.DataLoader`.
- .. note::
- Override this method to provide a custom batch sampler.
- Args:
- dataset (Dataset): The dataset to sample from.
- batch_size (int): Number of samples per batch.
- drop_last (bool): If True, drop the last incomplete batch if the dataset size
- is not divisible by the batch size.
- valid_label_columns (List[str]): List of column names to check for labels.
- The first column name from ``valid_label_columns`` found in the dataset will
- be used as the label column.
- generator (torch.Generator, optional): Optional random number generator for shuffling
- the indices.
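- Example (a minimal sketch; the batch sampler is selected via the training arguments rather than by calling this
- method directly, and the ``output_dir`` below is a placeholder)::
- from sentence_transformers.training_args import BatchSamplers, SentenceTransformerTrainingArguments
- args = SentenceTransformerTrainingArguments(
- output_dir="checkpoints",
- batch_sampler=BatchSamplers.NO_DUPLICATES,  # e.g. useful for losses with in-batch negatives
- )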
- """
- if isinstance(dataset, IterableDataset):
- if self.args.batch_sampler != BatchSamplers.BATCH_SAMPLER:
- logger.warning("When using an IterableDataset, you cannot specify a batch sampler.")
- return None
- if self.args.batch_sampler == BatchSamplers.NO_DUPLICATES:
- return NoDuplicatesBatchSampler(
- dataset=dataset,
- batch_size=batch_size,
- drop_last=drop_last,
- valid_label_columns=valid_label_columns,
- generator=generator,
- )
- if self.args.batch_sampler == BatchSamplers.GROUP_BY_LABEL:
- return GroupByLabelBatchSampler(
- dataset=dataset,
- batch_size=batch_size,
- drop_last=drop_last,
- valid_label_columns=valid_label_columns,
- )
- if self.args.batch_sampler == BatchSamplers.BATCH_SAMPLER:
- return DefaultBatchSampler(
- SubsetRandomSampler(range(len(dataset)), generator=generator),
- batch_size=batch_size,
- drop_last=drop_last,
- )
- def get_multi_dataset_batch_sampler(
- self,
- dataset: ConcatDataset,
- batch_samplers: list[BatchSampler],
- generator: torch.Generator | None = None,
- seed: int | None = 0,
- ) -> BatchSampler:
- """
- Returns the appropriate multi-dataset batch sampler based on the ``multi_dataset_batch_sampler`` argument
- in ``self.args``. This batch sampler class supports ``__len__`` and ``__iter__`` methods, and is used as the
- ``batch_sampler`` to create the :class:`torch.utils.data.DataLoader`.
- .. note::
- Override this method to provide a custom multi-dataset batch sampler.
- Args:
- dataset (ConcatDataset): The concatenation of all datasets.
- batch_samplers (List[BatchSampler]): List of batch samplers for each dataset in the concatenated dataset.
- generator (torch.Generator, optional): Optional random number generator for shuffling the indices.
- seed (int, optional): Optional seed for the random number generator
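- Example (a minimal sketch; the multi-dataset strategy is selected via the training arguments rather than by calling
- this method directly, and the ``output_dir`` below is a placeholder)::
- from sentence_transformers.training_args import MultiDatasetBatchSamplers, SentenceTransformerTrainingArguments
- args = SentenceTransformerTrainingArguments(
- output_dir="checkpoints",
- multi_dataset_batch_sampler=MultiDatasetBatchSamplers.ROUND_ROBIN,  # sample from each dataset in turn
- )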
- """
- if self.args.multi_dataset_batch_sampler == MultiDatasetBatchSamplers.ROUND_ROBIN:
- return RoundRobinBatchSampler(
- dataset=dataset,
- batch_samplers=batch_samplers,
- generator=generator,
- seed=seed,
- )
- if self.args.multi_dataset_batch_sampler == MultiDatasetBatchSamplers.PROPORTIONAL:
- return ProportionalBatchSampler(
- dataset=dataset,
- batch_samplers=batch_samplers,
- generator=generator,
- seed=seed,
- )
- def get_train_dataloader(self) -> DataLoader:
- """
- Returns the training [`~torch.utils.data.DataLoader`].
- Will use no batch sampler if `train_dataset` is an :class:`~datasets.IterableDataset`; otherwise, it uses the
- batch sampler returned by :meth:`get_batch_sampler` (or :meth:`get_multi_dataset_batch_sampler` when training
- on a :class:`~datasets.DatasetDict`).
- Subclass and override this method if you want to inject some custom behavior.
- """
- if self.train_dataset is None:
- raise ValueError("Trainer: training requires a train_dataset.")
- train_dataset = self.train_dataset
- data_collator = self.data_collator
- generator = torch.Generator()
- if self.args.seed:
- generator.manual_seed(self.args.seed)
- dataloader_params = {
- "collate_fn": data_collator,
- "num_workers": self.args.dataloader_num_workers,
- "pin_memory": self.args.dataloader_pin_memory,
- "persistent_workers": self.args.dataloader_persistent_workers,
- "prefetch_factor": self.args.dataloader_prefetch_factor,
- }
- if isinstance(train_dataset, IterableDataset):
- dataloader_params.update(
- {
- "batch_size": self.args.train_batch_size,
- "drop_last": self.args.dataloader_drop_last,
- }
- )
- elif isinstance(train_dataset, IterableDatasetDict):
- raise ValueError(
- "Sentence Transformers is not compatible with IterableDatasetDict. Please use a DatasetDict instead."
- )
- elif isinstance(train_dataset, DatasetDict):
- for dataset_name, dataset in train_dataset.items():
- self.validate_column_names(dataset, dataset_name=dataset_name)
- if isinstance(dataset, IterableDataset):
- raise ValueError(
- "Sentence Transformers is not compatible with a DatasetDict containing an IterableDataset."
- )
- if isinstance(self.loss, dict):
- train_dataset = self.add_dataset_name_column(train_dataset)
- batch_samplers = [
- self.get_batch_sampler(
- dataset,
- batch_size=self.args.train_batch_size,
- drop_last=self.args.dataloader_drop_last,
- valid_label_columns=data_collator.valid_label_columns,
- generator=generator,
- )
- for dataset in train_dataset.values()
- ]
- train_dataset = ConcatDataset(train_dataset.values())
- batch_sampler = self.get_multi_dataset_batch_sampler(
- dataset=train_dataset,
- batch_samplers=batch_samplers,
- generator=generator,
- seed=self.args.seed,
- )
- dataloader_params["batch_sampler"] = batch_sampler
- elif isinstance(train_dataset, Dataset):
- self.validate_column_names(train_dataset)
- batch_sampler = self.get_batch_sampler(
- train_dataset,
- batch_size=self.args.train_batch_size,
- drop_last=self.args.dataloader_drop_last,
- valid_label_columns=data_collator.valid_label_columns,
- generator=generator,
- )
- dataloader_params["batch_sampler"] = batch_sampler
- else:
- raise ValueError(
- "Unsupported `train_dataset` type. Use a Dataset, DatasetDict, or IterableDataset for training."
- )
- # If 'even_batches' is True, it will use the initial few samples to pad out the last sample. This can
- # cause issues with multi-dataset training, so we want to set this to False.
- # For evaluation, setting 'even_batches' to False results in hanging, so we keep it as True there.
- self.accelerator.even_batches = False
- self._train_dataloader = self.accelerator.prepare(DataLoader(train_dataset, **dataloader_params))
- return self._train_dataloader
- def get_eval_dataloader(self, eval_dataset: Dataset | None = None) -> DataLoader:
- """
- Returns the evaluation [`~torch.utils.data.DataLoader`].
- Subclass and override this method if you want to inject some custom behavior.
- Args:
- eval_dataset (`torch.utils.data.Dataset`, *optional*):
- If provided, will override `self.eval_dataset`. If it is a [`~datasets.Dataset`], columns not accepted
- by the `model.forward()` method are automatically removed. It must implement `__len__`.
- """
- if eval_dataset is None and self.eval_dataset is None:
- # Prevent errors if the evaluator is set but no eval_dataset is provided
- if self.evaluator is not None:
- return DataLoader([])
- raise ValueError("Trainer: evaluation requires an eval_dataset.")
- eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset
- data_collator = self.data_collator
- generator = torch.Generator()
- if self.args.seed:
- generator.manual_seed(self.args.seed)
- dataloader_params = {
- "collate_fn": data_collator,
- "num_workers": self.args.dataloader_num_workers,
- "pin_memory": self.args.dataloader_pin_memory,
- "persistent_workers": self.args.dataloader_persistent_workers,
- "prefetch_factor": self.args.dataloader_prefetch_factor,
- }
- if isinstance(eval_dataset, IterableDataset):
- dataloader_params.update(
- {
- "batch_size": self.args.eval_batch_size,
- "drop_last": self.args.dataloader_drop_last,
- }
- )
- elif isinstance(eval_dataset, IterableDatasetDict):
- raise ValueError(
- "Sentence Transformers is not compatible with IterableDatasetDict. Please use a DatasetDict instead."
- )
- elif isinstance(eval_dataset, DatasetDict):
- for dataset in eval_dataset.values():
- if isinstance(dataset, IterableDataset):
- raise ValueError(
- "Sentence Transformers is not compatible with a DatasetDict containing an IterableDataset."
- )
- if isinstance(self.loss, dict):
- eval_dataset = self.add_dataset_name_column(eval_dataset)
- batch_samplers = [
- self.get_batch_sampler(
- dataset,
- batch_size=self.args.eval_batch_size,
- drop_last=self.args.dataloader_drop_last,
- valid_label_columns=data_collator.valid_label_columns,
- generator=generator,
- )
- for dataset in eval_dataset.values()
- ]
- eval_dataset = ConcatDataset(eval_dataset.values())
- batch_sampler = self.get_multi_dataset_batch_sampler(
- dataset=eval_dataset,
- batch_samplers=batch_samplers,
- generator=generator,
- seed=self.args.seed,
- )
- dataloader_params["batch_sampler"] = batch_sampler
- elif isinstance(eval_dataset, Dataset):
- batch_sampler = self.get_batch_sampler(
- eval_dataset,
- batch_size=self.args.eval_batch_size,
- drop_last=self.args.dataloader_drop_last,
- valid_label_columns=data_collator.valid_label_columns,
- generator=generator,
- )
- dataloader_params["batch_sampler"] = batch_sampler
- else:
- raise ValueError(
- "Unsupported `eval_dataset` type. Use a Dataset, DatasetDict, or IterableDataset for evaluation."
- )
- # If 'even_batches' is True, it will use the initial few samples to pad out the last sample. This can
- # cause issues with multi-dataset training, so we want to set this to False during training.
- # For evaluation, setting 'even_batches' to False results in hanging, so we keep it as True here.
- self.accelerator.even_batches = True
- return self.accelerator.prepare(DataLoader(eval_dataset, **dataloader_params))
- def get_test_dataloader(self, test_dataset: Dataset) -> DataLoader:
- """
- Returns the test [`~torch.utils.data.DataLoader`].
- Subclass and override this method if you want to inject some custom behavior.
- Args:
- test_dataset (`torch.utils.data.Dataset`, *optional*):
- The test dataset to use. If it is a [`~datasets.Dataset`], columns not accepted by the
- `model.forward()` method are automatically removed. It must implement `__len__`.
- """
- data_collator = self.data_collator
- generator = torch.Generator()
- if self.args.seed:
- generator.manual_seed(self.args.seed)
- dataloader_params = {
- "collate_fn": data_collator,
- "num_workers": self.args.dataloader_num_workers,
- "pin_memory": self.args.dataloader_pin_memory,
- "persistent_workers": self.args.dataloader_persistent_workers,
- "prefetch_factor": self.args.dataloader_prefetch_factor,
- }
- if isinstance(test_dataset, IterableDataset):
- dataloader_params.update(
- {
- "batch_size": self.args.eval_batch_size,
- "drop_last": self.args.dataloader_drop_last,
- }
- )
- elif isinstance(test_dataset, IterableDatasetDict):
- raise ValueError(
- "Sentence Transformers is not compatible with IterableDatasetDict. Please use a DatasetDict instead."
- )
- elif isinstance(test_dataset, DatasetDict):
- for dataset_name, dataset in test_dataset.items():
- self.validate_column_names(dataset, dataset_name=dataset_name)
- if isinstance(dataset, IterableDataset):
- raise ValueError(
- "Sentence Transformers is not compatible with a DatasetDict containing an IterableDataset."
- )
- if isinstance(self.loss, dict):
- test_dataset = self.add_dataset_name_column(test_dataset)
- batch_samplers = [
- self.get_batch_sampler(
- dataset,
- batch_size=self.args.eval_batch_size,
- drop_last=self.args.dataloader_drop_last,
- valid_label_columns=data_collator.valid_label_columns,
- generator=generator,
- )
- for dataset in test_dataset.values()
- ]
- test_dataset = ConcatDataset(test_dataset.values())
- batch_sampler = self.get_multi_dataset_batch_sampler(
- dataset=test_dataset,
- batch_samplers=batch_samplers,
- generator=generator,
- seed=self.args.seed,
- )
- dataloader_params["batch_sampler"] = batch_sampler
- elif isinstance(test_dataset, Dataset):
- self.validate_column_names(test_dataset)
- batch_sampler = self.get_batch_sampler(
- test_dataset,
- batch_size=self.args.eval_batch_size,
- drop_last=self.args.dataloader_drop_last,
- valid_label_columns=data_collator.valid_label_columns,
- generator=generator,
- )
- dataloader_params["batch_sampler"] = batch_sampler
- else:
- raise ValueError(
- "Unsupported `test_dataset` type. Use a Dataset, DatasetDict, or IterableDataset for testing."
- )
- # If 'even_batches' is True, it will use the initial few samples to pad out the last sample. This can
- # cause issues with multi-dataset batches, so we set it to False here, just like in `get_train_dataloader`.
- self.accelerator.even_batches = False
- # Unlike `get_train_dataloader`, do not cache this dataloader as `self._train_dataloader`.
- return self.accelerator.prepare(DataLoader(test_dataset, **dataloader_params))
- def _save(self, output_dir: str | None = None, state_dict=None) -> None:
- # If we are executing this function, we are the process zero, so we don't check for that.
- output_dir = output_dir if output_dir is not None else self.args.output_dir
- os.makedirs(output_dir, exist_ok=True)
- logger.info(f"Saving model checkpoint to {output_dir}")
- self.model.save_pretrained(output_dir, safe_serialization=self.args.save_safetensors)
- if self.tokenizer is not None:
- self.tokenizer.save_pretrained(output_dir)
- # Good practice: save your training arguments together with the trained model
- torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))
- def _load_from_checkpoint(self, checkpoint_path: str) -> None:
- from sentence_transformers import SentenceTransformer
- loaded_model = SentenceTransformer(checkpoint_path, trust_remote_code=self.model.trust_remote_code)
- self.model.load_state_dict(loaded_model.state_dict())
- def create_model_card(
- self,
- language: str | None = None,
- license: str | None = None,
- tags: str | list[str] | None = None,
- model_name: str | None = None,
- finetuned_from: str | None = None,
- tasks: str | list[str] | None = None,
- dataset_tags: str | list[str] | None = None,
- dataset: str | list[str] | None = None,
- dataset_args: str | list[str] | None = None,
- **kwargs,
- ) -> None:
- if not self.is_world_process_zero():
- return
- if language:
- self.model.model_card_data.set_language(language)
- if license:
- self.model.model_card_data.set_license(license)
- if tags:
- self.model.model_card_data.add_tags(tags)
- self.model._create_model_card(self.args.output_dir, model_name=model_name)
- def get_optimizer_cls_and_kwargs(
- self, args: SentenceTransformerTrainingArguments, model: SentenceTransformer | None = None
- ) -> tuple[Any, Any]:
- """
- We have to override the optimizer_grouped_parameters because the Trainer superclass bases it on the `model`
- itself, but the SentenceTransformer losses can have weights that should be updated as well, e.g.
- SoftmaxLoss (see #2872).
- This method requires `transformers` >= 4.43.0.
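- Example (a minimal sketch of a loss with its own trainable parameters, which this override also hands to the
- optimizer; it assumes ``model`` is an initialized :class:`~sentence_transformers.SentenceTransformer`, and the
- label count is illustrative)::
- from sentence_transformers import losses
- # SoftmaxLoss wraps a linear classifier whose weights must be trained alongside the model.
- loss = losses.SoftmaxLoss(
- model,
- sentence_embedding_dimension=model.get_sentence_embedding_dimension(),
- num_labels=3,
- )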
- """
- if isinstance(self.loss, dict):
- loss_model = nn.Sequential(OrderedDict(self.loss))
- else:
- loss_model = self.loss
- optimizer_cls, optimizer_kwargs = super().get_optimizer_cls_and_kwargs(args, loss_model)
- # If the kwargs were not overridden by the super() call, then we should override them here so that the potential
- # weights in the loss(es) can also be updated.
- if not {"params", "model", "optimizer_dict"} & set(optimizer_kwargs.keys()):
- decay_parameters = self.get_decay_parameter_names(loss_model)
- optimizer_kwargs["optimizer_dict"] = [
- {
- "params": [
- p for n, p in loss_model.named_parameters() if (n in decay_parameters and p.requires_grad)
- ],
- "weight_decay": self.args.weight_decay,
- },
- {
- "params": [
- p for n, p in loss_model.named_parameters() if (n not in decay_parameters and p.requires_grad)
- ],
- "weight_decay": 0.0,
- },
- ]
- return optimizer_cls, optimizer_kwargs
|