trainer.py

from __future__ import annotations

import logging
import os
from collections import OrderedDict
from contextlib import nullcontext
from typing import TYPE_CHECKING, Any, Callable

import torch
from torch import nn
from torch.utils.data import BatchSampler, ConcatDataset, DataLoader, SubsetRandomSampler
from transformers import EvalPrediction, PreTrainedTokenizerBase, Trainer, TrainerCallback
from transformers.data.data_collator import DataCollator
from transformers.integrations import WandbCallback
from transformers.trainer import TRAINING_ARGS_NAME
from transformers.trainer_utils import EvalLoopOutput

from sentence_transformers.data_collator import SentenceTransformerDataCollator
from sentence_transformers.evaluation import SentenceEvaluator, SequentialEvaluator
from sentence_transformers.losses.CoSENTLoss import CoSENTLoss
from sentence_transformers.model_card import ModelCardCallback
from sentence_transformers.models.Transformer import Transformer
from sentence_transformers.sampler import (
    DefaultBatchSampler,
    GroupByLabelBatchSampler,
    NoDuplicatesBatchSampler,
    ProportionalBatchSampler,
    RoundRobinBatchSampler,
)
from sentence_transformers.training_args import (
    BatchSamplers,
    MultiDatasetBatchSamplers,
    SentenceTransformerTrainingArguments,
)
from sentence_transformers.util import disable_logging, is_datasets_available, is_training_available

if is_datasets_available():
    from datasets import Dataset, DatasetDict, IterableDataset, IterableDatasetDict, Value

logger = logging.getLogger(__name__)

if TYPE_CHECKING:
    from sentence_transformers.SentenceTransformer import SentenceTransformer


class SentenceTransformerTrainer(Trainer):
  39. """
  40. SentenceTransformerTrainer is a simple but feature-complete training and eval loop for PyTorch
  41. based on the 🤗 Transformers :class:`~transformers.Trainer`.
  42. This trainer integrates support for various :class:`transformers.TrainerCallback` subclasses, such as:
  43. - :class:`~transformers.integrations.WandbCallback` to automatically log training metrics to W&B if `wandb` is installed
  44. - :class:`~transformers.integrations.TensorBoardCallback` to log training metrics to TensorBoard if `tensorboard` is accessible.
  45. - :class:`~transformers.integrations.CodeCarbonCallback` to track the carbon emissions of your model during training if `codecarbon` is installed.
  46. - Note: These carbon emissions will be included in your automatically generated model card.
  47. See the Transformers `Callbacks <https://huggingface.co/docs/transformers/main/en/main_classes/callback>`_
  48. documentation for more information on the integrated callbacks and how to write your own callbacks.
  49. Args:
  50. model (:class:`~sentence_transformers.SentenceTransformer`, *optional*):
  51. The model to train, evaluate or use for predictions. If not provided, a `model_init` must be passed.
  52. args (:class:`~sentence_transformers.training_args.SentenceTransformerTrainingArguments`, *optional*):
  53. The arguments to tweak for training. Will default to a basic instance of
  54. :class:`~sentence_transformers.training_args.SentenceTransformerTrainingArguments` with the
  55. `output_dir` set to a directory named *tmp_trainer* in the current directory if not provided.
  56. train_dataset (Union[:class:`datasets.Dataset`, :class:`datasets.DatasetDict`, :class:`datasets.IterableDataset`, Dict[str, :class:`datasets.Dataset`]], *optional*):
  57. The dataset to use for training. Must have a format accepted by your loss function, see
  58. `Training Overview > Dataset Format <../../../docs/sentence_transformer/training_overview.html#dataset-format>`_.
  59. eval_dataset (Union[:class:`datasets.Dataset`, :class:`datasets.DatasetDict`, :class:`datasets.IterableDataset`, Dict[str, :class:`datasets.Dataset`]], *optional*):
  60. The dataset to use for evaluation. Must have a format accepted by your loss function, see
  61. `Training Overview > Dataset Format <../../../docs/sentence_transformer/training_overview.html#dataset-format>`_.
  62. loss (Optional[Union[:class:`torch.nn.Module`, Dict[str, :class:`torch.nn.Module`],\
  63. Callable[[:class:`~sentence_transformers.SentenceTransformer`], :class:`torch.nn.Module`],\
  64. Dict[str, Callable[[:class:`~sentence_transformers.SentenceTransformer`]]]], *optional*):
  65. The loss function to use for training. Can either be a loss class instance, a dictionary mapping dataset names to
  66. loss class instances, a function that returns a loss class instance given a model, or a dictionary mapping
  67. dataset names to functions that return a loss class instance given a model. In practice, the latter two
  68. are primarily used for hyper-parameter optimization. Will default to
  69. :class:`~sentence_transformers.losses.CoSENTLoss` if no ``loss`` is provided.
  70. evaluator (Union[:class:`~sentence_transformers.evaluation.SentenceEvaluator`,\
  71. List[:class:`~sentence_transformers.evaluation.SentenceEvaluator`]], *optional*):
  72. The evaluator instance for useful evaluation metrics during training. You can use an ``evaluator`` with
  73. or without an ``eval_dataset``, and vice versa. Generally, the metrics that an ``evaluator`` returns
  74. are more useful than the loss value returned from the ``eval_dataset``. A list of evaluators will be
  75. wrapped in a :class:`~sentence_transformers.evaluation.SequentialEvaluator` to run them sequentially.
  76. callbacks (List of [:class:`transformers.TrainerCallback`], *optional*):
  77. A list of callbacks to customize the training loop. Will add those to the list of default callbacks
  78. detailed in [here](callback).
  79. If you want to remove one of the default callbacks used, use the [`Trainer.remove_callback`] method.
  80. optimizers (`Tuple[:class:`torch.optim.Optimizer`, :class:`torch.optim.lr_scheduler.LambdaLR`]`, *optional*, defaults to `(None, None)`):
  81. A tuple containing the optimizer and the scheduler to use. Will default to an instance of :class:`torch.optim.AdamW`
  82. on your model and a scheduler given by :func:`transformers.get_linear_schedule_with_warmup` controlled by `args`.
  83. Important attributes:
  84. - **model** -- Always points to the core model. If using a transformers model, it will be a [`PreTrainedModel`]
  85. subclass.
  86. - **model_wrapped** -- Always points to the most external model in case one or more other modules wrap the
  87. original model. This is the model that should be used for the forward pass. For example, under `DeepSpeed`,
  88. the inner model is wrapped in `DeepSpeed` and then again in `torch.nn.DistributedDataParallel`. If the inner
  89. model hasn't been wrapped, then `self.model_wrapped` is the same as `self.model`.
  90. - **is_model_parallel** -- Whether or not a model has been switched to a model parallel mode (different from
  91. data parallelism, this means some of the model layers are split on different GPUs).
  92. - **place_model_on_device** -- Whether or not to automatically place the model on the device - it will be set
  93. to `False` if model parallel or deepspeed is used, or if the default
  94. `TrainingArguments.place_model_on_device` is overridden to return `False` .
  95. - **is_in_train** -- Whether or not a model is currently running `train` (e.g. when `evaluate` is called while
  96. in `train`)
  97. """

    def __init__(
        self,
        model: SentenceTransformer | None = None,
        args: SentenceTransformerTrainingArguments = None,
        train_dataset: Dataset | DatasetDict | IterableDataset | dict[str, Dataset] | None = None,
        eval_dataset: Dataset | DatasetDict | IterableDataset | dict[str, Dataset] | None = None,
        loss: nn.Module
        | dict[str, nn.Module]
        | Callable[[SentenceTransformer], torch.nn.Module]
        | dict[str, Callable[[SentenceTransformer], torch.nn.Module]]
        | None = None,
        evaluator: SentenceEvaluator | list[SentenceEvaluator] | None = None,
        data_collator: DataCollator | None = None,
        tokenizer: PreTrainedTokenizerBase | Callable | None = None,
        model_init: Callable[[], SentenceTransformer] | None = None,
        compute_metrics: Callable[[EvalPrediction], dict] | None = None,
        callbacks: list[TrainerCallback] | None = None,
        optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
        preprocess_logits_for_metrics: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] | None = None,
    ) -> None:
        if not is_training_available():
            raise RuntimeError(
                "To train a SentenceTransformer model, you need to install the `accelerate` and `datasets` modules. "
                "You can do so with the `train` extra:\n"
                'pip install -U "sentence-transformers[train]"'
            )

        if args is None:
            output_dir = "tmp_trainer"
            logger.info(f"No `TrainingArguments` passed, using `output_dir={output_dir}`.")
            args = SentenceTransformerTrainingArguments(output_dir=output_dir)
        elif not isinstance(args, SentenceTransformerTrainingArguments):
            raise ValueError("Please use `TrainingArguments` imported from `sentence_transformers`.")

        if model is None:
            if model_init is not None:
                self.model_init = model_init
                model = self.call_model_init()
            else:
                raise RuntimeError("`Trainer` requires either a `model` or `model_init` argument")
        else:
            if model_init is not None:
                logger.warning(
                    "`Trainer` requires either a `model` or `model_init` argument, but not both. `model_init` will"
                    " overwrite your model when calling the `train` method."
                )
            self.model_init = model_init

        if compute_metrics is not None:
            logger.warning(
                "`compute_metrics` is currently not compatible with the SentenceTransformerTrainer. Please use the "
                "`evaluator` argument instead for detailed evaluation metrics, or the `eval_dataset` argument for "
                "the evaluation loss."
            )

        # Get a dictionary of the default training arguments, so we can determine which arguments have been changed
        # for the model card
        default_args_dict = SentenceTransformerTrainingArguments(output_dir="unused").to_dict()

        # If the model ID is set via the SentenceTransformerTrainingArguments, but not via the SentenceTransformerModelCardData,
        # then we can set it here for the model card regardless
        if args.hub_model_id and not model.model_card_data.model_id:
            model.model_card_data.set_model_id(args.hub_model_id)

        if tokenizer is None and isinstance(model.tokenizer, PreTrainedTokenizerBase):
            tokenizer = model.tokenizer

        if data_collator is None:
            data_collator = SentenceTransformerDataCollator(tokenize_fn=model.tokenize)

        for dataset_name, dataset in zip(["train", "eval"], [train_dataset, eval_dataset]):
            if isinstance(dataset, IterableDataset) and dataset.column_names is None:
                sample = next(iter(dataset))
                naive_type_mapping = {str: "string", int: "int64", float: "float32", bool: "bool"}
                example_features = {
                    key: Value(naive_type_mapping.get(type(value), "null")) for key, value in sample.items()
                }
                raise ValueError(
                    f"The provided `{dataset_name}_dataset` must have Features. Specify them with e.g.:\n"
                    f"{dataset_name}_dataset = {dataset_name}_dataset.cast(Features({example_features}))\n"
                    "or by providing the Features to the IterableDataset initialization method. See the Datasets "
                    "documentation for more information on dataset Features: "
                    "https://huggingface.co/docs/datasets/en/about_dataset_features"
                )

        if isinstance(train_dataset, dict) and not isinstance(train_dataset, DatasetDict):
            train_dataset = DatasetDict(train_dataset)
        if isinstance(eval_dataset, dict) and not isinstance(eval_dataset, DatasetDict):
            eval_dataset = DatasetDict(eval_dataset)

        super().__init__(
            model=None if self.model_init else model,
            args=args,
            data_collator=data_collator,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            tokenizer=tokenizer,
            model_init=model_init,
            compute_metrics=compute_metrics,
            callbacks=callbacks,
            optimizers=optimizers,
            preprocess_logits_for_metrics=preprocess_logits_for_metrics,
        )

        # Every Sentence Transformer model can always return a loss, so we set this to True
        # to avoid having to specify it in the data collator or model's forward
        self.can_return_loss = True

        self.model: SentenceTransformer
        self.args: SentenceTransformerTrainingArguments
        self.data_collator: SentenceTransformerDataCollator

        # Set the W&B project via environment variables if it's not already set
        if any(isinstance(callback, WandbCallback) for callback in self.callback_handler.callbacks):
            os.environ.setdefault("WANDB_PROJECT", "sentence-transformers")

        if loss is None:
            logger.info("No `loss` passed, using `losses.CoSENTLoss` as a default option.")
            loss = CoSENTLoss(self.model)

        if isinstance(loss, dict):
            self.loss = {dataset_name: self.prepare_loss(loss_fn, model) for dataset_name, loss_fn in loss.items()}
            for dataset_name, dataset in zip(["train", "eval"], [train_dataset, eval_dataset]):
                if dataset is None:
                    continue
                if not isinstance(dataset, dict):
                    raise ValueError(
                        f"If the provided `loss` is a dict, then the `{dataset_name}_dataset` must be a `DatasetDict`."
                    )
                if missing := set(dataset.keys()) - set(loss.keys()):
                    raise ValueError(
                        f"If the provided `loss` is a dict, then all keys from the `{dataset_name}_dataset` dictionary must occur in `loss` also. "
                        f"Currently, {sorted(missing)} occur{'s' if len(missing) == 1 else ''} in `{dataset_name}_dataset` but not in `loss`."
                    )
        else:
            self.loss = self.prepare_loss(loss, model)

        # If evaluator is a list, we wrap it in a SequentialEvaluator
        if evaluator is not None and not isinstance(evaluator, SentenceEvaluator):
            evaluator = SequentialEvaluator(evaluator)
        self.evaluator = evaluator

        # Add a callback responsible for automatically tracking data required for the automatic model card generation
        model_card_callback = ModelCardCallback(self, default_args_dict)
        self.add_callback(model_card_callback)
        model_card_callback.on_init_end(self.args, self.state, self.control, self.model)
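
    # Multi-dataset sketch (hypothetical dataset names, for illustration only): when `loss` is a dict, its keys
    # must cover every key of the train/eval DatasetDict, and each batch is routed to the matching loss via the
    # internally added "dataset_name" column.
    #
    #     train_dataset = DatasetDict({"all-nli": nli_dataset, "stsb": stsb_dataset})
    #     loss = {
    #         "all-nli": MultipleNegativesRankingLoss(model),
    #         "stsb": CoSENTLoss(model),
    #     }
    #     trainer = SentenceTransformerTrainer(model=model, train_dataset=train_dataset, loss=loss)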

    def call_model_init(self, trial=None) -> SentenceTransformer:
        model = super().call_model_init(trial=trial)
        # If the Trainer already has a loss, then we'll want to override the model in the loss function
        if not hasattr(self, "loss"):
            return model

        # Multi-loss training:
        if isinstance(self.loss, dict):
            for key, loss_fn in self.loss.items():
                # If a loss function is not yet initialized, we initialize it here
                if not isinstance(loss_fn, torch.nn.Module):
                    self.loss[key] = loss_fn(model)
                # Otherwise, we override the original model with the updated model in the loss function
                elif hasattr(loss_fn, "model"):
                    self.loss = self.override_model_in_loss(self.loss, model)
        # Loss is a function accepting a model as an argument
        elif not isinstance(self.loss, torch.nn.Module):
            self.loss = self.loss(model)
        # Loss is an initialized torch.nn.Module
        elif hasattr(self.loss, "model"):
            self.loss = self.override_model_in_loss(self.loss, model)
        return model

    def override_model_in_loss(self, loss: torch.nn.Module, model: SentenceTransformer) -> torch.nn.Module:
        from sentence_transformers import SentenceTransformer

        for name, child in loss.named_children():
            if name == "model" and isinstance(child, SentenceTransformer):
                loss.model = model
            elif isinstance(child, torch.nn.Module):
                setattr(loss, name, self.override_model_in_loss(child, model))
        return loss

    def prepare_loss(
        self,
        loss: Callable[[SentenceTransformer], torch.nn.Module] | torch.nn.Module,
        model: SentenceTransformer,
    ) -> torch.nn.Module:
        if isinstance(loss, torch.nn.Module):
            return loss.to(model.device)
        return loss(model).to(model.device)

    def add_dataset_name_column(self, dataset_dict: DatasetDict) -> DatasetDict:
        for key, dataset in dataset_dict.items():
            if "dataset_name" not in dataset.column_names:
                dataset_dict[key] = dataset.add_column("dataset_name", [key] * len(dataset))
        return dataset_dict

    def compute_loss(
        self,
        model: SentenceTransformer,
        inputs: dict[str, torch.Tensor | Any],
        return_outputs: bool = False,
    ) -> torch.Tensor | tuple[torch.Tensor, dict[str, Any]]:
        """
        Computes the loss for the SentenceTransformer model.

        It uses ``self.loss`` to compute the loss, which can be a single loss function or a dictionary of loss functions
        for different datasets. If the loss is a dictionary, the dataset name is expected to be passed in the inputs
        under the key "dataset_name". This is done automatically in the ``add_dataset_name_column`` method.

        Note that even if ``return_outputs = True``, the outputs will be empty, as the SentenceTransformers losses do not
        return outputs.

        Args:
            model (SentenceTransformer): The SentenceTransformer model.
            inputs (Dict[str, Union[torch.Tensor, Any]]): The input data for the model.
            return_outputs (bool, optional): Whether to return the outputs along with the loss. Defaults to False.

        Returns:
            Union[torch.Tensor, Tuple[torch.Tensor, Dict[str, Any]]]: The computed loss. If `return_outputs` is True,
            returns a tuple of loss and outputs. Otherwise, returns only the loss.
        """
        dataset_name = inputs.pop("dataset_name", None)
        features, labels = self.collect_features(inputs)
        loss_fn = self.loss

        if isinstance(loss_fn, dict) and dataset_name:
            loss_fn = loss_fn[dataset_name]

        # Insert the wrapped (e.g. distributed or compiled) model into the loss function,
        # if the loss stores the model. Only called once per process
        if (
            model == self.model_wrapped
            and model != self.model  # Only if the model is wrapped
            and hasattr(loss_fn, "model")  # Only if the loss stores the model
            and loss_fn.model != model  # Only if the wrapped model is not already stored
        ):
            loss_fn = self.override_model_in_loss(loss_fn, model)
        loss = loss_fn(features, labels)
        if return_outputs:
            # During prediction/evaluation, `compute_loss` will be called with `return_outputs=True`.
            # However, Sentence Transformer losses do not return outputs, so we return an empty dictionary.
            # This does not result in any problems, as the SentenceTransformerTrainingArguments sets
            # `prediction_loss_only=True` which means that the output is not used.
            return loss, {}
        return loss
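
    # Sketch of the loss interface that `compute_loss` relies on (a simplified, hypothetical loss): any
    # `nn.Module` whose `forward` accepts the list of tokenized feature dicts plus the labels tensor works here.
    #
    #     class MyCosineLoss(nn.Module):
    #         def __init__(self, model: SentenceTransformer) -> None:
    #             super().__init__()
    #             self.model = model  # stored so the trainer can swap in the wrapped model
    #
    #         def forward(self, features: list[dict[str, torch.Tensor]], labels: torch.Tensor) -> torch.Tensor:
    #             # Encode each input column, then compare the first two embeddings against the labels
    #             embeddings = [self.model(feature)["sentence_embedding"] for feature in features]
    #             cosine = torch.cosine_similarity(embeddings[0], embeddings[1])
    #             return nn.functional.mse_loss(cosine, labels.float())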

    def collect_features(
        self, inputs: dict[str, torch.Tensor | Any]
    ) -> tuple[list[dict[str, torch.Tensor]], torch.Tensor | None]:
        """Turn the inputs from the dataloader into the separate model inputs & the labels.

        Example::

            >>> list(inputs.keys())
            ['return_loss', 'label', 'sentence_0_input_ids', 'sentence_0_token_type_ids', 'sentence_0_attention_mask', 'sentence_1_input_ids', 'sentence_1_token_type_ids', 'sentence_1_attention_mask']
            >>> features, labels = self.collect_features(inputs)
            >>> len(features)
            2
            >>> list(features[0].keys())
            ['input_ids', 'token_type_ids', 'attention_mask']
            >>> list(features[1].keys())
            ['input_ids', 'token_type_ids', 'attention_mask']
            >>> torch.equal(labels, inputs["label"])
            True
        """
        # All inputs ending with `_input_ids` (Transformers), `_sentence_embedding` (BoW), `_pixel_values` (CLIPModel)
        # are considered to correspond to a feature
        features = []
        for column in inputs:
            if column.endswith("_input_ids"):
                prefix = column[: -len("input_ids")]
            elif column.endswith("_sentence_embedding"):
                prefix = column[: -len("sentence_embedding")]
            elif column.endswith("_pixel_values"):
                prefix = column[: -len("pixel_values")]
            else:
                continue
            features.append({key[len(prefix) :]: value for key, value in inputs.items() if key.startswith(prefix)})
        labels = inputs.get("label", None)
        return features, labels

    def evaluate(
        self,
        eval_dataset: Dataset | dict[str, Dataset] | None = None,
        ignore_keys: list[str] | None = None,
        metric_key_prefix: str = "eval",
    ) -> dict[str, float]:
        eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset
        if isinstance(eval_dataset, DatasetDict) and isinstance(self.loss, dict):
            eval_dataset = self.add_dataset_name_column(eval_dataset)
        return super().evaluate(eval_dataset, ignore_keys, metric_key_prefix)

    def evaluation_loop(
        self,
        dataloader: DataLoader,
        description: str,
        prediction_loss_only: bool | None = None,
        ignore_keys: list[str] | None = None,
        metric_key_prefix: str = "eval",
    ) -> EvalLoopOutput:
        output = super().evaluation_loop(
            dataloader=dataloader,
            description=description,
            prediction_loss_only=prediction_loss_only,
            ignore_keys=ignore_keys,
            metric_key_prefix=metric_key_prefix,
        )

        # If the evaluator is not defined, we can just return the output
        if self.evaluator is None:
            return output

        # If we are training and eval_dataset is a DatasetDict, then we should
        # 1) only run the evaluator for the first dataset, and
        # 2) prefix that run with "eval", rather than e.g. "eval_multi_nli"
        if self.is_in_train and isinstance(self.eval_dataset, dict) and metric_key_prefix.startswith("eval_"):
            if metric_key_prefix[5:] == list(self.eval_dataset.keys())[0]:
                metric_key_prefix = "eval"
            else:
                return output

        with nullcontext() if self.is_local_process_zero() else disable_logging(logging.INFO):
            evaluator_metrics = self.evaluator(self.model)
        if not isinstance(evaluator_metrics, dict):
            evaluator_metrics = {"evaluator": evaluator_metrics}

        # Prefix all keys with metric_key_prefix + '_'
        for key in list(evaluator_metrics.keys()):
            if not key.startswith(f"{metric_key_prefix}_"):
                evaluator_metrics[f"{metric_key_prefix}_{key}"] = evaluator_metrics.pop(key)

        output.metrics.update(evaluator_metrics)
        return output
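
    # Evaluator sketch (the sentences and score below are placeholder data): metrics produced by an evaluator
    # are merged into the evaluation metrics above, prefixed with `metric_key_prefix`.
    #
    #     from sentence_transformers.evaluation import EmbeddingSimilarityEvaluator
    #
    #     dev_evaluator = EmbeddingSimilarityEvaluator(
    #         sentences1=["A man is eating food."],
    #         sentences2=["A man is eating a piece of bread."],
    #         scores=[0.8],
    #         name="sts-dev",
    #     )
    #     trainer = SentenceTransformerTrainer(
    #         model=model, train_dataset=train_dataset, loss=loss, evaluator=dev_evaluator
    #     )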

    def _load_best_model(self) -> None:
        # We want to ensure that this does not fail, and it may change if transformers updates how checkpoints are saved
        # Loading the best model is only supported for `transformers`-based models
        if not isinstance(self.model[0], Transformer):
            logger.info("Could not load best model, as the model is not a `transformers`-based model.")
            return

        try:
            if checkpoint := self.state.best_model_checkpoint:
                step = checkpoint.rsplit("-", 1)[-1]
                self.model.model_card_data.set_best_model_step(int(step))
        except Exception:
            pass

        # Override the model with the `transformers`-based auto_model, and restore the original SentenceTransformers
        # model with the loaded `transformers` model
        full_model = self.model
        self.model = self.model[0].auto_model
        try:
            return super()._load_best_model()
        finally:
            loaded_auto_model = self.model
            self.model = full_model
            self.model[0].auto_model = loaded_auto_model

    def validate_column_names(self, dataset: Dataset, dataset_name: str | None = None) -> bool:
        if overlap := set(dataset.column_names) & {"return_loss", "dataset_name"}:
            raise ValueError(
                f"The following column names are invalid in your {dataset_name + ' ' if dataset_name else ''}dataset: {list(overlap)}."
                " Avoid using these column names, as they are reserved for internal use."
            )

    def get_batch_sampler(
        self,
        dataset: Dataset,
        batch_size: int,
        drop_last: bool,
        valid_label_columns: list[str] | None = None,
        generator: torch.Generator | None = None,
    ) -> BatchSampler | None:
        """
        Returns the appropriate batch sampler based on the ``batch_sampler`` argument in ``self.args``.
        This batch sampler class supports ``__len__`` and ``__iter__`` methods, and is used as the ``batch_sampler``
        to create the :class:`torch.utils.data.DataLoader`.

        .. note::
            Override this method to provide a custom batch sampler.

        Args:
            dataset (Dataset): The dataset to sample from.
            batch_size (int): Number of samples per batch.
            drop_last (bool): If True, drop the last incomplete batch if the dataset size
                is not divisible by the batch size.
            valid_label_columns (List[str]): List of column names to check for labels.
                The first column name from ``valid_label_columns`` found in the dataset will
                be used as the label column.
            generator (torch.Generator, optional): Optional random number generator for shuffling
                the indices.
        """
        if isinstance(dataset, IterableDataset):
            if self.args.batch_sampler != BatchSamplers.BATCH_SAMPLER:
                logger.warning("When using an IterableDataset, you cannot specify a batch sampler.")
            return None

        if self.args.batch_sampler == BatchSamplers.NO_DUPLICATES:
            return NoDuplicatesBatchSampler(
                dataset=dataset,
                batch_size=batch_size,
                drop_last=drop_last,
                valid_label_columns=valid_label_columns,
                generator=generator,
            )

        if self.args.batch_sampler == BatchSamplers.GROUP_BY_LABEL:
            return GroupByLabelBatchSampler(
                dataset=dataset,
                batch_size=batch_size,
                drop_last=drop_last,
                valid_label_columns=valid_label_columns,
            )

        if self.args.batch_sampler == BatchSamplers.BATCH_SAMPLER:
            return DefaultBatchSampler(
                SubsetRandomSampler(range(len(dataset)), generator=generator),
                batch_size=batch_size,
                drop_last=drop_last,
            )
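
    # Selecting a sampler is done through the training arguments rather than by calling this method directly;
    # a brief sketch (the argument values are illustrative):
    #
    #     args = SentenceTransformerTrainingArguments(
    #         output_dir="checkpoints",
    #         per_device_train_batch_size=32,
    #         batch_sampler=BatchSamplers.NO_DUPLICATES,  # avoid duplicate texts within a batch
    #     )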

    def get_multi_dataset_batch_sampler(
        self,
        dataset: ConcatDataset,
        batch_samplers: list[BatchSampler],
        generator: torch.Generator | None = None,
        seed: int | None = 0,
    ) -> BatchSampler:
        """
        Returns the appropriate multi-dataset batch sampler based on the ``multi_dataset_batch_sampler`` argument
        in ``self.args``. This batch sampler class supports ``__len__`` and ``__iter__`` methods, and is used as the
        ``batch_sampler`` to create the :class:`torch.utils.data.DataLoader`.

        .. note::
            Override this method to provide a custom multi-dataset batch sampler.

        Args:
            dataset (ConcatDataset): The concatenation of all datasets.
            batch_samplers (List[BatchSampler]): List of batch samplers for each dataset in the concatenated dataset.
            generator (torch.Generator, optional): Optional random number generator for shuffling the indices.
            seed (int, optional): Optional seed for the random number generator.
        """
        if self.args.multi_dataset_batch_sampler == MultiDatasetBatchSamplers.ROUND_ROBIN:
            return RoundRobinBatchSampler(
                dataset=dataset,
                batch_samplers=batch_samplers,
                generator=generator,
                seed=seed,
            )
        if self.args.multi_dataset_batch_sampler == MultiDatasetBatchSamplers.PROPORTIONAL:
            return ProportionalBatchSampler(
                dataset=dataset,
                batch_samplers=batch_samplers,
                generator=generator,
                seed=seed,
            )
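
    # The multi-dataset strategy is likewise chosen via the training arguments (values are illustrative):
    # ROUND_ROBIN alternates between datasets until the smallest is exhausted, while PROPORTIONAL samples from
    # each dataset in proportion to its size.
    #
    #     args = SentenceTransformerTrainingArguments(
    #         output_dir="checkpoints",
    #         multi_dataset_batch_sampler=MultiDatasetBatchSamplers.ROUND_ROBIN,
    #     )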

    def get_train_dataloader(self) -> DataLoader:
        """
        Returns the training [`~torch.utils.data.DataLoader`].

        Will use no sampler if `train_dataset` does not implement `__len__`, a random sampler (adapted to distributed
        training if necessary) otherwise.

        Subclass and override this method if you want to inject some custom behavior.
        """
        if self.train_dataset is None:
            raise ValueError("Trainer: training requires a train_dataset.")

        train_dataset = self.train_dataset
        data_collator = self.data_collator
        generator = torch.Generator()
        if self.args.seed:
            generator.manual_seed(self.args.seed)

        dataloader_params = {
            "collate_fn": data_collator,
            "num_workers": self.args.dataloader_num_workers,
            "pin_memory": self.args.dataloader_pin_memory,
            "persistent_workers": self.args.dataloader_persistent_workers,
            "prefetch_factor": self.args.dataloader_prefetch_factor,
        }

        if isinstance(train_dataset, IterableDataset):
            dataloader_params.update(
                {
                    "batch_size": self.args.train_batch_size,
                    "drop_last": self.args.dataloader_drop_last,
                }
            )
        elif isinstance(train_dataset, IterableDatasetDict):
            raise ValueError(
                "Sentence Transformers is not compatible with IterableDatasetDict. Please use a DatasetDict instead."
            )
        elif isinstance(train_dataset, DatasetDict):
            for dataset_name, dataset in train_dataset.items():
                self.validate_column_names(dataset, dataset_name=dataset_name)
                if isinstance(dataset, IterableDataset):
                    raise ValueError(
                        "Sentence Transformers is not compatible with a DatasetDict containing an IterableDataset."
                    )
            if isinstance(self.loss, dict):
                train_dataset = self.add_dataset_name_column(train_dataset)
            batch_samplers = [
                self.get_batch_sampler(
                    dataset,
                    batch_size=self.args.train_batch_size,
                    drop_last=self.args.dataloader_drop_last,
                    valid_label_columns=data_collator.valid_label_columns,
                    generator=generator,
                )
                for dataset in train_dataset.values()
            ]

            train_dataset = ConcatDataset(train_dataset.values())
            batch_sampler = self.get_multi_dataset_batch_sampler(
                dataset=train_dataset,
                batch_samplers=batch_samplers,
                generator=generator,
                seed=self.args.seed,
            )
            dataloader_params["batch_sampler"] = batch_sampler
        elif isinstance(train_dataset, Dataset):
            self.validate_column_names(train_dataset)
            batch_sampler = self.get_batch_sampler(
                train_dataset,
                batch_size=self.args.train_batch_size,
                drop_last=self.args.dataloader_drop_last,
                valid_label_columns=data_collator.valid_label_columns,
                generator=generator,
            )
            dataloader_params["batch_sampler"] = batch_sampler
        else:
            raise ValueError(
                "Unsupported `train_dataset` type. Use a Dataset, DatasetDict, or IterableDataset for training."
            )

        # If 'even_batches' is True, it will use the initial few samples to pad out the last sample. This can
        # cause issues with multi-dataset training, so we want to set this to False.
        # For evaluation, setting 'even_batches' to False results in hanging, so we keep it as True there.
        self.accelerator.even_batches = False
        self._train_dataloader = self.accelerator.prepare(DataLoader(train_dataset, **dataloader_params))
        return self._train_dataloader

    def get_eval_dataloader(self, eval_dataset: Dataset | None = None) -> DataLoader:
        """
        Returns the evaluation [`~torch.utils.data.DataLoader`].

        Subclass and override this method if you want to inject some custom behavior.

        Args:
            eval_dataset (`torch.utils.data.Dataset`, *optional*):
                If provided, will override `self.eval_dataset`. If it is a [`~datasets.Dataset`], columns not accepted
                by the `model.forward()` method are automatically removed. It must implement `__len__`.
        """
        if eval_dataset is None and self.eval_dataset is None:
            # Prevent errors if the evaluator is set but no eval_dataset is provided
            if self.evaluator is not None:
                return DataLoader([])
            raise ValueError("Trainer: evaluation requires an eval_dataset.")

        eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset
        data_collator = self.data_collator
        generator = torch.Generator()
        if self.args.seed:
            generator.manual_seed(self.args.seed)

        dataloader_params = {
            "collate_fn": data_collator,
            "num_workers": self.args.dataloader_num_workers,
            "pin_memory": self.args.dataloader_pin_memory,
            "persistent_workers": self.args.dataloader_persistent_workers,
            "prefetch_factor": self.args.dataloader_prefetch_factor,
        }

        if isinstance(eval_dataset, IterableDataset):
            dataloader_params.update(
                {
                    "batch_size": self.args.eval_batch_size,
                    "drop_last": self.args.dataloader_drop_last,
                }
            )
        elif isinstance(eval_dataset, IterableDatasetDict):
            raise ValueError(
                "Sentence Transformers is not compatible with IterableDatasetDict. Please use a DatasetDict instead."
            )
        elif isinstance(eval_dataset, DatasetDict):
            for dataset in eval_dataset.values():
                if isinstance(dataset, IterableDataset):
                    raise ValueError(
                        "Sentence Transformers is not compatible with a DatasetDict containing an IterableDataset."
                    )
            if isinstance(self.loss, dict):
                eval_dataset = self.add_dataset_name_column(eval_dataset)
            batch_samplers = [
                self.get_batch_sampler(
                    dataset,
                    batch_size=self.args.eval_batch_size,
                    drop_last=self.args.dataloader_drop_last,
                    valid_label_columns=data_collator.valid_label_columns,
                    generator=generator,
                )
                for dataset in eval_dataset.values()
            ]

            eval_dataset = ConcatDataset(eval_dataset.values())
            batch_sampler = self.get_multi_dataset_batch_sampler(
                dataset=eval_dataset,
                batch_samplers=batch_samplers,
                generator=generator,
                seed=self.args.seed,
            )
            dataloader_params["batch_sampler"] = batch_sampler
        elif isinstance(eval_dataset, Dataset):
            batch_sampler = self.get_batch_sampler(
                eval_dataset,
                batch_size=self.args.eval_batch_size,
                drop_last=self.args.dataloader_drop_last,
                valid_label_columns=data_collator.valid_label_columns,
                generator=generator,
            )
            dataloader_params["batch_sampler"] = batch_sampler
        else:
            raise ValueError(
                "Unsupported `eval_dataset` type. Use a Dataset, DatasetDict, or IterableDataset for evaluation."
            )

        # If 'even_batches' is True, it will use the initial few samples to pad out the last sample. This can
        # cause issues with multi-dataset training, so we want to set this to False during training.
        # For evaluation, setting 'even_batches' to False results in hanging, so we keep it as True here.
        self.accelerator.even_batches = True
        return self.accelerator.prepare(DataLoader(eval_dataset, **dataloader_params))

    def get_test_dataloader(self, test_dataset: Dataset) -> DataLoader:
        """
        Returns the test [`~torch.utils.data.DataLoader`].

        Subclass and override this method if you want to inject some custom behavior.

        Args:
            test_dataset (`torch.utils.data.Dataset`, *optional*):
                The test dataset to use. If it is a [`~datasets.Dataset`], columns not accepted by the
                `model.forward()` method are automatically removed. It must implement `__len__`.
        """
        data_collator = self.data_collator
        generator = torch.Generator()
        if self.args.seed:
            generator.manual_seed(self.args.seed)

        dataloader_params = {
            "collate_fn": data_collator,
            "num_workers": self.args.dataloader_num_workers,
            "pin_memory": self.args.dataloader_pin_memory,
            "persistent_workers": self.args.dataloader_persistent_workers,
            "prefetch_factor": self.args.dataloader_prefetch_factor,
        }

        if isinstance(test_dataset, IterableDataset):
            dataloader_params.update(
                {
                    "batch_size": self.args.eval_batch_size,
                    "drop_last": self.args.dataloader_drop_last,
                }
            )
        elif isinstance(test_dataset, IterableDatasetDict):
            raise ValueError(
                "Sentence Transformers is not compatible with IterableDatasetDict. Please use a DatasetDict instead."
            )
        elif isinstance(test_dataset, DatasetDict):
            for dataset_name, dataset in test_dataset.items():
                self.validate_column_names(dataset, dataset_name=dataset_name)
                if isinstance(dataset, IterableDataset):
                    raise ValueError(
                        "Sentence Transformers is not compatible with a DatasetDict containing an IterableDataset."
                    )
            if isinstance(self.loss, dict):
                test_dataset = self.add_dataset_name_column(test_dataset)
            batch_samplers = [
                self.get_batch_sampler(
                    dataset,
                    batch_size=self.args.eval_batch_size,
                    drop_last=self.args.dataloader_drop_last,
                    valid_label_columns=data_collator.valid_label_columns,
                    generator=generator,
                )
                for dataset in test_dataset.values()
            ]

            test_dataset = ConcatDataset(test_dataset.values())
            batch_sampler = self.get_multi_dataset_batch_sampler(
                dataset=test_dataset,
                batch_samplers=batch_samplers,
                generator=generator,
                seed=self.args.seed,
            )
            dataloader_params["batch_sampler"] = batch_sampler
        elif isinstance(test_dataset, Dataset):
            self.validate_column_names(test_dataset)
            batch_sampler = self.get_batch_sampler(
                test_dataset,
                batch_size=self.args.eval_batch_size,
                drop_last=self.args.dataloader_drop_last,
                valid_label_columns=data_collator.valid_label_columns,
                generator=generator,
            )
            dataloader_params["batch_sampler"] = batch_sampler
        else:
            raise ValueError(
                "Unsupported `test_dataset` type. Use a Dataset, DatasetDict, or IterableDataset for testing."
            )

        # If 'even_batches' is True, it will use the initial few samples to pad out the last sample. This can
        # cause issues with multi-dataset training, so we want to set this to False.
        # For evaluation, setting 'even_batches' to False results in hanging, so we keep it as True there.
        self.accelerator.even_batches = False
        return self.accelerator.prepare(DataLoader(test_dataset, **dataloader_params))

    def _save(self, output_dir: str | None = None, state_dict=None) -> None:
        # If we are executing this function, we are the process zero, so we don't check for that.
        output_dir = output_dir if output_dir is not None else self.args.output_dir
        os.makedirs(output_dir, exist_ok=True)
        logger.info(f"Saving model checkpoint to {output_dir}")

        self.model.save_pretrained(output_dir, safe_serialization=self.args.save_safetensors)

        if self.tokenizer is not None:
            self.tokenizer.save_pretrained(output_dir)

        # Good practice: save your training arguments together with the trained model
        torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))

    def _load_from_checkpoint(self, checkpoint_path: str) -> None:
        from sentence_transformers import SentenceTransformer

        loaded_model = SentenceTransformer(checkpoint_path, trust_remote_code=self.model.trust_remote_code)
        self.model.load_state_dict(loaded_model.state_dict())

    def create_model_card(
        self,
        language: str | None = None,
        license: str | None = None,
        tags: str | list[str] | None = None,
        model_name: str | None = None,
        finetuned_from: str | None = None,
        tasks: str | list[str] | None = None,
        dataset_tags: str | list[str] | None = None,
        dataset: str | list[str] | None = None,
        dataset_args: str | list[str] | None = None,
        **kwargs,
    ) -> None:
        if not self.is_world_process_zero():
            return

        if language:
            self.model.model_card_data.set_language(language)
        if license:
            self.model.model_card_data.set_license(license)
        if tags:
            self.model.model_card_data.add_tags(tags)

        self.model._create_model_card(self.args.output_dir, model_name=model_name)

    def get_optimizer_cls_and_kwargs(
        self, args: SentenceTransformerTrainingArguments, model: SentenceTransformer | None = None
    ) -> tuple[Any, Any]:
        """
        We have to override the optimizer_grouped_parameters because the Trainer superclass bases it on the `model`
        itself, but the SentenceTransformer losses can have weights that should be updated as well, e.g.
        SoftmaxLoss (see #2872).

        This method requires `transformers` >= 4.43.0.
        """
        if isinstance(self.loss, dict):
            loss_model = nn.Sequential(OrderedDict(self.loss))
        else:
            loss_model = self.loss
        optimizer_cls, optimizer_kwargs = super().get_optimizer_cls_and_kwargs(args, loss_model)

        # If the kwargs were not overridden by the super() call, then we should override them here so that the potential
        # weights in the loss(es) can also be updated.
        if not {"params", "model", "optimizer_dict"} & set(optimizer_kwargs.keys()):
            decay_parameters = self.get_decay_parameter_names(loss_model)
            optimizer_kwargs["optimizer_dict"] = [
                {
                    "params": [
                        p for n, p in loss_model.named_parameters() if (n in decay_parameters and p.requires_grad)
                    ],
                    "weight_decay": self.args.weight_decay,
                },
                {
                    "params": [
                        p for n, p in loss_model.named_parameters() if (n not in decay_parameters and p.requires_grad)
                    ],
                    "weight_decay": 0.0,
                },
            ]

        return optimizer_cls, optimizer_kwargs
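
    # Example of a loss with its own trainable weights (the case motivating the override above): SoftmaxLoss
    # carries a classification head, so its parameters must be included in the optimizer groups. A rough sketch
    # with a hypothetical NLI-style dataset:
    #
    #     from sentence_transformers.losses import SoftmaxLoss
    #
    #     loss = SoftmaxLoss(
    #         model,
    #         sentence_embedding_dimension=model.get_sentence_embedding_dimension(),
    #         num_labels=3,
    #     )
    #     trainer = SentenceTransformerTrainer(model=model, train_dataset=nli_dataset, loss=loss)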