keras_callbacks.py 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413
  1. import logging
  2. import os
  3. from pathlib import Path
  4. from time import sleep
  5. from typing import Callable, List, Optional, Union
  6. import numpy as np
  7. import tensorflow as tf
  8. from huggingface_hub import Repository, create_repo
  9. from packaging.version import parse
  10. from . import IntervalStrategy, PreTrainedTokenizerBase
  11. from .modelcard import TrainingSummary
  12. from .modeling_tf_utils import keras
  13. logger = logging.getLogger(__name__)
  14. class KerasMetricCallback(keras.callbacks.Callback):
  15. """
  16. Callback to compute metrics at the end of every epoch. Unlike normal Keras metrics, these do not need to be
  17. compilable by TF. It is particularly useful for common NLP metrics like BLEU and ROUGE that require string
  18. operations or generation loops that cannot be compiled. Predictions (or generations) will be computed on the
  19. `eval_dataset` before being passed to the `metric_fn` in `np.ndarray` format. The `metric_fn` should compute
  20. metrics and return a dict mapping metric names to metric values.
  21. We provide an example of a suitable metric_fn that computes ROUGE scores for a summarization model below. Note that
  22. this example skips some post-processing for readability and simplicity, and should probably not be used as-is!
  23. ```py
  24. from datasets import load_metric
  25. rouge_metric = load_metric("rouge")
  26. def rouge_fn(predictions, labels):
  27. decoded_predictions = tokenizer.batch_decode(predictions, skip_special_tokens=True)
  28. decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
  29. result = rouge_metric.compute(predictions=decoded_predictions, references=decoded_labels)
  30. return {key: value.mid.fmeasure * 100 for key, value in result.items()}
  31. ```
  32. The above function will return a dict containing values which will be logged like any other Keras metric:
  33. ```
  34. {'rouge1': 37.4199, 'rouge2': 13.9768, 'rougeL': 34.361, 'rougeLsum': 35.0781
  35. ```
  36. Args:
  37. metric_fn (`Callable`):
  38. Metric function provided by the user. It will be called with two arguments - `predictions` and `labels`.
  39. These contain the model's outputs and matching labels from the dataset. It should return a dict mapping
  40. metric names to numerical values.
  41. eval_dataset (`tf.data.Dataset` or `dict` or `tuple` or `np.ndarray` or `tf.Tensor`):
  42. Validation data to be used to generate predictions for the `metric_fn`.
  43. output_cols (`List[str], *optional*):
  44. A list of columns to be retained from the model output as the predictions. Defaults to all.
  45. label_cols ('`List[str]`, *optional*'):
  46. A list of columns to be retained from the input dataset as the labels. Will be autodetected if this is not
  47. supplied.
  48. batch_size (`int`, *optional*):
  49. Batch size. Only used when the data is not a pre-batched `tf.data.Dataset`.
  50. predict_with_generate (`bool`, *optional*, defaults to `False`):
  51. Whether we should use `model.generate()` to get outputs for the model.
  52. use_xla_generation (`bool`, *optional*, defaults to `False`):
  53. If we're generating, whether to compile model generation with XLA. This can massively increase the speed of
  54. generation (up to 100X speedup) but will require a new XLA compilation for each input shape. When using XLA
  55. generation, it's a good idea to pad your inputs to the same size, or to use the `pad_to_multiple_of`
  56. argument in your `tokenizer` or `DataCollator`, which will reduce the number of unique input shapes and
  57. save a lot of compilation time. This option has no effect is `predict_with_generate` is `False`.
  58. generate_kwargs (`dict`, *optional*):
  59. Keyword arguments to pass to `model.generate()` when generating. Has no effect if `predict_with_generate`
  60. is `False`.
  61. """
  62. def __init__(
  63. self,
  64. metric_fn: Callable,
  65. eval_dataset: Union[tf.data.Dataset, np.ndarray, tf.Tensor, tuple, dict],
  66. output_cols: Optional[List[str]] = None,
  67. label_cols: Optional[List[str]] = None,
  68. batch_size: Optional[int] = None,
  69. predict_with_generate: bool = False,
  70. use_xla_generation: bool = False,
  71. generate_kwargs: Optional[dict] = None,
  72. ):
  73. super().__init__()
  74. self.metric_fn = metric_fn
  75. self.batch_size = batch_size
  76. if not isinstance(eval_dataset, tf.data.Dataset):
  77. if batch_size is None:
  78. raise ValueError(
  79. "When passing data to KerasMetricCallback that is not a pre-batched tf.data.Dataset "
  80. "the batch_size argument must be set."
  81. )
  82. # Wrap a tf.data.Dataset around it
  83. eval_dataset = tf.data.Dataset.from_tensor_slices(eval_dataset).batch(batch_size, drop_remainder=False)
  84. self.eval_dataset = eval_dataset
  85. self.predict_with_generate = predict_with_generate
  86. self.output_cols = output_cols
  87. # This next block attempts to parse out which elements of the dataset should be appended to the labels list
  88. # that is passed to the metric_fn
  89. if isinstance(eval_dataset.element_spec, tuple) and len(eval_dataset.element_spec) == 2:
  90. input_spec, label_spec = eval_dataset.element_spec
  91. else:
  92. input_spec = eval_dataset.element_spec
  93. label_spec = None
  94. if label_cols is not None:
  95. for label in label_cols:
  96. if label not in input_spec:
  97. raise ValueError(f"Label {label} is in label_cols but could not be found in the dataset inputs!")
  98. self.label_cols = label_cols
  99. self.use_keras_label = False
  100. elif label_spec is not None:
  101. # If the dataset inputs are split into a 2-tuple of inputs and labels,
  102. # assume the second element is the labels
  103. self.label_cols = None
  104. self.use_keras_label = True
  105. elif "labels" in input_spec:
  106. self.label_cols = ["labels"]
  107. self.use_keras_label = False
  108. logging.warning("No label_cols specified for KerasMetricCallback, assuming you want the 'labels' key.")
  109. elif "start_positions" in input_spec and "end_positions" in input_spec:
  110. self.label_cols = ["start_positions", "end_positions"]
  111. self.use_keras_label = False
  112. logging.warning(
  113. "No label_cols specified for KerasMetricCallback, assuming you want the "
  114. "start_positions and end_positions keys."
  115. )
  116. else:
  117. raise ValueError("Could not autodetect label_cols for KerasMetricCallback, please specify them!")
  118. if parse(tf.__version__) < parse("2.7"):
  119. logging.warning("TF versions less than 2.7 may encounter issues with KerasMetricCallback!")
  120. self.use_xla_generation = use_xla_generation
  121. self.generate_kwargs = {} if generate_kwargs is None else generate_kwargs
  122. self.generation_function = None
  123. @staticmethod
  124. def _concatenate_batches(batches, padding_index=-100):
  125. # If all batches are unidimensional or same length, do a simple concatenation
  126. if batches[0].ndim == 1 or all(batch.shape[1] == batches[0].shape[1] for batch in batches):
  127. return np.concatenate(batches, axis=0)
  128. # Welp, they're not the same length. Let's do some padding
  129. max_len = max([batch.shape[1] for batch in batches])
  130. num_samples = sum([batch.shape[0] for batch in batches])
  131. output = np.full_like(
  132. batches[0], fill_value=padding_index, shape=[num_samples, max_len] + list(batches[0].shape[2:])
  133. )
  134. # i keeps track of which part of the concatenated array we're writing the next batch to
  135. i = 0
  136. for batch in batches:
  137. output[i : i + len(batch), : batch.shape[1]] = batch
  138. i += len(batch)
  139. return output
  140. def _postprocess_predictions_or_labels(self, inputs):
  141. if isinstance(inputs[0], dict):
  142. outputs = {}
  143. for key in inputs[0].keys():
  144. outputs[key] = self._concatenate_batches([batch[key] for batch in inputs])
  145. # If it's a dict with only one key, just return the array
  146. if len(outputs) == 1:
  147. outputs = list(outputs.values())[0]
  148. elif isinstance(inputs[0], list) or isinstance(inputs[0], tuple):
  149. outputs = []
  150. for input_list in zip(*inputs):
  151. outputs.append(self._concatenate_batches(input_list))
  152. if len(outputs) == 1:
  153. outputs = outputs[0] # If it's a list with only one element, just return the array
  154. elif isinstance(inputs[0], np.ndarray):
  155. outputs = self._concatenate_batches(inputs)
  156. elif isinstance(inputs[0], tf.Tensor):
  157. outputs = self._concatenate_batches([tensor.numpy() for tensor in inputs])
  158. else:
  159. raise TypeError(f"Couldn't handle batch of type {type(inputs[0])}!")
  160. return outputs
  161. def on_epoch_end(self, epoch, logs=None):
  162. if hasattr(self.model, "config"):
  163. ignore_keys = getattr(self.model.config, "keys_to_ignore_at_inference", [])
  164. else:
  165. ignore_keys = []
  166. main_input_name = None
  167. if self.predict_with_generate:
  168. # This dense conditional recognizes the case where we have an encoder-decoder model, but
  169. # avoids getting tangled up when we just have a model with a layer called 'encoder'
  170. if hasattr(self.model, "encoder") and hasattr(self.model.encoder, "main_input_name"):
  171. main_input_name = self.model.encoder.main_input_name
  172. else:
  173. main_input_name = getattr(self.model, "main_input_name", "input_ids")
  174. if self.use_xla_generation and self.generation_function is None:
  175. def generation_function(inputs, attention_mask):
  176. return self.model.generate(inputs, attention_mask=attention_mask, **self.generate_kwargs)
  177. self.generation_function = tf.function(generation_function, jit_compile=True)
  178. prediction_list = []
  179. label_list = []
  180. # The whole predict/generate loop is handled inside this method
  181. for batch in self.eval_dataset:
  182. if isinstance(batch, tuple):
  183. batch, labels = batch
  184. else:
  185. labels = None
  186. if self.predict_with_generate:
  187. if isinstance(batch, dict):
  188. generation_inputs = batch[main_input_name]
  189. attention_mask = batch.get("attention_mask", None)
  190. else:
  191. generation_inputs = batch
  192. attention_mask = None
  193. if self.use_xla_generation:
  194. predictions = self.generation_function(generation_inputs, attention_mask=attention_mask)
  195. else:
  196. predictions = self.model.generate(
  197. generation_inputs, attention_mask=attention_mask, **self.generate_kwargs
  198. )
  199. else:
  200. predictions = self.model.predict_on_batch(batch)
  201. if isinstance(predictions, dict):
  202. # This converts any dict-subclass to a regular dict
  203. # Keras REALLY doesn't like it when we pass around a BatchEncoding or other derived class
  204. predictions = dict(predictions)
  205. if self.output_cols is not None:
  206. predictions = {key: predictions[key] for key in self.output_cols}
  207. else:
  208. predictions = {
  209. key: val for key, val in predictions.items() if key not in ignore_keys + ["loss"]
  210. }
  211. prediction_list.append(predictions)
  212. if not self.use_keras_label:
  213. labels = {key: batch[key].numpy() for key in self.label_cols}
  214. elif isinstance(labels, dict):
  215. labels = {key: array.numpy() for key, array in labels.items()}
  216. elif isinstance(labels, list) or isinstance(labels, tuple):
  217. labels = [array.numpy() for array in labels]
  218. elif isinstance(labels, tf.Tensor):
  219. labels = labels.numpy()
  220. else:
  221. raise TypeError(f"Confused by labels of type {type(labels)}")
  222. label_list.append(labels)
  223. all_preds = self._postprocess_predictions_or_labels(prediction_list)
  224. all_labels = self._postprocess_predictions_or_labels(label_list)
  225. metric_output = self.metric_fn((all_preds, all_labels))
  226. if not isinstance(metric_output, dict):
  227. raise TypeError(
  228. f"metric_fn should return a dict mapping metric names to values but instead returned {metric_output}"
  229. )
  230. # This is the critical bit - Keras passes a dict containing the loss and standard metric values for this epoch
  231. # in the logs argument. Ordinarily, this is so the callback can read them, but in this case we write a bunch of
  232. # new keys in there, which will then get read by the History callback and treated like any other metric value.
  233. # I promise that I have it in writing from Chollet that this is okay.
  234. logs.update(metric_output)
  235. class PushToHubCallback(keras.callbacks.Callback):
  236. """
  237. Callback that will save and push the model to the Hub regularly. By default, it pushes once per epoch, but this can
  238. be changed with the `save_strategy` argument. Pushed models can be accessed like any other model on the hub, such
  239. as with the `from_pretrained` method.
  240. ```py
  241. from transformers.keras_callbacks import PushToHubCallback
  242. push_to_hub_callback = PushToHubCallback(
  243. output_dir="./model_save",
  244. tokenizer=tokenizer,
  245. hub_model_id="gpt5-7xlarge",
  246. )
  247. model.fit(train_dataset, callbacks=[push_to_hub_callback])
  248. ```
  249. Args:
  250. output_dir (`str`):
  251. The output directory where the model predictions and checkpoints will be written and synced with the
  252. repository on the Hub.
  253. save_strategy (`str` or [`~trainer_utils.IntervalStrategy`], *optional*, defaults to `"epoch"`):
  254. The checkpoint save strategy to adopt during training. Possible values are:
  255. - `"no"`: Save is done at the end of training.
  256. - `"epoch"`: Save is done at the end of each epoch.
  257. - `"steps"`: Save is done every `save_steps`
  258. save_steps (`int`, *optional*):
  259. The number of steps between saves when using the "steps" `save_strategy`.
  260. tokenizer (`PreTrainedTokenizerBase`, *optional*):
  261. The tokenizer used by the model. If supplied, will be uploaded to the repo alongside the weights.
  262. hub_model_id (`str`, *optional*):
  263. The name of the repository to keep in sync with the local `output_dir`. It can be a simple model ID in
  264. which case the model will be pushed in your namespace. Otherwise it should be the whole repository name,
  265. for instance `"user_name/model"`, which allows you to push to an organization you are a member of with
  266. `"organization_name/model"`.
  267. Will default to the name of `output_dir`.
  268. hub_token (`str`, *optional*):
  269. The token to use to push the model to the Hub. Will default to the token in the cache folder obtained with
  270. `huggingface-cli login`.
  271. checkpoint (`bool`, *optional*, defaults to `False`):
  272. Whether to save full training checkpoints (including epoch and optimizer state) to allow training to be
  273. resumed. Only usable when `save_strategy` is `"epoch"`.
  274. """
  275. def __init__(
  276. self,
  277. output_dir: Union[str, Path],
  278. save_strategy: Union[str, IntervalStrategy] = "epoch",
  279. save_steps: Optional[int] = None,
  280. tokenizer: Optional[PreTrainedTokenizerBase] = None,
  281. hub_model_id: Optional[str] = None,
  282. hub_token: Optional[str] = None,
  283. checkpoint: bool = False,
  284. **model_card_args,
  285. ):
  286. super().__init__()
  287. if checkpoint and save_strategy != "epoch":
  288. raise ValueError("Cannot save checkpoints when save_strategy is not 'epoch'!")
  289. if isinstance(save_strategy, str):
  290. save_strategy = IntervalStrategy(save_strategy.lower())
  291. self.save_strategy = save_strategy
  292. if self.save_strategy == IntervalStrategy.STEPS and (not isinstance(save_steps, int) or save_steps <= 0):
  293. raise ValueError("Please supply a positive integer argument for save_steps when save_strategy == 'steps'!")
  294. self.save_steps = save_steps
  295. output_dir = Path(output_dir)
  296. # Create repo and retrieve repo_id
  297. if hub_model_id is None:
  298. hub_model_id = output_dir.absolute().name
  299. self.hub_model_id = create_repo(repo_id=hub_model_id, exist_ok=True, token=hub_token).repo_id
  300. self.output_dir = output_dir
  301. self.repo = Repository(str(self.output_dir), clone_from=self.hub_model_id, token=hub_token)
  302. self.tokenizer = tokenizer
  303. self.last_job = None
  304. self.checkpoint = checkpoint
  305. self.training_history = None
  306. self.model_card_args = model_card_args
  307. def on_train_begin(self, logs=None):
  308. # Although we can access model.history, we have no guarantees that the History callback will fire before this
  309. # one, so we keep track of it here too
  310. self.training_history = []
  311. def on_train_batch_end(self, batch, logs=None):
  312. if self.save_strategy == IntervalStrategy.STEPS and (batch + 1) % self.save_steps == 0:
  313. if self.last_job is not None and not self.last_job.is_done:
  314. return # The last upload is still running, don't start another
  315. self.model.save_pretrained(self.output_dir)
  316. if self.tokenizer is not None:
  317. self.tokenizer.save_pretrained(self.output_dir)
  318. _, self.last_job = self.repo.push_to_hub(
  319. commit_message=f"Training in progress steps {batch}", blocking=False
  320. )
  321. def on_epoch_end(self, epoch, logs=None):
  322. logs = logs.copy() # Don't accidentally write things that Keras will read later
  323. if "epoch" not in logs:
  324. logs["epoch"] = epoch
  325. self.training_history.append(logs)
  326. if self.save_strategy == IntervalStrategy.EPOCH:
  327. if self.last_job is not None and not self.last_job.is_done:
  328. return # The last upload is still running, don't start another
  329. self.model.save_pretrained(self.output_dir)
  330. if self.tokenizer is not None:
  331. self.tokenizer.save_pretrained(self.output_dir)
  332. if self.checkpoint:
  333. checkpoint_dir = os.path.join(self.output_dir, "checkpoint")
  334. self.model._save_checkpoint(checkpoint_dir, epoch)
  335. train_summary = TrainingSummary.from_keras(
  336. model=self.model,
  337. model_name=self.hub_model_id,
  338. keras_history=self.training_history,
  339. **self.model_card_args,
  340. )
  341. model_card = train_summary.to_model_card()
  342. with (self.output_dir / "README.md").open("w") as f:
  343. f.write(model_card)
  344. _, self.last_job = self.repo.push_to_hub(
  345. commit_message=f"Training in progress epoch {epoch}", blocking=False
  346. )
  347. def on_train_end(self, logs=None):
  348. # Makes sure the latest version of the model is uploaded
  349. if self.last_job is not None and not self.last_job.is_done:
  350. logging.info("Pushing the last epoch to the Hub, this may take a while...")
  351. while not self.last_job.is_done:
  352. sleep(1)
  353. else:
  354. self.model.save_pretrained(self.output_dir)
  355. if self.tokenizer is not None:
  356. self.tokenizer.save_pretrained(self.output_dir)
  357. train_summary = TrainingSummary.from_keras(
  358. model=self.model,
  359. model_name=self.hub_model_id,
  360. keras_history=self.training_history,
  361. **self.model_card_args,
  362. )
  363. model_card = train_summary.to_model_card()
  364. with (self.output_dir / "README.md").open("w") as f:
  365. f.write(model_card)
  366. self.repo.push_to_hub(commit_message="End of training", blocking=True)