- # coding=utf-8
- # Copyright 2020-present the HuggingFace Inc. team.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- """
- Callbacks to use with the Trainer class and customize the training loop.
- """
- import dataclasses
- import json
- from dataclasses import dataclass
- from typing import Dict, List, Optional, Union
- import numpy as np
- from tqdm.auto import tqdm
- from .trainer_utils import IntervalStrategy, has_length
- from .training_args import TrainingArguments
- from .utils import logging
- logger = logging.get_logger(__name__)
- @dataclass
- class TrainerState:
- """
- A class containing the [`Trainer`] inner state that will be saved along with the model and optimizer when
- checkpointing and passed to the [`TrainerCallback`].
- <Tip>
- Throughout this class, one step is to be understood as one update step. When using gradient accumulation, one update
- step may require several forward and backward passes: if you use `gradient_accumulation_steps=n`, then one update
- step requires going through *n* batches.
- </Tip>
- Args:
- epoch (`float`, *optional*):
- Only set during training, will represent the epoch the training is at (the decimal part being the
- percentage of the current epoch completed).
- global_step (`int`, *optional*, defaults to 0):
- During training, represents the number of update steps completed.
- max_steps (`int`, *optional*, defaults to 0):
- The number of update steps to do during the current training.
- logging_steps (`int`, *optional*, defaults to 500):
- Log every X update steps.
- eval_steps (`int`, *optional*, defaults to 500):
- Run an evaluation every X update steps.
- save_steps (`int`, *optional*, defaults to 500):
- Save a checkpoint every X update steps.
- train_batch_size (`int`, *optional*):
- The batch size for the training dataloader. Only needed when
- `auto_find_batch_size` has been used.
- num_input_tokens_seen (`int`, *optional*, defaults to 0):
- The number of tokens seen during training (number of input tokens, not the number of prediction tokens).
- total_flos (`float`, *optional*, defaults to 0):
- The total number of floating-point operations done by the model since the beginning of training (stored as a
- float to avoid overflow).
- log_history (`List[Dict[str, float]]`, *optional*):
- The list of logs done since the beginning of training.
- best_metric (`float`, *optional*):
- When tracking the best model, the value of the best metric encountered so far.
- best_model_checkpoint (`str`, *optional*):
- When tracking the best model, the name of the checkpoint for the best model encountered so far.
- is_local_process_zero (`bool`, *optional*, defaults to `True`):
- Whether or not this process is the local main process (e.g., the main process on one machine, when training
- in a distributed fashion on several machines).
- is_world_process_zero (`bool`, *optional*, defaults to `True`):
- Whether or not this process is the global main process (when training in a distributed fashion on several
- machines, this is only going to be `True` for one process).
- is_hyper_param_search (`bool`, *optional*, defaults to `False`):
- Whether we are in the process of a hyperparameter search using [`Trainer.hyperparameter_search`]. This will
- impact the way data will be logged in TensorBoard.
- stateful_callbacks (`List[TrainerCallback]`, *optional*):
- Callbacks attached to the `Trainer` that should have their states saved or restored.
- Relevant callbacks should implement `state` and `from_state` functions.
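- Example:
- A minimal sketch of saving and restoring the state (the file name and field values are illustrative only,
- not part of the API):
- ```python
- from transformers import TrainerState
-
- state = TrainerState(max_steps=1000, logging_steps=100)
- state.save_to_json("trainer_state.json")  # illustrative path
- restored = TrainerState.load_from_json("trainer_state.json")
- assert restored.max_steps == 1000
- ```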
- """
- epoch: Optional[float] = None
- global_step: int = 0
- max_steps: int = 0
- logging_steps: int = 500
- eval_steps: int = 500
- save_steps: int = 500
- train_batch_size: Optional[int] = None
- num_train_epochs: int = 0
- num_input_tokens_seen: int = 0
- total_flos: float = 0
- log_history: Optional[List[Dict[str, float]]] = None
- best_metric: Optional[float] = None
- best_model_checkpoint: Optional[str] = None
- is_local_process_zero: bool = True
- is_world_process_zero: bool = True
- is_hyper_param_search: bool = False
- trial_name: Optional[str] = None
- trial_params: Optional[Dict[str, Union[str, float, int, bool]]] = None
- stateful_callbacks: Optional[List["TrainerCallback"]] = None
- def __post_init__(self):
- if self.log_history is None:
- self.log_history = []
- if self.stateful_callbacks is None:
- self.stateful_callbacks = {}
- elif isinstance(self.stateful_callbacks, dict):
- # We are loading the callbacks in from the state file, no need to process them
- pass
- else:
- # Saveable callbacks get stored as dict of kwargs
- stateful_callbacks = {}
- for callback in self.stateful_callbacks:
- if not isinstance(callback, ExportableState):
- raise TypeError(
- f"All callbacks passed to be saved must inherit `ExportableState`, but received {type(callback)}"
- )
- name = callback.__class__.__name__
- if name in stateful_callbacks:
- # We can have multiple versions of the same callback
- # if so, we store them as a list of states to restore
- if not isinstance(stateful_callbacks[name], list):
- stateful_callbacks[name] = [stateful_callbacks[name]]
- stateful_callbacks[name].append(callback.state())
- else:
- stateful_callbacks[name] = callback.state()
- self.stateful_callbacks = stateful_callbacks
- def save_to_json(self, json_path: str):
- """Save the content of this instance in JSON format inside `json_path`."""
- json_string = json.dumps(dataclasses.asdict(self), indent=2, sort_keys=True) + "\n"
- with open(json_path, "w", encoding="utf-8") as f:
- f.write(json_string)
- @classmethod
- def load_from_json(cls, json_path: str):
- """Create an instance from the content of `json_path`."""
- with open(json_path, "r", encoding="utf-8") as f:
- text = f.read()
- return cls(**json.loads(text))
- class ExportableState:
- """
- A class for objects that include the ability to have their state
- saved during `Trainer._save_checkpoint` and loaded back in during
- `Trainer._load_from_checkpoint`.
- These must implement a `state` function that gets called during the respective
- Trainer function call. It should only include parameters and attributes needed to
- recreate the state at a particular point in time, avoiding the use of pickle and keeping to
- standard file I/O.
- Example:
- ```python
- class EarlyStoppingCallback(TrainerCallback, ExportableState):
- def __init__(self, early_stopping_patience: int = 1, early_stopping_threshold: Optional[float] = 0.0):
- self.early_stopping_patience = early_stopping_patience
- self.early_stopping_threshold = early_stopping_threshold
- # early_stopping_patience_counter denotes the number of times validation metrics failed to improve.
- self.early_stopping_patience_counter = 0
- def state(self) -> dict:
- return {
- "args": {
- "early_stopping_patience": self.early_stopping_patience,
- "early_stopping_threshold": self.early_stopping_threshold,
- },
- "attributes": {
- "early_stopping_patience_counter": self.early_stopping_patience_counter,
- }
- }
- ```"""
- def state(self) -> dict:
- raise NotImplementedError("You must implement a `state` function to utilize this class.")
- @classmethod
- def from_state(cls, state):
- instance = cls(**state["args"])
- for k, v in state["attributes"].items():
- setattr(instance, k, v)
- return instance
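- # A hedged usage sketch (not part of the library API surface): round-tripping a stateful
- # callback through `state()` / `from_state()`, mirroring what checkpointing does:
- #
- #   cb = EarlyStoppingCallback(early_stopping_patience=3)
- #   cb.early_stopping_patience_counter = 2
- #   restored = EarlyStoppingCallback.from_state(cb.state())
- #   assert restored.early_stopping_patience_counter == 2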
- @dataclass
- class TrainerControl(ExportableState):
- """
- A class that handles the [`Trainer`] control flow. This class is used by the [`TrainerCallback`] to activate some
- switches in the training loop.
- Args:
- should_training_stop (`bool`, *optional*, defaults to `False`):
- Whether or not the training should be interrupted.
- If `True`, this variable will not be set back to `False`. The training will just stop.
- should_epoch_stop (`bool`, *optional*, defaults to `False`):
- Whether or not the current epoch should be interrupted.
- If `True`, this variable will be set back to `False` at the beginning of the next epoch.
- should_save (`bool`, *optional*, defaults to `False`):
- Whether or not the model should be saved at this step.
- If `True`, this variable will be set back to `False` at the beginning of the next step.
- should_evaluate (`bool`, *optional*, defaults to `False`):
- Whether or not the model should be evaluated at this step.
- If `True`, this variable will be set back to `False` at the beginning of the next step.
- should_log (`bool`, *optional*, defaults to `False`):
- Whether or not the logs should be reported at this step.
- If `True`, this variable will be set back to `False` at the beginning of the next step.
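- Example:
- A minimal sketch of a callback that flips one of these switches (the class name and the step interval are
- illustrative, not part of the API):
- ```python
- class SaveEveryHundredStepsCallback(TrainerCallback):
-     def on_step_end(self, args, state, control, **kwargs):
-         if state.global_step % 100 == 0:
-             control.should_save = True
-         return control
- ```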
- """
- should_training_stop: bool = False
- should_epoch_stop: bool = False
- should_save: bool = False
- should_evaluate: bool = False
- should_log: bool = False
- def _new_training(self):
- """Internal method that resets the variable for a new training."""
- self.should_training_stop = False
- def _new_epoch(self):
- """Internal method that resets the variable for a new epoch."""
- self.should_epoch_stop = False
- def _new_step(self):
- """Internal method that resets the variable for a new step."""
- self.should_save = False
- self.should_evaluate = False
- self.should_log = False
- def state(self) -> dict:
- return {
- "args": {
- "should_training_stop": self.should_training_stop,
- "should_epoch_stop": self.should_epoch_stop,
- "should_save": self.should_save,
- "should_evaluate": self.should_evaluate,
- "should_log": self.should_log,
- },
- "attributes": {},
- }
- class TrainerCallback:
- # no-format
- """
- A class for objects that will inspect the state of the training loop at some events and take some decisions. At
- each of those events the following arguments are available:
- Args:
- args ([`TrainingArguments`]):
- The training arguments used to instantiate the [`Trainer`].
- state ([`TrainerState`]):
- The current state of the [`Trainer`].
- control ([`TrainerControl`]):
- The object that is returned to the [`Trainer`] and can be used to make some decisions.
- model ([`PreTrainedModel`] or `torch.nn.Module`):
- The model being trained.
- tokenizer ([`PreTrainedTokenizer`]):
- The tokenizer used for encoding the data. This is deprecated in favour of `processing_class`.
- processing_class ([`PreTrainedTokenizer` or `BaseImageProcessor` or `ProcessorMixin` or `FeatureExtractionMixin`]):
- The processing class used for encoding the data. Can be a tokenizer, a processor, an image processor or a feature extractor.
- optimizer (`torch.optim.Optimizer`):
- The optimizer used for the training steps.
- lr_scheduler (`torch.optim.lr_scheduler.LambdaLR`):
- The scheduler used for setting the learning rate.
- train_dataloader (`torch.utils.data.DataLoader`, *optional*):
- The current dataloader used for training.
- eval_dataloader (`torch.utils.data.DataLoader`, *optional*):
- The current dataloader used for evaluation.
- metrics (`Dict[str, float]`):
- The metrics computed by the last evaluation phase.
- Those are only accessible in the event `on_evaluate`.
- logs (`Dict[str, float]`):
- The values to log.
- Those are only accessible in the event `on_log`.
- The `control` object is the only one that can be changed by the callback, in which case the event that changes it
- should return the modified version.
- The arguments `args`, `state` and `control` are positional for all events; all the others are grouped in `kwargs`.
- You can unpack the ones you need in the signature of the event you are overriding. As an example, see the code of
- the simple [`~transformers.PrinterCallback`].
- Example:
- ```python
- class PrinterCallback(TrainerCallback):
- def on_log(self, args, state, control, logs=None, **kwargs):
- _ = logs.pop("total_flos", None)
- if state.is_local_process_zero:
- print(logs)
- ```"""
- def on_init_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
- """
- Event called at the end of the initialization of the [`Trainer`].
- """
- pass
- def on_train_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
- """
- Event called at the beginning of training.
- """
- pass
- def on_train_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
- """
- Event called at the end of training.
- """
- pass
- def on_epoch_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
- """
- Event called at the beginning of an epoch.
- """
- pass
- def on_epoch_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
- """
- Event called at the end of an epoch.
- """
- pass
- def on_step_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
- """
- Event called at the beginning of a training step. If using gradient accumulation, one training step might take
- several inputs.
- """
- pass
- def on_pre_optimizer_step(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
- """
- Event called before the optimizer step but after gradient clipping. Useful for monitoring gradients.
- """
- pass
- def on_optimizer_step(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
- """
- Event called after the optimizer step but before gradients are zeroed out. Useful for monitoring gradients.
- """
- pass
- def on_substep_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
- """
- Event called at the end of a substep during gradient accumulation.
- """
- pass
- def on_step_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
- """
- Event called at the end of a training step. If using gradient accumulation, one training step might take
- several inputs.
- """
- pass
- def on_evaluate(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
- """
- Event called after an evaluation phase.
- """
- pass
- def on_predict(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, metrics, **kwargs):
- """
- Event called after a successful prediction.
- """
- pass
- def on_save(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
- """
- Event called after a checkpoint save.
- """
- pass
- def on_log(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
- """
- Event called after logging the last logs.
- """
- pass
- def on_prediction_step(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
- """
- Event called after a prediction step.
- """
- pass
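- # A hedged usage sketch (class and variable names are illustrative): callbacks are registered with the
- # `Trainer` either as a class or as an instance, via the `callbacks` argument or `Trainer.add_callback`:
- #
- #   class PrintOnTrainBegin(TrainerCallback):
- #       def on_train_begin(self, args, state, control, **kwargs):
- #           print("Starting training")
- #
- #   trainer = Trainer(model=model, args=args, train_dataset=train_dataset, callbacks=[PrintOnTrainBegin])
- #   trainer.add_callback(PrintOnTrainBegin())  # equivalent: pass an instance instead of the class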
- class CallbackHandler(TrainerCallback):
- """Internal class that just calls the list of callbacks in order."""
- def __init__(self, callbacks, model, processing_class, optimizer, lr_scheduler):
- self.callbacks = []
- for cb in callbacks:
- self.add_callback(cb)
- self.model = model
- self.processing_class = processing_class
- self.optimizer = optimizer
- self.lr_scheduler = lr_scheduler
- self.train_dataloader = None
- self.eval_dataloader = None
- if not any(isinstance(cb, DefaultFlowCallback) for cb in self.callbacks):
- logger.warning(
- "The Trainer will not work properly if you don't have a `DefaultFlowCallback` in its callbacks. You\n"
- + "should add one before training with `trainer.add_callback(DefaultFlowCallback)`. The current list of "
- + "callbacks is:\n"
- + self.callback_list
- )
- def add_callback(self, callback):
- cb = callback() if isinstance(callback, type) else callback
- cb_class = callback if isinstance(callback, type) else callback.__class__
- if cb_class in [c.__class__ for c in self.callbacks]:
- logger.warning(
- f"You are adding a {cb_class} to the callbacks of this Trainer, but there is already one. The current "
- + "list of callbacks is:\n"
- + self.callback_list
- )
- self.callbacks.append(cb)
- def pop_callback(self, callback):
- if isinstance(callback, type):
- for cb in self.callbacks:
- if isinstance(cb, callback):
- self.callbacks.remove(cb)
- return cb
- else:
- for cb in self.callbacks:
- if cb == callback:
- self.callbacks.remove(cb)
- return cb
- def remove_callback(self, callback):
- if isinstance(callback, type):
- for cb in self.callbacks:
- if isinstance(cb, callback):
- self.callbacks.remove(cb)
- return
- else:
- self.callbacks.remove(callback)
- @property
- def callback_list(self):
- return "\n".join(cb.__class__.__name__ for cb in self.callbacks)
- def on_init_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl):
- return self.call_event("on_init_end", args, state, control)
- def on_train_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl):
- control.should_training_stop = False
- return self.call_event("on_train_begin", args, state, control)
- def on_train_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl):
- return self.call_event("on_train_end", args, state, control)
- def on_epoch_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl):
- control.should_epoch_stop = False
- return self.call_event("on_epoch_begin", args, state, control)
- def on_epoch_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl):
- return self.call_event("on_epoch_end", args, state, control)
- def on_step_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl):
- control.should_log = False
- control.should_evaluate = False
- control.should_save = False
- return self.call_event("on_step_begin", args, state, control)
- def on_pre_optimizer_step(self, args: TrainingArguments, state: TrainerState, control: TrainerControl):
- return self.call_event("on_pre_optimizer_step", args, state, control)
- def on_optimizer_step(self, args: TrainingArguments, state: TrainerState, control: TrainerControl):
- return self.call_event("on_optimizer_step", args, state, control)
- def on_substep_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl):
- return self.call_event("on_substep_end", args, state, control)
- def on_step_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl):
- return self.call_event("on_step_end", args, state, control)
- def on_evaluate(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, metrics):
- control.should_evaluate = False
- return self.call_event("on_evaluate", args, state, control, metrics=metrics)
- def on_predict(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, metrics):
- return self.call_event("on_predict", args, state, control, metrics=metrics)
- def on_save(self, args: TrainingArguments, state: TrainerState, control: TrainerControl):
- control.should_save = False
- return self.call_event("on_save", args, state, control)
- def on_log(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, logs):
- control.should_log = False
- return self.call_event("on_log", args, state, control, logs=logs)
- def on_prediction_step(self, args: TrainingArguments, state: TrainerState, control: TrainerControl):
- return self.call_event("on_prediction_step", args, state, control)
- def call_event(self, event, args, state, control, **kwargs):
- for callback in self.callbacks:
- result = getattr(callback, event)(
- args,
- state,
- control,
- model=self.model,
- processing_class=self.processing_class,
- optimizer=self.optimizer,
- lr_scheduler=self.lr_scheduler,
- train_dataloader=self.train_dataloader,
- eval_dataloader=self.eval_dataloader,
- **kwargs,
- )
- # A Callback can skip the return of `control` if it doesn't change it.
- if result is not None:
- control = result
- return control
- class DefaultFlowCallback(TrainerCallback):
- """
- A [`TrainerCallback`] that handles the default flow of the training loop for logs, evaluation and checkpoints.
- """
- def on_step_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
- # Log
- if state.global_step == 1 and args.logging_first_step:
- control.should_log = True
- if args.logging_strategy == IntervalStrategy.STEPS and state.global_step % state.logging_steps == 0:
- control.should_log = True
- # Evaluate
- if (
- args.eval_strategy == IntervalStrategy.STEPS
- and state.global_step % state.eval_steps == 0
- and args.eval_delay <= state.global_step
- ):
- control.should_evaluate = True
- # Save
- if (
- args.save_strategy == IntervalStrategy.STEPS
- and state.save_steps > 0
- and state.global_step % state.save_steps == 0
- ):
- control.should_save = True
- # End training
- if state.global_step >= state.max_steps:
- control.should_training_stop = True
- # Save the model at the end if we have a save strategy
- if args.save_strategy != IntervalStrategy.NO:
- control.should_save = True
- return control
- def on_epoch_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
- # Log
- if args.logging_strategy == IntervalStrategy.EPOCH:
- control.should_log = True
- # Evaluate
- if args.eval_strategy == IntervalStrategy.EPOCH and args.eval_delay <= state.epoch:
- control.should_evaluate = True
- # Save
- if args.save_strategy == IntervalStrategy.EPOCH:
- control.should_save = True
- return control
- class ProgressCallback(TrainerCallback):
- """
- A [`TrainerCallback`] that displays the progress of training or evaluation.
- """
- def __init__(self):
- self.training_bar = None
- self.prediction_bar = None
- def on_train_begin(self, args, state, control, **kwargs):
- if state.is_world_process_zero:
- self.training_bar = tqdm(total=state.max_steps, dynamic_ncols=True)
- self.current_step = 0
- def on_step_end(self, args, state, control, **kwargs):
- if state.is_world_process_zero:
- self.training_bar.update(state.global_step - self.current_step)
- self.current_step = state.global_step
- def on_prediction_step(self, args, state, control, eval_dataloader=None, **kwargs):
- if state.is_world_process_zero and has_length(eval_dataloader):
- if self.prediction_bar is None:
- self.prediction_bar = tqdm(
- total=len(eval_dataloader), leave=self.training_bar is None, dynamic_ncols=True
- )
- self.prediction_bar.update(1)
- def on_evaluate(self, args, state, control, **kwargs):
- if state.is_world_process_zero:
- if self.prediction_bar is not None:
- self.prediction_bar.close()
- self.prediction_bar = None
- def on_predict(self, args, state, control, **kwargs):
- if state.is_world_process_zero:
- if self.prediction_bar is not None:
- self.prediction_bar.close()
- self.prediction_bar = None
- def on_log(self, args, state, control, logs=None, **kwargs):
- if state.is_world_process_zero and self.training_bar is not None:
- # make a shallow copy of logs so we can mutate the fields copied
- # but avoid doing any value pickling.
- shallow_logs = {}
- for k, v in logs.items():
- shallow_logs[k] = v
- _ = shallow_logs.pop("total_flos", None)
- # round numbers so that it looks better in console
- if "epoch" in shallow_logs:
- shallow_logs["epoch"] = round(shallow_logs["epoch"], 2)
- self.training_bar.write(str(shallow_logs))
- def on_train_end(self, args, state, control, **kwargs):
- if state.is_world_process_zero:
- self.training_bar.close()
- self.training_bar = None
- class PrinterCallback(TrainerCallback):
- """
- A bare [`TrainerCallback`] that just prints the logs.
- """
- def on_log(self, args, state, control, logs=None, **kwargs):
- _ = logs.pop("total_flos", None)
- if state.is_local_process_zero:
- print(logs)
- class EarlyStoppingCallback(TrainerCallback, ExportableState):
- """
- A [`TrainerCallback`] that handles early stopping.
- Args:
- early_stopping_patience (`int`):
- Use with `metric_for_best_model` to stop training when the specified metric worsens for
- `early_stopping_patience` evaluation calls.
- early_stopping_threshold (`float`, *optional*):
- Use with TrainingArguments `metric_for_best_model` and `early_stopping_patience` to denote how much the
- specified metric must improve to satisfy early stopping conditions.
- This callback depends on the [`TrainingArguments`] argument *load_best_model_at_end* to set *best_metric*
- in [`TrainerState`]. Note that if the [`TrainingArguments`] argument *save_steps* differs from *eval_steps*,
- early stopping will not occur until the next save step.
- """
- def __init__(self, early_stopping_patience: int = 1, early_stopping_threshold: Optional[float] = 0.0):
- self.early_stopping_patience = early_stopping_patience
- self.early_stopping_threshold = early_stopping_threshold
- # early_stopping_patience_counter denotes the number of times validation metrics failed to improve.
- self.early_stopping_patience_counter = 0
- def check_metric_value(self, args, state, control, metric_value):
- # best_metric is set by code for load_best_model
- operator = np.greater if args.greater_is_better else np.less
- if state.best_metric is None or (
- operator(metric_value, state.best_metric)
- and abs(metric_value - state.best_metric) > self.early_stopping_threshold
- ):
- self.early_stopping_patience_counter = 0
- else:
- self.early_stopping_patience_counter += 1
- def on_train_begin(self, args, state, control, **kwargs):
- assert args.load_best_model_at_end, "EarlyStoppingCallback requires load_best_model_at_end = True"
- assert (
- args.metric_for_best_model is not None
- ), "EarlyStoppingCallback requires metric_for_best_model to be defined"
- assert (
- args.eval_strategy != IntervalStrategy.NO
- ), "EarlyStoppingCallback requires IntervalStrategy of steps or epoch"
- def on_evaluate(self, args, state, control, metrics, **kwargs):
- metric_to_check = args.metric_for_best_model
- if not metric_to_check.startswith("eval_"):
- metric_to_check = f"eval_{metric_to_check}"
- metric_value = metrics.get(metric_to_check)
- if metric_value is None:
- logger.warning(
- f"early stopping requires metric_for_best_model, but did not find {metric_to_check}, so early stopping"
- " is disabled"
- )
- return
- self.check_metric_value(args, state, control, metric_value)
- if self.early_stopping_patience_counter >= self.early_stopping_patience:
- control.should_training_stop = True
- def state(self) -> dict:
- return {
- "args": {
- "early_stopping_patience": self.early_stopping_patience,
- "early_stopping_threshold": self.early_stopping_threshold,
- },
- "attributes": {
- "early_stopping_patience_counter": self.early_stopping_patience_counter,
- },
- }
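- # A hedged usage sketch (values are illustrative): early stopping needs `load_best_model_at_end=True`,
- # a `metric_for_best_model`, and an evaluation strategy; with the "steps" strategy, `save_steps` should be
- # a multiple of `eval_steps` so the best model can be tracked:
- #
- #   args = TrainingArguments(
- #       output_dir="out",
- #       eval_strategy="steps",
- #       eval_steps=500,
- #       save_steps=500,
- #       load_best_model_at_end=True,
- #       metric_for_best_model="eval_loss",
- #       greater_is_better=False,
- #   )
- #   trainer = Trainer(
- #       model=model,
- #       args=args,
- #       train_dataset=train_dataset,
- #       eval_dataset=eval_dataset,
- #       callbacks=[EarlyStoppingCallback(early_stopping_patience=3)],
- #   )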