integration_utils.py
# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Integrations with other Python libraries.
"""

import functools
import importlib.metadata
import importlib.util
import json
import numbers
import os
import pickle
import shutil
import sys
import tempfile
from dataclasses import asdict, fields
from enum import Enum
from pathlib import Path
from typing import TYPE_CHECKING, Any, Dict, Literal, Optional, Union

import numpy as np
import packaging.version

from .. import PreTrainedModel, TFPreTrainedModel
from .. import __version__ as version
from ..utils import (
    PushToHubMixin,
    flatten_dict,
    is_datasets_available,
    is_pandas_available,
    is_tf_available,
    is_torch_available,
    logging,
)


logger = logging.get_logger(__name__)

if is_torch_available():
    import torch

# comet_ml must be imported before any ML framework
_MIN_COMET_VERSION = "3.43.2"
try:
    _comet_version = importlib.metadata.version("comet_ml")
    _is_comet_installed = True

    _is_comet_recent_enough = packaging.version.parse(_comet_version) >= packaging.version.parse(_MIN_COMET_VERSION)

    # Check if the Comet API Key is set
    import comet_ml

    if comet_ml.config.get_config("comet.api_key") is not None:
        _is_comet_configured = True
    else:
        _is_comet_configured = False
except (importlib.metadata.PackageNotFoundError, ImportError, ValueError, TypeError, AttributeError, KeyError):
    _comet_version = None
    _is_comet_installed = False
    _is_comet_recent_enough = False
    _is_comet_configured = False

_has_neptune = (
    importlib.util.find_spec("neptune") is not None or importlib.util.find_spec("neptune-client") is not None
)
if TYPE_CHECKING and _has_neptune:
    try:
        _neptune_version = importlib.metadata.version("neptune")
        logger.info(f"Neptune version {_neptune_version} available.")
    except importlib.metadata.PackageNotFoundError:
        try:
            _neptune_version = importlib.metadata.version("neptune-client")
            logger.info(f"Neptune-client version {_neptune_version} available.")
        except importlib.metadata.PackageNotFoundError:
            _has_neptune = False

from .. import modelcard  # noqa: E402
from ..trainer_callback import ProgressCallback, TrainerCallback  # noqa: E402
from ..trainer_utils import PREFIX_CHECKPOINT_DIR, BestRun, IntervalStrategy  # noqa: E402
from ..training_args import ParallelMode  # noqa: E402
from ..utils import ENV_VARS_TRUE_VALUES, is_torch_xla_available  # noqa: E402


# Integration functions:
def is_wandb_available():
    # any value of WANDB_DISABLED disables wandb
    if os.getenv("WANDB_DISABLED", "").upper() in ENV_VARS_TRUE_VALUES:
        logger.warning(
            "Using the `WANDB_DISABLED` environment variable is deprecated and will be removed in v5. Use the "
            "--report_to flag to control the integrations used for logging results (for instance --report_to none)."
        )
        return False
    return importlib.util.find_spec("wandb") is not None


def is_clearml_available():
    return importlib.util.find_spec("clearml") is not None


def is_comet_available():
    if os.getenv("COMET_MODE", "").upper() == "DISABLED":
        logger.warning(
            "Using the `COMET_MODE=DISABLED` environment variable is deprecated and will be removed in v5. Use the "
            "--report_to flag to control the integrations used for logging results (for instance --report_to none)."
        )
        return False

    if _is_comet_installed is False:
        return False

    if _is_comet_recent_enough is False:
        logger.warning(
            "comet_ml version %s is installed, but version %s or higher is required. "
            "Please update comet_ml to the latest version to enable Comet logging with pip install 'comet-ml>=%s'.",
            _comet_version,
            _MIN_COMET_VERSION,
            _MIN_COMET_VERSION,
        )
        return False

    if _is_comet_configured is False:
        logger.warning(
            "comet_ml is installed but the Comet API Key is not configured. "
            "Please set the `COMET_API_KEY` environment variable to enable Comet logging. "
            "Check out the documentation for other ways of configuring it: "
            "https://www.comet.com/docs/v2/guides/experiment-management/configure-sdk/#set-the-api-key"
        )
        return False

    return True


def is_tensorboard_available():
    return importlib.util.find_spec("tensorboard") is not None or importlib.util.find_spec("tensorboardX") is not None


def is_optuna_available():
    return importlib.util.find_spec("optuna") is not None


def is_ray_available():
    return importlib.util.find_spec("ray") is not None


def is_ray_tune_available():
    if not is_ray_available():
        return False
    return importlib.util.find_spec("ray.tune") is not None


def is_sigopt_available():
    return importlib.util.find_spec("sigopt") is not None


def is_azureml_available():
    if importlib.util.find_spec("azureml") is None:
        return False
    if importlib.util.find_spec("azureml.core") is None:
        return False
    return importlib.util.find_spec("azureml.core.run") is not None


def is_mlflow_available():
    if os.getenv("DISABLE_MLFLOW_INTEGRATION", "FALSE").upper() == "TRUE":
        return False
    return importlib.util.find_spec("mlflow") is not None


def is_dagshub_available():
    return None not in [importlib.util.find_spec("dagshub"), importlib.util.find_spec("mlflow")]


def is_neptune_available():
    return _has_neptune


def is_codecarbon_available():
    return importlib.util.find_spec("codecarbon") is not None


def is_flytekit_available():
    return importlib.util.find_spec("flytekit") is not None


def is_flyte_deck_standard_available():
    if not is_flytekit_available():
        return False
    return importlib.util.find_spec("flytekitplugins.deck") is not None


def is_dvclive_available():
    return importlib.util.find_spec("dvclive") is not None


def hp_params(trial):
    if is_optuna_available():
        import optuna

        if isinstance(trial, optuna.Trial):
            return trial.params
    if is_ray_tune_available():
        if isinstance(trial, dict):
            return trial
    if is_sigopt_available():
        if isinstance(trial, dict):
            return trial
    if is_wandb_available():
        if isinstance(trial, dict):
            return trial
    raise RuntimeError(f"Unknown type for trial {trial.__class__}")
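
# Illustrative note (not part of the original module): `hp_params` normalizes a
# backend-specific trial object into a plain dict of hyperparameters, e.g.
#
#     hp_params({"learning_rate": 3e-5})  # ray/sigopt/wandb trials are dicts -> returned as-is
#     hp_params(optuna_trial)             # an optuna.Trial -> optuna_trial.params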


def run_hp_search_optuna(trainer, n_trials: int, direction: str, **kwargs) -> BestRun:
    import optuna

    if trainer.args.process_index == 0:

        def _objective(trial, checkpoint_dir=None):
            checkpoint = None
            if checkpoint_dir:
                for subdir in os.listdir(checkpoint_dir):
                    if subdir.startswith(PREFIX_CHECKPOINT_DIR):
                        checkpoint = os.path.join(checkpoint_dir, subdir)
            trainer.objective = None
            if trainer.args.world_size > 1:
                if trainer.args.parallel_mode != ParallelMode.DISTRIBUTED:
                    raise RuntimeError("only support DDP optuna HPO for ParallelMode.DISTRIBUTED currently.")
                trainer._hp_search_setup(trial)
                args_main_rank_list = [pickle.dumps(trainer.args)]
                torch.distributed.broadcast_object_list(args_main_rank_list, src=0)
                trainer.train(resume_from_checkpoint=checkpoint)
            else:
                trainer.train(resume_from_checkpoint=checkpoint, trial=trial)
            # If there hasn't been any evaluation during the training loop.
            if getattr(trainer, "objective", None) is None:
                metrics = trainer.evaluate()
                trainer.objective = trainer.compute_objective(metrics)
            return trainer.objective

        timeout = kwargs.pop("timeout", None)
        n_jobs = kwargs.pop("n_jobs", 1)
        gc_after_trial = kwargs.pop("gc_after_trial", False)
        directions = direction if isinstance(direction, list) else None
        direction = None if directions is not None else direction
        study = optuna.create_study(direction=direction, directions=directions, **kwargs)
        study.optimize(_objective, n_trials=n_trials, timeout=timeout, n_jobs=n_jobs, gc_after_trial=gc_after_trial)
        if not study._is_multi_objective():
            best_trial = study.best_trial
            return BestRun(str(best_trial.number), best_trial.value, best_trial.params)
        else:
            best_trials = study.best_trials
            return [BestRun(str(best.number), best.values, best.params) for best in best_trials]
    else:
        for i in range(n_trials):
            trainer.objective = None
            args_main_rank_list = [None]
            if trainer.args.parallel_mode != ParallelMode.DISTRIBUTED:
                raise RuntimeError("only support DDP optuna HPO for ParallelMode.DISTRIBUTED currently.")
            torch.distributed.broadcast_object_list(args_main_rank_list, src=0)
            args = pickle.loads(bytes(args_main_rank_list[0]))
            for key, value in asdict(args).items():
                if key != "local_rank":
                    setattr(trainer.args, key, value)
            trainer.train(resume_from_checkpoint=None)
            # If there hasn't been any evaluation during the training loop.
            if getattr(trainer, "objective", None) is None:
                metrics = trainer.evaluate()
                trainer.objective = trainer.compute_objective(metrics)
        return None
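
# Illustrative sketch (not part of the original module): `run_hp_search_optuna` is
# normally reached through `Trainer.hyperparameter_search`. The `hp_space` below is a
# hypothetical search space; any Optuna distribution would work.
#
#     def hp_space(trial):
#         return {"learning_rate": trial.suggest_float("learning_rate", 1e-5, 1e-3, log=True)}
#
#     best_run = trainer.hyperparameter_search(
#         hp_space=hp_space, backend="optuna", n_trials=20, direction="minimize"
#     )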


def run_hp_search_ray(trainer, n_trials: int, direction: str, **kwargs) -> BestRun:
    import ray
    import ray.train

    def _objective(trial: dict, local_trainer):
        try:
            from transformers.utils.notebook import NotebookProgressCallback

            if local_trainer.pop_callback(NotebookProgressCallback):
                local_trainer.add_callback(ProgressCallback)
        except ModuleNotFoundError:
            pass

        local_trainer.objective = None

        checkpoint = ray.train.get_checkpoint()
        if checkpoint:
            # Upon trial resume, the local_trainer's objective gets reset to None.
            # If `local_trainer.train` is a noop (training has already reached
            # the target number of epochs/steps), then this would
            # trigger an unnecessary extra checkpoint at the end of training.
            # -> Set the objective to a dummy value upon resume as a workaround.
            local_trainer.objective = "objective"

            with checkpoint.as_directory() as checkpoint_dir:
                checkpoint_path = next(Path(checkpoint_dir).glob(f"{PREFIX_CHECKPOINT_DIR}*")).as_posix()
                local_trainer.train(resume_from_checkpoint=checkpoint_path, trial=trial)
        else:
            local_trainer.train(trial=trial)

        # If there hasn't been any evaluation during the training loop.
        if getattr(local_trainer, "objective", None) is None:
            metrics = local_trainer.evaluate()
            local_trainer.objective = local_trainer.compute_objective(metrics)

            metrics.update({"objective": local_trainer.objective, "done": True})

            with tempfile.TemporaryDirectory() as temp_checkpoint_dir:
                local_trainer._tune_save_checkpoint(checkpoint_dir=temp_checkpoint_dir)
                checkpoint = ray.train.Checkpoint.from_directory(temp_checkpoint_dir)
                ray.train.report(metrics, checkpoint=checkpoint)

    if not trainer._memory_tracker.skip_memory_metrics:
        from ..trainer_utils import TrainerMemoryTracker

        logger.warning(
            "Memory tracking for your Trainer is currently "
            "enabled. Automatically disabling the memory tracker "
            "since the memory tracker is not serializable."
        )
        trainer._memory_tracker = TrainerMemoryTracker(skip_memory_metrics=True)

    # The model and TensorBoard writer do not pickle so we have to remove them (if they exist)
    # while doing the ray hp search.
    _tb_writer = trainer.pop_callback(TensorBoardCallback)
    trainer.model = None

    # Setup default `resources_per_trial`.
    if "resources_per_trial" not in kwargs:
        # Default to 1 CPU and 1 GPU (if applicable) per trial.
        kwargs["resources_per_trial"] = {"cpu": 1}
        if trainer.args.n_gpu > 0:
            kwargs["resources_per_trial"]["gpu"] = 1
        resource_msg = "1 CPU" + (" and 1 GPU" if trainer.args.n_gpu > 0 else "")
        logger.info(
            "No `resources_per_trial` arg was passed into "
            "`hyperparameter_search`. Setting it to a default value "
            f"of {resource_msg} for each trial."
        )
    # Make sure each trainer only uses GPUs that were allocated per trial.
    gpus_per_trial = kwargs["resources_per_trial"].get("gpu", 0)
    trainer.args._n_gpu = gpus_per_trial

    # Setup default `progress_reporter`.
    if "progress_reporter" not in kwargs:
        from ray.tune import CLIReporter

        kwargs["progress_reporter"] = CLIReporter(metric_columns=["objective"])

    if "scheduler" in kwargs:
        from ray.tune.schedulers import ASHAScheduler, HyperBandForBOHB, MedianStoppingRule, PopulationBasedTraining

        # Check for `do_eval` and `eval_during_training` for schedulers that require intermediate reporting.
        if isinstance(
            kwargs["scheduler"], (ASHAScheduler, MedianStoppingRule, HyperBandForBOHB, PopulationBasedTraining)
        ) and (not trainer.args.do_eval or trainer.args.eval_strategy == IntervalStrategy.NO):
            raise RuntimeError(
                "You are using {cls} as a scheduler but you haven't enabled evaluation during training. "
                "This means your trials will not report intermediate results to Ray Tune, and "
                "can thus not be stopped early or used to exploit other trials parameters. "
                "If this is what you want, do not use {cls}. If you would like to use {cls}, "
                "make sure you pass `do_eval=True` and `eval_strategy='steps'` in the "
                "Trainer `args`.".format(cls=type(kwargs["scheduler"]).__name__)
            )

    trainable = ray.tune.with_parameters(_objective, local_trainer=trainer)

    @functools.wraps(trainable)
    def dynamic_modules_import_trainable(*args, **kwargs):
        """
        Wrapper around `tune.with_parameters` to ensure datasets_modules are loaded on each Actor.

        Without this, an ImportError will be thrown. See https://github.com/huggingface/transformers/issues/11565.

        Assumes that `_objective`, defined above, is a function.
        """
        if is_datasets_available():
            import datasets.load

            dynamic_modules_path = os.path.join(datasets.load.init_dynamic_modules(), "__init__.py")
            # load dynamic_modules from path
            spec = importlib.util.spec_from_file_location("datasets_modules", dynamic_modules_path)
            datasets_modules = importlib.util.module_from_spec(spec)
            sys.modules[spec.name] = datasets_modules
            spec.loader.exec_module(datasets_modules)
        return trainable(*args, **kwargs)

    # special attr set by tune.with_parameters
    if hasattr(trainable, "__mixins__"):
        dynamic_modules_import_trainable.__mixins__ = trainable.__mixins__

    analysis = ray.tune.run(
        dynamic_modules_import_trainable,
        config=trainer.hp_space(None),
        num_samples=n_trials,
        **kwargs,
    )
    best_trial = analysis.get_best_trial(metric="objective", mode=direction[:3], scope=trainer.args.ray_scope)
    best_run = BestRun(best_trial.trial_id, best_trial.last_result["objective"], best_trial.config, analysis)
    if _tb_writer is not None:
        trainer.add_callback(_tb_writer)
    return best_run
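
# Illustrative sketch (not part of the original module): the Ray Tune backend expects
# `hp_space` to return a Ray search space; `tune.loguniform` here is a hypothetical choice.
#
#     from ray import tune
#
#     best_run = trainer.hyperparameter_search(
#         hp_space=lambda _: {"learning_rate": tune.loguniform(1e-5, 1e-3)},
#         backend="ray",
#         n_trials=8,
#         direction="minimize",
#     )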


def run_hp_search_sigopt(trainer, n_trials: int, direction: str, **kwargs) -> BestRun:
    import sigopt

    if trainer.args.process_index == 0:
        if packaging.version.parse(importlib.metadata.version("sigopt")) >= packaging.version.parse("8.0.0"):
            sigopt.set_project("huggingface")

            experiment = sigopt.create_experiment(
                name="huggingface-tune",
                type="offline",
                parameters=trainer.hp_space(None),
                metrics=[{"name": "objective", "objective": direction, "strategy": "optimize"}],
                parallel_bandwidth=1,
                budget=n_trials,
            )

            logger.info(f"created experiment: https://app.sigopt.com/experiment/{experiment.id}")

            for run in experiment.loop():
                with run:
                    trainer.objective = None
                    if trainer.args.world_size > 1:
                        if trainer.args.parallel_mode != ParallelMode.DISTRIBUTED:
                            raise RuntimeError("only support DDP Sigopt HPO for ParallelMode.DISTRIBUTED currently.")
                        trainer._hp_search_setup(run.run)
                        torch.distributed.broadcast_object_list(list(pickle.dumps(trainer.args)), src=0)
                        trainer.train(resume_from_checkpoint=None)
                    else:
                        trainer.train(resume_from_checkpoint=None, trial=run.run)
                    # If there hasn't been any evaluation during the training loop.
                    if getattr(trainer, "objective", None) is None:
                        metrics = trainer.evaluate()
                        trainer.objective = trainer.compute_objective(metrics)
                    run.log_metric("objective", trainer.objective)

            best = list(experiment.get_best_runs())[0]
            best_run = BestRun(best.id, best.values["objective"].value, best.assignments)
        else:
            from sigopt import Connection

            conn = Connection()
            proxies = kwargs.pop("proxies", None)
            if proxies is not None:
                conn.set_proxies(proxies)

            experiment = conn.experiments().create(
                name="huggingface-tune",
                parameters=trainer.hp_space(None),
                metrics=[{"name": "objective", "objective": direction, "strategy": "optimize"}],
                parallel_bandwidth=1,
                observation_budget=n_trials,
                project="huggingface",
            )
            logger.info(f"created experiment: https://app.sigopt.com/experiment/{experiment.id}")

            while experiment.progress.observation_count < experiment.observation_budget:
                suggestion = conn.experiments(experiment.id).suggestions().create()
                trainer.objective = None
                if trainer.args.world_size > 1:
                    if trainer.args.parallel_mode != ParallelMode.DISTRIBUTED:
                        raise RuntimeError("only support DDP Sigopt HPO for ParallelMode.DISTRIBUTED currently.")
                    trainer._hp_search_setup(suggestion)
                    torch.distributed.broadcast_object_list(list(pickle.dumps(trainer.args)), src=0)
                    trainer.train(resume_from_checkpoint=None)
                else:
                    trainer.train(resume_from_checkpoint=None, trial=suggestion)
                # If there hasn't been any evaluation during the training loop.
                if getattr(trainer, "objective", None) is None:
                    metrics = trainer.evaluate()
                    trainer.objective = trainer.compute_objective(metrics)

                values = [{"name": "objective", "value": trainer.objective}]
                obs = conn.experiments(experiment.id).observations().create(suggestion=suggestion.id, values=values)
                logger.info(f"[suggestion_id, observation_id]: [{suggestion.id}, {obs.id}]")
                experiment = conn.experiments(experiment.id).fetch()

            best = list(conn.experiments(experiment.id).best_assignments().fetch().iterate_pages())[0]
            best_run = BestRun(best.id, best.value, best.assignments)
        return best_run
    else:
        for i in range(n_trials):
            trainer.objective = None
            args_main_rank = list(pickle.dumps(trainer.args))
            if trainer.args.parallel_mode != ParallelMode.DISTRIBUTED:
                raise RuntimeError("only support DDP Sigopt HPO for ParallelMode.DISTRIBUTED currently.")
            torch.distributed.broadcast_object_list(args_main_rank, src=0)
            args = pickle.loads(bytes(args_main_rank))
            for key, value in asdict(args).items():
                if key != "local_rank":
                    setattr(trainer.args, key, value)
            trainer.train(resume_from_checkpoint=None)
            # If there hasn't been any evaluation during the training loop.
            if getattr(trainer, "objective", None) is None:
                metrics = trainer.evaluate()
                trainer.objective = trainer.compute_objective(metrics)
        return None


def run_hp_search_wandb(trainer, n_trials: int, direction: str, **kwargs) -> BestRun:
    from ..integrations import is_wandb_available

    if not is_wandb_available():
        raise ImportError("This function needs wandb installed: `pip install wandb`")
    import wandb

    # add WandbCallback if not already added in trainer callbacks
    reporting_to_wandb = False
    for callback in trainer.callback_handler.callbacks:
        if isinstance(callback, WandbCallback):
            reporting_to_wandb = True
            break
    if not reporting_to_wandb:
        trainer.add_callback(WandbCallback())
    trainer.args.report_to = ["wandb"]
    best_trial = {"run_id": None, "objective": None, "hyperparameters": None}
    sweep_id = kwargs.pop("sweep_id", None)
    project = kwargs.pop("project", None)
    name = kwargs.pop("name", None)
    entity = kwargs.pop("entity", None)
    metric = kwargs.pop("metric", "eval/loss")

    sweep_config = trainer.hp_space(None)
    sweep_config["metric"]["goal"] = direction
    sweep_config["metric"]["name"] = metric
    if name:
        sweep_config["name"] = name

    def _objective():
        run = wandb.run if wandb.run else wandb.init()
        trainer.state.trial_name = run.name
        run.config.update({"assignments": {}, "metric": metric})
        config = wandb.config

        trainer.objective = None

        trainer.train(resume_from_checkpoint=None, trial=vars(config)["_items"])
        # If there hasn't been any evaluation during the training loop.
        if getattr(trainer, "objective", None) is None:
            metrics = trainer.evaluate()
            trainer.objective = trainer.compute_objective(metrics)
            format_metrics = rewrite_logs(metrics)
            if metric not in format_metrics:
                logger.warning(
                    f"Provided metric {metric} not found. This might result in unexpected sweeps charts. The available"
                    f" metrics are {format_metrics.keys()}"
                )
        best_score = False
        if best_trial["run_id"] is not None:
            if direction == "minimize":
                best_score = trainer.objective < best_trial["objective"]
            elif direction == "maximize":
                best_score = trainer.objective > best_trial["objective"]

        if best_score or best_trial["run_id"] is None:
            best_trial["run_id"] = run.id
            best_trial["objective"] = trainer.objective
            best_trial["hyperparameters"] = dict(config)

        return trainer.objective

    sweep_id = wandb.sweep(sweep_config, project=project, entity=entity) if not sweep_id else sweep_id
    logger.info(f"wandb sweep id - {sweep_id}")
    wandb.agent(sweep_id, function=_objective, count=n_trials)

    return BestRun(best_trial["run_id"], best_trial["objective"], best_trial["hyperparameters"])
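
# Illustrative sketch (not part of the original module): for the wandb backend,
# `hp_space` returns a sweep configuration dict; the parameters below are hypothetical.
#
#     def wandb_hp_space(trial):
#         return {
#             "method": "random",
#             "metric": {"name": "objective", "goal": "minimize"},
#             "parameters": {"learning_rate": {"min": 1e-5, "max": 1e-3}},
#         }
#
#     best_run = trainer.hyperparameter_search(
#         hp_space=wandb_hp_space, backend="wandb", n_trials=10, direction="minimize"
#     )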


def get_available_reporting_integrations():
    integrations = []
    if is_azureml_available() and not is_mlflow_available():
        integrations.append("azure_ml")
    if is_comet_available():
        integrations.append("comet_ml")
    if is_dagshub_available():
        integrations.append("dagshub")
    if is_dvclive_available():
        integrations.append("dvclive")
    if is_mlflow_available():
        integrations.append("mlflow")
    if is_neptune_available():
        integrations.append("neptune")
    if is_tensorboard_available():
        integrations.append("tensorboard")
    if is_wandb_available():
        integrations.append("wandb")
    if is_codecarbon_available():
        integrations.append("codecarbon")
    if is_clearml_available():
        integrations.append("clearml")
    return integrations


def rewrite_logs(d):
    new_d = {}
    eval_prefix = "eval_"
    eval_prefix_len = len(eval_prefix)
    test_prefix = "test_"
    test_prefix_len = len(test_prefix)
    for k, v in d.items():
        if k.startswith(eval_prefix):
            new_d["eval/" + k[eval_prefix_len:]] = v
        elif k.startswith(test_prefix):
            new_d["test/" + k[test_prefix_len:]] = v
        else:
            new_d["train/" + k] = v
    return new_d
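
# Example (illustrative, not part of the original module):
#
#     rewrite_logs({"eval_loss": 0.4, "test_accuracy": 0.9, "loss": 0.7})
#     # -> {"eval/loss": 0.4, "test/accuracy": 0.9, "train/loss": 0.7}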


class TensorBoardCallback(TrainerCallback):
    """
    A [`TrainerCallback`] that sends the logs to [TensorBoard](https://www.tensorflow.org/tensorboard).

    Args:
        tb_writer (`SummaryWriter`, *optional*):
            The writer to use. Will instantiate one if not set.
    """

    def __init__(self, tb_writer=None):
        has_tensorboard = is_tensorboard_available()
        if not has_tensorboard:
            raise RuntimeError(
                "TensorBoardCallback requires tensorboard to be installed. Either update your PyTorch version or"
                " install tensorboardX."
            )
        if has_tensorboard:
            try:
                from torch.utils.tensorboard import SummaryWriter  # noqa: F401

                self._SummaryWriter = SummaryWriter
            except ImportError:
                try:
                    from tensorboardX import SummaryWriter

                    self._SummaryWriter = SummaryWriter
                except ImportError:
                    self._SummaryWriter = None
        else:
            self._SummaryWriter = None
        self.tb_writer = tb_writer

    def _init_summary_writer(self, args, log_dir=None):
        log_dir = log_dir or args.logging_dir
        if self._SummaryWriter is not None:
            self.tb_writer = self._SummaryWriter(log_dir=log_dir)

    def on_train_begin(self, args, state, control, **kwargs):
        if not state.is_world_process_zero:
            return

        log_dir = None

        if state.is_hyper_param_search:
            trial_name = state.trial_name
            if trial_name is not None:
                log_dir = os.path.join(args.logging_dir, trial_name)

        if self.tb_writer is None:
            self._init_summary_writer(args, log_dir)

        if self.tb_writer is not None:
            self.tb_writer.add_text("args", args.to_json_string())
            if "model" in kwargs:
                model = kwargs["model"]
                if hasattr(model, "config") and model.config is not None:
                    model_config_json = model.config.to_json_string()
                    self.tb_writer.add_text("model_config", model_config_json)

    def on_log(self, args, state, control, logs=None, **kwargs):
        if not state.is_world_process_zero:
            return

        if self.tb_writer is None:
            self._init_summary_writer(args)

        if self.tb_writer is not None:
            logs = rewrite_logs(logs)
            for k, v in logs.items():
                if isinstance(v, (int, float)):
                    self.tb_writer.add_scalar(k, v, state.global_step)
                else:
                    logger.warning(
                        "Trainer is attempting to log a value of "
                        f'"{v}" of type {type(v)} for key "{k}" as a scalar. '
                        "This invocation of Tensorboard's writer.add_scalar() "
                        "is incorrect so we dropped this attribute."
                    )
            self.tb_writer.flush()

    def on_train_end(self, args, state, control, **kwargs):
        if self.tb_writer:
            self.tb_writer.close()
            self.tb_writer = None
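
# Illustrative sketch (not part of the original module): the callback is usually enabled
# through `TrainingArguments` rather than constructed by hand; the paths below are hypothetical.
#
#     args = TrainingArguments(output_dir="out", report_to=["tensorboard"], logging_dir="out/logs")
#     trainer = Trainer(model=model, args=args, train_dataset=train_dataset)
#     trainer.train()  # scalars land in out/logs, viewable with `tensorboard --logdir out/logs`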


def save_model_architecture_to_file(model: Any, output_dir: str):
    with open(f"{output_dir}/model_architecture.txt", "w+") as f:
        if isinstance(model, PreTrainedModel):
            print(model, file=f)
        elif is_tf_available() and isinstance(model, TFPreTrainedModel):

            def print_to_file(s):
                print(s, file=f)

            model.summary(print_fn=print_to_file)
        elif is_torch_available() and (
            isinstance(model, (torch.nn.Module, PushToHubMixin)) and hasattr(model, "base_model")
        ):
            print(model, file=f)


class WandbLogModel(str, Enum):
    """Enum of possible log model values in W&B."""

    CHECKPOINT = "checkpoint"
    END = "end"
    FALSE = "false"

    @property
    def is_enabled(self) -> bool:
        """Check if the value corresponds to a state where the `WANDB_LOG_MODEL` setting is enabled."""
        return self in (WandbLogModel.CHECKPOINT, WandbLogModel.END)

    @classmethod
    def _missing_(cls, value: Any) -> "WandbLogModel":
        if not isinstance(value, str):
            raise ValueError(f"Expecting to have a string `WANDB_LOG_MODEL` setting, but got {type(value)}")
        if value.upper() in ENV_VARS_TRUE_VALUES:
            logger.warning(
                f"Setting `WANDB_LOG_MODEL` as {os.getenv('WANDB_LOG_MODEL')} is deprecated and will be removed in "
                "version 5 of transformers. Use one of `'end'` or `'checkpoint'` instead."
            )
            logger.info(f"Setting `WANDB_LOG_MODEL` from {os.getenv('WANDB_LOG_MODEL')} to `end` instead")
            return WandbLogModel.END
        logger.warning(
            f"Received unrecognized `WANDB_LOG_MODEL` setting value={value}; so disabling `WANDB_LOG_MODEL`"
        )
        return WandbLogModel.FALSE
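
# Example (illustrative, not part of the original module): the enum is constructed from
# the raw `WANDB_LOG_MODEL` value, with `_missing_` handling anything non-canonical.
#
#     WandbLogModel("checkpoint").is_enabled  # True
#     WandbLogModel("false").is_enabled       # False
#     WandbLogModel("bogus")                  # logs a warning, -> WandbLogModel.FALSE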


class WandbCallback(TrainerCallback):
    """
    A [`TrainerCallback`] that logs metrics, media, and model checkpoints to [Weights & Biases](https://www.wandb.com/).
    """

    def __init__(self):
        has_wandb = is_wandb_available()
        if not has_wandb:
            raise RuntimeError("WandbCallback requires wandb to be installed. Run `pip install wandb`.")
        if has_wandb:
            import wandb

            self._wandb = wandb
        self._initialized = False
        self._log_model = WandbLogModel(os.getenv("WANDB_LOG_MODEL", "false"))

    def setup(self, args, state, model, **kwargs):
        """
        Setup the optional Weights & Biases (*wandb*) integration.

        One can subclass and override this method to customize the setup if needed. Find more information
        [here](https://docs.wandb.ai/guides/integrations/huggingface). You can also override the following environment
        variables:

        Environment:
        - **WANDB_LOG_MODEL** (`str`, *optional*, defaults to `"false"`):
            Whether to log the model and checkpoints during training. Can be `"end"`, `"checkpoint"` or `"false"`. If
            set to `"end"`, the model will be uploaded at the end of training. If set to `"checkpoint"`, the checkpoint
            will be uploaded every `args.save_steps`. If set to `"false"`, the model will not be uploaded. Use along
            with [`~transformers.TrainingArguments.load_best_model_at_end`] to upload the best model.

            <Deprecated version="5.0">

            Setting `WANDB_LOG_MODEL` as `bool` will be deprecated in version 5 of 🤗 Transformers.

            </Deprecated>
        - **WANDB_WATCH** (`str`, *optional*, defaults to `"false"`):
            Can be `"gradients"`, `"all"`, `"parameters"`, or `"false"`. Set to `"all"` to log gradients and
            parameters.
        - **WANDB_PROJECT** (`str`, *optional*, defaults to `"huggingface"`):
            Set this to a custom string to store results in a different project.
        - **WANDB_DISABLED** (`bool`, *optional*, defaults to `False`):
            Whether to disable wandb entirely. Set `WANDB_DISABLED=true` to disable.
        """
        if self._wandb is None:
            return

        self._initialized = True

        # prepare to handle potential configuration issues during setup
        from wandb.sdk.lib.config_util import ConfigError as WandbConfigError

        if state.is_world_process_zero:
            logger.info(
                'Automatic Weights & Biases logging enabled, to disable set os.environ["WANDB_DISABLED"] = "true"'
            )
            combined_dict = {**args.to_dict()}

            if hasattr(model, "config") and model.config is not None:
                model_config = model.config if isinstance(model.config, dict) else model.config.to_dict()
                combined_dict = {**model_config, **combined_dict}
            if hasattr(model, "peft_config") and model.peft_config is not None:
                peft_config = model.peft_config
                combined_dict = {**{"peft_config": peft_config}, **combined_dict}
            trial_name = state.trial_name
            init_args = {}
            if trial_name is not None:
                init_args["name"] = trial_name
                init_args["group"] = args.run_name
            elif args.run_name is not None:
                init_args["name"] = args.run_name
                if args.run_name == args.output_dir:
                    self._wandb.termwarn(
                        "The `run_name` is currently set to the same value as `TrainingArguments.output_dir`. If this was "
                        "not intended, please specify a different run name by setting the `TrainingArguments.run_name` parameter.",
                        repeat=False,
                    )

            if self._wandb.run is None:
                self._wandb.init(
                    project=os.getenv("WANDB_PROJECT", "huggingface"),
                    **init_args,
                )
            # add config parameters (run may have been created manually)
            self._wandb.config.update(combined_dict, allow_val_change=True)

            # define default x-axis (for latest wandb versions)
            if getattr(self._wandb, "define_metric", None):
                self._wandb.define_metric("train/global_step")
                self._wandb.define_metric("*", step_metric="train/global_step", step_sync=True)

            # keep track of model topology and gradients, unsupported on TPU
            _watch_model = os.getenv("WANDB_WATCH", "false")
            if not is_torch_xla_available() and _watch_model in ("all", "parameters", "gradients"):
                self._wandb.watch(model, log=_watch_model, log_freq=max(100, state.logging_steps))
            self._wandb.run._label(code="transformers_trainer")

            # add number of model parameters to wandb config
            try:
                self._wandb.config["model/num_parameters"] = model.num_parameters()
            except AttributeError:
                logger.info(
                    "Could not log the number of model parameters in Weights & Biases due to an AttributeError."
                )
            except WandbConfigError:
                logger.warning(
                    "A ConfigError was raised whilst setting the number of model parameters in Weights & Biases config."
                )

            # log the initial model architecture to an artifact
            if self._log_model.is_enabled:
                with tempfile.TemporaryDirectory() as temp_dir:
                    model_name = (
                        f"model-{self._wandb.run.id}"
                        if (args.run_name is None or args.run_name == args.output_dir)
                        else f"model-{self._wandb.run.name}"
                    )
                    model_artifact = self._wandb.Artifact(
                        name=model_name,
                        type="model",
                        metadata={
                            "model_config": model.config.to_dict() if hasattr(model, "config") else None,
                            "num_parameters": self._wandb.config.get("model/num_parameters"),
                            "initial_model": True,
                        },
                    )
                    # add the architecture to a separate text file
                    save_model_architecture_to_file(model, temp_dir)

                    for f in Path(temp_dir).glob("*"):
                        if f.is_file():
                            with model_artifact.new_file(f.name, mode="wb") as fa:
                                fa.write(f.read_bytes())
                    self._wandb.run.log_artifact(model_artifact, aliases=["base_model"])

            badge_markdown = (
                f'[<img src="https://raw.githubusercontent.com/wandb/assets/main/wandb-github-badge'
                f'-28.svg" alt="Visualize in Weights & Biases" width="20'
                f'0" height="32"/>]({self._wandb.run.get_url()})'
            )

            modelcard.AUTOGENERATED_TRAINER_COMMENT += f"\n{badge_markdown}"

    def on_train_begin(self, args, state, control, model=None, **kwargs):
        if self._wandb is None:
            return
        hp_search = state.is_hyper_param_search
        if hp_search:
            self._wandb.finish()
            self._initialized = False
            args.run_name = None
        if not self._initialized:
            self.setup(args, state, model, **kwargs)

    def on_train_end(self, args, state, control, model=None, tokenizer=None, **kwargs):
        if self._wandb is None:
            return
        if self._log_model.is_enabled and self._initialized and state.is_world_process_zero:
            from ..trainer import Trainer

            fake_trainer = Trainer(args=args, model=model, processing_class=tokenizer)
            with tempfile.TemporaryDirectory() as temp_dir:
                fake_trainer.save_model(temp_dir)
                metadata = (
                    {
                        k: v
                        for k, v in dict(self._wandb.summary).items()
                        if isinstance(v, numbers.Number) and not k.startswith("_")
                    }
                    if not args.load_best_model_at_end
                    else {
                        f"eval/{args.metric_for_best_model}": state.best_metric,
                        "train/total_floss": state.total_flos,
                        "model/num_parameters": self._wandb.config.get("model/num_parameters"),
                    }
                )
                metadata["final_model"] = True
                logger.info("Logging model artifacts. ...")
                model_name = (
                    f"model-{self._wandb.run.id}"
                    if (args.run_name is None or args.run_name == args.output_dir)
                    else f"model-{self._wandb.run.name}"
                )
                # add the model architecture to a separate text file
                save_model_architecture_to_file(model, temp_dir)

                artifact = self._wandb.Artifact(name=model_name, type="model", metadata=metadata)
                for f in Path(temp_dir).glob("*"):
                    if f.is_file():
                        with artifact.new_file(f.name, mode="wb") as fa:
                            fa.write(f.read_bytes())
                self._wandb.run.log_artifact(artifact, aliases=["final_model"])

    def on_log(self, args, state, control, model=None, logs=None, **kwargs):
        single_value_scalars = [
            "train_runtime",
            "train_samples_per_second",
            "train_steps_per_second",
            "train_loss",
            "total_flos",
        ]

        if self._wandb is None:
            return
        if not self._initialized:
            self.setup(args, state, model)
        if state.is_world_process_zero:
            for k, v in logs.items():
                if k in single_value_scalars:
                    self._wandb.run.summary[k] = v
            non_scalar_logs = {k: v for k, v in logs.items() if k not in single_value_scalars}
            non_scalar_logs = rewrite_logs(non_scalar_logs)
            self._wandb.log({**non_scalar_logs, "train/global_step": state.global_step})

    def on_save(self, args, state, control, **kwargs):
        if self._log_model == WandbLogModel.CHECKPOINT and self._initialized and state.is_world_process_zero:
            checkpoint_metadata = {
                k: v
                for k, v in dict(self._wandb.summary).items()
                if isinstance(v, numbers.Number) and not k.startswith("_")
            }
            checkpoint_metadata["model/num_parameters"] = self._wandb.config.get("model/num_parameters")

            ckpt_dir = f"checkpoint-{state.global_step}"
            artifact_path = os.path.join(args.output_dir, ckpt_dir)
            logger.info(f"Logging checkpoint artifacts in {ckpt_dir}. ...")
            checkpoint_name = (
                f"model-{self._wandb.run.id}"
                if (args.run_name is None or args.run_name == args.output_dir)
                else f"model-{self._wandb.run.name}"
            )
            artifact = self._wandb.Artifact(name=checkpoint_name, type="model", metadata=checkpoint_metadata)
            artifact.add_dir(artifact_path)
            self._wandb.log_artifact(
                artifact, aliases=[f"epoch_{round(state.epoch, 2)}", f"checkpoint_global_step_{state.global_step}"]
            )

    def on_predict(self, args, state, control, metrics, **kwargs):
        if self._wandb is None:
            return
        if not self._initialized:
            self.setup(args, state, **kwargs)
        if state.is_world_process_zero:
            metrics = rewrite_logs(metrics)
            self._wandb.log(metrics)
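
# Illustrative sketch (not part of the original module): typical environment-driven
# configuration of the W&B integration, per the `setup` docstring above. The project
# and run names are hypothetical.
#
#     os.environ["WANDB_PROJECT"] = "my-project"   # store runs under a custom project
#     os.environ["WANDB_LOG_MODEL"] = "end"        # upload the final model as an artifact
#     os.environ["WANDB_WATCH"] = "gradients"      # also log gradient histograms
#     args = TrainingArguments(output_dir="out", report_to=["wandb"], run_name="my-run")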


class CometCallback(TrainerCallback):
    """
    A [`TrainerCallback`] that sends the logs to [Comet ML](https://www.comet.com/site/).
    """

    def __init__(self):
        if _is_comet_installed is False or _is_comet_recent_enough is False:
            raise RuntimeError(
                f"CometCallback requires comet-ml>={_MIN_COMET_VERSION} to be installed. Run `pip install comet-ml>={_MIN_COMET_VERSION}`."
            )
        self._initialized = False
        self._log_assets = False
        self._experiment = None

    def setup(self, args, state, model):
        """
        Setup the optional Comet integration.

        Environment:
        - **COMET_MODE** (`str`, *optional*, defaults to `get_or_create`):
            Control whether to create and log to a new Comet experiment or append to an existing experiment.
            It accepts the following values:
                * `get_or_create`: Decides automatically depending on whether
                  `COMET_EXPERIMENT_KEY` is set and whether an Experiment
                  with that key already exists or not.
                * `create`: Always create a new Comet Experiment.
                * `get`: Always try to append to an existing Comet Experiment.
                  Requires `COMET_EXPERIMENT_KEY` to be set.
                * `ONLINE`: **deprecated**, used to create an online
                  Experiment. Use `COMET_START_ONLINE=1` instead.
                * `OFFLINE`: **deprecated**, used to create an offline
                  Experiment. Use `COMET_START_ONLINE=0` instead.
                * `DISABLED`: **deprecated**, used to disable Comet logging.
                  Use the `--report_to` flag to control the integrations used
                  for logging results instead.
        - **COMET_PROJECT_NAME** (`str`, *optional*):
            Comet project name for experiments.
        - **COMET_LOG_ASSETS** (`str`, *optional*, defaults to `TRUE`):
            Whether or not to log training assets (tf event logs, checkpoints, etc), to Comet. Can be `TRUE`, or
            `FALSE`.

        For a number of configurable items in the environment, see
        [here](https://www.comet.com/docs/v2/guides/experiment-management/configure-sdk/#explore-comet-configuration-options).
        """
        self._initialized = True
        log_assets = os.getenv("COMET_LOG_ASSETS", "FALSE").upper()
        if log_assets in {"TRUE", "1"}:
            self._log_assets = True
        if state.is_world_process_zero:
            comet_old_mode = os.getenv("COMET_MODE")

            mode = None
            online = None

            if comet_old_mode is not None:
                comet_old_mode = comet_old_mode.lower()

                if comet_old_mode == "online":
                    online = True
                elif comet_old_mode == "offline":
                    online = False
                elif comet_old_mode in ("get", "get_or_create", "create"):
                    mode = comet_old_mode
                elif comet_old_mode:
                    logger.warning("Invalid COMET_MODE env value %r, Comet logging is disabled", comet_old_mode)
                    return

            # For HPO, we always create a new experiment for each trial
            if state.is_hyper_param_search:
                if mode is not None:
                    logger.warning(
                        "Hyperparameter Search is enabled, forcing the creation of new experiments, COMET_MODE value %r is ignored",
                        comet_old_mode,
                    )
                mode = "create"

            import comet_ml

            # Do not use the default run_name as the experiment name
            if args.run_name is not None and args.run_name != args.output_dir:
                experiment_config = comet_ml.ExperimentConfig(name=args.run_name)
            else:
                experiment_config = comet_ml.ExperimentConfig()

            self._experiment = comet_ml.start(online=online, mode=mode, experiment_config=experiment_config)
            self._experiment.__internal_api__set_model_graph__(model, framework="transformers")

            params = {"args": args.to_dict()}

            if hasattr(model, "config") and model.config is not None:
                model_config = model.config.to_dict()
                params["config"] = model_config
            if hasattr(model, "peft_config") and model.peft_config is not None:
                peft_config = model.peft_config
                params["peft_config"] = peft_config

            self._experiment.__internal_api__log_parameters__(
                params, framework="transformers", source="manual", flatten_nested=True
            )

            if state.is_hyper_param_search:
                optimization_id = getattr(state, "trial_name", None)
                optimization_params = getattr(state, "trial_params", None)

                self._experiment.log_optimization(optimization_id=optimization_id, parameters=optimization_params)

    def on_train_begin(self, args, state, control, model=None, **kwargs):
        if not self._initialized:
            self.setup(args, state, model)

    def on_log(self, args, state, control, model=None, logs=None, **kwargs):
        if not self._initialized:
            self.setup(args, state, model)
        if state.is_world_process_zero:
            if self._experiment is not None:
                rewritten_logs = rewrite_logs(logs)
                self._experiment.__internal_api__log_metrics__(
                    rewritten_logs, step=state.global_step, epoch=state.epoch, framework="transformers"
                )

    def on_train_end(self, args, state, control, **kwargs):
        if self._initialized and state.is_world_process_zero:
            if self._experiment is not None:
                if self._log_assets is True:
                    logger.info("Logging checkpoints. This may take time.")
                    self._experiment.log_asset_folder(
                        args.output_dir, recursive=True, log_file_name=True, step=state.global_step
                    )

                # We create one experiment per trial in HPO mode
                if state.is_hyper_param_search:
                    self._experiment.clean()
                    self._initialized = False

    def on_predict(self, args, state, control, metrics, **kwargs):
        if not self._initialized:
            self.setup(args, state, model=None)
        if state.is_world_process_zero and self._experiment is not None:
            rewritten_metrics = rewrite_logs(metrics)
            self._experiment.__internal_api__log_metrics__(
                rewritten_metrics, step=state.global_step, epoch=state.epoch, framework="transformers"
            )
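
# Illustrative sketch (not part of the original module): typical environment-driven
# configuration of the Comet integration, per the `setup` docstring above. The project
# name is hypothetical; the API key placeholder is intentionally elided.
#
#     os.environ["COMET_API_KEY"] = "..."               # required for Comet logging
#     os.environ["COMET_PROJECT_NAME"] = "my-project"   # store experiments under a project
#     os.environ["COMET_START_ONLINE"] = "1"            # replaces the deprecated COMET_MODE=ONLINE
#     args = TrainingArguments(output_dir="out", report_to=["comet_ml"])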


class AzureMLCallback(TrainerCallback):
    """
    A [`TrainerCallback`] that sends the logs to [AzureML](https://pypi.org/project/azureml-sdk/).
    """

    def __init__(self, azureml_run=None):
        if not is_azureml_available():
            raise RuntimeError("AzureMLCallback requires azureml to be installed. Run `pip install azureml-sdk`.")
        self.azureml_run = azureml_run

    def on_init_end(self, args, state, control, **kwargs):
        from azureml.core.run import Run

        if self.azureml_run is None and state.is_world_process_zero:
            self.azureml_run = Run.get_context()

    def on_log(self, args, state, control, logs=None, **kwargs):
        if self.azureml_run and state.is_world_process_zero:
            for k, v in logs.items():
                if isinstance(v, (int, float)):
                    self.azureml_run.log(k, v, description=k)
  975. class MLflowCallback(TrainerCallback):
  976. """
  977. A [`TrainerCallback`] that sends the logs to [MLflow](https://www.mlflow.org/). Can be disabled by setting
  978. environment variable `DISABLE_MLFLOW_INTEGRATION = TRUE`.
  979. """
  980. def __init__(self):
  981. if not is_mlflow_available():
  982. raise RuntimeError("MLflowCallback requires mlflow to be installed. Run `pip install mlflow`.")
  983. import mlflow
  984. self._MAX_PARAM_VAL_LENGTH = mlflow.utils.validation.MAX_PARAM_VAL_LENGTH
  985. self._MAX_PARAMS_TAGS_PER_BATCH = mlflow.utils.validation.MAX_PARAMS_TAGS_PER_BATCH
  986. self._initialized = False
  987. self._auto_end_run = False
  988. self._log_artifacts = False
  989. self._ml_flow = mlflow

    def setup(self, args, state, model):
        """
        Setup the optional MLflow integration.

        Environment:
        - **HF_MLFLOW_LOG_ARTIFACTS** (`str`, *optional*):
            Whether to use MLflow `.log_artifact()` facility to log artifacts. This only makes sense if logging to a
            remote server, e.g. s3 or GCS. If set to `True` or *1*, will copy each saved checkpoint on each save in
            [`TrainingArguments`]'s `output_dir` to the local or remote artifact storage. Using it without a remote
            storage will just copy the files to your artifact location.
        - **MLFLOW_TRACKING_URI** (`str`, *optional*):
            The tracking URI (local path or remote server) under which to store runs. Unset by default, in which case
            the tracking URI is not set explicitly.
        - **MLFLOW_EXPERIMENT_NAME** (`str`, *optional*, defaults to `None`):
            The name of the MLflow experiment under which to launch the run. Defaults to `None`, which points to the
            `Default` experiment in MLflow. Otherwise, it is the case-sensitive name of the experiment to be
            activated. If an experiment with this name does not exist, a new experiment with this name is created.
        - **MLFLOW_TAGS** (`str`, *optional*):
            A JSON string of a dictionary of key/value pairs to be added to the MLflow run as tags. Example:
            `os.environ['MLFLOW_TAGS']='{"release.candidate": "RC1", "release.version": "2.2.0"}'`.
        - **MLFLOW_NESTED_RUN** (`str`, *optional*):
            Whether to use MLflow nested runs. If set to `True` or *1*, will create a nested run inside the current
            run.
        - **MLFLOW_RUN_ID** (`str`, *optional*):
            Allows reattaching to an existing run, which can be useful when resuming training from a checkpoint. When
            the `MLFLOW_RUN_ID` environment variable is set, `start_run` attempts to resume a run with the specified
            run ID and other parameters are ignored.
        - **MLFLOW_FLATTEN_PARAMS** (`str`, *optional*, defaults to `False`):
            Whether to flatten the parameters dictionary before logging.
        - **MLFLOW_MAX_LOG_PARAMS** (`int`, *optional*):
            Set the maximum number of parameters to log in the run.
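
        Example (a minimal sketch; the experiment name is a placeholder, and `report_to=["mlflow"]` is how the
        [`Trainer`] selects this callback):

        ```python
        import os

        os.environ["MLFLOW_EXPERIMENT_NAME"] = "my-experiment"
        os.environ["MLFLOW_FLATTEN_PARAMS"] = "TRUE"

        trainer = Trainer(..., args=TrainingArguments(..., report_to=["mlflow"]))
        ```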
  1020. """
  1021. self._log_artifacts = os.getenv("HF_MLFLOW_LOG_ARTIFACTS", "FALSE").upper() in ENV_VARS_TRUE_VALUES
  1022. self._nested_run = os.getenv("MLFLOW_NESTED_RUN", "FALSE").upper() in ENV_VARS_TRUE_VALUES
  1023. self._tracking_uri = os.getenv("MLFLOW_TRACKING_URI", None)
  1024. self._experiment_name = os.getenv("MLFLOW_EXPERIMENT_NAME", None)
  1025. self._flatten_params = os.getenv("MLFLOW_FLATTEN_PARAMS", "FALSE").upper() in ENV_VARS_TRUE_VALUES
  1026. self._run_id = os.getenv("MLFLOW_RUN_ID", None)
  1027. self._max_log_params = os.getenv("MLFLOW_MAX_LOG_PARAMS", None)
  1028. # "synchronous" flag is only available with mlflow version >= 2.8.0
  1029. # https://github.com/mlflow/mlflow/pull/9705
  1030. # https://github.com/mlflow/mlflow/releases/tag/v2.8.0
  1031. self._async_log = packaging.version.parse(self._ml_flow.__version__) >= packaging.version.parse("2.8.0")
  1032. logger.debug(
  1033. f"MLflow experiment_name={self._experiment_name}, run_name={args.run_name}, nested={self._nested_run},"
  1034. f" tags={self._nested_run}, tracking_uri={self._tracking_uri}"
  1035. )
        if state.is_world_process_zero:
            if not self._ml_flow.is_tracking_uri_set():
                if self._tracking_uri:
                    self._ml_flow.set_tracking_uri(self._tracking_uri)
                    logger.debug(f"MLflow tracking URI is set to {self._tracking_uri}")
                else:
                    logger.debug(
                        "Environment variable `MLFLOW_TRACKING_URI` is not provided and therefore will not be"
                        " explicitly set."
                    )
            else:
                logger.debug(f"MLflow tracking URI is set to {self._ml_flow.get_tracking_uri()}")
            if self._ml_flow.active_run() is None or self._nested_run or self._run_id:
                if self._experiment_name:
                    # set_experiment() ensures that the experiment is created if it does not exist yet
                    self._ml_flow.set_experiment(self._experiment_name)
                self._ml_flow.start_run(run_name=args.run_name, nested=self._nested_run)
                logger.debug(f"MLflow run started with run_id={self._ml_flow.active_run().info.run_id}")
                self._auto_end_run = True
            combined_dict = args.to_dict()
            if hasattr(model, "config") and model.config is not None:
                model_config = model.config.to_dict()
                combined_dict = {**model_config, **combined_dict}
            combined_dict = flatten_dict(combined_dict) if self._flatten_params else combined_dict

            # remove params that are too long for MLflow
            for name, value in list(combined_dict.items()):
                # internally, all values are converted to str in MLflow
                if len(str(value)) > self._MAX_PARAM_VAL_LENGTH:
                    logger.warning(
                        f'Trainer is attempting to log a value of "{value}" for key "{name}" as a parameter. MLflow\'s'
                        " log_param() only accepts values no longer than 250 characters so we dropped this attribute."
                        " You can use `MLFLOW_FLATTEN_PARAMS` environment variable to flatten the parameters and"
                        " avoid this message."
                    )
                    del combined_dict[name]

            # MLflow cannot log more than 100 values in one go, so we have to split it
            combined_dict_items = list(combined_dict.items())
            if self._max_log_params and self._max_log_params.isdigit():
                max_log_params = int(self._max_log_params)
                if max_log_params < len(combined_dict_items):
                    logger.debug(
                        f"Reducing the number of parameters to log from {len(combined_dict_items)} to {max_log_params}."
                    )
                    combined_dict_items = combined_dict_items[:max_log_params]
            for i in range(0, len(combined_dict_items), self._MAX_PARAMS_TAGS_PER_BATCH):
                if self._async_log:
                    self._ml_flow.log_params(
                        dict(combined_dict_items[i : i + self._MAX_PARAMS_TAGS_PER_BATCH]), synchronous=False
                    )
                else:
                    self._ml_flow.log_params(dict(combined_dict_items[i : i + self._MAX_PARAMS_TAGS_PER_BATCH]))
            mlflow_tags = os.getenv("MLFLOW_TAGS", None)
            if mlflow_tags:
                mlflow_tags = json.loads(mlflow_tags)
                self._ml_flow.set_tags(mlflow_tags)
        self._initialized = True

    def on_train_begin(self, args, state, control, model=None, **kwargs):
        if not self._initialized:
            self.setup(args, state, model)

    def on_log(self, args, state, control, logs, model=None, **kwargs):
        if not self._initialized:
            self.setup(args, state, model)
        if state.is_world_process_zero:
            metrics = {}
            for k, v in logs.items():
                if isinstance(v, (int, float)):
                    metrics[k] = v
                elif isinstance(v, torch.Tensor) and v.numel() == 1:
                    metrics[k] = v.item()
                else:
                    logger.warning(
                        f'Trainer is attempting to log a value of "{v}" of type {type(v)} for key "{k}" as a metric. '
                        "MLflow's log_metric() only accepts float and int types so we dropped this attribute."
                    )
            if self._async_log:
                self._ml_flow.log_metrics(metrics=metrics, step=state.global_step, synchronous=False)
            else:
                self._ml_flow.log_metrics(metrics=metrics, step=state.global_step)

    def on_train_end(self, args, state, control, **kwargs):
        if self._initialized and state.is_world_process_zero:
            if self._auto_end_run and self._ml_flow.active_run():
                self._ml_flow.end_run()

    def on_save(self, args, state, control, **kwargs):
        if self._initialized and state.is_world_process_zero and self._log_artifacts:
            ckpt_dir = f"checkpoint-{state.global_step}"
            artifact_path = os.path.join(args.output_dir, ckpt_dir)
            logger.info(f"Logging checkpoint artifacts in {ckpt_dir}. This may take time.")
            self._ml_flow.pyfunc.log_model(
                ckpt_dir,
                artifacts={"model_path": artifact_path},
                python_model=self._ml_flow.pyfunc.PythonModel(),
            )

    def __del__(self):
        # if the previous run is not terminated correctly, the fluent API will
        # not let you start a new run before the previous one is killed
        if (
            self._auto_end_run
            and callable(getattr(self._ml_flow, "active_run", None))
            and self._ml_flow.active_run() is not None
        ):
            self._ml_flow.end_run()


class DagsHubCallback(MLflowCallback):
    """
    A [`TrainerCallback`] that logs to [DagsHub](https://dagshub.com/). Extends [`MLflowCallback`].
    """

    def __init__(self):
        super().__init__()
        if not is_dagshub_available():
            raise ImportError("DagsHubCallback requires dagshub to be installed. Run `pip install dagshub`.")

        from dagshub.upload import Repo

        self.Repo = Repo

    def setup(self, *args, **kwargs):
        """
        Setup the DagsHub logging integration.

        Environment:
        - **HF_DAGSHUB_LOG_ARTIFACTS** (`str`, *optional*):
            Whether to save the data and model artifacts for the experiment. Defaults to `False`.
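
        Example (a minimal sketch; `your-user`/`your-repo` are placeholders, and `dagshub.init` is the helper that
        the error message below refers to):

        ```python
        import os

        import dagshub

        dagshub.init(repo_name="your-repo", repo_owner="your-user")  # sets MLFLOW_TRACKING_URI
        os.environ["HF_DAGSHUB_LOG_ARTIFACTS"] = "TRUE"
        ```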
  1153. """
  1154. self.log_artifacts = os.getenv("HF_DAGSHUB_LOG_ARTIFACTS", "FALSE").upper() in ENV_VARS_TRUE_VALUES
  1155. self.name = os.getenv("HF_DAGSHUB_MODEL_NAME") or "main"
  1156. self.remote = os.getenv("MLFLOW_TRACKING_URI")
  1157. self.repo = self.Repo(
  1158. owner=self.remote.split(os.sep)[-2],
  1159. name=self.remote.split(os.sep)[-1].split(".")[0],
  1160. branch=os.getenv("BRANCH") or "main",
  1161. )
  1162. self.path = Path("artifacts")
  1163. if self.remote is None:
  1164. raise RuntimeError(
  1165. "DagsHubCallback requires the `MLFLOW_TRACKING_URI` environment variable to be set. Did you run"
  1166. " `dagshub.init()`?"
  1167. )
  1168. super().setup(*args, **kwargs)

    def on_train_end(self, args, state, control, **kwargs):
        if self.log_artifacts:
            if getattr(self, "train_dataloader", None):
                torch.save(self.train_dataloader.dataset, os.path.join(args.output_dir, "dataset.pt"))

            self.repo.directory(str(self.path)).add_dir(args.output_dir)


class NeptuneMissingConfiguration(Exception):
    def __init__(self):
        super().__init__(
            """
            ------ Unsupported ---- We were not able to create new runs. You provided a custom Neptune run to
            `NeptuneCallback` with the `run` argument. For the integration to work fully, provide your `api_token` and
            `project` by saving them as environment variables or passing them to the callback.
            """
        )


class NeptuneCallback(TrainerCallback):
    """TrainerCallback that sends the logs to [Neptune](https://app.neptune.ai).

    Args:
        api_token (`str`, *optional*): Neptune API token obtained upon registration.
            You can leave this argument out if you have saved your token to the `NEPTUNE_API_TOKEN` environment
            variable (strongly recommended). See full setup instructions in the
            [docs](https://docs.neptune.ai/setup/installation).
        project (`str`, *optional*): Name of an existing Neptune project, in the form "workspace-name/project-name".
            You can find and copy the name in Neptune from the project settings -> Properties. If None (default), the
            value of the `NEPTUNE_PROJECT` environment variable is used.
        name (`str`, *optional*): Custom name for the run.
        base_namespace (`str`, *optional*, defaults to "finetuning"): In the Neptune run, the root namespace
            that will contain all of the metadata logged by the callback.
        log_parameters (`bool`, *optional*, defaults to `True`):
            If True, logs all Trainer arguments and model parameters provided by the Trainer.
        log_checkpoints (`str`, *optional*): If "same", uploads checkpoints whenever they are saved by the Trainer.
            If "last", uploads only the most recently saved checkpoint. If "best", uploads the best checkpoint (among
            the ones saved by the Trainer). If `None`, does not upload checkpoints.
        run (`Run`, *optional*): Pass a Neptune run object if you want to continue logging to an existing run.
            Read more about resuming runs in the [docs](https://docs.neptune.ai/logging/to_existing_object).
        **neptune_run_kwargs (*optional*):
            Additional keyword arguments to be passed directly to the
            [`neptune.init_run()`](https://docs.neptune.ai/api/neptune#init_run) function when a new run is created.

    For instructions and examples, see the [Transformers integration
    guide](https://docs.neptune.ai/integrations/transformers) in the Neptune documentation.
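
    Example (a minimal sketch; the project name is a placeholder and the API token is read from the
    `NEPTUNE_API_TOKEN` environment variable):

    ```python
    neptune_callback = NeptuneCallback(project="workspace-name/project-name", log_checkpoints="last")
    trainer = Trainer(..., callbacks=[neptune_callback])
    ```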
  1208. """
  1209. integration_version_key = "source_code/integrations/transformers"
  1210. model_parameters_key = "model_parameters"
  1211. trial_name_key = "trial"
  1212. trial_params_key = "trial_params"
  1213. trainer_parameters_key = "trainer_parameters"
  1214. flat_metrics = {"train/epoch"}

    def __init__(
        self,
        *,
        api_token: Optional[str] = None,
        project: Optional[str] = None,
        name: Optional[str] = None,
        base_namespace: str = "finetuning",
        run=None,
        log_parameters: bool = True,
        log_checkpoints: Optional[str] = None,
        **neptune_run_kwargs,
    ):
        if not is_neptune_available():
            raise ValueError(
                "NeptuneCallback requires the Neptune client library to be installed. "
                "To install the library, run `pip install neptune`."
            )

        try:
            from neptune import Run
            from neptune.internal.utils import verify_type
        except ImportError:
            from neptune.new.internal.utils import verify_type
            from neptune.new.metadata_containers.run import Run

        verify_type("api_token", api_token, (str, type(None)))
        verify_type("project", project, (str, type(None)))
        verify_type("name", name, (str, type(None)))
        verify_type("base_namespace", base_namespace, str)
        verify_type("run", run, (Run, type(None)))
        verify_type("log_parameters", log_parameters, bool)
        verify_type("log_checkpoints", log_checkpoints, (str, type(None)))

        self._base_namespace_path = base_namespace
        self._log_parameters = log_parameters
        self._log_checkpoints = log_checkpoints
        self._initial_run: Optional[Run] = run

        self._run = None
        self._is_monitoring_run = False
        self._run_id = None
        self._force_reset_monitoring_run = False
        self._init_run_kwargs = {"api_token": api_token, "project": project, "name": name, **neptune_run_kwargs}

        self._volatile_checkpoints_dir = None
        self._should_upload_checkpoint = self._log_checkpoints is not None
        self._recent_checkpoint_path = None

        if self._log_checkpoints in {"last", "best"}:
            self._target_checkpoints_namespace = f"checkpoints/{self._log_checkpoints}"
            self._should_clean_recently_uploaded_checkpoint = True
        else:
            self._target_checkpoints_namespace = "checkpoints"
            self._should_clean_recently_uploaded_checkpoint = False

    def _stop_run_if_exists(self):
        if self._run:
            self._run.stop()
            del self._run
            self._run = None

    def _initialize_run(self, **additional_neptune_kwargs):
        try:
            from neptune import init_run
            from neptune.exceptions import NeptuneMissingApiTokenException, NeptuneMissingProjectNameException
        except ImportError:
            from neptune.new import init_run
            from neptune.new.exceptions import NeptuneMissingApiTokenException, NeptuneMissingProjectNameException

        self._stop_run_if_exists()

        try:
            run_params = additional_neptune_kwargs.copy()
            run_params.update(self._init_run_kwargs)
            self._run = init_run(**run_params)
            self._run_id = self._run["sys/id"].fetch()
        except (NeptuneMissingProjectNameException, NeptuneMissingApiTokenException) as e:
            raise NeptuneMissingConfiguration() from e

    def _use_initial_run(self):
        self._run = self._initial_run
        self._is_monitoring_run = True
        self._run_id = self._run["sys/id"].fetch()
        self._initial_run = None

    def _ensure_run_with_monitoring(self):
        if self._initial_run is not None:
            self._use_initial_run()
        else:
            if not self._force_reset_monitoring_run and self._is_monitoring_run:
                return

            if self._run and not self._is_monitoring_run and not self._force_reset_monitoring_run:
                self._initialize_run(with_id=self._run_id)
                self._is_monitoring_run = True
            else:
                self._initialize_run()
                self._force_reset_monitoring_run = False

    def _ensure_at_least_run_without_monitoring(self):
        if self._initial_run is not None:
            self._use_initial_run()
        else:
            if not self._run:
                self._initialize_run(
                    with_id=self._run_id,
                    capture_stdout=False,
                    capture_stderr=False,
                    capture_hardware_metrics=False,
                    capture_traceback=False,
                )
                self._is_monitoring_run = False

    @property
    def run(self):
        if self._run is None:
            self._ensure_at_least_run_without_monitoring()
        return self._run

    @property
    def _metadata_namespace(self):
        return self.run[self._base_namespace_path]

    def _log_integration_version(self):
        self.run[NeptuneCallback.integration_version_key] = version

    def _log_trainer_parameters(self, args):
        self._metadata_namespace[NeptuneCallback.trainer_parameters_key] = args.to_sanitized_dict()

    def _log_model_parameters(self, model):
        from neptune.utils import stringify_unsupported

        if model and hasattr(model, "config") and model.config is not None:
            self._metadata_namespace[NeptuneCallback.model_parameters_key] = stringify_unsupported(
                model.config.to_dict()
            )

    def _log_hyper_param_search_parameters(self, state):
        if state and hasattr(state, "trial_name"):
            self._metadata_namespace[NeptuneCallback.trial_name_key] = state.trial_name

        if state and hasattr(state, "trial_params") and state.trial_params is not None:
            self._metadata_namespace[NeptuneCallback.trial_params_key] = state.trial_params

    def _log_model_checkpoint(self, source_directory: str, checkpoint: str):
        target_path = relative_path = os.path.join(source_directory, checkpoint)

        if self._volatile_checkpoints_dir is not None:
            consistent_checkpoint_path = os.path.join(self._volatile_checkpoints_dir, checkpoint)
            try:
                # Remove leading ../ from a relative path.
                cpkt_path = relative_path.replace("..", "").lstrip(os.path.sep)
                copy_path = os.path.join(consistent_checkpoint_path, cpkt_path)
                shutil.copytree(relative_path, copy_path)
                target_path = consistent_checkpoint_path
            except IOError as e:
                logger.warning(
                    "NeptuneCallback was unable to make a copy of the checkpoint due to an I/O exception: '{}'. "
                    "The upload may fail.".format(e)
                )

        self._metadata_namespace[self._target_checkpoints_namespace].upload_files(target_path)

        if self._should_clean_recently_uploaded_checkpoint and self._recent_checkpoint_path is not None:
            self._metadata_namespace[self._target_checkpoints_namespace].delete_files(self._recent_checkpoint_path)

        self._recent_checkpoint_path = relative_path

    def on_init_end(self, args, state, control, **kwargs):
        self._volatile_checkpoints_dir = None
        if self._log_checkpoints and (args.overwrite_output_dir or args.save_total_limit is not None):
            self._volatile_checkpoints_dir = tempfile.TemporaryDirectory().name

        if self._log_checkpoints == "best" and not args.load_best_model_at_end:
            raise ValueError("To save the best model checkpoint, the load_best_model_at_end argument must be enabled.")

    def on_train_begin(self, args, state, control, model=None, **kwargs):
        if not state.is_world_process_zero:
            return

        self._ensure_run_with_monitoring()
        self._force_reset_monitoring_run = True

        self._log_integration_version()
        if self._log_parameters:
            self._log_trainer_parameters(args)
            self._log_model_parameters(model)

        if state.is_hyper_param_search:
            self._log_hyper_param_search_parameters(state)

    def on_train_end(self, args, state, control, **kwargs):
        self._stop_run_if_exists()

    def __del__(self):
        if self._volatile_checkpoints_dir is not None:
            shutil.rmtree(self._volatile_checkpoints_dir, ignore_errors=True)

        self._stop_run_if_exists()

    def on_save(self, args, state, control, **kwargs):
        if self._should_upload_checkpoint:
            self._log_model_checkpoint(args.output_dir, f"checkpoint-{state.global_step}")

    def on_evaluate(self, args, state, control, metrics=None, **kwargs):
        if self._log_checkpoints == "best":
            best_metric_name = args.metric_for_best_model
            if not best_metric_name.startswith("eval_"):
                best_metric_name = f"eval_{best_metric_name}"

            metric_value = metrics.get(best_metric_name)

            operator = np.greater if args.greater_is_better else np.less

            self._should_upload_checkpoint = state.best_metric is None or operator(metric_value, state.best_metric)

    @classmethod
    def get_run(cls, trainer):
        for callback in trainer.callback_handler.callbacks:
            if isinstance(callback, cls):
                return callback.run

        raise Exception("The trainer doesn't have a NeptuneCallback configured.")

    def on_log(self, args, state, control, logs: Optional[Dict[str, float]] = None, **kwargs):
        if not state.is_world_process_zero:
            return

        if logs is not None:
            for name, value in rewrite_logs(logs).items():
                if isinstance(value, (int, float)):
                    if name in NeptuneCallback.flat_metrics:
                        self._metadata_namespace[name] = value
                    else:
                        self._metadata_namespace[name].log(value, step=state.global_step)


class CodeCarbonCallback(TrainerCallback):
    """
    A [`TrainerCallback`] that tracks the CO2 emission of training.
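
    Example (a minimal sketch; the callback is selected through `report_to`, and the emissions tracker writes its
    output to the training `output_dir`):

    ```python
    trainer = Trainer(..., args=TrainingArguments(..., report_to=["codecarbon"]))
    ```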
  1408. """
  1409. def __init__(self):
  1410. if not is_codecarbon_available():
  1411. raise RuntimeError(
  1412. "CodeCarbonCallback requires `codecarbon` to be installed. Run `pip install codecarbon`."
  1413. )
  1414. elif torch.version.hip:
  1415. raise RuntimeError(
  1416. "CodeCarbonCallback requires `codecarbon` package, which is not compatible with AMD ROCm (https://github.com/mlco2/codecarbon/pull/490). When using the Trainer, please specify the `report_to` argument (https://huggingface.co/docs/transformers/v4.39.3/en/main_classes/trainer#transformers.TrainingArguments.report_to) to disable CodeCarbonCallback."
  1417. )
  1418. import codecarbon
  1419. self._codecarbon = codecarbon
  1420. self.tracker = None
  1421. def on_init_end(self, args, state, control, **kwargs):
  1422. if self.tracker is None and state.is_local_process_zero:
  1423. # CodeCarbon will automatically handle environment variables for configuration
  1424. self.tracker = self._codecarbon.EmissionsTracker(output_dir=args.output_dir)
  1425. def on_train_begin(self, args, state, control, model=None, **kwargs):
  1426. if self.tracker and state.is_local_process_zero:
  1427. self.tracker.start()
  1428. def on_train_end(self, args, state, control, **kwargs):
  1429. if self.tracker and state.is_local_process_zero:
  1430. self.tracker.stop()


class ClearMLCallback(TrainerCallback):
    """
    A [`TrainerCallback`] that sends the logs to [ClearML](https://clear.ml/).

    Environment:
    - **CLEARML_PROJECT** (`str`, *optional*, defaults to `HuggingFace Transformers`):
        ClearML project name.
    - **CLEARML_TASK** (`str`, *optional*, defaults to `Trainer`):
        ClearML task name.
    - **CLEARML_LOG_MODEL** (`bool`, *optional*, defaults to `False`):
        Whether to log models as artifacts during training.
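
    Example (a minimal sketch; the project and task names below are placeholders that override the defaults listed
    above):

    ```python
    import os

    os.environ["CLEARML_PROJECT"] = "my-project"
    os.environ["CLEARML_TASK"] = "my-task"
    os.environ["CLEARML_LOG_MODEL"] = "TRUE"

    trainer = Trainer(..., args=TrainingArguments(..., report_to=["clearml"]))
    ```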
  1441. """
  1442. log_suffix = ""
  1443. _hparams_section = "Transformers"
  1444. _model_config_section = "Model Configuration"
  1445. _ignore_hparams_overrides = "_ignore_hparams_ui_overrides_"
  1446. _ignoge_model_config_overrides = "_ignore_model_config_ui_overrides_"
  1447. _model_config_description = "The configuration of model number {}."
  1448. _model_config_description_note = (
  1449. "Note that, when cloning this task and running it remotely,"
  1450. " the configuration might be applied to another model instead of this one."
  1451. " To avoid this, initialize the task externally by calling `Task.init`"
  1452. " before the `ClearMLCallback` is instantiated."
  1453. )
  1454. _train_run_counter = 0
  1455. _model_connect_counter = 0
  1456. _task_created_in_callback = False
  1457. _should_close_on_train_end = None

    def __init__(self):
        if is_clearml_available():
            import clearml

            self._clearml = clearml
        else:
            raise RuntimeError("ClearMLCallback requires 'clearml' to be installed. Run `pip install clearml`.")

        self._initialized = False
        self._clearml_task = None

        self._log_model = False
        self._checkpoints_saved = []

    def setup(self, args, state, model, tokenizer, **kwargs):
        if self._clearml is None:
            return
        if self._initialized:
            return
        ClearMLCallback._train_run_counter += 1
        ClearMLCallback._model_connect_counter += 1
        ClearMLCallback.log_suffix = (
            "" if ClearMLCallback._train_run_counter == 1 else "_" + str(ClearMLCallback._train_run_counter)
        )
        if state.is_world_process_zero:
            logger.info("Automatic ClearML logging enabled.")
            if self._clearml_task is None:
                if ClearMLCallback._should_close_on_train_end is None:
                    if not self._clearml.Task.running_locally() or self._clearml.Task.current_task():
                        ClearMLCallback._should_close_on_train_end = False
                    else:
                        ClearMLCallback._should_close_on_train_end = True

                # This might happen when running inside of a pipeline, where the task is already initialized
                # from outside of Hugging Face
                if self._clearml.Task.running_locally() and self._clearml.Task.current_task():
                    self._clearml_task = self._clearml.Task.current_task()
                    self._log_model = os.getenv(
                        "CLEARML_LOG_MODEL",
                        "FALSE" if not ClearMLCallback._task_created_in_callback else "TRUE",
                    ).upper() in ENV_VARS_TRUE_VALUES.union({"TRUE"})
                    logger.info("External ClearML Task has been connected.")
                else:
                    self._clearml_task = self._clearml.Task.init(
                        project_name=os.getenv("CLEARML_PROJECT", "HuggingFace Transformers"),
                        task_name=os.getenv("CLEARML_TASK", "Trainer"),
                        auto_connect_frameworks={"tensorboard": False, "pytorch": False},
                        output_uri=True,
                    )
                    self._log_model = os.getenv("CLEARML_LOG_MODEL", "TRUE").upper() in ENV_VARS_TRUE_VALUES.union(
                        {"TRUE"}
                    )
                    ClearMLCallback._task_created_in_callback = True
                    logger.info("ClearML Task has been initialized.")
                self._initialized = True

            suffixed_hparams_section = ClearMLCallback._hparams_section + ClearMLCallback.log_suffix
            ignore_hparams_config_section = suffixed_hparams_section + "/" + ClearMLCallback._ignore_hparams_overrides
            if self._clearml.Task.running_locally():
                self._copy_training_args_as_hparams(args, suffixed_hparams_section)
                self._clearml_task.set_parameter(
                    name=ignore_hparams_config_section,
                    value=True,
                    value_type=bool,
                    description=(
                        "If True, ignore Transformers hyperparameters overrides done in the UI/backend "
                        + "when running remotely. Otherwise, the overrides will be applied when running remotely"
                    ),
                )
            elif not self._clearml_task.get_parameter(ignore_hparams_config_section, default=True, cast=True):
                self._clearml_task.connect(args, suffixed_hparams_section)
            else:
                self._copy_training_args_as_hparams(
                    args, ClearMLCallback._hparams_section + ClearMLCallback.log_suffix
                )

            if getattr(model, "config", None) is not None:
                ignore_model_config_section = (
                    suffixed_hparams_section + "/" + ClearMLCallback._ignore_model_config_overrides
                )
                configuration_object_description = ClearMLCallback._model_config_description.format(
                    ClearMLCallback._model_connect_counter
                )
                if ClearMLCallback._model_connect_counter != ClearMLCallback._train_run_counter:
                    configuration_object_description += " " + ClearMLCallback._model_config_description_note
                if self._clearml.Task.running_locally():
                    self._clearml_task.set_parameter(
                        name=ignore_model_config_section,
                        value=True,
                        value_type=bool,
                        description=(
                            "If True, ignore Transformers model configuration overrides done in the UI/backend "
                            + "when running remotely. Otherwise, the overrides will be applied when running remotely"
                        ),
                    )
                    self._clearml_task.set_configuration_object(
                        name=ClearMLCallback._model_config_section + ClearMLCallback.log_suffix,
                        config_dict=model.config.to_dict(),
                        description=configuration_object_description,
                    )
                elif not self._clearml_task.get_parameter(ignore_model_config_section, default=True, cast=True):
                    model.config = model.config.from_dict(
                        self._clearml_task.get_configuration_object_as_dict(
                            ClearMLCallback._model_config_section + ClearMLCallback.log_suffix
                        )
                    )
                else:
                    self._clearml_task.set_configuration_object(
                        name=ClearMLCallback._model_config_section + ClearMLCallback.log_suffix,
                        config_dict=model.config.to_dict(),
                        description=configuration_object_description,
                    )

    def on_train_begin(self, args, state, control, model=None, tokenizer=None, **kwargs):
        if self._clearml is None:
            return
        self._checkpoints_saved = []
        if state.is_hyper_param_search:
            self._initialized = False
        if not self._initialized:
            self.setup(args, state, model, tokenizer, **kwargs)

    def on_train_end(self, args, state, control, **kwargs):
        if ClearMLCallback._should_close_on_train_end:
            self._clearml_task.close()
            ClearMLCallback._train_run_counter = 0

    def on_log(self, args, state, control, model=None, tokenizer=None, logs=None, **kwargs):
        if self._clearml is None:
            return
        if not self._initialized:
            self.setup(args, state, model, tokenizer, **kwargs)
        if state.is_world_process_zero:
            eval_prefix = "eval_"
            eval_prefix_len = len(eval_prefix)
            test_prefix = "test_"
            test_prefix_len = len(test_prefix)
            single_value_scalars = [
                "train_runtime",
                "train_samples_per_second",
                "train_steps_per_second",
                "train_loss",
                "total_flos",
                "epoch",
            ]
            for k, v in logs.items():
                if isinstance(v, (int, float)):
                    if k in single_value_scalars:
                        self._clearml_task.get_logger().report_single_value(
                            name=k + ClearMLCallback.log_suffix, value=v
                        )
                    elif k.startswith(eval_prefix):
                        self._clearml_task.get_logger().report_scalar(
                            title="eval" + ClearMLCallback.log_suffix,
                            series=k[eval_prefix_len:],
                            value=v,
                            iteration=state.global_step,
                        )
                    elif k.startswith(test_prefix):
                        self._clearml_task.get_logger().report_scalar(
                            title="test" + ClearMLCallback.log_suffix,
                            series=k[test_prefix_len:],
                            value=v,
                            iteration=state.global_step,
                        )
                    else:
                        self._clearml_task.get_logger().report_scalar(
                            title="train" + ClearMLCallback.log_suffix,
                            series=k,
                            value=v,
                            iteration=state.global_step,
                        )
                else:
                    logger.warning(
                        "Trainer is attempting to log a value of "
                        f'"{v}" of type {type(v)} for key "{k}" as a scalar. '
                        "This invocation of ClearML logger's report_scalar() "
                        "is incorrect so we dropped this attribute."
                    )

    def on_save(self, args, state, control, **kwargs):
        if self._log_model and self._clearml_task and state.is_world_process_zero:
            ckpt_dir = f"checkpoint-{state.global_step}"
            artifact_path = os.path.join(args.output_dir, ckpt_dir)
            name = ckpt_dir + ClearMLCallback.log_suffix
            logger.info(f"Logging checkpoint artifact `{name}`. This may take some time.")
            output_model = self._clearml.OutputModel(task=self._clearml_task, name=name)
            output_model.connect(task=self._clearml_task, name=name)
            output_model.update_weights_package(
                weights_path=artifact_path,
                target_filename=ckpt_dir,
                iteration=state.global_step,
                auto_delete_file=False,
            )
            self._checkpoints_saved.append(output_model)
            while args.save_total_limit and args.save_total_limit < len(self._checkpoints_saved):
                try:
                    self._clearml.model.Model.remove(
                        self._checkpoints_saved[0],
                        delete_weights_file=True,
                        force=True,
                        raise_on_errors=True,
                    )
                except Exception as e:
                    logger.warning(
                        "Could not remove checkpoint `{}` after going over the `save_total_limit`. Error is: {}".format(
                            self._checkpoints_saved[0].name, e
                        )
                    )
                    break
                self._checkpoints_saved = self._checkpoints_saved[1:]

    def _copy_training_args_as_hparams(self, training_args, prefix):
        as_dict = {
            field.name: getattr(training_args, field.name)
            for field in fields(training_args)
            if field.init and not field.name.endswith("_token")
        }
        flat_dict = {str(k): v for k, v in self._clearml.utilities.proxy_object.flatten_dictionary(as_dict).items()}
        self._clearml_task._arguments.copy_from_dict(flat_dict, prefix=prefix)


class FlyteCallback(TrainerCallback):
    """A [`TrainerCallback`] that sends the logs to [Flyte](https://flyte.org/).

    NOTE: This callback only works within a Flyte task.

    Args:
        save_log_history (`bool`, *optional*, defaults to `True`):
            When set to True, the training logs are saved as a Flyte Deck.
        sync_checkpoints (`bool`, *optional*, defaults to `True`):
            When set to True, checkpoints are synced with Flyte and can be used to resume training in the case of an
            interruption.

    Example:

    ```python
    # Note: This example skips over some setup steps for brevity.
    from flytekit import current_context, task


    @task
    def train_hf_transformer():
        cp = current_context().checkpoint
        trainer = Trainer(..., callbacks=[FlyteCallback()])
        output = trainer.train(resume_from_checkpoint=cp.restore())
    ```
    """

    def __init__(self, save_log_history: bool = True, sync_checkpoints: bool = True):
        super().__init__()
        if not is_flytekit_available():
            raise ImportError("FlyteCallback requires flytekit to be installed. Run `pip install flytekit`.")

        if not is_flyte_deck_standard_available() or not is_pandas_available():
            logger.warning(
                "Syncing log history requires both flytekitplugins-deck-standard and pandas to be installed. "
                "Run `pip install flytekitplugins-deck-standard pandas` to enable this feature."
            )
            save_log_history = False

        from flytekit import current_context

        self.cp = current_context().checkpoint
        self.save_log_history = save_log_history
        self.sync_checkpoints = sync_checkpoints

    def on_save(self, args, state, control, **kwargs):
        if self.sync_checkpoints and state.is_world_process_zero:
            ckpt_dir = f"checkpoint-{state.global_step}"
            artifact_path = os.path.join(args.output_dir, ckpt_dir)

            logger.info(f"Syncing checkpoint in {ckpt_dir} to Flyte. This may take time.")
            self.cp.save(artifact_path)

    def on_train_end(self, args, state, control, **kwargs):
        if self.save_log_history:
            import pandas as pd
            from flytekit import Deck
            from flytekitplugins.deck.renderer import TableRenderer

            log_history_df = pd.DataFrame(state.log_history)
            Deck("Log History", TableRenderer().to_html(log_history_df))


class DVCLiveCallback(TrainerCallback):
    """
    A [`TrainerCallback`] that sends the logs to [DVCLive](https://www.dvc.org/doc/dvclive).

    Use the environment variables below in `setup` to configure the integration. To customize this callback beyond
    those environment variables, see [here](https://dvc.org/doc/dvclive/ml-frameworks/huggingface).

    Args:
        live (`dvclive.Live`, *optional*, defaults to `None`):
            Optional Live instance. If None, a new instance will be created using **kwargs.
        log_model (Union[Literal["all"], bool], *optional*, defaults to `None`):
            Whether to use `dvclive.Live.log_artifact()` to log checkpoints created by [`Trainer`]. If set to `True`,
            the final checkpoint is logged at the end of training. If set to `"all"`, the entire
            [`TrainingArguments`]'s `output_dir` is logged at each checkpoint.
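
    Example (a minimal sketch; passing an explicit `Live` instance is optional, since the callback creates one in
    `setup` when none is given):

    ```python
    from dvclive import Live

    trainer = Trainer(..., callbacks=[DVCLiveCallback(live=Live(), log_model=True)])
    ```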
  1725. """
  1726. def __init__(
  1727. self,
  1728. live: Optional[Any] = None,
  1729. log_model: Optional[Union[Literal["all"], bool]] = None,
  1730. **kwargs,
  1731. ):
  1732. if not is_dvclive_available():
  1733. raise RuntimeError("DVCLiveCallback requires dvclive to be installed. Run `pip install dvclive`.")
  1734. from dvclive import Live
  1735. self._initialized = False
  1736. self.live = None
  1737. if isinstance(live, Live):
  1738. self.live = live
  1739. elif live is not None:
  1740. raise RuntimeError(f"Found class {live.__class__} for live, expected dvclive.Live")
  1741. self._log_model = log_model
  1742. if self._log_model is None:
  1743. log_model_env = os.getenv("HF_DVCLIVE_LOG_MODEL", "FALSE")
  1744. if log_model_env.upper() in ENV_VARS_TRUE_VALUES:
  1745. self._log_model = True
  1746. elif log_model_env.lower() == "all":
  1747. self._log_model = "all"

    def setup(self, args, state, model):
        """
        Setup the optional DVCLive integration. To customize this callback beyond the environment variables below, see
        [here](https://dvc.org/doc/dvclive/ml-frameworks/huggingface).

        Environment:
        - **HF_DVCLIVE_LOG_MODEL** (`str`, *optional*):
            Whether to use `dvclive.Live.log_artifact()` to log checkpoints created by [`Trainer`]. If set to `True` or
            *1*, the final checkpoint is logged at the end of training. If set to `all`, the entire
            [`TrainingArguments`]'s `output_dir` is logged at each checkpoint.
        """
        from dvclive import Live

        self._initialized = True
        if state.is_world_process_zero:
            if not self.live:
                self.live = Live()
            self.live.log_params(args.to_dict())

    def on_train_begin(self, args, state, control, model=None, **kwargs):
        if not self._initialized:
            self.setup(args, state, model)

    def on_log(self, args, state, control, model=None, logs=None, **kwargs):
        if not self._initialized:
            self.setup(args, state, model)
        if state.is_world_process_zero:
            from dvclive.plots import Metric
            from dvclive.utils import standardize_metric_name

            for key, value in logs.items():
                if Metric.could_log(value):
                    self.live.log_metric(standardize_metric_name(key, "dvclive.huggingface"), value)
                else:
                    logger.warning(
                        "Trainer is attempting to log a value of "
                        f'"{value}" of type {type(value)} for key "{key}" as a scalar. '
                        "This invocation of DVCLive's Live.log_metric() "
                        "is incorrect so we dropped this attribute."
                    )
            self.live.next_step()

    def on_save(self, args, state, control, **kwargs):
        if self._log_model == "all" and self._initialized and state.is_world_process_zero:
            self.live.log_artifact(args.output_dir)

    def on_train_end(self, args, state, control, **kwargs):
        if self._initialized and state.is_world_process_zero:
            from transformers.trainer import Trainer

            if self._log_model is True:
                fake_trainer = Trainer(args=args, model=kwargs.get("model"), processing_class=kwargs.get("tokenizer"))
                name = "best" if args.load_best_model_at_end else "last"
                output_dir = os.path.join(args.output_dir, name)
                fake_trainer.save_model(output_dir)
                self.live.log_artifact(output_dir, name=name, type="model", copy=True)
            self.live.end()


INTEGRATION_TO_CALLBACK = {
    "azure_ml": AzureMLCallback,
    "comet_ml": CometCallback,
    "mlflow": MLflowCallback,
    "neptune": NeptuneCallback,
    "tensorboard": TensorBoardCallback,
    "wandb": WandbCallback,
    "codecarbon": CodeCarbonCallback,
    "clearml": ClearMLCallback,
    "dagshub": DagsHubCallback,
    "flyte": FlyteCallback,
    "dvclive": DVCLiveCallback,
}


def get_reporting_integration_callbacks(report_to):
    """Return the [`TrainerCallback`] classes for the integration names in `report_to`, raising on unknown names."""
    for integration in report_to:
        if integration not in INTEGRATION_TO_CALLBACK:
            raise ValueError(
                f"{integration} is not supported, only {', '.join(INTEGRATION_TO_CALLBACK.keys())} are supported."
            )

    return [INTEGRATION_TO_CALLBACK[integration] for integration in report_to]
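

# A minimal usage sketch (illustrative): `report_to=["mlflow", "tensorboard"]` in `TrainingArguments`
# resolves here to `[MLflowCallback, TensorBoardCallback]`, which the `Trainer` then instantiates and
# registers alongside its default callbacks.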