trainer_utils.py 32 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880
  1. # coding=utf-8
  2. # Copyright 2020-present the HuggingFace Inc. team.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. """
  16. PyTorch-independent utilities for the Trainer class.
  17. """
  18. import copy
  19. import functools
  20. import gc
  21. import inspect
  22. import os
  23. import random
  24. import re
  25. import threading
  26. import time
  27. from typing import Any, Dict, List, NamedTuple, Optional, Tuple, Union
  28. import numpy as np
  29. from .utils import (
  30. ExplicitEnum,
  31. is_psutil_available,
  32. is_tf_available,
  33. is_torch_available,
  34. is_torch_cuda_available,
  35. is_torch_mlu_available,
  36. is_torch_mps_available,
  37. is_torch_musa_available,
  38. is_torch_npu_available,
  39. is_torch_xla_available,
  40. is_torch_xpu_available,
  41. requires_backends,
  42. )
  43. if is_torch_available():
  44. import torch
  def seed_worker(_):
      """
      Seed a Dataloader worker process.

      The seed is derived from `torch.initial_seed()`, reduced modulo 2**32 and forwarded to `set_seed`. The
      single positional argument is accepted and ignored.
      """
      set_seed(torch.initial_seed() % 2**32)
  def enable_full_determinism(seed: int, warn_only: bool = False):
      """
      Seed every available framework and switch it to deterministic execution. See

      - https://pytorch.org/docs/stable/notes/randomness.html for pytorch
      - https://www.tensorflow.org/api_docs/python/tf/config/experimental/enable_op_determinism for tensorflow

      Args:
          seed (`int`):
              The seed forwarded to `set_seed`.
          warn_only (`bool`, *optional*, defaults to `False`):
              Forwarded to `torch.use_deterministic_algorithms`.
      """
      # Seeding always happens before any determinism switch is flipped.
      set_seed(seed)

      if is_torch_available():
          os.environ.update(
              {
                  # Depending on the CUDA version, PyTorch deterministic mode may need either of these two
                  # variables, so both are set.
                  "CUDA_LAUNCH_BLOCKING": "1",
                  "CUBLAS_WORKSPACE_CONFIG": ":16:8",
                  # Variables required to enable deterministic mode on Ascend NPUs.
                  "ASCEND_LAUNCH_BLOCKING": "1",
                  "HCCL_DETERMINISTIC": "1",
                  "FLASH_ATTENTION_DETERMINISTIC": "1",
              }
          )
          torch.use_deterministic_algorithms(True, warn_only=warn_only)

          # CUDNN: deterministic kernels on, auto-tuner off.
          cudnn = torch.backends.cudnn
          cudnn.deterministic = True
          cudnn.benchmark = False

      if is_tf_available():
          import tensorflow as tf

          tf.config.experimental.enable_op_determinism()
  def set_seed(seed: int, deterministic: bool = False):
      """
      Seed `random`, `numpy`, `torch` and/or `tf` (whichever are installed) for reproducible behavior.

      Args:
          seed (`int`):
              The seed to set.
          deterministic (`bool`, *optional*, defaults to `False`):
              Whether to use deterministic algorithms where available. Can slow down training.
      """
      random.seed(seed)
      np.random.seed(seed)

      if is_torch_available():
          torch.manual_seed(seed)
          # Safe to call even if cuda is not available.
          torch.cuda.manual_seed_all(seed)
          if deterministic:
              torch.use_deterministic_algorithms(True)

      # Vendor accelerators: each is probed independently and seeded through its own torch sub-module.
      accelerator_probes = (
          (is_torch_mlu_available, "mlu"),
          (is_torch_musa_available, "musa"),
          (is_torch_npu_available, "npu"),
          (is_torch_xpu_available, "xpu"),
      )
      for is_available, backend_name in accelerator_probes:
          if is_available():
              getattr(torch, backend_name).manual_seed_all(seed)

      if is_tf_available():
          import tensorflow as tf

          tf.random.set_seed(seed)
          if deterministic:
              tf.config.experimental.enable_op_determinism()
  def neftune_post_forward_hook(module, input, output):
      """
      Forward hook implementing NEFTune: while the module is in training mode, uniform noise is added to its
      output. Note this works only for torch.nn.Embedding layers. Slightly adapted from the original source code
      at https://github.com/neelsjain/NEFTune. Simply add it to your model as follows:

      ```python
      model = ...
      model.embed_tokens.neftune_noise_alpha = 0.1
      model.embed_tokens.register_forward_hook(neftune_post_forward_hook)
      ```

      Args:
          module (`torch.nn.Module`):
              The embedding module where the hook is attached. `module.neftune_noise_alpha` must be set to the
              desired noise alpha value.
          input (`torch.Tensor`):
              The input tensor to the model (unused by the hook).
          output (`torch.Tensor`):
              The output tensor of the model (i.e. the embeddings).

      Returns:
          `output` itself when `module.training` is falsy, otherwise a new tensor `output + noise`.
      """
      if not module.training:
          return output

      # The noise bound is alpha / sqrt(size(1) * size(2)) of the output.
      scaled_dims = torch.tensor(output.size(1) * output.size(2))
      noise_bound = module.neftune_noise_alpha / torch.sqrt(scaled_dims)
      noise = torch.zeros_like(output).uniform_(-noise_bound, noise_bound)
      return output + noise
  class EvalPrediction:
      """
      Evaluation output (always contains labels), to be used to compute metrics.

      Instances can be iterated and indexed like a tuple of `(predictions, label_ids[, inputs][, losses])`, where
      the optional entries are present only when they were supplied.

      Parameters:
          predictions (`np.ndarray`): Predictions of the model.
          label_ids (`np.ndarray`): Targets to be matched.
          inputs (`np.ndarray`, *optional*): Input data passed to the model.
          losses (`np.ndarray`, *optional*): Loss values computed during evaluation.
      """

      def __init__(
          self,
          predictions: Union[np.ndarray, Tuple[np.ndarray]],
          label_ids: Union[np.ndarray, Tuple[np.ndarray]],
          inputs: Optional[Union[np.ndarray, Tuple[np.ndarray]]] = None,
          losses: Optional[Union[np.ndarray, Tuple[np.ndarray]]] = None,
      ):
          self.predictions = predictions
          self.label_ids = label_ids
          self.inputs = inputs
          self.losses = losses

          # Mandatory entries first, then whichever optional ones were provided, in declaration order.
          optional_entries = tuple(entry for entry in (self.inputs, self.losses) if entry is not None)
          self.elements = (self.predictions, self.label_ids) + optional_entries

      def __iter__(self):
          return iter(self.elements)

      def __getitem__(self, idx):
          # Negative indices are rejected on purpose, unlike a plain tuple.
          if not (idx < 0 or idx >= len(self.elements)):
              return self.elements[idx]
          raise IndexError("tuple index out of range")
  class EvalLoopOutput(NamedTuple):
      """Named tuple with the fields `predictions`, `label_ids`, `metrics` and `num_samples`."""

      predictions: Union[np.ndarray, Tuple[np.ndarray]]
      label_ids: Optional[Union[np.ndarray, Tuple[np.ndarray]]]
      metrics: Optional[Dict[str, float]]
      num_samples: Optional[int]
  class PredictionOutput(NamedTuple):
      """Named tuple with the fields `predictions`, `label_ids` and `metrics`."""

      predictions: Union[np.ndarray, Tuple[np.ndarray]]
      label_ids: Optional[Union[np.ndarray, Tuple[np.ndarray]]]
      metrics: Optional[Dict[str, float]]
  class TrainOutput(NamedTuple):
      """Named tuple with the fields `global_step`, `training_loss` and `metrics`."""

      global_step: int
      training_loss: float
      metrics: Dict[str, float]
  174. PREFIX_CHECKPOINT_DIR = "checkpoint"
  175. _re_checkpoint = re.compile(r"^" + PREFIX_CHECKPOINT_DIR + r"\-(\d+)$")
  176. def get_last_checkpoint(folder):
  177. content = os.listdir(folder)
  178. checkpoints = [
  179. path
  180. for path in content
  181. if _re_checkpoint.search(path) is not None and os.path.isdir(os.path.join(folder, path))
  182. ]
  183. if len(checkpoints) == 0:
  184. return
  185. return os.path.join(folder, max(checkpoints, key=lambda x: int(_re_checkpoint.search(x).groups()[0])))
  class IntervalStrategy(ExplicitEnum):
      """Closed set of interval choices: `"no"`, `"steps"` or `"epoch"`."""

      NO = "no"
      STEPS = "steps"
      EPOCH = "epoch"
  class EvaluationStrategy(ExplicitEnum):
      """Closed set of evaluation choices: `"no"`, `"steps"` or `"epoch"` (same values as `IntervalStrategy`)."""

      NO = "no"
      STEPS = "steps"
      EPOCH = "epoch"
  class HubStrategy(ExplicitEnum):
      """Closed set of hub strategy names: `"end"`, `"every_save"`, `"checkpoint"` or `"all_checkpoints"`."""

      END = "end"
      EVERY_SAVE = "every_save"
      CHECKPOINT = "checkpoint"
      ALL_CHECKPOINTS = "all_checkpoints"
  class BestRun(NamedTuple):
      """
      The best run found by a hyperparameter search (see [`~Trainer.hyperparameter_search`]).

      Parameters:
          run_id (`str`):
              The id of the best run (if models were saved, the corresponding checkpoint will be in the folder
              ending with run-{run_id}).
          objective (`float` or `List[float]`):
              The objective that was obtained for this run.
          hyperparameters (`Dict[str, Any]`):
              The hyperparameters picked to get this run.
          run_summary (`Optional[Any]`, defaults to `None`):
              A summary of tuning experiments. `ray.tune.ExperimentAnalysis` object for Ray backend.
      """

      run_id: str
      objective: Union[float, List[float]]
      hyperparameters: Dict[str, Any]
      run_summary: Optional[Any] = None
  def default_compute_objective(metrics: Dict[str, float]) -> float:
      """
      The default objective to maximize/minimize when doing a hyperparameter search. It is the evaluation loss if
      no metrics are provided to the [`Trainer`], the sum of all metrics otherwise.

      The `"eval_loss"` and `"epoch"` entries and every speed metric (keys ending in `_runtime`, `_per_second` or
      `_compilation_time`) are left out of the sum. The input dictionary is not modified.

      Args:
          metrics (`Dict[str, float]`): The metrics returned by the evaluate method.

      Return:
          `float`: The objective to minimize or maximize
      """
      remaining = copy.deepcopy(metrics)
      loss = remaining.pop("eval_loss", None)
      remaining.pop("epoch", None)

      # Drop speed metrics; the keys are collected first so the dict is not mutated while iterating.
      speed_suffixes = ("_runtime", "_per_second", "_compilation_time")
      for key in [name for name in remaining if name.endswith(speed_suffixes)]:
          remaining.pop(key, None)

      if len(remaining) == 0:
          return loss
      return sum(remaining.values())
  def default_hp_space_optuna(trial) -> Dict[str, float]:
      """
      Default search space for the Optuna backend, sampled from `trial`.

      Samples `learning_rate` (log scale between 1e-6 and 1e-4), `num_train_epochs` (1 to 5), `seed` (1 to 40) and
      `per_device_train_batch_size` (one of 4, 8, 16, 32, 64) through the `trial.suggest_*` methods.
      """
      from .integrations import is_optuna_available

      assert is_optuna_available(), "This function needs Optuna installed: `pip install optuna`"

      learning_rate = trial.suggest_float("learning_rate", 1e-6, 1e-4, log=True)
      num_train_epochs = trial.suggest_int("num_train_epochs", 1, 5)
      seed = trial.suggest_int("seed", 1, 40)
      batch_size = trial.suggest_categorical("per_device_train_batch_size", [4, 8, 16, 32, 64])
      return {
          "learning_rate": learning_rate,
          "num_train_epochs": num_train_epochs,
          "seed": seed,
          "per_device_train_batch_size": batch_size,
      }
  def default_hp_space_ray(trial) -> Dict[str, float]:
      """
      Default search space for the Ray Tune backend.

      `trial` is accepted but not read; the space is built from `ray.tune` sampling primitives.
      """
      from .integrations import is_ray_tune_available

      assert is_ray_tune_available(), "This function needs ray installed: `pip install ray[tune]`"
      from ray import tune

      search_space = {}
      search_space["learning_rate"] = tune.loguniform(1e-6, 1e-4)
      search_space["num_train_epochs"] = tune.choice(list(range(1, 6)))
      search_space["seed"] = tune.uniform(1, 40)
      search_space["per_device_train_batch_size"] = tune.choice([4, 8, 16, 32, 64])
      return search_space
  def default_hp_space_sigopt(trial):
      """
      Default search space for the SigOpt backend.

      Args:
          trial: Accepted for parity with the other `default_hp_space_*` helpers; not read.

      Returns:
          `List[dict]`: One parameter definition per hyperparameter (`learning_rate`, `num_train_epochs`, `seed`
          and `per_device_train_batch_size`).
      """
      return [
          {
              "bounds": {"min": 1e-6, "max": 1e-4},
              "name": "learning_rate",
              "type": "double",
              # The key was previously misspelled "transformamtion", so the log-scale request was presumably
              # never seen by SigOpt.
              "transformation": "log",
          },
          {"bounds": {"min": 1, "max": 6}, "name": "num_train_epochs", "type": "int"},
          {"bounds": {"min": 1, "max": 40}, "name": "seed", "type": "int"},
          {
              "categorical_values": ["4", "8", "16", "32", "64"],
              "name": "per_device_train_batch_size",
              "type": "categorical",
          },
      ]
  def default_hp_space_wandb(trial) -> Dict[str, float]:
      """
      Default sweep configuration for the Weights & Biases backend.

      `trial` is accepted but not read. The configuration uses the `"random"` method and minimizes the metric
      named `"objective"`.

      Raises:
          ImportError: If `wandb` is not available.
      """
      from .integrations import is_wandb_available

      if not is_wandb_available():
          raise ImportError("This function needs wandb installed: `pip install wandb`")

      parameters = {
          "learning_rate": {"distribution": "uniform", "min": 1e-6, "max": 1e-4},
          "num_train_epochs": {"distribution": "int_uniform", "min": 1, "max": 6},
          "seed": {"distribution": "int_uniform", "min": 1, "max": 40},
          "per_device_train_batch_size": {"values": [4, 8, 16, 32, 64]},
      }
      metric = {"name": "objective", "goal": "minimize"}
      return {"method": "random", "metric": metric, "parameters": parameters}
  class HPSearchBackend(ExplicitEnum):
      """Names of the hyperparameter-search backends: `"optuna"`, `"ray"`, `"sigopt"` and `"wandb"`."""

      OPTUNA = "optuna"
      RAY = "ray"
      SIGOPT = "sigopt"
      WANDB = "wandb"
  def is_main_process(local_rank):
      """
      Whether or not the current process is the main one, based on `xm.get_ordinal()` (for TPUs) first, then on
      `local_rank`.

      Returns:
          `bool`: With torch XLA available, `True` when the XLA ordinal is 0; otherwise `True` when `local_rank`
          is -1 or 0.
      """
      if not is_torch_xla_available():
          return local_rank in [-1, 0]

      import torch_xla.core.xla_model as xm

      return xm.get_ordinal() == 0
  def total_processes_number(local_rank):
      """
      Return the number of processes launched in parallel. Works with `torch.distributed` and TPUs.

      Returns:
          The XLA world size when torch XLA is available; otherwise the `torch.distributed` world size when
          `local_rank` is not -1 and torch is available; otherwise 1.
      """
      if is_torch_xla_available():
          import torch_xla.core.xla_model as xm

          return xm.xrt_world_size()

      # `local_rank == -1` is checked first so the torch probe is skipped in that case.
      if local_rank == -1 or not is_torch_available():
          return 1

      import torch

      return torch.distributed.get_world_size()
  def speed_metrics(split, start_time, num_samples=None, num_steps=None, num_tokens=None):
      """
      Measure and return speed performance metrics.

      This function requires a time snapshot `start_time` before the operation to be measured starts and this
      function should be run immediately after the operation to be measured has completed.

      Args:
      - split: name to prefix metric (like train, eval, test...)
      - start_time: operation start time
      - num_samples: number of samples processed
      - num_steps: number of steps processed
      - num_tokens: number of tokens processed

      Returns:
          A dict with `{split}_runtime` (rounded to 4 decimals) plus one `{split}_<unit>_per_second` entry
          (rounded to 3 decimals) for every count that is not `None`. Only the runtime is returned when the
          elapsed time is exactly 0.
      """
      elapsed = time.time() - start_time
      metrics = {f"{split}_runtime": round(elapsed, 4)}
      if elapsed == 0:
          # No rate can be computed without dividing by zero.
          return metrics

      counters = (("samples", num_samples), ("steps", num_steps), ("tokens", num_tokens))
      for unit, count in counters:
          if count is not None:
              metrics[f"{split}_{unit}_per_second"] = round(count / elapsed, 3)
      return metrics
  class SchedulerType(ExplicitEnum):
      """
      Scheduler names for the parameter `lr_scheduler_type` in [`TrainingArguments`].
      By default, it uses "linear". Internally, this retrieves `get_linear_schedule_with_warmup` scheduler from
      [`Trainer`].

      Scheduler types:
          - "linear" = get_linear_schedule_with_warmup
          - "cosine" = get_cosine_schedule_with_warmup
          - "cosine_with_restarts" = get_cosine_with_hard_restarts_schedule_with_warmup
          - "polynomial" = get_polynomial_decay_schedule_with_warmup
          - "constant" = get_constant_schedule
          - "constant_with_warmup" = get_constant_schedule_with_warmup
          - "inverse_sqrt" = get_inverse_sqrt_schedule
          - "reduce_lr_on_plateau" = get_reduce_on_plateau_schedule
          - "cosine_with_min_lr" = get_cosine_with_min_lr_schedule_with_warmup
          - "warmup_stable_decay" = get_wsd_schedule
      """

      LINEAR = "linear"
      COSINE = "cosine"
      COSINE_WITH_RESTARTS = "cosine_with_restarts"
      POLYNOMIAL = "polynomial"
      CONSTANT = "constant"
      CONSTANT_WITH_WARMUP = "constant_with_warmup"
      INVERSE_SQRT = "inverse_sqrt"
      # The member name omits "LR" while its value keeps it.
      REDUCE_ON_PLATEAU = "reduce_lr_on_plateau"
      COSINE_WITH_MIN_LR = "cosine_with_min_lr"
      WARMUP_STABLE_DECAY = "warmup_stable_decay"
  class TrainerMemoryTracker:
      """
      A helper class that tracks cpu and gpu memory.

      This class will silently skip unless `psutil` is available. Install with `pip install psutil`.

      When a stage completes, it can pass metrics dict to update with the memory metrics gathered during this stage.

      Example :

      ```python
      self._memory_tracker = TrainerMemoryTracker(self.args.skip_memory_metrics)
      self._memory_tracker.start()
      # code ...
      metrics = {"train_runtime": 10.5}
      self._memory_tracker.stop_and_update_metrics(metrics)
      ```

      At the moment GPU tracking is only for `pytorch`, but can be extended to support `tensorflow`.

      To understand this class' intricacies please read the documentation of [`~Trainer.log_metrics`].
      """

      # map trainer methods to metrics prefix
      stages = {
          "__init__": "init",
          "train": "train",
          "_inner_training_loop": "train",
          "evaluate": "eval",
          "predict": "test",
      }

      def __init__(self, skip_memory_metrics=False):
          """
          Args:
              skip_memory_metrics: When truthy, every public method becomes a no-op. Tracking is also disabled
                  when `psutil` is not available.
          """
          self.skip_memory_metrics = skip_memory_metrics
          if self.skip_memory_metrics:
              # Nothing else is needed (not even the psutil probe) when the caller opted out.
              return

          if not is_psutil_available():
              # soft dependency on psutil
              self.skip_memory_metrics = True
              return

          import psutil  # noqa

          has_accelerator = (
              is_torch_cuda_available()
              or is_torch_mlu_available()
              or is_torch_musa_available()
              or is_torch_mps_available()
              or is_torch_xpu_available()
              or is_torch_npu_available()
          )
          if has_accelerator:
              import torch

              self.torch = torch
          else:
              self.torch = None

          self.gpu = {}
          self.process = psutil.Process()

          self.cur_stage = None
          self.cpu = {}
          self.init_reported = False

      def _accelerator_name(self):
          """Return the name of the `torch` sub-module of the active accelerator, or `None` if there is none."""
          if self.torch is None:
              return None
          if self.torch.cuda.is_available():
              return "cuda"
          # Same precedence as the probes were originally chained in.
          probes = (
              (is_torch_mlu_available, "mlu"),
              (is_torch_musa_available, "musa"),
              (is_torch_xpu_available, "xpu"),
              (is_torch_npu_available, "npu"),
              (is_torch_mps_available, "mps"),
          )
          for is_available, name in probes:
              if is_available():
                  return name
          return None

      def _device_mem_allocated(self, accelerator):
          """Return the memory currently allocated on `accelerator` (a name from `_accelerator_name`)."""
          backend = getattr(self.torch, accelerator)
          if accelerator == "mps":
              # torch.mps names this call differently from the other backends
              return backend.current_allocated_memory()
          return backend.memory_allocated()

      def derive_stage(self):
          """derives the stage/caller name automatically"""
          # Two frames up: the caller of the tracker method that called `derive_stage`.
          caller = inspect.currentframe().f_back.f_back.f_code.co_name
          if caller in self.stages:
              return self.stages[caller]
          else:
              raise ValueError(
                  f"was called from {caller}, but only expect to be called from one of {self.stages.keys()}"
              )

      def cpu_mem_used(self):
          """get resident set size memory for the current process"""
          return self.process.memory_info().rss

      def peak_monitor_func(self):
          """Thread target: keep sampling `cpu_mem_used` into `cpu_mem_used_peak` until `peak_monitoring` is falsy."""
          self.cpu_mem_used_peak = -1

          while True:
              self.cpu_mem_used_peak = max(self.cpu_mem_used(), self.cpu_mem_used_peak)

              # can't sleep or will not catch the peak right (this comment is here on purpose)
              # time.sleep(0.001) # 1msec

              if not self.peak_monitoring:
                  break

      def start(self):
          """start tracking for the caller's stage"""
          if self.skip_memory_metrics:
              return

          stage = self.derive_stage()
          # deal with nested calls of eval during train - simply ignore those
          if self.cur_stage is not None and self.cur_stage != stage:
              return

          self.cur_stage = stage

          gc.collect()

          # gpu
          accelerator = self._accelerator_name()
          if accelerator is not None:
              backend = getattr(self.torch, accelerator)
              if accelerator != "mps":
                  # only `empty_cache` is called on torch.mps
                  backend.reset_peak_memory_stats()
              backend.empty_cache()
              self.gpu_mem_used_at_start = self._device_mem_allocated(accelerator)

          # cpu
          self.cpu_mem_used_at_start = self.cpu_mem_used()

          # Defined before the thread runs so that `stop()` can always read it, even if it is called before the
          # monitor thread has executed its first statement.
          self.cpu_mem_used_peak = -1
          self.peak_monitoring = True
          peak_monitor_thread = threading.Thread(target=self.peak_monitor_func)
          peak_monitor_thread.daemon = True
          peak_monitor_thread.start()

      def stop(self, stage):
          """stop tracking for the passed stage"""

          # deal with nested calls of eval during train - simply ignore those
          if self.cur_stage is not None and self.cur_stage != stage:
              return

          # this sends a signal to peak_monitor_func to complete its loop
          self.peak_monitoring = False

          # first ensure all objects get collected and their memory is freed
          gc.collect()

          # concepts:
          # - alloc_delta:  the difference of allocated memory between the end and the start
          # - peaked_delta: the difference between the peak memory and the current memory
          # in order to know how much memory the measured code consumed one needs to sum these two

          # gpu
          if self.torch is not None:
              accelerator = self._accelerator_name()
              if accelerator is None:
                  raise ValueError("No available GPU device found!")

              backend = getattr(self.torch, accelerator)
              backend.empty_cache()

              self.gpu_mem_used_now = self._device_mem_allocated(accelerator)
              if accelerator == "mps":
                  # self.torch.mps.max_memory_allocated() does not exist yet
                  self.gpu_mem_used_peak = None
              else:
                  self.gpu_mem_used_peak = backend.max_memory_allocated()

              self.gpu[self.cur_stage] = {
                  "begin": self.gpu_mem_used_at_start,
                  "end": self.gpu_mem_used_now,
                  "alloc": (self.gpu_mem_used_now - self.gpu_mem_used_at_start),
              }
              if self.gpu_mem_used_peak is not None:
                  self.gpu[self.cur_stage]["peaked"] = max(0, self.gpu_mem_used_peak - self.gpu_mem_used_now)
              else:
                  self.gpu[self.cur_stage]["peaked"] = "Not available"

          # cpu
          self.cpu_mem_used_now = self.cpu_mem_used()
          self.cpu[self.cur_stage] = {
              "begin": self.cpu_mem_used_at_start,
              "end": self.cpu_mem_used_now,
              "alloc": (self.cpu_mem_used_now - self.cpu_mem_used_at_start),
              "peaked": max(0, self.cpu_mem_used_peak - self.cpu_mem_used_now),
          }

          # reset - cycle finished
          self.cur_stage = None

      def update_metrics(self, stage, metrics):
          """updates the metrics"""
          if self.skip_memory_metrics:
              return

          # deal with nested calls of eval during train - simply ignore those
          if self.cur_stage is not None and self.cur_stage != stage:
              return

          # since we don't have a way to return init metrics, we push them into the first of train/val/predict
          stages = [stage]
          if not self.init_reported:
              stages.insert(0, "init")
              self.init_reported = True

          for stage_name in stages:
              for t in ["alloc", "peaked"]:
                  if stage_name in self.cpu and t in self.cpu[stage_name]:
                      metrics[f"{stage_name}_mem_cpu_{t}_delta"] = self.cpu[stage_name][t]
                  if self.torch is not None and stage_name in self.gpu and t in self.gpu[stage_name]:
                      metrics[f"{stage_name}_mem_gpu_{t}_delta"] = self.gpu[stage_name][t]
              # For additional debug info, the "begin"/"end" entries of `self.cpu[stage_name]` and
              # `self.gpu[stage_name]` could be reported here in the same way.

          # since memory can be allocated before init, and it might be difficult to track overall
          # memory usage, in particular for GPU, let's report memory usage at the point init was called.
          # The membership checks keep this from failing when no "init" stage was ever recorded.
          if stages[0] == "init" and "init" in self.cpu:
              metrics["before_init_mem_cpu"] = self.cpu["init"]["begin"]
              if self.torch is not None and "init" in self.gpu:
                  metrics["before_init_mem_gpu"] = self.gpu["init"]["begin"]
              # Any additional allocation between init and the next stage could also be reported here, as the
              # difference between that stage's "begin" and init's "end".

      def stop_and_update_metrics(self, metrics=None):
          """combine stop and metrics update in one call for simpler code"""
          if self.skip_memory_metrics:
              return

          stage = self.derive_stage()
          self.stop(stage)

          # init doesn't have metrics to update so we just save that data for later stages to retrieve
          if metrics is not None:
              self.update_metrics(stage, metrics)
  def has_length(dataset):
      """
      Checks if the dataset implements __len__() and it doesn't raise an error

      Returns:
          `bool`: `True` when `len(dataset)` succeeds, `False` when it raises `TypeError` (unsized object),
          `NotImplementedError` or `AttributeError`.
      """
      try:
          return len(dataset) is not None
      except (TypeError, NotImplementedError, AttributeError):
          # TypeError: len() of unsized object. The other two cover a `__len__` that is declared but cannot
          # actually produce a length.
          return False
  def denumpify_detensorize(metrics):
      """
      Recursively calls `.item()` on the element of the dictionary passed

      Lists, tuples (including namedtuples) and dicts are rebuilt with their own type and converted element by
      element. NumPy scalars and single-element torch tensors are replaced by `.item()`; anything else is
      returned unchanged.
      """
      if isinstance(metrics, (list, tuple)):
          converted = [denumpify_detensorize(m) for m in metrics]
          if isinstance(metrics, tuple) and hasattr(metrics, "_fields"):
              # namedtuple constructors take one positional argument per field, not a single iterable
              return type(metrics)(*converted)
          return type(metrics)(converted)
      elif isinstance(metrics, dict):
          return type(metrics)({k: denumpify_detensorize(v) for k, v in metrics.items()})
      elif isinstance(metrics, np.generic):
          return metrics.item()
      elif is_torch_available() and isinstance(metrics, torch.Tensor) and metrics.numel() == 1:
          return metrics.item()
      return metrics
  def number_of_arguments(func):
      """
      Return the number of arguments of the passed function, even if it's a partial function.

      For a `functools.partial`, the positional and keyword arguments already bound are subtracted from the
      parameter count of the wrapped function.
      """
      if not isinstance(func, functools.partial):
          return len(inspect.signature(func).parameters)

      already_bound = len(func.args) + len(func.keywords)
      wrapped_total = len(inspect.signature(func.func).parameters)
      return wrapped_total - already_bound
  def find_executable_batch_size(
      function: callable = None, starting_batch_size: int = 128, auto_find_batch_size: bool = False
  ):
      """
      A basic decorator that will try to execute `function`. If it fails from exceptions related to out-of-memory
      or CUDNN, the batch size is cut in half and passed to `function`. `function` must take in a `batch_size`
      parameter as its first argument.

      Args:
          function (`callable`, *optional*)
              A function to wrap. When omitted, a decorator carrying the other two arguments is returned.
          starting_batch_size (`int`, *optional*)
              The batch size to try and fit into memory
          auto_find_batch_size (`bool`, *optional*)
              If False, will just execute `function` with `batch_size=starting_batch_size`
      """
      if function is None:
          # Used as `@find_executable_batch_size(...)`: hand back a decorator with the options pre-bound.
          bound_options = {
              "starting_batch_size": starting_batch_size,
              "auto_find_batch_size": auto_find_batch_size,
          }
          return functools.partial(find_executable_batch_size, **bound_options)

      if not auto_find_batch_size:
          return functools.partial(function, batch_size=starting_batch_size)

      requires_backends(find_executable_batch_size, "accelerate")
      from accelerate.utils import find_executable_batch_size as accelerate_find_executable_batch_size

      return accelerate_find_executable_batch_size(function=function, starting_batch_size=starting_batch_size)
  class FSDPOption(ExplicitEnum):
      """String identifiers for the FSDP-related options."""

      FULL_SHARD = "full_shard"
      SHARD_GRAD_OP = "shard_grad_op"
      NO_SHARD = "no_shard"
      HYBRID_SHARD = "hybrid_shard"
      HYBRID_SHARD_ZERO2 = "hybrid_shard_zero2"
      OFFLOAD = "offload"
      AUTO_WRAP = "auto_wrap"
  class RemoveColumnsCollator:
      """
      Wrap the data collator to remove unused columns before they are passed to the collator.

      Dict features are reduced to the keys listed in `signature_columns`; non-dict features are passed through
      unchanged. The first time columns are dropped, one info message is emitted, provided both `logger` and
      `model_name` were given.
      """

      def __init__(
          self,
          data_collator,
          signature_columns,
          logger=None,
          model_name: Optional[str] = None,
          description: Optional[str] = None,
      ):
          self.data_collator = data_collator
          self.signature_columns = signature_columns
          self.logger = logger
          self.description = description
          self.model_name = model_name
          # Flipped to True once the "ignored columns" message has been emitted.
          self.message_logged = False

      def _remove_columns(self, feature: dict) -> dict:
          """Return `feature` restricted to `signature_columns` (or `feature` itself if it is not a dict)."""
          if not isinstance(feature, dict):
              return feature

          may_log = not self.message_logged and self.logger and self.model_name
          if may_log:
              ignored_columns = list(set(feature.keys()) - set(self.signature_columns))
              if ignored_columns:
                  if self.description is None:
                      dset_description = ""
                  else:
                      dset_description = f"in the {self.description} set"
                  ignored = ", ".join(ignored_columns)
                  self.logger.info(
                      f"The following columns {dset_description} don't have a corresponding argument in "
                      f"`{self.model_name}.forward` and have been ignored: {ignored}."
                      f" If {ignored} are not expected by `{self.model_name}.forward`, "
                      " you can safely ignore this message."
                  )
                  self.message_logged = True

          kept = {}
          for column, value in feature.items():
              if column in self.signature_columns:
                  kept[column] = value
          return kept

      def __call__(self, features: List[dict]):
          cleaned = [self._remove_columns(feature) for feature in features]
          return self.data_collator(cleaned)
  def check_target_module_exists(optim_target_modules, key: str, return_is_regex: bool = False):
      """A helper method to check if the passed module's key name matches any of the target modules in the optim_target_modules.

      Args:
          optim_target_modules (`Union[str, List[str]]`):
              A list of strings to try to match. Can be also a full string, which is matched against `key` with
              `re.fullmatch`.
          key (`str`):
              A key to search any matches in optim_target_modules
          return_is_regex (`bool`):
              If set to `True`, the method will return whether the passed `optim_target_modules`
              is a regex or not.

      Returns:
          `bool` : True if key matches any target modules from config, False if no match found
          `bool` : If the matched target module is a regex to silence out the warnings in Trainer
          for extra modules being found (only returned when `return_is_regex=True`).
      """
      if isinstance(optim_target_modules, str):
          found = bool(re.fullmatch(optim_target_modules, key))
          # A single string counts as a regex whenever it is not literally equal to the key.
          matched_by_regex = optim_target_modules != key
      else:
          # from here, optim_target_modules must be a list of str
          found = False
          matched_by_regex = False
          if key in optim_target_modules:
              # this module is specified directly in target_modules
              found = True
          elif any(candidate in key for candidate in optim_target_modules):
              # one of the targets is a substring of the key
              found = True
          elif any(re.fullmatch(pattern, key) for pattern in optim_target_modules):
              found = True
              matched_by_regex = True

      return (found, matched_by_regex) if return_is_regex else found