- # coding=utf-8
- # Copyright 2020-present the HuggingFace Inc. team.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- """
- The Trainer class, to easily train a 🤗 Transformers model from scratch or fine-tune it on a new task.
- """
- import contextlib
- import copy
- import functools
- import glob
- import importlib.metadata
- import inspect
- import json
- import math
- import os
- import random
- import re
- import shutil
- import sys
- import tempfile
- import time
- import warnings
- from collections.abc import Mapping
- from pathlib import Path
- from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
- # Integrations must be imported before ML frameworks:
- # isort: off
- from .integrations import (
- get_reporting_integration_callbacks,
- hp_params,
- )
- # isort: on
- import huggingface_hub.utils as hf_hub_utils
- import numpy as np
- import torch
- import torch.distributed as dist
- from huggingface_hub import ModelCard, create_repo, upload_folder
- from packaging import version
- from torch import nn
- from torch.utils.data import DataLoader, Dataset, IterableDataset, RandomSampler, SequentialSampler
- from . import __version__
- from .configuration_utils import PretrainedConfig
- from .data.data_collator import DataCollator, DataCollatorWithPadding, default_data_collator
- from .debug_utils import DebugOption, DebugUnderflowOverflow
- from .feature_extraction_sequence_utils import SequenceFeatureExtractor
- from .feature_extraction_utils import FeatureExtractionMixin
- from .hyperparameter_search import ALL_HYPERPARAMETER_SEARCH_BACKENDS, default_hp_search_backend
- from .image_processing_utils import BaseImageProcessor
- from .integrations.deepspeed import deepspeed_init, deepspeed_load_checkpoint, is_deepspeed_available
- from .integrations.tpu import tpu_spmd_dataloader
- from .modelcard import TrainingSummary
- from .modeling_utils import PreTrainedModel, load_sharded_checkpoint
- from .models.auto.modeling_auto import (
- MODEL_FOR_CAUSAL_LM_MAPPING_NAMES,
- MODEL_MAPPING_NAMES,
- )
- from .optimization import Adafactor, get_scheduler
- from .processing_utils import ProcessorMixin
- from .pytorch_utils import (
- ALL_LAYERNORM_LAYERS,
- is_torch_greater_or_equal_than_1_13,
- is_torch_greater_or_equal_than_2_3,
- )
- from .tokenization_utils_base import PreTrainedTokenizerBase
- from .trainer_callback import (
- CallbackHandler,
- DefaultFlowCallback,
- ExportableState,
- PrinterCallback,
- ProgressCallback,
- TrainerCallback,
- TrainerControl,
- TrainerState,
- )
- from .trainer_pt_utils import (
- DistributedTensorGatherer,
- EvalLoopContainer,
- IterableDatasetShard,
- LabelSmoother,
- LayerWiseDummyOptimizer,
- LengthGroupedSampler,
- SequentialDistributedSampler,
- distributed_broadcast_scalars,
- distributed_concat,
- find_batch_size,
- get_model_param_count,
- get_module_class_from_name,
- get_parameter_names,
- nested_concat,
- nested_detach,
- nested_numpify,
- nested_xla_mesh_reduce,
- reissue_pt_warnings,
- remove_dummy_checkpoint,
- )
- from .trainer_utils import (
- PREFIX_CHECKPOINT_DIR,
- BestRun,
- EvalLoopOutput,
- EvalPrediction,
- HPSearchBackend,
- HubStrategy,
- IntervalStrategy,
- PredictionOutput,
- RemoveColumnsCollator,
- TrainerMemoryTracker,
- TrainOutput,
- check_target_module_exists,
- default_compute_objective,
- denumpify_detensorize,
- enable_full_determinism,
- find_executable_batch_size,
- get_last_checkpoint,
- has_length,
- neftune_post_forward_hook,
- number_of_arguments,
- seed_worker,
- set_seed,
- speed_metrics,
- )
- from .training_args import OptimizerNames, ParallelMode, TrainingArguments
- from .utils import (
- ADAPTER_CONFIG_NAME,
- ADAPTER_SAFE_WEIGHTS_NAME,
- ADAPTER_WEIGHTS_NAME,
- CONFIG_NAME,
- SAFE_WEIGHTS_INDEX_NAME,
- SAFE_WEIGHTS_NAME,
- WEIGHTS_INDEX_NAME,
- WEIGHTS_NAME,
- XLA_FSDPV2_MIN_VERSION,
- PushInProgress,
- PushToHubMixin,
- can_return_loss,
- find_labels,
- is_accelerate_available,
- is_apex_available,
- is_bitsandbytes_available,
- is_datasets_available,
- is_galore_torch_available,
- is_grokadamw_available,
- is_in_notebook,
- is_ipex_available,
- is_liger_kernel_available,
- is_lomo_available,
- is_peft_available,
- is_safetensors_available,
- is_sagemaker_dp_enabled,
- is_sagemaker_mp_enabled,
- is_schedulefree_available,
- is_torch_compile_available,
- is_torch_mlu_available,
- is_torch_mps_available,
- is_torch_musa_available,
- is_torch_neuroncore_available,
- is_torch_npu_available,
- is_torch_xla_available,
- is_torch_xpu_available,
- is_torchao_available,
- logging,
- strtobool,
- )
- from .utils.deprecation import deprecate_kwarg
- from .utils.quantization_config import QuantizationMethod
- DEFAULT_CALLBACKS = [DefaultFlowCallback]
- DEFAULT_PROGRESS_CALLBACK = ProgressCallback
- if is_in_notebook():
- from .utils.notebook import NotebookProgressCallback
- DEFAULT_PROGRESS_CALLBACK = NotebookProgressCallback
- if is_apex_available():
- from apex import amp
- if is_datasets_available():
- import datasets
- if is_torch_xla_available():
- import torch_xla.core.xla_model as xm
- import torch_xla.debug.metrics as met
- from torch_xla import __version__ as XLA_VERSION
- IS_XLA_FSDPV2_POST_2_2 = version.parse(XLA_VERSION) >= version.parse(XLA_FSDPV2_MIN_VERSION)
- if IS_XLA_FSDPV2_POST_2_2:
- import torch_xla.distributed.spmd as xs
- import torch_xla.runtime as xr
- else:
- IS_XLA_FSDPV2_POST_2_2 = False
- if is_sagemaker_mp_enabled():
- import smdistributed.modelparallel.torch as smp
- from smdistributed.modelparallel import __version__ as SMP_VERSION
- IS_SAGEMAKER_MP_POST_1_10 = version.parse(SMP_VERSION) >= version.parse("1.10")
- from .trainer_pt_utils import smp_forward_backward, smp_forward_only, smp_gather, smp_nested_concat
- else:
- IS_SAGEMAKER_MP_POST_1_10 = False
- if is_safetensors_available():
- import safetensors.torch
- if is_peft_available():
- from peft import PeftModel
- if is_accelerate_available():
- from accelerate import Accelerator, skip_first_batches
- from accelerate import __version__ as accelerate_version
- from accelerate.state import AcceleratorState
- from accelerate.utils import (
- DistributedDataParallelKwargs,
- DistributedType,
- load_fsdp_model,
- load_fsdp_optimizer,
- save_fsdp_model,
- save_fsdp_optimizer,
- )
- DATA_SAMPLERS = [RandomSampler]
- if version.parse(accelerate_version) > version.parse("0.23.0"):
- from accelerate.data_loader import SeedableRandomSampler
- DATA_SAMPLERS += [SeedableRandomSampler]
- if is_deepspeed_available():
- from accelerate.utils import DeepSpeedSchedulerWrapper
- if is_accelerate_available("0.28.0"):
- from accelerate.utils import DataLoaderConfiguration
- def _is_peft_model(model):
- if is_peft_available():
- classes_to_check = (PeftModel,)
- # Here we also check if the model is an instance of `PeftMixedModel` introduced in peft>=0.7.0: https://github.com/huggingface/transformers/pull/28321
- if version.parse(importlib.metadata.version("peft")) >= version.parse("0.7.0"):
- from peft import PeftMixedModel
- classes_to_check = (*classes_to_check, PeftMixedModel)
- return isinstance(model, classes_to_check)
- return False
- def _get_fsdp_ckpt_kwargs():
- # TODO: @AjayP13, @younesbelkada replace this check with version check at the next `accelerate` release
- if is_accelerate_available() and "adapter_only" in list(inspect.signature(save_fsdp_model).parameters):
- return {"adapter_only": True}
- else:
- return {}
- if TYPE_CHECKING:
- import optuna
- if is_datasets_available():
- import datasets
- logger = logging.get_logger(__name__)
- # Names of the files used for checkpointing
- TRAINING_ARGS_NAME = "training_args.bin"
- TRAINER_STATE_NAME = "trainer_state.json"
- OPTIMIZER_NAME = "optimizer.pt"
- OPTIMIZER_NAME_BIN = "optimizer.bin"
- SCHEDULER_NAME = "scheduler.pt"
- SCALER_NAME = "scaler.pt"
- FSDP_MODEL_NAME = "pytorch_model_fsdp"
- class Trainer:
- """
- Trainer is a simple but feature-complete training and eval loop for PyTorch, optimized for 🤗 Transformers.
- Args:
- model ([`PreTrainedModel`] or `torch.nn.Module`, *optional*):
- The model to train, evaluate or use for predictions. If not provided, a `model_init` must be passed.
- <Tip>
- [`Trainer`] is optimized to work with the [`PreTrainedModel`] provided by the library. You can still use
- your own models defined as `torch.nn.Module` as long as they work the same way as the 🤗 Transformers
- models.
- </Tip>
- args ([`TrainingArguments`], *optional*):
- The arguments to tweak for training. Will default to a basic instance of [`TrainingArguments`] with the
- `output_dir` set to a directory named *tmp_trainer* in the current directory if not provided.
- data_collator (`DataCollator`, *optional*):
- The function to use to form a batch from a list of elements of `train_dataset` or `eval_dataset`. Will
- default to [`default_data_collator`] if no `processing_class` is provided, or to an instance of
- [`DataCollatorWithPadding`] if the `processing_class` is a feature extractor or tokenizer.
- train_dataset (Union[`torch.utils.data.Dataset`, `torch.utils.data.IterableDataset`, `datasets.Dataset`], *optional*):
- The dataset to use for training. If it is a [`~datasets.Dataset`], columns not accepted by the
- `model.forward()` method are automatically removed.
- Note that if it's a `torch.utils.data.IterableDataset` with some randomization and you are training in a
- distributed fashion, your iterable dataset should either use an internal attribute `generator` that is a
- `torch.Generator` for the randomization that must be identical on all processes (and the Trainer will
- manually set the seed of this `generator` at each epoch) or have a `set_epoch()` method that internally
- sets the seed of the RNGs used.
- eval_dataset (Union[`torch.utils.data.Dataset`, Dict[str, `torch.utils.data.Dataset`], `datasets.Dataset`], *optional*):
- The dataset to use for evaluation. If it is a [`~datasets.Dataset`], columns not accepted by the
- `model.forward()` method are automatically removed. If it is a dictionary, it will evaluate on each
- dataset prepending the dictionary key to the metric name.
- processing_class (`PreTrainedTokenizerBase` or `BaseImageProcessor` or `FeatureExtractionMixin` or `ProcessorMixin`, *optional*):
- Processing class used to process the data. If provided, will be used to automatically process the inputs
- for the model, and it will be saved along with the model to make it easier to rerun an interrupted training or
- reuse the fine-tuned model.
- This supersedes the `tokenizer` argument, which is now deprecated.
- model_init (`Callable[[], PreTrainedModel]`, *optional*):
- A function that instantiates the model to be used. If provided, each call to [`~Trainer.train`] will start
- from a new instance of the model as given by this function.
- The function may have zero arguments, or a single one containing the optuna/Ray Tune/SigOpt trial object, to
- be able to choose different architectures according to hyperparameters (such as layer count, sizes of
- inner layers, dropout probabilities, etc.).
- compute_loss_func (`Callable`, *optional*):
- A function that accepts the raw model outputs, labels, and the number of items in the entire accumulated
- batch (batch_size * gradient_accumulation_steps) and returns the loss. For an illustrative sketch, see
- the example at the end of this docstring.
- compute_metrics (`Callable[[EvalPrediction], Dict]`, *optional*):
- The function that will be used to compute metrics at evaluation. Must take an [`EvalPrediction`] and return
- a dictionary mapping metric names to metric values. *Note*: when passing [`TrainingArguments`] with
- `batch_eval_metrics` set to `True`, your `compute_metrics` function must take a boolean `compute_result`
- argument. It will be triggered after the last eval batch to signal that the function needs to calculate
- and return the global summary statistics rather than accumulating batch-level statistics.
- callbacks (List of [`TrainerCallback`], *optional*):
- A list of callbacks to customize the training loop. Will add those to the list of default callbacks
- detailed [here](callback).
- If you want to remove one of the default callbacks used, use the [`Trainer.remove_callback`] method.
- optimizers (`Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`, *optional*, defaults to `(None, None)`):
- A tuple containing the optimizer and the scheduler to use. Will default to an instance of [`AdamW`] on your
- model and a scheduler given by [`get_linear_schedule_with_warmup`] controlled by `args`.
- preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`, *optional*):
- A function that preprocesses the logits right before caching them at each evaluation step. Must take two
- tensors, the logits and the labels, and return the logits once processed as desired. The modifications made
- by this function will be reflected in the predictions received by `compute_metrics`.
- Note that the labels (second parameter) will be `None` if the dataset does not have them.
- Important attributes:
- - **model** -- Always points to the core model. If using a transformers model, it will be a [`PreTrainedModel`]
- subclass.
- - **model_wrapped** -- Always points to the most external model in case one or more other modules wrap the
- original model. This is the model that should be used for the forward pass. For example, under `DeepSpeed`,
- the inner model is wrapped in `DeepSpeed` and then again in `torch.nn.DistributedDataParallel`. If the inner
- model hasn't been wrapped, then `self.model_wrapped` is the same as `self.model`.
- - **is_model_parallel** -- Whether or not a model has been switched to a model parallel mode (different from
- data parallelism, this means some of the model layers are split on different GPUs).
- - **place_model_on_device** -- Whether or not to automatically place the model on the device - it will be set
- to `False` if model parallel or deepspeed is used, or if the default
- `TrainingArguments.place_model_on_device` is overridden to return `False`.
- - **is_in_train** -- Whether or not a model is currently running `train` (e.g. when `evaluate` is called while
- in `train`)
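- Example (a minimal sketch; `train_dataset` and `eval_dataset` are placeholders for tokenized datasets you
- provide, and the `compute_loss_func` shown is only an illustrative classification loss, not the library's
- internal implementation):
- ```python
- import torch
- from transformers import AutoModelForSequenceClassification, AutoTokenizer, Trainer, TrainingArguments
- model = AutoModelForSequenceClassification.from_pretrained("bert-base-uncased")
- tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
- def compute_metrics(eval_pred):
-     # `eval_pred` is an `EvalPrediction` with `.predictions` and `.label_ids` (numpy arrays)
-     predictions = eval_pred.predictions.argmax(axis=-1)
-     return {"accuracy": float((predictions == eval_pred.label_ids).mean())}
- def compute_loss_func(outputs, labels, num_items_in_batch):
-     # Sum the per-example losses, then normalize by the number of items in the full accumulated batch.
-     return torch.nn.functional.cross_entropy(outputs.logits, labels, reduction="sum") / num_items_in_batch
- trainer = Trainer(
-     model=model,
-     args=TrainingArguments(output_dir="out", eval_strategy="epoch"),
-     train_dataset=train_dataset,
-     eval_dataset=eval_dataset,
-     processing_class=tokenizer,
-     compute_metrics=compute_metrics,
-     compute_loss_func=compute_loss_func,
- )
- trainer.train()
- ```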
- """
- # Those are used as methods of the Trainer in examples.
- from .trainer_pt_utils import _get_learning_rate, log_metrics, metrics_format, save_metrics, save_state
- @deprecate_kwarg("tokenizer", new_name="processing_class", version="5.0.0", raise_if_both_names=True)
- def __init__(
- self,
- model: Union[PreTrainedModel, nn.Module] = None,
- args: TrainingArguments = None,
- data_collator: Optional[DataCollator] = None,
- train_dataset: Optional[Union[Dataset, IterableDataset, "datasets.Dataset"]] = None,
- eval_dataset: Optional[Union[Dataset, Dict[str, Dataset], "datasets.Dataset"]] = None,
- processing_class: Optional[
- Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
- ] = None,
- model_init: Optional[Callable[[], PreTrainedModel]] = None,
- compute_loss_func: Optional[Callable] = None,
- compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,
- callbacks: Optional[List[TrainerCallback]] = None,
- optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
- preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
- ):
- if args is None:
- output_dir = "tmp_trainer"
- logger.info(f"No `TrainingArguments` passed, using `output_dir={output_dir}`.")
- args = TrainingArguments(output_dir=output_dir)
- if args.batch_eval_metrics and compute_metrics is not None:
- if "compute_result" not in inspect.signature(compute_metrics).parameters.keys():
- raise ValueError(
- "When using `batch_eval_metrics`, your `compute_metrics` function must take a `compute_result`"
- " boolean argument which will be triggered after the last batch of the eval set to signal that the"
- " summary statistics should be returned by the function."
- )
- if args.eval_strategy is not None and args.eval_strategy != "no" and eval_dataset is None:
- raise ValueError(
- f"You have set `args.eval_strategy` to {args.eval_strategy} but you didn't pass an `eval_dataset` to `Trainer`. Either set `args.eval_strategy` to `no` or pass an `eval_dataset`. "
- )
- self.args = args
- self.compute_loss_func = compute_loss_func
- # Seed must be set before instantiating the model when using model_init.
- enable_full_determinism(self.args.seed) if self.args.full_determinism else set_seed(self.args.seed)
- self.hp_name = None
- self.deepspeed = None
- self.is_in_train = False
- self.create_accelerator_and_postprocess()
- # memory metrics - must set up as early as possible
- self._memory_tracker = TrainerMemoryTracker(self.args.skip_memory_metrics)
- self._memory_tracker.start()
- # set the correct log level depending on the node
- log_level = args.get_process_log_level()
- logging.set_verbosity(log_level)
- # force device and distributed setup init explicitly
- args._setup_devices
- if model is None:
- if model_init is not None:
- self.model_init = model_init
- model = self.call_model_init()
- else:
- raise RuntimeError("`Trainer` requires either a `model` or `model_init` argument")
- else:
- if model_init is not None:
- warnings.warn(
- "`Trainer` requires either a `model` or `model_init` argument, but not both. `model_init` will"
- " overwrite your model when calling the `train` method. This will become a fatal error in the next"
- " release.",
- FutureWarning,
- )
- self.model_init = model_init
- if model.__class__.__name__ in MODEL_MAPPING_NAMES:
- raise ValueError(
- f"The model you have picked ({model.__class__.__name__}) cannot be used as is for training: it only "
- "computes hidden states and does not accept any labels. You should choose a model with a head "
- "suitable for your task like any of the `AutoModelForXxx` listed at "
- "https://huggingface.co/docs/transformers/model_doc/auto"
- )
- if getattr(model, "is_parallelizable", False) and getattr(model, "model_parallel", False):
- self.is_model_parallel = True
- else:
- self.is_model_parallel = False
- if getattr(model, "hf_device_map", None) is not None:
- devices = [device for device in set(model.hf_device_map.values()) if device not in ["cpu", "disk"]]
- if len(devices) > 1:
- self.is_model_parallel = True
- elif len(devices) == 1:
- self.is_model_parallel = self.args.device != torch.device(devices[0])
- else:
- self.is_model_parallel = False
- # warn users
- if self.is_model_parallel:
- logger.info(
- "You have loaded a model on multiple GPUs. `is_model_parallel` attribute will be force-set"
- " to `True` to avoid any unexpected behavior such as device placement mismatching."
- )
- if self.args.use_liger_kernel:
- if is_liger_kernel_available():
- from liger_kernel.transformers import _apply_liger_kernel_to_instance
- if isinstance(model, PreTrainedModel):
- # Patch the model with liger kernels. Use the default kernel configurations.
- _apply_liger_kernel_to_instance(model=model)
- else:
- logger.warning(
- "The model is not an instance of PreTrainedModel. No liger kernels will be applied."
- )
- else:
- raise ImportError(
- "You have set `use_liger_kernel` to `True` but liger-kernel >= 0.3.0 is not available. "
- "Please install it with `pip install liger-kernel`"
- )
- _is_quantized_and_base_model = getattr(model, "is_quantized", False) and not getattr(
- model, "_hf_peft_config_loaded", False
- )
- _quantization_method_supports_training = (
- getattr(model, "hf_quantizer", None) is not None and model.hf_quantizer.is_trainable
- )
- # Filter out quantized + compiled models
- if _is_quantized_and_base_model and hasattr(model, "_orig_mod"):
- raise ValueError(
- "You cannot fine-tune quantized model with `torch.compile()` make sure to pass a non-compiled model when fine-tuning a quantized model with PEFT"
- )
- # At this stage the model is already loaded
- if _is_quantized_and_base_model and not _is_peft_model(model):
- raise ValueError(
- "You cannot perform fine-tuning on purely quantized models. Please attach trainable adapters on top of"
- " the quantized model to correctly perform fine-tuning. Please see: https://huggingface.co/docs/transformers/peft"
- " for more details"
- )
- elif _is_quantized_and_base_model and not _quantization_method_supports_training:
- raise ValueError(
- f"The model you are trying to fine-tune is quantized with {model.hf_quantizer.quantization_config.quant_method}"
- " but that quantization method do not support training. Please open an issue on GitHub: https://github.com/huggingface/transformers"
- f" to request the support for training support for {model.hf_quantizer.quantization_config.quant_method}"
- )
- self.is_fsdp_xla_enabled = args.fsdp_config["xla"]
- if len(args.fsdp) > 0:
- if self.is_deepspeed_enabled:
- raise ValueError(
- "Using --fsdp xxx together with --deepspeed is not possible, deactivate one of those flags."
- )
- if not args.fsdp_config["xla"] and args.parallel_mode != ParallelMode.DISTRIBUTED:
- raise ValueError("Using fsdp only works in distributed training.")
- # one place to sort out whether to place the model on device or not
- # postpone switching model to cuda when:
- # 1. MP - since we are trying to fit a model that is much bigger than what a single GPU can hold
- # 2. fp16-enabled DeepSpeed loads the model in half the size and it doesn't need .to() anyway,
- # and we only use deepspeed for training at the moment
- # 3. full bf16 or fp16 eval - since the model needs to be cast to the right dtype first
- # 4. FSDP - same as MP
- self.place_model_on_device = args.place_model_on_device
- if (
- self.is_model_parallel
- or self.is_deepspeed_enabled
- or ((args.fp16_full_eval or args.bf16_full_eval) and not args.do_train)
- or self.is_fsdp_xla_enabled
- or self.is_fsdp_enabled
- ):
- self.place_model_on_device = False
- default_collator = (
- DataCollatorWithPadding(processing_class)
- if processing_class is not None
- and isinstance(processing_class, (PreTrainedTokenizerBase, SequenceFeatureExtractor))
- else default_data_collator
- )
- self.data_collator = data_collator if data_collator is not None else default_collator
- self.train_dataset = train_dataset
- self.eval_dataset = eval_dataset
- self.processing_class = processing_class
- # BnB quantized models don't support the `.to` operation.
- if (
- self.place_model_on_device
- and not getattr(model, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES
- ):
- self._move_model_to_device(model, args.device)
- # Force n_gpu to 1 to avoid DataParallel as MP will manage the GPUs
- if self.is_model_parallel:
- self.args._n_gpu = 1
- # later use `self.model is self.model_wrapped` to check if it's wrapped or not
- self.model_wrapped = model
- self.model = model
- # Just in case the model was wrapped outside of the `Trainer`
- unwrapped_model = self.accelerator.unwrap_model(model)
- model_forward = (
- unwrapped_model.forward
- if not _is_peft_model(unwrapped_model)
- else unwrapped_model.get_base_model().forward
- )
- forward_params = inspect.signature(model_forward).parameters
- self.model_accepts_loss_kwargs = (
- "loss_kwargs" in forward_params and forward_params["loss_kwargs"].kind == inspect.Parameter.VAR_KEYWORD
- )
- self.neftune_noise_alpha = args.neftune_noise_alpha
- self.compute_metrics = compute_metrics
- self.preprocess_logits_for_metrics = preprocess_logits_for_metrics
- self.optimizer, self.lr_scheduler = optimizers
- if model_init is not None and (self.optimizer is not None or self.lr_scheduler is not None):
- raise RuntimeError(
- "Passing a `model_init` is incompatible with providing the `optimizers` argument. "
- "You should subclass `Trainer` and override the `create_optimizer_and_scheduler` method."
- )
- if is_torch_xla_available() and self.optimizer is not None:
- for param in self.model.parameters():
- model_device = param.device
- break
- for param_group in self.optimizer.param_groups:
- if len(param_group["params"]) > 0:
- optimizer_device = param_group["params"][0].device
- break
- if model_device != optimizer_device:
- raise ValueError(
- "The model and the optimizer parameters are not on the same device, which probably means you"
- " created an optimizer around your model **before** putting on the device and passing it to the"
- " `Trainer`. Make sure the lines `import torch_xla.core.xla_model as xm` and"
- " `model.to(xm.xla_device())` is performed before the optimizer creation in your script."
- )
- if (self.is_fsdp_xla_enabled or self.is_fsdp_enabled) and (
- self.optimizer is not None or self.lr_scheduler is not None
- ):
- raise RuntimeError(
- "Passing `optimizers` is not allowed if PyTorch FSDP is enabled. "
- "You should subclass `Trainer` and override the `create_optimizer_and_scheduler` method."
- )
- default_callbacks = DEFAULT_CALLBACKS + get_reporting_integration_callbacks(self.args.report_to)
- callbacks = default_callbacks if callbacks is None else default_callbacks + callbacks
- self.callback_handler = CallbackHandler(
- callbacks, self.model, self.processing_class, self.optimizer, self.lr_scheduler
- )
- self.add_callback(PrinterCallback if self.args.disable_tqdm else DEFAULT_PROGRESS_CALLBACK)
- # Will be set to True by `self._setup_loggers()` on first call to `self.log()`.
- self._loggers_initialized = False
- # Create distant repo and output directory if needed
- self.hub_model_id = None
- if self.args.push_to_hub:
- self.init_hf_repo()
- if self.args.should_save:
- os.makedirs(self.args.output_dir, exist_ok=True)
- if not callable(self.data_collator) and callable(getattr(self.data_collator, "collate_batch", None)):
- raise ValueError("The `data_collator` should be a simple callable (function, class with `__call__`).")
- if args.max_steps > 0 and args.num_train_epochs > 0:
- logger.warning("max_steps is given, it will override any value given in num_train_epochs")
- if train_dataset is not None and not has_length(train_dataset) and args.max_steps <= 0:
- raise ValueError(
- "The train_dataset does not implement __len__, max_steps has to be specified. "
- "The number of steps needs to be known in advance for the learning rate scheduler."
- )
- if (
- train_dataset is not None
- and isinstance(train_dataset, torch.utils.data.IterableDataset)
- and args.group_by_length
- ):
- raise ValueError("the `--group_by_length` option is only available for `Dataset`, not `IterableDataset")
- self._signature_columns = None
- # Mixed precision setup
- self.use_apex = False
- self.use_cpu_amp = False
- # Mixed precision setup for SageMaker Model Parallel
- if is_sagemaker_mp_enabled():
- # BF16 + model parallelism in SageMaker: currently not supported, raise an error
- if args.bf16:
- raise ValueError("SageMaker Model Parallelism does not support BF16 yet. Please use FP16 instead ")
- if IS_SAGEMAKER_MP_POST_1_10:
- # When there's a mismatch between the SMP config and the trainer argument, use the SMP config as truth
- if args.fp16 != smp.state.cfg.fp16:
- logger.warning(
- f"FP16 provided in SM_HP_MP_PARAMETERS is {smp.state.cfg.fp16}, "
- f"but FP16 provided in trainer argument is {args.fp16}, "
- f"setting to {smp.state.cfg.fp16}"
- )
- args.fp16 = smp.state.cfg.fp16
- else:
- # smp < 1.10 does not support fp16 in trainer.
- if hasattr(smp.state.cfg, "fp16"):
- logger.warning(
- f"FP16 provided in SM_HP_MP_PARAMETERS is {smp.state.cfg.fp16}, "
- "but SageMaker Model Parallelism < 1.10 does not support FP16 in trainer."
- )
- if (args.fp16 or args.bf16) and args.half_precision_backend == "auto":
- if args.device == torch.device("cpu"):
- if args.fp16:
- if not is_torch_greater_or_equal_than_2_3:
- raise ValueError("Tried to use `fp16` but it is not supported on cpu")
- else:
- args.half_precision_backend = "cpu_amp"
- logger.info(f"Using {args.half_precision_backend} half precision backend")
- if (args.fp16 or args.bf16) and not (self.is_deepspeed_enabled or is_sagemaker_mp_enabled()):
- # deepspeed and SageMaker Model Parallel manage their own half precision
- if args.half_precision_backend == "cpu_amp":
- self.use_cpu_amp = True
- self.amp_dtype = torch.bfloat16
- elif args.half_precision_backend == "apex":
- if not is_apex_available():
- raise ImportError(
- "Using FP16 with APEX but APEX is not installed, please refer to"
- " https://www.github.com/nvidia/apex."
- )
- self.use_apex = True
- # Label smoothing
- if self.args.label_smoothing_factor != 0:
- self.label_smoother = LabelSmoother(epsilon=self.args.label_smoothing_factor)
- else:
- self.label_smoother = None
- self.control = TrainerControl()
- self.state = TrainerState(
- is_local_process_zero=self.is_local_process_zero(),
- is_world_process_zero=self.is_world_process_zero(),
- stateful_callbacks=[
- cb for cb in self.callback_handler.callbacks + [self.control] if isinstance(cb, ExportableState)
- ],
- )
- # Internal variable to count flos in each process; it is accumulated into `self.state.total_flos` and then
- # reset to 0 every time flos need to be logged
- self.current_flos = 0
- self.hp_search_backend = None
- default_label_names = find_labels(self.model.__class__)
- self.label_names = default_label_names if self.args.label_names is None else self.args.label_names
- self.can_return_loss = can_return_loss(self.model.__class__)
- self.control = self.callback_handler.on_init_end(self.args, self.state, self.control)
- # Internal variables to help with automatic batch size reduction
- self._train_batch_size = args.train_batch_size
- self._created_lr_scheduler = False
- # very last
- self._memory_tracker.stop_and_update_metrics()
- # torch.compile
- if args.torch_compile and not is_torch_compile_available():
- raise RuntimeError("Using torch.compile requires PyTorch 2.0 or higher.")
- self.is_fsdp_xla_v2_enabled = args.fsdp_config.get("xla_fsdp_v2", False)
- if self.is_fsdp_xla_v2_enabled:
- if not IS_XLA_FSDPV2_POST_2_2:
- raise ValueError("FSDPv2 requires `torch_xla` 2.2 or higher.")
- # Prepare the SPMD mesh that is going to be used by the data loader and the FSDPv2 wrapper.
- # The tensor axis is just a placeholder; it will not be used in FSDPv2.
- num_devices = xr.global_runtime_device_count()
- xs.set_global_mesh(xs.Mesh(np.array(range(num_devices)), (num_devices, 1), axis_names=("fsdp", "tensor")))
- self.is_fsdp_xla_v1_enabled = self.is_fsdp_xla_enabled and not self.is_fsdp_xla_v2_enabled
- @property
- def tokenizer(self) -> Optional[PreTrainedTokenizerBase]:
- logger.warning("Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.")
- return self.processing_class
- @tokenizer.setter
- def tokenizer(self, processing_class) -> None:
- logger.warning(
- "Trainer.tokenizer is now deprecated. You should use `Trainer.processing_class = processing_class` instead."
- )
- self.processing_class = processing_class
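- # Migration sketch (illustrative): prefer the new name over the deprecated `tokenizer` one, e.g.
- #     trainer = Trainer(model=model, args=args, processing_class=tokenizer)  # rather than tokenizer=tokenizer
- #     trainer.processing_class.save_pretrained("out")                        # rather than trainer.tokenizer.save_pretrained("out")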
- def _activate_neftune(self, model):
- r"""
- Activates NEFTune as presented in the code https://github.com/neelsjain/NEFTune and the paper:
- https://arxiv.org/abs/2310.05914
- """
- unwrapped_model = self.accelerator.unwrap_model(model)
- if _is_peft_model(unwrapped_model):
- embeddings = unwrapped_model.base_model.model.get_input_embeddings()
- else:
- embeddings = unwrapped_model.get_input_embeddings()
- del unwrapped_model
- embeddings.neftune_noise_alpha = self.neftune_noise_alpha
- hook_handle = embeddings.register_forward_hook(neftune_post_forward_hook)
- self.neftune_hook_handle = hook_handle
- return model
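- # Usage sketch (illustrative): NEFTune is enabled by setting `neftune_noise_alpha` on `TrainingArguments`;
- # the Trainer then attaches and removes the embedding hook above around training itself, e.g.
- #     args = TrainingArguments(output_dir="out", neftune_noise_alpha=5.0)
- #     trainer = Trainer(model=model, args=args, train_dataset=train_dataset)
- #     trainer.train()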
- def _deactivate_neftune(self, model):
- """
- Deactivates the NEFTune method. Make sure to call `_activate_neftune` first.
- """
- if not hasattr(self, "neftune_hook_handle"):
- raise ValueError("Neftune is not activated make sure to call `trainer._activate_neftune()` first")
- unwrapped_model = self.accelerator.unwrap_model(model)
- if _is_peft_model(unwrapped_model):
- embeddings = unwrapped_model.base_model.model.get_input_embeddings()
- else:
- embeddings = unwrapped_model.get_input_embeddings()
- self.neftune_hook_handle.remove()
- del embeddings.neftune_noise_alpha, unwrapped_model
- def add_callback(self, callback):
- """
- Add a callback to the current list of [`~transformers.TrainerCallback`].
- Args:
- callback (`type` or [`~transformers.TrainerCallback`]):
- A [`~transformers.TrainerCallback`] class or an instance of a [`~transformers.TrainerCallback`]. In the
- first case, will instantiate a member of that class.
- """
- self.callback_handler.add_callback(callback)
- def pop_callback(self, callback):
- """
- Remove a callback from the current list of [`~transformers.TrainerCallback`] and return it.
- If the callback is not found, returns `None` (and no error is raised).
- Args:
- callback (`type` or [`~transformers.TrainerCallback`]):
- A [`~transformers.TrainerCallback`] class or an instance of a [`~transformers.TrainerCallback`]. In the
- first case, will pop the first member of that class found in the list of callbacks.
- Returns:
- [`~transformers.TrainerCallback`]: The callback removed, if found.
- """
- return self.callback_handler.pop_callback(callback)
- def remove_callback(self, callback):
- """
- Remove a callback from the current list of [`~transformers.TrainerCallback`].
- Args:
- callback (`type` or [`~transformers.TrainerCallback`]):
- A [`~transformers.TrainerCallback`] class or an instance of a [`~transformers.TrainerCallback`]. In the
- first case, will remove the first member of that class found in the list of callbacks.
- """
- self.callback_handler.remove_callback(callback)
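- # Usage sketch for the callback API above (`EarlyStoppingCallback` is an existing transformers callback;
- # the patience value is illustrative):
- #     from transformers import EarlyStoppingCallback
- #     trainer.add_callback(EarlyStoppingCallback(early_stopping_patience=3))
- #     popped = trainer.pop_callback(EarlyStoppingCallback)  # returns the instance, or None if absent
- #     trainer.remove_callback(PrinterCallback)              # accepts a class or an instance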
- def _move_model_to_device(self, model, device):
- model = model.to(device)
- # Moving a model to an XLA device disconnects the tied weights, so we have to retie them.
- if self.args.parallel_mode == ParallelMode.TPU and hasattr(model, "tie_weights"):
- model.tie_weights()
- def _set_signature_columns_if_needed(self):
- if self._signature_columns is None:
- # Inspect model forward signature to keep only the arguments it accepts.
- model_to_inspect = self.model
- if _is_peft_model(self.model):
- if hasattr(self.model, "get_base_model"):
- model_to_inspect = self.model.get_base_model()
- else:
- # PeftMixedModel do not provide a `get_base_model` method
- model_to_inspect = self.model.base_model.model
- signature = inspect.signature(model_to_inspect.forward)
- self._signature_columns = list(signature.parameters.keys())
- # Labels may be named label or label_ids; the default data collator handles both.
- self._signature_columns += list(set(["label", "label_ids"] + self.label_names))
- def _remove_unused_columns(self, dataset: "datasets.Dataset", description: Optional[str] = None):
- if not self.args.remove_unused_columns:
- return dataset
- self._set_signature_columns_if_needed()
- signature_columns = self._signature_columns
- ignored_columns = list(set(dataset.column_names) - set(signature_columns))
- if len(ignored_columns) > 0:
- dset_description = "" if description is None else f"in the {description} set"
- logger.info(
- f"The following columns {dset_description} don't have a corresponding argument in "
- f"`{self.model.__class__.__name__}.forward` and have been ignored: {', '.join(ignored_columns)}."
- f" If {', '.join(ignored_columns)} are not expected by `{self.model.__class__.__name__}.forward`, "
- " you can safely ignore this message."
- )
- columns = [k for k in signature_columns if k in dataset.column_names]
- if len(columns) == 0:
- raise ValueError(
- "No columns in the dataset match the model's forward method signature. "
- f"The following columns have been ignored: [{', '.join(ignored_columns)}]. "
- "Please check the dataset and model. You may need to set `remove_unused_columns=False` in `TrainingArguments`."
- )
- if version.parse(datasets.__version__) < version.parse("1.4.0"):
- dataset.set_format(
- type=dataset.format["type"], columns=columns, format_kwargs=dataset.format["format_kwargs"]
- )
- return dataset
- else:
- return dataset.remove_columns(ignored_columns)
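- # Note (illustrative): this pruning is controlled by `TrainingArguments.remove_unused_columns`
- # (default `True`); keep extra dataset columns, e.g. for a custom data collator, with
- #     args = TrainingArguments(output_dir="out", remove_unused_columns=False)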
- def _get_collator_with_removed_columns(
- self, data_collator: Callable, description: Optional[str] = None
- ) -> Callable:
- """Wrap the data collator in a callable removing unused columns."""
- if not self.args.remove_unused_columns:
- return data_collator
- self._set_signature_columns_if_needed()
- signature_columns = self._signature_columns
- remove_columns_collator = RemoveColumnsCollator(
- data_collator=data_collator,
- signature_columns=signature_columns,
- logger=logger,
- description=description,
- model_name=self.model.__class__.__name__,
- )
- return remove_columns_collator
- def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
- if self.train_dataset is None or not has_length(self.train_dataset):
- return None
- # Build the sampler.
- if self.args.group_by_length:
- if is_datasets_available() and isinstance(self.train_dataset, datasets.Dataset):
- lengths = (
- self.train_dataset[self.args.length_column_name]
- if self.args.length_column_name in self.train_dataset.column_names
- else None
- )
- else:
- lengths = None
- model_input_name = (
- self.processing_class.model_input_names[0] if self.processing_class is not None else None
- )
- return LengthGroupedSampler(
- self.args.train_batch_size * self.args.gradient_accumulation_steps,
- dataset=self.train_dataset,
- lengths=lengths,
- model_input_name=model_input_name,
- )
- else:
- return RandomSampler(self.train_dataset)
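- # A minimal sketch of enabling the length-grouped sampler, assuming the training dataset
- # has a precomputed "length" column (configurable via `length_column_name`; if the column
- # is missing, lengths are inferred from the model inputs):
- #
- #     from transformers import TrainingArguments
- #     args = TrainingArguments(output_dir="out", group_by_length=True, length_column_name="length")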
- def get_train_dataloader(self) -> DataLoader:
- """
- Returns the training [`~torch.utils.data.DataLoader`].
- Will use no sampler if `train_dataset` does not implement `__len__`, a random sampler (adapted to distributed
- training if necessary) otherwise.
- Subclass and override this method if you want to inject some custom behavior.
- """
- if self.train_dataset is None:
- raise ValueError("Trainer: training requires a train_dataset.")
- train_dataset = self.train_dataset
- data_collator = self.data_collator
- if is_datasets_available() and isinstance(train_dataset, datasets.Dataset):
- train_dataset = self._remove_unused_columns(train_dataset, description="training")
- else:
- data_collator = self._get_collator_with_removed_columns(data_collator, description="training")
- dataloader_params = {
- "batch_size": self._train_batch_size,
- "collate_fn": data_collator,
- "num_workers": self.args.dataloader_num_workers,
- "pin_memory": self.args.dataloader_pin_memory,
- "persistent_workers": self.args.dataloader_persistent_workers,
- }
- if not isinstance(train_dataset, torch.utils.data.IterableDataset):
- dataloader_params["sampler"] = self._get_train_sampler()
- dataloader_params["drop_last"] = self.args.dataloader_drop_last
- dataloader_params["worker_init_fn"] = seed_worker
- dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor
- return self.accelerator.prepare(DataLoader(train_dataset, **dataloader_params))
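- # As the docstring suggests, custom dataloader behavior is injected by subclassing;
- # a minimal sketch that only wraps the default implementation (the subclass name is hypothetical):
- #
- #     from torch.utils.data import DataLoader
- #
- #     class LoggingTrainer(Trainer):
- #         def get_train_dataloader(self) -> DataLoader:
- #             dataloader = super().get_train_dataloader()
- #             logger.info("Training dataloader created")
- #             return dataloader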
- def _get_eval_sampler(self, eval_dataset: Dataset) -> Optional[torch.utils.data.Sampler]:
- if eval_dataset is None or not has_length(eval_dataset):
- return None
- # Build the sampler.
- # Deprecated code
- if self.args.use_legacy_prediction_loop:
- if is_torch_xla_available():
- return SequentialDistributedSampler(
- eval_dataset, num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal()
- )
- elif is_sagemaker_mp_enabled():
- return SequentialDistributedSampler(
- eval_dataset,
- num_replicas=smp.dp_size(),
- rank=smp.dp_rank(),
- batch_size=self.args.per_device_eval_batch_size,
- )
- else:
- return SequentialSampler(eval_dataset)
- if self.args.group_by_length:
- if is_datasets_available() and isinstance(eval_dataset, datasets.Dataset):
- lengths = (
- eval_dataset[self.args.length_column_name]
- if self.args.length_column_name in eval_dataset.column_names
- else None
- )
- else:
- lengths = None
- model_input_name = self.tokenizer.model_input_names[0] if self.tokenizer is not None else None
- return LengthGroupedSampler(
- self.args.eval_batch_size,
- dataset=eval_dataset,
- lengths=lengths,
- model_input_name=model_input_name,
- )
- if self.args.world_size <= 1:
- return SequentialSampler(eval_dataset)
- else:
- return None
- def get_eval_dataloader(self, eval_dataset: Optional[Union[str, Dataset]] = None) -> DataLoader:
- """
- Returns the evaluation [`~torch.utils.data.DataLoader`].
- Subclass and override this method if you want to inject some custom behavior.
- Args:
- eval_dataset (`str` or `torch.utils.data.Dataset`, *optional*):
- If a `str`, will use `self.eval_dataset[eval_dataset]` as the evaluation dataset. If a `Dataset`, will override `self.eval_dataset` and must implement `__len__`. If it is a [`~datasets.Dataset`], columns not accepted by the `model.forward()` method are automatically removed.
- """
- if eval_dataset is None and self.eval_dataset is None:
- raise ValueError("Trainer: evaluation requires an eval_dataset.")
- # If we have persistent workers, avoid re-creating the dataloader (and forking a new set of
- # workers) on every call, especially as eval datasets don't change during training
- dataloader_key = eval_dataset if isinstance(eval_dataset, str) else "eval"
- if (
- hasattr(self, "_eval_dataloaders")
- and dataloader_key in self._eval_dataloaders
- and self.args.dataloader_persistent_workers
- ):
- return self.accelerator.prepare(self._eval_dataloaders[dataloader_key])
- eval_dataset = (
- self.eval_dataset[eval_dataset]
- if isinstance(eval_dataset, str)
- else eval_dataset
- if eval_dataset is not None
- else self.eval_dataset
- )
- data_collator = self.data_collator
- if is_datasets_available() and isinstance(eval_dataset, datasets.Dataset):
- eval_dataset = self._remove_unused_columns(eval_dataset, description="evaluation")
- else:
- data_collator = self._get_collator_with_removed_columns(data_collator, description="evaluation")
- dataloader_params = {
- "batch_size": self.args.eval_batch_size,
- "collate_fn": data_collator,
- "num_workers": self.args.dataloader_num_workers,
- "pin_memory": self.args.dataloader_pin_memory,
- "persistent_workers": self.args.dataloader_persistent_workers,
- }
- if not isinstance(eval_dataset, torch.utils.data.IterableDataset):
- dataloader_params["sampler"] = self._get_eval_sampler(eval_dataset)
- dataloader_params["drop_last"] = self.args.dataloader_drop_last
- dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor
- # accelerator.free_memory() will destroy the references, so
- # we need to store the non-prepared version
- eval_dataloader = DataLoader(eval_dataset, **dataloader_params)
- if self.args.dataloader_persistent_workers:
- if hasattr(self, "_eval_dataloaders"):
- self._eval_dataloaders[dataloader_key] = eval_dataloader
- else:
- self._eval_dataloaders = {dataloader_key: eval_dataloader}
- return self.accelerator.prepare(eval_dataloader)
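- # A sketch of the `str` form of `eval_dataset` accepted above: pass a dict of named splits
- # at init and request a dataloader per key (split names are placeholders):
- #
- #     trainer = Trainer(model=model, args=args, train_dataset=train_ds,
- #                       eval_dataset={"in_domain": val_a, "out_of_domain": val_b})
- #     dl = trainer.get_eval_dataloader("in_domain")  # indexes self.eval_dataset by key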
- def get_test_dataloader(self, test_dataset: Dataset) -> DataLoader:
- """
- Returns the test [`~torch.utils.data.DataLoader`].
- Subclass and override this method if you want to inject some custom behavior.
- Args:
- test_dataset (`torch.utils.data.Dataset`):
- The test dataset to use. If it is a [`~datasets.Dataset`], columns not accepted by the
- `model.forward()` method are automatically removed. It must implement `__len__`.
- """
- data_collator = self.data_collator
- if is_datasets_available() and isinstance(test_dataset, datasets.Dataset):
- test_dataset = self._remove_unused_columns(test_dataset, description="test")
- else:
- data_collator = self._get_collator_with_removed_columns(data_collator, description="test")
- dataloader_params = {
- "batch_size": self.args.eval_batch_size,
- "collate_fn": data_collator,
- "num_workers": self.args.dataloader_num_workers,
- "pin_memory": self.args.dataloader_pin_memory,
- "persistent_workers": self.args.dataloader_persistent_workers,
- }
- if not isinstance(test_dataset, torch.utils.data.IterableDataset):
- dataloader_params["sampler"] = self._get_eval_sampler(test_dataset)
- dataloader_params["drop_last"] = self.args.dataloader_drop_last
- dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor
- # We use the same batch_size as for eval.
- return self.accelerator.prepare(DataLoader(test_dataset, **dataloader_params))
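- # This dataloader is normally consumed through `Trainer.predict`; a minimal sketch
- # (`test_ds` is a placeholder dataset):
- #
- #     output = trainer.predict(test_dataset=test_ds)
- #     print(output.metrics)  # model outputs are in output.predictions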
- def create_optimizer_and_scheduler(self, num_training_steps: int):
- """
- Setup the optimizer and the learning rate scheduler.
- We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the
- Trainer's init through `optimizers`, or subclass and override this method (or `create_optimizer` and/or
- `create_scheduler`) in a subclass.
- """
- self.create_optimizer()
- if IS_SAGEMAKER_MP_POST_1_10 and smp.state.cfg.fp16:
- # If smp >= 1.10 and fp16 is enabled, we unwrap the optimizer
- optimizer = self.optimizer.optimizer
- else:
- optimizer = self.optimizer
- self.create_scheduler(num_training_steps=num_training_steps, optimizer=optimizer)
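- # As noted in the docstring, a fully custom optimizer/scheduler pair can be supplied at init
- # instead of overriding these methods; a minimal sketch (the step counts are placeholders):
- #
- #     import torch
- #     from transformers import get_linear_schedule_with_warmup
- #
- #     optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)
- #     scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=100, num_training_steps=1000)
- #     trainer = Trainer(model=model, args=args, train_dataset=train_ds, optimizers=(optimizer, scheduler))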
- def get_decay_parameter_names(self, model) -> List[str]:
- """
- Get all parameter names that weight decay will be applied to.
- Note that some models implement their own layer norm instead of calling nn.LayerNorm, so weight decay could
- still apply to those modules since this function only filters out instances of nn.LayerNorm.
- """
- decay_parameters = get_parameter_names(model, ALL_LAYERNORM_LAYERS)
- decay_parameters = [name for name in decay_parameters if "bias" not in name]
- return decay_parameters
- def create_optimizer(self):
- """
- Setup the optimizer.
- We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the
- Trainer's init through `optimizers`, or subclass and override this method in a subclass.
- """
- opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
- if self.optimizer is None:
- decay_parameters = self.get_decay_parameter_names(opt_model)
- optimizer_grouped_parameters = [
- {
- "params": [
- p for n, p in opt_model.named_parameters() if (n in decay_parameters and p.requires_grad)
- ],
- "weight_decay": self.args.weight_decay,
- },
- {
- "params": [
- p for n, p in opt_model.named_parameters() if (n not in decay_parameters and p.requires_grad)
- ],
- "weight_decay": 0.0,
- },
- ]
- optimizer_cls, optimizer_kwargs = self.get_optimizer_cls_and_kwargs(self.args, opt_model)
- # Overwrite `params` in case it's created by `get_optimizer_cls_and_kwargs`
- # e.g. for GaLore optimizer.
- if "params" in optimizer_kwargs:
- optimizer_grouped_parameters = optimizer_kwargs.pop("params")
- # Overwrite `model` in case it's created by `get_optimizer_cls_and_kwargs`
- # e.g. for LOMO optimizer.
- if "model" in optimizer_kwargs:
- optimizer_grouped_parameters = optimizer_kwargs.pop("model")
- # For layer-wise dummy optimizers we overwrite optimizer_grouped_parameters with `optimizer_dict`
- # to avoid arguments conflicts.
- if "optimizer_dict" in optimizer_kwargs:
- optimizer_grouped_parameters = optimizer_kwargs.pop("optimizer_dict")
- self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
- if optimizer_cls.__name__ == "Adam8bit":
- import bitsandbytes
- manager = bitsandbytes.optim.GlobalOptimManager.get_instance()
- skipped = 0
- for module in opt_model.modules():
- if isinstance(module, nn.Embedding):
- skipped += sum({p.data_ptr(): p.numel() for p in module.parameters()}.values())
- logger.info(f"skipped {module}: {skipped/2**20}M params")
- manager.register_module_override(module, "weight", {"optim_bits": 32})
- logger.debug(f"bitsandbytes: will optimize {module} in fp32")
- logger.info(f"skipped: {skipped/2**20}M params")
- if is_sagemaker_mp_enabled():
- self.optimizer = smp.DistributedOptimizer(self.optimizer)
- return self.optimizer
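- # A sketch of reproducing the default weight-decay grouping outside of `create_optimizer`,
- # using the helper defined above:
- #
- #     decay_names = trainer.get_decay_parameter_names(model)
- #     param_groups = [
- #         {"params": [p for n, p in model.named_parameters() if n in decay_names and p.requires_grad],
- #          "weight_decay": args.weight_decay},
- #         {"params": [p for n, p in model.named_parameters() if n not in decay_names and p.requires_grad],
- #          "weight_decay": 0.0},
- #     ]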
- def get_num_trainable_parameters(self):
- """
- Get the number of trainable parameters.
- """
- return sum(p.numel() for p in self.model.parameters() if p.requires_grad)
- def get_learning_rates(self):
- """
- Returns the learning rate of each parameter from self.optimizer.
- """
- if self.optimizer is None:
- raise ValueError("Trainer optimizer is None, please make sure you have setup the optimizer before.")
- return [group["lr"] for group in self.optimizer.param_groups]
- def get_optimizer_group(self, param: Optional[Union[str, torch.nn.parameter.Parameter]] = None):
- """
- Returns the optimizer group for a parameter if given, else returns all optimizer parameter groups.
- Args:
- param (`str` or `torch.nn.parameter.Parameter`, *optional*):
- The parameter for which optimizer group needs to be returned.
- """
- if self.optimizer is None:
- raise ValueError("Trainer optimizer is None, please make sure you have setup the optimizer before.")
- if param is not None:
- for group in self.optimizer.param_groups:
- if param in group["params"]:
- return group
- return [group["params"] for group in self.optimizer.param_groups]
- @staticmethod
- def get_optimizer_cls_and_kwargs(
- args: TrainingArguments, model: Optional[PreTrainedModel] = None
- ) -> Tuple[Any, Any]:
- """
- Returns the optimizer class and optimizer parameters based on the training arguments.
- Args:
- args (`transformers.training_args.TrainingArguments`):
- The training arguments for the training session.
- """
- # parse args.optim_args
- optim_args = {}
- if args.optim_args:
- for mapping in args.optim_args.replace(" ", "").split(","):
- key, value = mapping.split("=")
- optim_args[key] = value
- optimizer_kwargs = {"lr": args.learning_rate}
- adam_kwargs = {
- "betas": (args.adam_beta1, args.adam_beta2),
- "eps": args.adam_epsilon,
- }
- if args.optim == OptimizerNames.ADAFACTOR:
- optimizer_cls = Adafactor
- optimizer_kwargs.update({"scale_parameter": False, "relative_step": False})
- elif args.optim == OptimizerNames.ADAMW_HF:
- from .optimization import AdamW
- optimizer_cls = AdamW
- optimizer_kwargs.update(adam_kwargs)
- elif args.optim in [OptimizerNames.ADAMW_TORCH, OptimizerNames.ADAMW_TORCH_FUSED]:
- from torch.optim import AdamW
- optimizer_cls = AdamW
- optimizer_kwargs.update(adam_kwargs)
- if args.optim == OptimizerNames.ADAMW_TORCH_FUSED:
- optimizer_kwargs.update({"fused": True})
- elif args.optim == OptimizerNames.ADAMW_TORCH_XLA:
- try:
- from torch_xla.amp.syncfree import AdamW
- optimizer_cls = AdamW
- optimizer_kwargs.update(adam_kwargs)
- except ImportError:
- raise ValueError("Trainer failed to import syncfree AdamW from torch_xla.")
- elif args.optim == OptimizerNames.ADAMW_TORCH_NPU_FUSED:
- try:
- from torch_npu.optim import NpuFusedAdamW
- optimizer_cls = NpuFusedAdamW
- optimizer_kwargs.update(adam_kwargs)
- except ImportError:
- raise ValueError("Trainer failed to import FusedAdamW from torch_npu.")
- elif args.optim == OptimizerNames.ADAMW_APEX_FUSED:
- try:
- from apex.optimizers import FusedAdam
- optimizer_cls = FusedAdam
- optimizer_kwargs.update(adam_kwargs)
- except ImportError:
- raise ValueError("Trainer tried to instantiate apex FusedAdam but apex is not installed!")
- elif args.optim in [
- OptimizerNames.ADAMW_BNB,
- OptimizerNames.ADAMW_8BIT,
- OptimizerNames.PAGED_ADAMW,
- OptimizerNames.PAGED_ADAMW_8BIT,
- OptimizerNames.ADEMAMIX,
- OptimizerNames.ADEMAMIX_8BIT,
- OptimizerNames.PAGED_ADEMAMIX,
- OptimizerNames.PAGED_ADEMAMIX_8BIT,
- OptimizerNames.LION,
- OptimizerNames.LION_8BIT,
- OptimizerNames.PAGED_LION,
- OptimizerNames.PAGED_LION_8BIT,
- OptimizerNames.RMSPROP_BNB,
- OptimizerNames.RMSPROP_8BIT,
- OptimizerNames.RMSPROP_32BIT,
- ]:
- try:
- from bitsandbytes.optim import AdamW, Lion, RMSprop
- is_paged = False
- optim_bits = 32
- optimizer_cls = None
- additional_optim_kwargs = adam_kwargs
- if "paged" in args.optim:
- is_paged = True
- if "8bit" in args.optim:
- optim_bits = 8
- if "adam" in args.optim:
- optimizer_cls = AdamW
- elif "lion" in args.optim:
- optimizer_cls = Lion
- additional_optim_kwargs = {"betas": (args.adam_beta1, args.adam_beta2)}
- elif "rmsprop" in args.optim:
- optimizer_cls = RMSprop
- # Above we pass all `adam_kwargs` to the optimizer, here
- # we only pass `optim_args` which can be passed by the user.
- additional_optim_kwargs = optim_args
- elif "ademamix" in args.optim:
- if is_bitsandbytes_available() and version.parse(
- importlib.metadata.version("bitsandbytes")
- ) < version.parse("0.44.0"):
- raise ValueError(
- "The AdEMAMix optimizer is not supported by your current version of `bitsandbytes`. "
- "Please install `bitsandbytes` >= 0.44.0."
- )
- from bitsandbytes.optim import AdEMAMix
- optimizer_cls = AdEMAMix
- additional_optim_kwargs = {
- "betas": (
- float(optim_args.get("beta1", args.adam_beta1)),
- float(optim_args.get("beta2", args.adam_beta2)),
- float(optim_args.get("beta3", 0.9999)),
- ),
- "alpha": float(optim_args.get("alpha", 5.0)),
- "eps": float(optim_args.get("eps", args.adam_epsilon)),
- }
- if "t_alpha" in optim_args:
- additional_optim_kwargs["t_alpha"] = int(optim_args["t_alpha"])
- if "t_beta3" in optim_args:
- additional_optim_kwargs["t_beta3"] = int(optim_args["t_beta3"])
- bnb_kwargs = {"optim_bits": optim_bits}
- if "rmsprop" not in args.optim:
- bnb_kwargs["is_paged"] = is_paged
- optimizer_kwargs.update(additional_optim_kwargs)
- optimizer_kwargs.update(bnb_kwargs)
- except ImportError:
- raise ValueError("Trainer tried to instantiate bnb optimizer but `bitsandbytes` is not installed!")
- if is_bitsandbytes_available() and version.parse(
- importlib.metadata.version("bitsandbytes")
- ) < version.parse("0.41.1"):
- logger.warning(
- "You are using 8-bit optimizers with a version of `bitsandbytes` < 0.41.1. "
- "It is recommended to update your version as a major bug has been fixed in 8-bit optimizers."
- )
- elif args.optim == OptimizerNames.ADAMW_ANYPRECISION:
- try:
- from torchdistx.optimizers import AnyPrecisionAdamW
- optimizer_cls = AnyPrecisionAdamW
- optimizer_kwargs.update(adam_kwargs)
- # TODO Change dtypes back to M=FP32, Var = BF16, Kahan = False once they can be cast together in torchdistx.
- optimizer_kwargs.update(
- {
- "use_kahan_summation": strtobool(optim_args.get("use_kahan_summation", "False")),
- "momentum_dtype": getattr(torch, optim_args.get("momentum_dtype", "float32")),
- "variance_dtype": getattr(torch, optim_args.get("variance_dtype", "float32")),
- "compensation_buffer_dtype": getattr(
- torch, optim_args.get("compensation_buffer_dtype", "bfloat16")
- ),
- }
- )
- except ImportError:
- raise ValueError("Please install https://github.com/pytorch/torchdistx")
- elif args.optim == OptimizerNames.SGD:
- optimizer_cls = torch.optim.SGD
- elif args.optim == OptimizerNames.ADAGRAD:
- optimizer_cls = torch.optim.Adagrad
- elif args.optim == OptimizerNames.RMSPROP:
- optimizer_cls = torch.optim.RMSprop
- elif args.optim in [
- OptimizerNames.GALORE_ADAMW,
- OptimizerNames.GALORE_ADAMW_8BIT,
- OptimizerNames.GALORE_ADAFACTOR,
- OptimizerNames.GALORE_ADAMW_LAYERWISE,
- OptimizerNames.GALORE_ADAMW_8BIT_LAYERWISE,
- OptimizerNames.GALORE_ADAFACTOR_LAYERWISE,
- ]:
- if not is_galore_torch_available():
- raise ImportError(
- "You need to install `galore_torch` in order to use GaLore optimizers"
- " install it with `pip install git+https://github.com/jiaweizzhao/GaLore`"
- )
- from galore_torch import GaLoreAdafactor, GaLoreAdamW, GaLoreAdamW8bit
- is_layerwise = args.optim.lower().endswith("layerwise")
- if is_layerwise and args.parallel_mode == ParallelMode.DISTRIBUTED:
- raise NotImplementedError("Layer-wise GaLore does not support DDP at this time")
- optimizer_mapping = {
- OptimizerNames.GALORE_ADAMW: GaLoreAdamW,
- OptimizerNames.GALORE_ADAMW_8BIT: GaLoreAdamW8bit,
- OptimizerNames.GALORE_ADAFACTOR: GaLoreAdafactor,
- OptimizerNames.GALORE_ADAMW_LAYERWISE: GaLoreAdamW,
- OptimizerNames.GALORE_ADAMW_8BIT_LAYERWISE: GaLoreAdamW8bit,
- OptimizerNames.GALORE_ADAFACTOR_LAYERWISE: GaLoreAdafactor,
- }
- optimizer_cls = optimizer_mapping[args.optim]
- if args.optim_target_modules is None:
- raise ValueError(
- "You need to define a `optim_target_modules` in order to properly use GaLore optimizers"
- )
- if not isinstance(args.optim_target_modules, (list, str)):
- raise ValueError(
- f"`optim_target_modules` has to be a list of strings, a string corresponding to a regex, or a specific module or 'all-linear', you passed {args.optim_target_modules}"
- )
- if model is None:
- raise ValueError("You need to pass a model in order to correctly initialize a GaLore optimizer.")
- logger.warning(
- "Activated GaLoRE fine-tuning, depending on your model size and hardware, the training might take a while before starting. Please be patient !"
- )
- all_linear = (
- isinstance(args.optim_target_modules, str)
- and args.optim_target_modules.replace("_", "-") == "all-linear"
- )
- galore_params = []
- galore_params_names = []
- for module_name, module in model.named_modules():
- target_module_exists, is_regex = check_target_module_exists(
- args.optim_target_modules, module_name, return_is_regex=True
- )
- if not isinstance(module, nn.Linear):
- # Warn in case we match but it's not a linear layer
- if target_module_exists and not is_regex:
- logger.warning(
- f"{module_name} has been matched but ignored as GaLore only supports linear layers. Please double check your `optim_target_modules`!"
- )
- continue
- if not target_module_exists and not all_linear:
- continue
- galore_params.append(module.weight)
- galore_params_names.append(module_name + ".weight")
- if len(galore_params) == 0:
- raise ValueError(
- f"None of the target modules were found! ({args.optim_target_modules}). Please make sure to pass a valid `target_modules`."
- )
- non_galore_params = [p for n, p in model.named_parameters() if n not in galore_params_names]
- galore_optim_kwargs = {
- "rank": int(optim_args.pop("rank", 128)),
- "update_proj_gap": int(optim_args.pop("update_proj_gap", 200)),
- "scale": float(optim_args.pop("scale", 0.25)),
- "proj_type": optim_args.pop("proj_type", "std"),
- }
- # The default args are from the official repository: https://github.com/jiaweizzhao/GaLore
- param_groups = [
- {"params": non_galore_params},
- {"params": galore_params, **galore_optim_kwargs},
- ]
- if is_layerwise:
- # For layer-wise optimizers, the optimization step is done through post accumulation
- # gradient hooks. The trick is to first attach these hooks to the model parameters then
- # create a dummy optimizer that will perform no-ops in the Trainer.
- # See the original implementation or the nice implementation from @hiyouga
- # here: https://github.com/hiyouga/LLaMA-Factory/commit/8664262cde3919e10eaecbd66e8c5d356856362e#diff-ebe08ab14496dfb9e06075f0fdd36799ef6d1535cc4dd4715b74c4e3e06fe3ba
- if args.gradient_accumulation_steps != 1:
- raise ValueError("Layerwise GaLoRE optimizer do not support gradient accumulation !")
- optimizer_dict = {}
- for param in non_galore_params:
- param_groups = [{"params": [param]}]
- optimizer_dict[param] = optimizer_cls(param_groups, **optimizer_kwargs)
- for param in galore_params:
- param_groups = [{"params": [param], **galore_optim_kwargs}]
- optimizer_dict[param] = optimizer_cls(param_groups, **optimizer_kwargs)
- def optimizer_hook(param):
- if param.grad is not None:
- optimizer_dict[param].step()
- optimizer_dict[param].zero_grad()
- for param in model.parameters():
- if param.requires_grad:
- param.register_post_accumulate_grad_hook(optimizer_hook)
- optimizer_cls = LayerWiseDummyOptimizer
- optimizer_kwargs.update({"optimizer_dict": optimizer_dict})
- optimizer_kwargs.update({"params": param_groups})
- if args.optim == OptimizerNames.GALORE_ADAFACTOR:
- optimizer_kwargs.update({"scale_parameter": False, "relative_step": False})
- elif args.optim in [OptimizerNames.LOMO, OptimizerNames.ADALOMO]:
- if not is_lomo_available():
- raise ImportError(
- "You need to install `lomo_optim` in order to use LOMO optimizers"
- " install it with `pip install lomo-optim`"
- )
- if not is_accelerate_available("0.30.0"):
- raise ImportError("You need to have `accelerate>=0.30.0` to be able to use LOMO optimizers")
- if model is None:
- raise ValueError("You need to pass a `model` in order to correctly initialize a LOMO optimizer.")
- from lomo_optim import AdaLomo, Lomo
- if "ada" in args.optim:
- optimizer_cls = AdaLomo
- else:
- optimizer_cls = Lomo
- optimizer_kwargs.update({"model": model})
- elif args.optim == OptimizerNames.GROKADAMW:
- if not is_grokadamw_available():
- raise ValueError("Please install grokadamw with `pip install grokadamw`")
- from grokadamw import GrokAdamW
- optimizer_cls = GrokAdamW
- optimizer_kwargs.update(
- {
- "alpha_init": float(optim_args.get("alpha_init", 0.98)),
- "lamb": float(optim_args.get("lamb", 2.0)),
- "gamma": float(optim_args.get("gamma", 0.1)),
- "grokking_signal_decay_rate": float(optim_args.get("grokking_signal_decay_rate", 0.1)),
- "gradient_clipping": float(optim_args.get("gradient_clipping", 1.0)),
- }
- )
- elif args.optim == OptimizerNames.ADAMW_TORCH_4BIT:
- if not is_torchao_available() or version.parse(importlib.metadata.version("torchao")) < version.parse(
- "0.4.0"
- ):
- raise ImportError(
- "You need to have `torchao>=0.4.0` in order to use torch 4-bit optimizers."
- "Install it with `pip install torchao` or follow the instructions here: https://github.com/pytorch/ao"
- )
- if version.parse(importlib.metadata.version("torch")) <= version.parse("2.4"):
- raise ImportError(
- "You need to have `torch>2.4` in order to use torch 4-bit optimizers. "
- "Install it with `pip install --upgrade torch` it is available on pipy. Otherwise, you need to install torch nightly."
- )
- from torchao.prototype.low_bit_optim import AdamW4bit
- optimizer_cls = AdamW4bit
- optimizer_kwargs.update(adam_kwargs)
- elif args.optim in [
- OptimizerNames.SCHEDULE_FREE_ADAMW,
- OptimizerNames.SCHEDULE_FREE_SGD,
- ]:
- if not is_schedulefree_available():
- raise ImportError(
- "You need to install `schedulefree` in order to use schedulefree optimizers"
- " install it with `pip install schedulefree`"
- )
- if not is_accelerate_available("0.30.0"):
- raise ImportError("You need to have `accelerate>=0.30.0` to be able to use schedulefree optimizers")
- from schedulefree import AdamWScheduleFree, SGDScheduleFree
- additional_optim_kwargs = {}
- if args.optim == OptimizerNames.SCHEDULE_FREE_ADAMW:
- optimizer_cls = AdamWScheduleFree
- additional_optim_kwargs = adam_kwargs
- elif args.optim == OptimizerNames.SCHEDULE_FREE_SGD:
- optimizer_cls = SGDScheduleFree
- else:
- raise ValueError("Invalid schedulefree optimizer")
- additional_optim_kwargs["weight_decay"] = args.weight_decay
- additional_optim_kwargs["warmup_steps"] = args.warmup_steps
- additional_optim_kwargs.update(
- {
- "weight_lr_power": float(optim_args.get("weight_lr_power", 2.0)),
- "r": float(optim_args.get("r", 0.0)),
- }
- )
- optimizer_kwargs.update(additional_optim_kwargs)
- else:
- raise ValueError(f"Trainer cannot instantiate unsupported optimizer: {args.optim}")
- return optimizer_cls, optimizer_kwargs
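- # Since this is a staticmethod, it can also be used to build an optimizer outside the Trainer;
- # a minimal sketch with a plain torch AdamW:
- #
- #     from transformers import Trainer, TrainingArguments
- #     args = TrainingArguments(output_dir="out", optim="adamw_torch", learning_rate=2e-5)
- #     optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(args)
- #     optimizer = optimizer_cls(model.parameters(), **optimizer_kwargs)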
- def create_scheduler(self, num_training_steps: int, optimizer: torch.optim.Optimizer = None):
- """
- Setup the scheduler. The optimizer of the trainer must have been set up either before this method is called or
- passed as an argument.
- Args:
- num_training_steps (int): The number of training steps to do.
- """
- if self.lr_scheduler is None:
- self.lr_scheduler = get_scheduler(
- self.args.lr_scheduler_type,
- optimizer=self.optimizer if optimizer is None else optimizer,
- num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
- num_training_steps=num_training_steps,
- scheduler_specific_kwargs=self.args.lr_scheduler_kwargs,
- )
- self._created_lr_scheduler = True
- return self.lr_scheduler
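- # The scheduler type and its extra arguments come straight from TrainingArguments;
- # a sketch of a cosine schedule with restarts (the kwarg values are only examples):
- #
- #     args = TrainingArguments(
- #         output_dir="out",
- #         lr_scheduler_type="cosine_with_restarts",
- #         lr_scheduler_kwargs={"num_cycles": 2},
- #         warmup_ratio=0.1,
- #     )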
- def num_examples(self, dataloader: DataLoader) -> int:
- """
- Helper to get number of samples in a [`~torch.utils.data.DataLoader`] by accessing its dataset. When
- dataloader.dataset does not exist or has no length, estimates as best it can
- """
- try:
- dataset = dataloader.dataset
- # Special case for IterableDatasetShard, we need to dig deeper
- if isinstance(dataset, IterableDatasetShard):
- return len(dataloader.dataset.dataset)
- return len(dataloader.dataset)
- except (NameError, AttributeError, TypeError): # no dataset or length, estimate by length of dataloader
- return len(dataloader) * self.args.per_device_train_batch_size
- def num_tokens(self, train_dl: DataLoader, max_steps: Optional[int] = None) -> int:
- """
- Helper to get number of tokens in a [`~torch.utils.data.DataLoader`] by enumerating dataloader.
- """
- train_tokens = 0
- try:
- for step, batch in enumerate(train_dl):
- tokens = batch["input_ids"].numel()
- if max_steps is not None:
- return tokens * max_steps
- train_tokens += tokens
- return train_tokens
- except KeyError:
- logger.warning("Cannot get num_tokens from dataloader")
- return train_tokens
- def _hp_search_setup(self, trial: Union["optuna.Trial", Dict[str, Any]]):
- """HP search setup code"""
- self._trial = trial
- if self.hp_search_backend is None or trial is None:
- return
- if self.hp_search_backend == HPSearchBackend.OPTUNA:
- params = self.hp_space(trial)
- elif self.hp_search_backend == HPSearchBackend.RAY:
- params = trial
- params.pop("wandb", None)
- elif self.hp_search_backend == HPSearchBackend.SIGOPT:
- params = {k: int(v) if isinstance(v, str) else v for k, v in trial.assignments.items()}
- elif self.hp_search_backend == HPSearchBackend.WANDB:
- params = trial
- for key, value in params.items():
- if not hasattr(self.args, key):
- logger.warning(
- f"Trying to set {key} in the hyperparameter search but there is no corresponding field in"
- " `TrainingArguments`."
- )
- continue
- old_attr = getattr(self.args, key, None)
- # Casting value to the proper type
- if old_attr is not None:
- value = type(old_attr)(value)
- setattr(self.args, key, value)
- if self.hp_search_backend == HPSearchBackend.OPTUNA:
- logger.info(f"Trial: {trial.params}")
- if self.hp_search_backend == HPSearchBackend.SIGOPT:
- logger.info(f"SigOpt Assignments: {trial.assignments}")
- if self.hp_search_backend == HPSearchBackend.WANDB:
- logger.info(f"W&B Sweep parameters: {trial}")
- if self.is_deepspeed_enabled:
- if self.args.deepspeed is None:
- raise ValueError("For sweeps with deepspeed, `args.deepspeed` must be set")
- # Rebuild the deepspeed config to reflect the updated training parameters
- from accelerate.utils import DeepSpeedPlugin
- from transformers.integrations.deepspeed import HfTrainerDeepSpeedConfig
- self.args.hf_deepspeed_config = HfTrainerDeepSpeedConfig(self.args.deepspeed)
- self.args.hf_deepspeed_config.trainer_config_process(self.args)
- self.args.deepspeed_plugin = DeepSpeedPlugin(hf_ds_config=self.args.hf_deepspeed_config)
- # From 1.0 on, we need to fully wipe the DS plugin when doing sweeps.
- # Simply calling `_reset_state` is enough and doesn't need a version pin.
- AcceleratorState()._reset_state()
- self.create_accelerator_and_postprocess()
- def _report_to_hp_search(self, trial: Union["optuna.Trial", Dict[str, Any]], step: int, metrics: Dict[str, float]):
- if self.hp_search_backend is None or trial is None:
- return
- metrics = metrics.copy()
- self.objective = self.compute_objective(metrics)
- if self.hp_search_backend == HPSearchBackend.OPTUNA:
- import optuna
- if not trial.study._is_multi_objective():
- trial.report(self.objective, step)
- if trial.should_prune():
- self.callback_handler.on_train_end(self.args, self.state, self.control)
- raise optuna.TrialPruned()
- elif self.hp_search_backend == HPSearchBackend.RAY:
- import ray.train
- with tempfile.TemporaryDirectory() as temp_checkpoint_dir:
- checkpoint = None
- if self.control.should_save:
- self._tune_save_checkpoint(checkpoint_dir=temp_checkpoint_dir)
- checkpoint = ray.train.Checkpoint.from_directory(temp_checkpoint_dir)
- metrics["objective"] = self.objective
- ray.train.report(metrics, checkpoint=checkpoint)
- def _tune_save_checkpoint(self, checkpoint_dir: str):
- output_dir = os.path.join(checkpoint_dir, f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}")
- self.save_model(output_dir, _internal_call=True)
- if self.args.should_save:
- # Update the `TrainerControl` state to where we are currently
- self.state.stateful_callbacks["TrainerControl"] = self.control.state()
- self.state.save_to_json(os.path.join(output_dir, TRAINER_STATE_NAME))
- torch.save(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME))
- torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))
- def call_model_init(self, trial=None):
- model_init_argcount = number_of_arguments(self.model_init)
- if model_init_argcount == 0:
- model = self.model_init()
- elif model_init_argcount == 1:
- model = self.model_init(trial)
- else:
- raise RuntimeError("model_init should have 0 or 1 argument.")
- if model is None:
- raise RuntimeError("model_init should not return None.")
- return model
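- # `model_init` is what lets each hyperparameter-search trial start from a fresh model;
- # a minimal sketch (the model name and trial count are placeholders):
- #
- #     from transformers import AutoModelForSequenceClassification
- #
- #     def model_init(trial=None):
- #         return AutoModelForSequenceClassification.from_pretrained("bert-base-uncased")
- #
- #     trainer = Trainer(model_init=model_init, args=args, train_dataset=train_ds, eval_dataset=eval_ds)
- #     best_run = trainer.hyperparameter_search(direction="minimize", backend="optuna", n_trials=5)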
- def torch_jit_model_eval(self, model, dataloader, training=False):
- if not training:
- if dataloader is None:
- logger.warning("failed to use PyTorch jit mode due to current dataloader is none.")
- return model
- example_batch = next(iter(dataloader))
- example_batch = self._prepare_inputs(example_batch)
- try:
- jit_model = copy.copy(model)
- jit_model.eval()
- original_forward = jit_model.__dict__.pop("_original_forward", None)
- # remove mixed precision hooks from the model
- if original_forward:
- jit_model.forward = original_forward
- with self.accelerator.autocast(cache_enabled=False), torch.no_grad():
- if version.parse(version.parse(torch.__version__).base_version) >= version.parse("2.0.0"):
- if isinstance(example_batch, dict):
- jit_model = torch.jit.trace(jit_model, example_kwarg_inputs=example_batch, strict=False)
- else:
- jit_model = torch.jit.trace(
- jit_model,
- example_kwarg_inputs={key: example_batch[key] for key in example_batch},
- strict=False,
- )
- else:
- jit_inputs = []
- for key in example_batch:
- example_tensor = torch.ones_like(example_batch[key])
- jit_inputs.append(example_tensor)
- jit_inputs = tuple(jit_inputs)
- jit_model = torch.jit.trace(jit_model, jit_inputs, strict=False)
- jit_model = torch.jit.freeze(jit_model)
- with torch.no_grad():
- jit_model(**example_batch)
- jit_model(**example_batch)
- model = jit_model
- self.use_cpu_amp = False
- except (RuntimeError, TypeError, ValueError, NameError, IndexError) as e:
- logger.warning(f"failed to use PyTorch jit mode due to: {e}.")
- return model
- def ipex_optimize_model(self, model, training=False, dtype=torch.float32):
- if not is_ipex_available():
- raise ImportError(
- "Using IPEX but IPEX is not installed or IPEX's version does not match current PyTorch, please refer"
- " to https://github.com/intel/intel-extension-for-pytorch."
- )
- import intel_extension_for_pytorch as ipex
- if not training:
- model.eval()
- dtype = torch.bfloat16 if not self.is_in_train and self.args.bf16_full_eval else dtype
- # conv_bn_folding is disabled as it fails in symbolic tracing, resulting in ipex warnings
- model = ipex.optimize(model, dtype=dtype, level="O1", conv_bn_folding=False, inplace=not self.is_in_train)
- else:
- if not model.training:
- model.train()
- model, self.optimizer = ipex.optimize(
- model, dtype=dtype, optimizer=self.optimizer, inplace=True, level="O1"
- )
- return model
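- # Both IPEX and JIT tracing are opt-in through TrainingArguments; a sketch (requires
- # `intel_extension_for_pytorch` to be installed for `use_ipex=True`):
- #
- #     args = TrainingArguments(output_dir="out", use_ipex=True, bf16=True, jit_mode_eval=True)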
- def compare_trainer_and_checkpoint_args(self, training_args, trainer_state):
- attributes_map = {
- "logging_steps": "logging_steps",
- "eval_steps": "eval_steps",
- "save_steps": "save_steps",
- }
- has_warning = False
- warning_str = "Warning: The following arguments do not match the ones in the `trainer_state.json` within the checkpoint directory: "
- for arg_attr, state_attr in attributes_map.items():
- arg_value = getattr(training_args, arg_attr, None)
- state_value = getattr(trainer_state, state_attr, None)
- if arg_value is not None and state_value is not None and arg_value != state_value:
- warning_str += f"\n\t{arg_attr}: {arg_value} (from args) != {state_value} (from trainer_state.json)"
- has_warning = True
- # train bs is special as we need to account for multi-GPU
- train_bs_args = training_args.per_device_train_batch_size
- train_bs_state = trainer_state.train_batch_size // max(1, training_args.n_gpu)
- if train_bs_args != train_bs_state:
- warning_str += f"\n\tper_device_train_batch_size: {train_bs_args} (from args) != {train_bs_state} (from trainer_state.json)"
- has_warning = True
- if has_warning:
- logger.warning_once(warning_str)
- def _wrap_model(self, model, training=True, dataloader=None):
- if self.args.use_ipex:
- dtype = torch.bfloat16 if self.use_cpu_amp else torch.float32
- model = self.ipex_optimize_model(model, training, dtype=dtype)
- if is_sagemaker_mp_enabled():
- # Wrapping the base model twice in a DistributedModel will raise an error.
- if isinstance(self.model_wrapped, smp.model.DistributedModel):
- return self.model_wrapped
- return smp.DistributedModel(model, backward_passes_per_step=self.args.gradient_accumulation_steps)
- # train/eval could be run multiple times; if the model is already wrapped, don't re-wrap it again
- if self.accelerator.unwrap_model(model) is not model:
- return model
- # Mixed precision training with apex (torch < 1.6)
- if self.use_apex and training:
- model, self.optimizer = amp.initialize(model, self.optimizer, opt_level=self.args.fp16_opt_level)
- # Multi-GPU training (should be after apex fp16 initialization); 8-bit models do not support DDP
- if self.args.n_gpu > 1 and not getattr(model, "is_loaded_in_8bit", False):
- model = nn.DataParallel(model)
- if self.args.jit_mode_eval:
- start_time = time.time()
- model = self.torch_jit_model_eval(model, dataloader, training)
- self.jit_compilation_time = round(time.time() - start_time, 4)
- # Note: in torch.distributed mode, there's no point in wrapping the model
- # inside a DistributedDataParallel as we'll be under `no_grad` anyway.
- if not training:
- return model
- # Distributed training (should be after apex fp16 initialization)
- # Distributed training using PyTorch FSDP
- if self.is_fsdp_xla_enabled:
- try:
- from torch_xla.distributed.fsdp import XlaFullyShardedDataParallel as FSDP
- from torch_xla.distributed.fsdp import checkpoint_module
- from torch_xla.distributed.fsdp.wrap import (
- size_based_auto_wrap_policy,
- transformer_auto_wrap_policy,
- )
- if self.is_fsdp_xla_v2_enabled:
- from torch_xla.experimental.spmd_fully_sharded_data_parallel import (
- SpmdFullyShardedDataParallel as FSDPv2,
- )
- except ImportError:
- raise ImportError("Missing XLA FSDP related module; please make sure to use torch-xla >= 2.0.")
- auto_wrap_policy = None
- auto_wrapper_callable = None
- default_transformer_cls_names_to_wrap = getattr(model, "_no_split_modules", None)
- fsdp_transformer_layer_cls_to_wrap = self.args.fsdp_config.get(
- "transformer_layer_cls_to_wrap", default_transformer_cls_names_to_wrap
- )
- if self.args.fsdp_config["min_num_params"] > 0:
- auto_wrap_policy = functools.partial(
- size_based_auto_wrap_policy, min_num_params=self.args.fsdp_config["min_num_params"]
- )
- elif fsdp_transformer_layer_cls_to_wrap is not None:
- transformer_cls_to_wrap = set()
- for layer_class in fsdp_transformer_layer_cls_to_wrap:
- transformer_cls = get_module_class_from_name(model, layer_class)
- if transformer_cls is None:
- raise Exception("Could not find the transformer layer class to wrap in the model.")
- else:
- transformer_cls_to_wrap.add(transformer_cls)
- auto_wrap_policy = functools.partial(
- transformer_auto_wrap_policy,
- # Transformer layer class to wrap
- transformer_layer_cls=transformer_cls_to_wrap,
- )
- fsdp_kwargs = self.args.xla_fsdp_config
- if self.args.fsdp_config["xla_fsdp_grad_ckpt"]:
- if model.config.use_cache:
- logger.warning_once(
- "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
- )
- model.config.use_cache = False
- # Apply gradient checkpointing to auto-wrapped sub-modules if specified
- def auto_wrapper_callable(m, *args, **kwargs):
- target_cls = FSDP if not self.is_fsdp_xla_v2_enabled else FSDPv2
- return target_cls(checkpoint_module(m), *args, **kwargs)
- # Wrap the base model with an outer FSDP wrapper
- if self.is_fsdp_xla_v2_enabled:
- def shard_output(output, mesh):
- from .modeling_outputs import CausalLMOutputWithPast
- real_output = None
- if isinstance(output, torch.Tensor):
- real_output = output
- elif isinstance(output, tuple):
- real_output = output[0]
- elif isinstance(output, CausalLMOutputWithPast):
- real_output = output.logits
- if real_output is None:
- raise ValueError("Something went wrong, the output of the model shouldn't be `None`")
- xs.mark_sharding(real_output, mesh, ("fsdp", None, None))
- self.model = model = FSDPv2(
- model,
- shard_output=shard_output,
- auto_wrap_policy=auto_wrap_policy,
- auto_wrapper_callable=auto_wrapper_callable,
- )
- else:
- self.model = model = FSDP(
- model,
- auto_wrap_policy=auto_wrap_policy,
- auto_wrapper_callable=auto_wrapper_callable,
- **fsdp_kwargs,
- )
- # Patch `xm.optimizer_step` so it does not reduce gradients in this case,
- # as FSDP does not need gradient reduction over sharded parameters.
- def patched_optimizer_step(optimizer, barrier=False, optimizer_args={}):
- loss = optimizer.step(**optimizer_args)
- if barrier:
- xm.mark_step()
- return loss
- xm.optimizer_step = patched_optimizer_step
- elif is_sagemaker_dp_enabled():
- model = nn.parallel.DistributedDataParallel(
- model, device_ids=[int(os.getenv("SMDATAPARALLEL_LOCAL_RANK"))]
- )
- elif self.args.parallel_mode == ParallelMode.DISTRIBUTED:
- if is_torch_neuroncore_available():
- return model
- kwargs = {}
- if self.args.ddp_find_unused_parameters is not None:
- kwargs["find_unused_parameters"] = self.args.ddp_find_unused_parameters
- elif isinstance(model, PreTrainedModel):
- # find_unused_parameters breaks checkpointing as per
- # https://github.com/huggingface/transformers/pull/4659#issuecomment-643356021
- kwargs["find_unused_parameters"] = not model.is_gradient_checkpointing
- else:
- kwargs["find_unused_parameters"] = True
- if self.args.ddp_bucket_cap_mb is not None:
- kwargs["bucket_cap_mb"] = self.args.ddp_bucket_cap_mb
- if self.args.ddp_broadcast_buffers is not None:
- kwargs["broadcast_buffers"] = self.args.ddp_broadcast_buffers
- self.accelerator.ddp_handler = DistributedDataParallelKwargs(**kwargs)
- return model
- def train(
- self,
- resume_from_checkpoint: Optional[Union[str, bool]] = None,
- trial: Union["optuna.Trial", Dict[str, Any]] = None,
- ignore_keys_for_eval: Optional[List[str]] = None,
- **kwargs,
- ):
- """
- Main training entry point.
- Args:
- resume_from_checkpoint (`str` or `bool`, *optional*):
- If a `str`, local path to a saved checkpoint as saved by a previous instance of [`Trainer`]. If a
- `bool` and equals `True`, load the last checkpoint in *args.output_dir* as saved by a previous instance
- of [`Trainer`]. If present, training will resume from the model/optimizer/scheduler states loaded here.
- trial (`optuna.Trial` or `Dict[str, Any]`, *optional*):
- The trial run or the hyperparameter dictionary for hyperparameter search.
- ignore_keys_for_eval (`List[str]`, *optional*):
- A list of keys in the output of your model (if it is a dictionary) that should be ignored when
- gathering predictions for evaluation during the training.
- kwargs (`Dict[str, Any]`, *optional*):
- Additional keyword arguments used to hide deprecated arguments
- """
- if resume_from_checkpoint is False:
- resume_from_checkpoint = None
- # memory metrics - must set up as early as possible
- self._memory_tracker.start()
- args = self.args
- self.is_in_train = True
- # Attach NEFTune hooks if necessary
- if self.neftune_noise_alpha is not None:
- self.model = self._activate_neftune(self.model)
- # do_train is not a reliable argument, as it might not be set and .train() still called, so
- # the following is a workaround:
- if (args.fp16_full_eval or args.bf16_full_eval) and not args.do_train and not self.is_model_parallel:
- self._move_model_to_device(self.model, args.device)
- if "model_path" in kwargs:
- resume_from_checkpoint = kwargs.pop("model_path")
- warnings.warn(
- "`model_path` is deprecated and will be removed in a future version. Use `resume_from_checkpoint` "
- "instead.",
- FutureWarning,
- )
- if len(kwargs) > 0:
- raise TypeError(f"train() received got unexpected keyword arguments: {', '.join(list(kwargs.keys()))}.")
- # This might change the seed so needs to run first.
- self._hp_search_setup(trial)
- self._train_batch_size = self.args.train_batch_size
- # Model re-init
- model_reloaded = False
- if self.model_init is not None:
- # Seed must be set before instantiating the model when using model_init.
- enable_full_determinism(self.args.seed) if self.args.full_determinism else set_seed(self.args.seed)
- self.model = self.call_model_init(trial)
- model_reloaded = True
- # Reinitializes optimizer and scheduler
- self.optimizer, self.lr_scheduler = None, None
- # Load potential model checkpoint
- if isinstance(resume_from_checkpoint, bool) and resume_from_checkpoint:
- resume_from_checkpoint = get_last_checkpoint(args.output_dir)
- if resume_from_checkpoint is None:
- raise ValueError(f"No valid checkpoint found in output directory ({args.output_dir})")
- if resume_from_checkpoint is not None:
- if not is_sagemaker_mp_enabled() and not self.is_deepspeed_enabled and not self.is_fsdp_enabled:
- self._load_from_checkpoint(resume_from_checkpoint)
- # In case of repeating the find_executable_batch_size, set `self._train_batch_size` properly
- state = TrainerState.load_from_json(os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME))
- if state.train_batch_size is not None:
- self._train_batch_size = state.train_batch_size
- # If model was re-initialized, put it on the right device and update self.model_wrapped
- if model_reloaded:
- if self.place_model_on_device:
- self._move_model_to_device(self.model, args.device)
- self.model_wrapped = self.model
- inner_training_loop = find_executable_batch_size(
- self._inner_training_loop, self._train_batch_size, args.auto_find_batch_size
- )
- if args.push_to_hub:
- try:
- # Disable progress bars when uploading models during checkpoints to avoid polluting stdout
- hf_hub_utils.disable_progress_bars()
- return inner_training_loop(
- args=args,
- resume_from_checkpoint=resume_from_checkpoint,
- trial=trial,
- ignore_keys_for_eval=ignore_keys_for_eval,
- )
- finally:
- hf_hub_utils.enable_progress_bars()
- else:
- return inner_training_loop(
- args=args,
- resume_from_checkpoint=resume_from_checkpoint,
- trial=trial,
- ignore_keys_for_eval=ignore_keys_for_eval,
- )
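- # Typical entry points for this method; a minimal sketch (the checkpoint path is a placeholder):
- #
- #     trainer.train()                                              # fresh run
- #     trainer.train(resume_from_checkpoint=True)                   # last checkpoint in args.output_dir
- #     trainer.train(resume_from_checkpoint="out/checkpoint-500")   # a specific checkpoint directory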
- def _inner_training_loop(
- self, batch_size=None, args=None, resume_from_checkpoint=None, trial=None, ignore_keys_for_eval=None
- ):
- self.accelerator.free_memory()
- self._train_batch_size = batch_size
- if self.args.auto_find_batch_size:
- if self.state.train_batch_size != self._train_batch_size:
- from accelerate.utils import release_memory
- (self.model_wrapped,) = release_memory(self.model_wrapped)
- self.model_wrapped = self.model
- # Check for DeepSpeed *after* the initial pass and modify the config
- if self.is_deepspeed_enabled:
- # Temporarily override `self.args.per_device_train_batch_size` so the DeepSpeed config picks up the new batch size
- original_bs = self.args.per_device_train_batch_size
- self.args.per_device_train_batch_size = self._train_batch_size // max(1, self.args.n_gpu)
- self.propagate_args_to_deepspeed(True)
- self.args.per_device_train_batch_size = original_bs
- self.state.train_batch_size = self._train_batch_size
- logger.debug(f"Currently training with a batch size of: {self._train_batch_size}")
- # Data loader and number of training steps
- train_dataloader = self.get_train_dataloader()
- if self.is_fsdp_xla_v2_enabled:
- train_dataloader = tpu_spmd_dataloader(train_dataloader)
- # Setting up training control variables:
- # number of training epochs: num_train_epochs
- # number of training steps per epoch: num_update_steps_per_epoch
- # total number of training steps to execute: max_steps
- total_train_batch_size = self._train_batch_size * args.gradient_accumulation_steps * args.world_size
- len_dataloader = None
- num_train_tokens = None
- if has_length(train_dataloader):
- len_dataloader = len(train_dataloader)
- num_update_steps_per_epoch = len_dataloader // args.gradient_accumulation_steps
- num_update_steps_per_epoch = max(num_update_steps_per_epoch, 1)
- num_examples = self.num_examples(train_dataloader)
- if args.max_steps > 0:
- max_steps = args.max_steps
- num_train_epochs = args.max_steps // num_update_steps_per_epoch + int(
- args.max_steps % num_update_steps_per_epoch > 0
- )
- # May be slightly incorrect if the last batch in the training dataloader has a smaller size but it's
- # the best we can do.
- num_train_samples = args.max_steps * total_train_batch_size
- if args.include_tokens_per_second:
- num_train_tokens = (
- self.num_tokens(train_dataloader, args.max_steps) * args.gradient_accumulation_steps
- )
- else:
- max_steps = math.ceil(args.num_train_epochs * num_update_steps_per_epoch)
- num_train_epochs = math.ceil(args.num_train_epochs)
- num_train_samples = self.num_examples(train_dataloader) * args.num_train_epochs
- if args.include_tokens_per_second:
- num_train_tokens = self.num_tokens(train_dataloader) * args.num_train_epochs
- elif args.max_steps > 0: # Rely on max_steps when dataloader does not have a working size
- max_steps = args.max_steps
- # Setting a very large number of epochs so we go as many times as necessary over the iterator.
- num_train_epochs = sys.maxsize
- num_update_steps_per_epoch = max_steps
- num_examples = total_train_batch_size * args.max_steps
- num_train_samples = args.max_steps * total_train_batch_size
- if args.include_tokens_per_second:
- num_train_tokens = self.num_tokens(train_dataloader, args.max_steps) * args.gradient_accumulation_steps
- else:
- raise ValueError(
- "args.max_steps must be set to a positive value if dataloader does not have a length, was"
- f" {args.max_steps}"
- )
- if DebugOption.UNDERFLOW_OVERFLOW in self.args.debug:
- if self.args.n_gpu > 1:
- # nn.DataParallel(model) replicates the model, so the module references registered here
- # no longer exist on the other GPUs, breaking the debug hooks
- raise ValueError(
- "Currently --debug underflow_overflow is not supported under DP. Please use DDP"
- " (torchrun or torch.distributed.launch (deprecated))."
- )
- else:
- debug_overflow = DebugUnderflowOverflow(self.model) # noqa
- delay_optimizer_creation = is_sagemaker_mp_enabled() or self.is_fsdp_xla_enabled or self.is_fsdp_enabled
- # We need to reset the scheduler, as its parameters may be different on subsequent calls
- if self._created_lr_scheduler:
- self.lr_scheduler = None
- self._created_lr_scheduler = False
- if self.is_deepspeed_enabled:
- self.optimizer, self.lr_scheduler = deepspeed_init(self, num_training_steps=max_steps)
- if not delay_optimizer_creation:
- self.create_optimizer_and_scheduler(num_training_steps=max_steps)
- self.state = TrainerState(
- stateful_callbacks=[
- cb for cb in self.callback_handler.callbacks + [self.control] if isinstance(cb, ExportableState)
- ]
- )
- self.state.is_hyper_param_search = trial is not None
- self.state.train_batch_size = self._train_batch_size
- # Compute absolute values for logging, eval, and save if given as ratio
- if args.logging_steps is not None:
- if args.logging_steps < 1:
- self.state.logging_steps = math.ceil(max_steps * args.logging_steps)
- else:
- self.state.logging_steps = args.logging_steps
- if args.eval_steps is not None:
- if args.eval_steps < 1:
- self.state.eval_steps = math.ceil(max_steps * args.eval_steps)
- else:
- self.state.eval_steps = args.eval_steps
- if args.save_steps is not None:
- if args.save_steps < 1:
- self.state.save_steps = math.ceil(max_steps * args.save_steps)
- else:
- self.state.save_steps = args.save_steps
- # Activate gradient checkpointing if needed
- if args.gradient_checkpointing:
- self.model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=args.gradient_checkpointing_kwargs)
- model = self._wrap_model(self.model_wrapped)
- # as the model is wrapped, don't use `accelerator.prepare`
- # this is for unhandled cases such as
- # FSDP-XLA, SageMaker MP/DP, DataParallel, IPEX
- use_accelerator_prepare = True if model is self.model else False
- if delay_optimizer_creation:
- if use_accelerator_prepare:
- self._fsdp_qlora_plugin_updates()
- self.model = self.accelerator.prepare(self.model)
- self.create_optimizer_and_scheduler(num_training_steps=max_steps)
- # prepare using `accelerator` prepare
- if use_accelerator_prepare:
- self.model.train()
- if hasattr(self.lr_scheduler, "step"):
- if self.use_apex:
- model = self.accelerator.prepare(self.model)
- else:
- model, self.optimizer = self.accelerator.prepare(self.model, self.optimizer)
- else:
- # to handle cases wherein we pass "DummyScheduler" such as when it is specified in DeepSpeed config.
- model, self.optimizer, self.lr_scheduler = self.accelerator.prepare(
- self.model, self.optimizer, self.lr_scheduler
- )
- elif self.args.optim in [OptimizerNames.LOMO, OptimizerNames.ADALOMO]:
- # In this case we are in DDP + LOMO, which should be supported
- self.optimizer = self.accelerator.prepare(self.optimizer)
- if self.is_fsdp_enabled:
- self.model = self.model_wrapped = model
- # for the rest of this function `model` is the outside model, whether it was wrapped or not
- if model is not self.model:
- self.model_wrapped = model
- # backward compatibility
- if self.is_deepspeed_enabled:
- self.deepspeed = self.model_wrapped
- # ckpt loading
- if resume_from_checkpoint is not None:
- if self.is_deepspeed_enabled:
- deepspeed_load_checkpoint(
- self.model_wrapped, resume_from_checkpoint, load_module_strict=not _is_peft_model(self.model)
- )
- elif is_sagemaker_mp_enabled() or self.is_fsdp_enabled:
- self._load_from_checkpoint(resume_from_checkpoint, self.model_wrapped)
- # Check if saved optimizer or scheduler states exist
- self._load_optimizer_and_scheduler(resume_from_checkpoint)
- # important: at this point:
- # self.model is the Transformers Model
- # self.model_wrapped is DDP(Transformers Model), Deepspeed(Transformers Model),
- # FSDP(Transformers Model), Dynamo Optimized Module(Transformers Model) etc.
- # Train!
- logger.info("***** Running training *****")
- logger.info(f" Num examples = {num_examples:,}")
- logger.info(f" Num Epochs = {num_train_epochs:,}")
- logger.info(f" Instantaneous batch size per device = {self.args.per_device_train_batch_size:,}")
- if self.args.per_device_train_batch_size != self._train_batch_size:
- logger.info(f" Training with DataParallel so batch size has been adjusted to: {self._train_batch_size:,}")
- logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_train_batch_size:,}")
- logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
- logger.info(f" Total optimization steps = {max_steps:,}")
- logger.info(f" Number of trainable parameters = {get_model_param_count(model, trainable_only=True):,}")
- self.state.epoch = 0
- start_time = time.time()
- epochs_trained = 0
- steps_trained_in_current_epoch = 0
- steps_trained_progress_bar = None
- # Check if continuing training from a checkpoint
- if resume_from_checkpoint is not None and os.path.isfile(
- os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME)
- ):
- self.state = TrainerState.load_from_json(os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME))
- self.compare_trainer_and_checkpoint_args(self.args, self.state)
- self._load_callback_state()
- epochs_trained = int(self.state.global_step // num_update_steps_per_epoch)
- if not args.ignore_data_skip:
- steps_trained_in_current_epoch = self.state.global_step % (num_update_steps_per_epoch)
- steps_trained_in_current_epoch *= args.gradient_accumulation_steps
- else:
- steps_trained_in_current_epoch = 0
- logger.info(" Continuing training from checkpoint, will skip to saved global_step")
- logger.info(f" Continuing training from epoch {epochs_trained}")
- logger.info(f" Continuing training from global step {self.state.global_step}")
- if not args.ignore_data_skip:
- logger.info(
- f" Will skip the first {epochs_trained} epochs then the first"
- f" {steps_trained_in_current_epoch} batches in the first epoch."
- )
- # Update the references
- self.callback_handler.model = self.model
- self.callback_handler.optimizer = self.optimizer
- self.callback_handler.lr_scheduler = self.lr_scheduler
- self.callback_handler.train_dataloader = train_dataloader
- if self.hp_name is not None and self._trial is not None:
- # use self._trial because the SigOpt/Optuna HPO only calls `_hp_search_setup(trial)` instead of passing the
- # trial parameter to `train` when using DDP.
- self.state.trial_name = self.hp_name(self._trial)
- if trial is not None:
- assignments = trial.assignments if self.hp_search_backend == HPSearchBackend.SIGOPT else trial
- self.state.trial_params = hp_params(assignments)
- else:
- self.state.trial_params = None
- # This should be the same if the state has been saved but in case the training arguments changed, it's safer
- # to set this after the load.
- self.state.max_steps = max_steps
- self.state.num_train_epochs = num_train_epochs
- self.state.is_local_process_zero = self.is_local_process_zero()
- self.state.is_world_process_zero = self.is_world_process_zero()
- # tr_loss is a tensor to avoid synchronization of TPUs through .item()
- tr_loss = torch.tensor(0.0).to(args.device)
- # _total_loss_scalar is updated every time .item() has to be called on tr_loss and stores the sum of all losses
- self._total_loss_scalar = 0.0
- self._globalstep_last_logged = self.state.global_step
- model.zero_grad()
- grad_norm: Optional[float] = None
- self.control = self.callback_handler.on_train_begin(args, self.state, self.control)
- if args.eval_on_start:
- self._evaluate(trial, ignore_keys_for_eval, skip_scheduler=True)
- total_batched_samples = 0
- for epoch in range(epochs_trained, num_train_epochs):
- epoch_dataloader = train_dataloader
- if hasattr(epoch_dataloader, "set_epoch"):
- epoch_dataloader.set_epoch(epoch)
- # Reset the past mems state at the beginning of each epoch if necessary.
- if args.past_index >= 0:
- self._past = None
- steps_in_epoch = (
- len(epoch_dataloader)
- if len_dataloader is not None
- else args.max_steps * args.gradient_accumulation_steps
- )
- self.control = self.callback_handler.on_epoch_begin(args, self.state, self.control)
- if epoch == epochs_trained and resume_from_checkpoint is not None and steps_trained_in_current_epoch == 0:
- self._load_rng_state(resume_from_checkpoint)
- rng_to_sync = False
- steps_skipped = 0
- if steps_trained_in_current_epoch > 0:
- epoch_dataloader = skip_first_batches(epoch_dataloader, steps_trained_in_current_epoch)
- steps_skipped = steps_trained_in_current_epoch
- steps_trained_in_current_epoch = 0
- rng_to_sync = True
- step = -1
- epoch_iterator = iter(epoch_dataloader)
- # We chunk the epoch iterator into groups of `args.gradient_accumulation_steps` batches
- remainder = num_examples % args.gradient_accumulation_steps
- num_items_in_batch = None
- if remainder == 0:
- remainder = args.gradient_accumulation_steps
- update_step = -1
- total_updates = steps_in_epoch // args.gradient_accumulation_steps + 1
- for _ in range(total_updates):
- update_step += 1
- num_batches = args.gradient_accumulation_steps if update_step != (total_updates - 1) else remainder
- batch_samples, num_items_in_batch = self.get_batch_samples(epoch_iterator, num_batches)
- for i, inputs in enumerate(batch_samples):
- step += 1
- total_batched_samples += 1
- is_last_step_and_steps_less_than_grad_acc = (
- steps_in_epoch <= args.gradient_accumulation_steps and (step + 1) == steps_in_epoch
- )
- do_sync_step = is_last_step_and_steps_less_than_grad_acc or (
- total_batched_samples % args.gradient_accumulation_steps == 0
- )
- # Since we perform prefetching, we need to manually set sync_gradients
- if not do_sync_step:
- self.accelerator.gradient_state._set_sync_gradients(False)
- else:
- self.accelerator.gradient_state._set_sync_gradients(True)
- if self.args.include_num_input_tokens_seen:
- main_input_name = getattr(self.model, "main_input_name", "input_ids")
- if main_input_name not in inputs:
- logger.warning(
- "Tried to track the number of tokens seen, however the current model is "
- "not configured properly to know what item is the input. To fix this, add "
- "a `main_input_name` attribute to the model class you are using."
- )
- else:
- input_tokens = inputs[main_input_name].numel()
- input_tokens = torch.tensor(input_tokens, device=self.args.device, dtype=torch.int64)
- self.state.num_input_tokens_seen += self.accelerator.gather(input_tokens).cpu().item()
- if rng_to_sync:
- self._load_rng_state(resume_from_checkpoint)
- rng_to_sync = False
- # Skip past any already trained steps if resuming training
- if steps_trained_in_current_epoch > 0:
- steps_trained_in_current_epoch -= 1
- if steps_trained_progress_bar is not None:
- steps_trained_progress_bar.update(1)
- if steps_trained_in_current_epoch == 0:
- self._load_rng_state(resume_from_checkpoint)
- continue
- elif steps_trained_progress_bar is not None:
- steps_trained_progress_bar.close()
- steps_trained_progress_bar = None
- if step % args.gradient_accumulation_steps == 0:
- self.control = self.callback_handler.on_step_begin(args, self.state, self.control)
- # We explicitly want to avoid relying on `accelerator.accumulate` for generation training
- context = (
- functools.partial(self.accelerator.no_sync, model=model)
- if i != len(batch_samples) - 1
- else contextlib.nullcontext
- )
- with context():
- tr_loss_step = self.training_step(model, inputs, num_items_in_batch)
- if (
- args.logging_nan_inf_filter
- and not is_torch_xla_available()
- and (torch.isnan(tr_loss_step) or torch.isinf(tr_loss_step))
- ):
- # if loss is nan or inf simply add the average of previous logged losses
- tr_loss = tr_loss + tr_loss / (1 + self.state.global_step - self._globalstep_last_logged)
- else:
- if tr_loss.device != tr_loss_step.device:
- raise ValueError(
- f"Calculated loss must be on the original device: {tr_loss.device} but device in use is {tr_loss_step.device}"
- )
- tr_loss = tr_loss + tr_loss_step
- self.current_flos += float(self.floating_point_ops(inputs))
- if do_sync_step:
- # Since we perform prefetching, we need to manually set sync_gradients to True
- self.accelerator.gradient_state._set_sync_gradients(True)
- # Gradient clipping
- if args.max_grad_norm is not None and args.max_grad_norm > 0:
- # deepspeed does its own clipping
- if is_sagemaker_mp_enabled() and args.fp16:
- _grad_norm = self.optimizer.clip_master_grads(args.max_grad_norm)
- elif self.use_apex:
- # Revert to normal clipping otherwise, handling Apex or full precision
- _grad_norm = nn.utils.clip_grad_norm_(
- amp.master_params(self.optimizer),
- args.max_grad_norm,
- )
- else:
- _grad_norm = self.accelerator.clip_grad_norm_(
- model.parameters(),
- args.max_grad_norm,
- )
- if (
- is_accelerate_available()
- and self.accelerator.distributed_type == DistributedType.DEEPSPEED
- ):
- grad_norm = model.get_global_grad_norm()
- # In some cases the returned grad norm may not be a float
- if hasattr(grad_norm, "item"):
- grad_norm = grad_norm.item()
- else:
- grad_norm = _grad_norm
- self.control = self.callback_handler.on_pre_optimizer_step(args, self.state, self.control)
- self.optimizer.step()
- self.control = self.callback_handler.on_optimizer_step(args, self.state, self.control)
- optimizer_was_run = not self.accelerator.optimizer_step_was_skipped
- if optimizer_was_run:
- # Delay optimizer scheduling until metrics are generated
- if not isinstance(self.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
- self.lr_scheduler.step()
- model.zero_grad()
- self.state.global_step += 1
- self.state.epoch = epoch + (step + 1 + steps_skipped) / steps_in_epoch
- self.control = self.callback_handler.on_step_end(args, self.state, self.control)
- self._maybe_log_save_evaluate(tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval)
- else:
- self.control = self.callback_handler.on_substep_end(args, self.state, self.control)
- # PyTorch/XLA relies on the data loader to insert the mark_step for
- # each step. Since we are breaking the loop early, we need to manually
- # insert the mark_step here.
- if self.control.should_epoch_stop or self.control.should_training_stop:
- if is_torch_xla_available():
- xm.mark_step()
- break
- # We also need to break out of the nested loop
- if self.control.should_epoch_stop or self.control.should_training_stop:
- if is_torch_xla_available():
- xm.mark_step()
- break
- if step < 0:
- logger.warning(
- "There seems not to be a single sample in your epoch_iterator, stopping training at step"
- f" {self.state.global_step}! This is expected if you're using an IterableDataset and set"
- f" num_steps ({max_steps}) higher than the number of available samples."
- )
- self.control.should_training_stop = True
- self.control = self.callback_handler.on_epoch_end(args, self.state, self.control)
- self._maybe_log_save_evaluate(tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval)
- if DebugOption.TPU_METRICS_DEBUG in self.args.debug:
- if is_torch_xla_available():
- # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.)
- xm.master_print(met.metrics_report())
- else:
- logger.warning(
- "You enabled PyTorch/XLA debug metrics but you don't have a TPU "
- "configured. Check your training configuration if this is unexpected."
- )
- if self.control.should_training_stop:
- break
- if args.past_index and hasattr(self, "_past"):
- # Clean the state at the end of training
- delattr(self, "_past")
- logger.info("\n\nTraining completed. Do not forget to share your model on huggingface.co/models =)\n\n")
- if args.load_best_model_at_end and self.state.best_model_checkpoint is not None:
- # Wait for everyone to get here so we are sure the model has been saved by process 0.
- if is_torch_xla_available():
- xm.rendezvous("load_best_model_at_end")
- elif args.parallel_mode == ParallelMode.DISTRIBUTED:
- dist.barrier()
- elif is_sagemaker_mp_enabled():
- smp.barrier()
- self._load_best_model()
- # add remaining tr_loss
- self._total_loss_scalar += tr_loss.item()
- effective_global_step = max(self.state.global_step, 0.001) # Avoid ZeroDivisionError
- train_loss = self._total_loss_scalar / effective_global_step
- metrics = speed_metrics(
- "train",
- start_time,
- num_samples=num_train_samples,
- num_steps=self.state.max_steps,
- num_tokens=num_train_tokens,
- )
- self.store_flos()
- metrics["total_flos"] = self.state.total_flos
- metrics["train_loss"] = train_loss
- self.is_in_train = False
- self._memory_tracker.stop_and_update_metrics(metrics)
- self.log(metrics)
- run_dir = self._get_output_dir(trial)
- checkpoints_sorted = self._sorted_checkpoints(use_mtime=False, output_dir=run_dir)
- # Delete the last checkpoint when save_total_limit=1 if it's different from the best checkpoint and the process is allowed to save.
- if self.args.should_save and self.state.best_model_checkpoint is not None and self.args.save_total_limit == 1:
- for checkpoint in checkpoints_sorted:
- if not os.path.samefile(checkpoint, self.state.best_model_checkpoint):
- logger.info(f"Deleting older checkpoint [{checkpoint}] due to args.save_total_limit")
- shutil.rmtree(checkpoint, ignore_errors=True)
- self.control = self.callback_handler.on_train_end(args, self.state, self.control)
- # Wait for the checkpoint to be uploaded.
- self._finish_current_push()
- # After training we make sure to retrieve back the original forward pass method
- # for the embedding layer by removing the forward post hook.
- if self.neftune_noise_alpha is not None:
- self._deactivate_neftune(self.model)
- return TrainOutput(self.state.global_step, train_loss, metrics)
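- # Example (illustrative sketch; the trainer construction and checkpoint path are placeholders): the
- # checkpoint-resumption logic above is driven by the `resume_from_checkpoint` argument of `train`:
- #
- #     trainer = Trainer(model=model, args=training_args, train_dataset=train_dataset)
- #     trainer.train(resume_from_checkpoint="output/checkpoint-500")  # resume from a specific checkpoint folder
- #     trainer.train(resume_from_checkpoint=True)                     # resume from the latest checkpoint in args.output_dir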
- def _get_output_dir(self, trial):
- if self.hp_search_backend is not None and trial is not None:
- if self.hp_search_backend == HPSearchBackend.OPTUNA:
- run_id = trial.number
- elif self.hp_search_backend == HPSearchBackend.RAY:
- import ray.train
- run_id = ray.train.get_context().get_trial_id()
- elif self.hp_search_backend == HPSearchBackend.SIGOPT:
- run_id = trial.id
- elif self.hp_search_backend == HPSearchBackend.WANDB:
- import wandb
- run_id = wandb.run.id
- run_name = self.hp_name(trial) if self.hp_name is not None else f"run-{run_id}"
- run_dir = os.path.join(self.args.output_dir, run_name)
- else:
- run_dir = self.args.output_dir
- return run_dir
- def _load_from_checkpoint(self, resume_from_checkpoint, model=None):
- if model is None:
- model = self.model
- config_file = os.path.join(resume_from_checkpoint, CONFIG_NAME)
- adapter_weights_file = os.path.join(resume_from_checkpoint, ADAPTER_WEIGHTS_NAME)
- adapter_safe_weights_file = os.path.join(resume_from_checkpoint, ADAPTER_SAFE_WEIGHTS_NAME)
- weights_file = os.path.join(resume_from_checkpoint, WEIGHTS_NAME)
- weights_index_file = os.path.join(resume_from_checkpoint, WEIGHTS_INDEX_NAME)
- safe_weights_file = os.path.join(resume_from_checkpoint, SAFE_WEIGHTS_NAME)
- safe_weights_index_file = os.path.join(resume_from_checkpoint, SAFE_WEIGHTS_INDEX_NAME)
- is_fsdp_ckpt = os.path.isdir(resume_from_checkpoint) and (
- # this checks the FSDP state dict when `SHARDED_STATE_DICT` is used
- any(
- FSDP_MODEL_NAME in folder_name
- for folder_name in os.listdir(resume_from_checkpoint)
- if os.path.isdir(os.path.join(resume_from_checkpoint, folder_name))
- )
- # this checks the FSDP state dict when `FULL_STATE_DICT` is used
- or os.path.isfile(os.path.join(resume_from_checkpoint, f"{FSDP_MODEL_NAME}.bin"))
- )
- # if multiple adapters exist, they get saved in subdirectories
- adapter_subdirs = (
- [
- folder_name
- for folder_name in os.listdir(resume_from_checkpoint)
- if os.path.isdir(os.path.join(resume_from_checkpoint, folder_name))
- and (
- os.path.isfile(os.path.join(resume_from_checkpoint, folder_name, ADAPTER_WEIGHTS_NAME))
- or os.path.isfile(os.path.join(resume_from_checkpoint, folder_name, ADAPTER_SAFE_WEIGHTS_NAME))
- )
- ]
- if os.path.isdir(resume_from_checkpoint)
- else []
- )
- if is_fsdp_ckpt and not self.is_fsdp_enabled:
- raise ValueError(f"Checkpoint found at {resume_from_checkpoint} is only supported when using PyTorch FSDP")
- if not (
- any(
- os.path.isfile(f)
- for f in [
- weights_file,
- safe_weights_file,
- weights_index_file,
- safe_weights_index_file,
- adapter_weights_file,
- adapter_safe_weights_file,
- ]
- )
- or is_fsdp_ckpt
- or adapter_subdirs
- ):
- raise ValueError(f"Can't find a valid checkpoint at {resume_from_checkpoint}")
- logger.info(f"Loading model from {resume_from_checkpoint}.")
- if os.path.isfile(config_file):
- config = PretrainedConfig.from_json_file(config_file)
- checkpoint_version = config.transformers_version
- if checkpoint_version is not None and checkpoint_version != __version__:
- logger.warning(
- f"You are resuming training from a checkpoint trained with {checkpoint_version} of "
- f"Transformers but your current version is {__version__}. This is not recommended and could "
- "yield to errors or unwanted behaviors."
- )
- if os.path.isfile(weights_file) or os.path.isfile(safe_weights_file) or is_fsdp_ckpt:
- weights_only_kwarg = {"weights_only": True} if is_torch_greater_or_equal_than_1_13 else {}
- # If the model is on the GPU, it still works!
- if is_sagemaker_mp_enabled():
- if os.path.isfile(os.path.join(resume_from_checkpoint, "user_content.pt")):
- # If the 'user_content.pt' file exists, load with the new smp api.
- # Checkpoint must have been saved with the new smp api.
- smp.resume_from_checkpoint(
- path=resume_from_checkpoint, tag=WEIGHTS_NAME, partial=False, load_optimizer=False
- )
- else:
- # If the 'user_content.pt' file does NOT exist, load with the old smp api.
- # Checkpoint must have been saved with the old smp api.
- if hasattr(self.args, "fp16") and self.args.fp16 is True:
- logger.warning(
- "Enabling FP16 and loading from smp < 1.10 checkpoint together is not suppported."
- )
- state_dict = torch.load(
- weights_file,
- map_location="cpu",
- **weights_only_kwarg,
- )
- # Required for smp to not auto-translate state_dict from hf to smp (is already smp).
- state_dict["_smp_is_partial"] = False
- load_result = model.load_state_dict(state_dict, strict=True)
- # release memory
- del state_dict
- elif self.is_fsdp_enabled:
- load_fsdp_model(
- self.accelerator.state.fsdp_plugin,
- self.accelerator,
- model,
- resume_from_checkpoint,
- **_get_fsdp_ckpt_kwargs(),
- )
- else:
- # We load the model state dict on the CPU to avoid an OOM error.
- if self.args.save_safetensors and os.path.isfile(safe_weights_file):
- state_dict = safetensors.torch.load_file(safe_weights_file, device="cpu")
- else:
- state_dict = torch.load(
- weights_file,
- map_location="cpu",
- **weights_only_kwarg,
- )
- # workaround for FSDP bug https://github.com/pytorch/pytorch/issues/82963
- # which takes *args instead of **kwargs
- load_result = model.load_state_dict(state_dict, False)
- # release memory
- del state_dict
- self._issue_warnings_after_load(load_result)
- # Load adapters following PR # 24096
- elif _is_peft_model(model):
- # If training a model using PEFT & LoRA, assume that the adapter has been saved properly.
- # TODO: in the future support only specific min PEFT versions
- if (hasattr(model, "active_adapter") or hasattr(model, "active_adapters")) and hasattr(
- model, "load_adapter"
- ):
- if os.path.exists(resume_from_checkpoint):
- # For backward compatibility with older PEFT versions
- if hasattr(model, "active_adapters"):
- active_adapters = model.active_adapters
- if len(active_adapters) > 1:
- logger.warning("Multiple active adapters detected will only consider the first adapter")
- active_adapter = active_adapters[0]
- else:
- active_adapter = model.active_adapter
- if adapter_subdirs:
- for subdir_name in adapter_subdirs:
- peft_id = os.path.join(resume_from_checkpoint, subdir_name)
- model.load_adapter(peft_id, subdir_name, is_trainable=(subdir_name == active_adapter))
- model.set_adapter(active_adapter)
- else:
- model.load_adapter(resume_from_checkpoint, active_adapter, is_trainable=True)
- else:
- logger.warning(
- "The intermediate checkpoints of PEFT may not be saved correctly, "
- f"consider using a custom callback to save {ADAPTER_WEIGHTS_NAME} in corresponding saving folders. "
- "Check some examples here: https://github.com/huggingface/peft/issues/96"
- )
- else:
- logger.warning("Could not load adapter model, make sure to have `peft>=0.3.0` installed")
- else:
- # We load the sharded checkpoint
- load_result = load_sharded_checkpoint(
- model, resume_from_checkpoint, strict=is_sagemaker_mp_enabled(), prefer_safe=self.args.save_safetensors
- )
- if not is_sagemaker_mp_enabled():
- self._issue_warnings_after_load(load_result)
- def _load_best_model(self):
- logger.info(f"Loading best model from {self.state.best_model_checkpoint} (score: {self.state.best_metric}).")
- best_model_path = os.path.join(self.state.best_model_checkpoint, WEIGHTS_NAME)
- best_safe_model_path = os.path.join(self.state.best_model_checkpoint, SAFE_WEIGHTS_NAME)
- best_adapter_model_path = os.path.join(self.state.best_model_checkpoint, ADAPTER_WEIGHTS_NAME)
- best_safe_adapter_model_path = os.path.join(self.state.best_model_checkpoint, ADAPTER_SAFE_WEIGHTS_NAME)
- model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
- if self.is_deepspeed_enabled:
- deepspeed_load_checkpoint(
- self.model_wrapped,
- self.state.best_model_checkpoint,
- load_module_strict=not _is_peft_model(self.model),
- )
- elif self.is_fsdp_enabled:
- load_result = load_fsdp_model(
- self.accelerator.state.fsdp_plugin,
- self.accelerator,
- model,
- self.state.best_model_checkpoint,
- **_get_fsdp_ckpt_kwargs(),
- )
- elif (
- os.path.exists(best_model_path)
- or os.path.exists(best_safe_model_path)
- or os.path.exists(best_adapter_model_path)
- or os.path.exists(best_safe_adapter_model_path)
- ):
- has_been_loaded = True
- weights_only_kwarg = {"weights_only": True} if is_torch_greater_or_equal_than_1_13 else {}
- if is_sagemaker_mp_enabled():
- if os.path.isfile(os.path.join(self.state.best_model_checkpoint, "user_content.pt")):
- # If the 'user_content.pt' file exists, load with the new smp api.
- # Checkpoint must have been saved with the new smp api.
- smp.resume_from_checkpoint(
- path=self.state.best_model_checkpoint,
- tag=WEIGHTS_NAME,
- partial=False,
- load_optimizer=False,
- )
- else:
- # If the 'user_content.pt' file does NOT exist, load with the old smp api.
- # Checkpoint must have been saved with the old smp api.
- if self.args.save_safetensors and os.path.isfile(best_safe_model_path):
- state_dict = safetensors.torch.load_file(best_safe_model_path, device="cpu")
- else:
- state_dict = torch.load(
- best_model_path,
- map_location="cpu",
- **weights_only_kwarg,
- )
- state_dict["_smp_is_partial"] = False
- load_result = model.load_state_dict(state_dict, strict=True)
- else:
- if _is_peft_model(model):
- # If training a model using PEFT & LoRA, assume that the adapter has been saved properly.
- # TODO: in the future support only specific min PEFT versions
- if (hasattr(model, "active_adapter") or hasattr(model, "active_adapters")) and hasattr(
- model, "load_adapter"
- ):
- # For backward compatibility with older PEFT versions
- if hasattr(model, "active_adapters"):
- active_adapter = model.active_adapters[0]
- if len(model.active_adapters) > 1:
- logger.warning("Detected multiple active adapters, will only consider the first one")
- else:
- active_adapter = model.active_adapter
- if os.path.exists(best_adapter_model_path) or os.path.exists(best_safe_adapter_model_path):
- model.load_adapter(self.state.best_model_checkpoint, active_adapter)
- # load_adapter has no return value, so we create an empty `_IncompatibleKeys`; adjust this if that changes.
- from torch.nn.modules.module import _IncompatibleKeys
- load_result = _IncompatibleKeys([], [])
- else:
- logger.warning(
- "The intermediate checkpoints of PEFT may not be saved correctly, "
- f"consider using a custom callback to save {ADAPTER_WEIGHTS_NAME} in corresponding saving folders. "
- "Check some examples here: https://github.com/huggingface/peft/issues/96"
- )
- has_been_loaded = False
- else:
- logger.warning("Could not load adapter model, make sure to have `peft>=0.3.0` installed")
- has_been_loaded = False
- else:
- # We load the model state dict on the CPU to avoid an OOM error.
- if self.args.save_safetensors and os.path.isfile(best_safe_model_path):
- state_dict = safetensors.torch.load_file(best_safe_model_path, device="cpu")
- else:
- state_dict = torch.load(
- best_model_path,
- map_location="cpu",
- **weights_only_kwarg,
- )
- # If the model is on the GPU, it still works!
- # workaround for FSDP bug https://github.com/pytorch/pytorch/issues/82963
- # which takes *args instead of **kwargs
- load_result = model.load_state_dict(state_dict, False)
- if not is_sagemaker_mp_enabled() and has_been_loaded:
- self._issue_warnings_after_load(load_result)
- elif os.path.exists(os.path.join(self.state.best_model_checkpoint, SAFE_WEIGHTS_INDEX_NAME)) or os.path.exists(
- os.path.join(self.state.best_model_checkpoint, WEIGHTS_INDEX_NAME)
- ):
- load_result = load_sharded_checkpoint(
- model, self.state.best_model_checkpoint, strict=is_sagemaker_mp_enabled()
- )
- if not is_sagemaker_mp_enabled():
- self._issue_warnings_after_load(load_result)
- else:
- logger.warning(
- f"Could not locate the best model at {best_model_path}, if you are running a distributed training "
- "on multiple nodes, you should activate `--save_on_each_node`."
- )
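- # Example (illustrative; values are placeholders, and argument names may differ slightly across library
- # versions): `_load_best_model` is only reached when checkpointing tracks a best metric, e.g.
- #
- #     training_args = TrainingArguments(
- #         output_dir="out",
- #         eval_strategy="steps",
- #         save_strategy="steps",
- #         load_best_model_at_end=True,
- #         metric_for_best_model="eval_loss",
- #         greater_is_better=False,
- #     )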
- def _issue_warnings_after_load(self, load_result):
- if len(load_result.missing_keys) != 0:
- if self.model._keys_to_ignore_on_save is not None and set(load_result.missing_keys) == set(
- self.model._keys_to_ignore_on_save
- ):
- self.model.tie_weights()
- else:
- logger.warning(f"There were missing keys in the checkpoint model loaded: {load_result.missing_keys}.")
- if len(load_result.unexpected_keys) != 0:
- logger.warning(
- f"There were unexpected keys in the checkpoint model loaded: {load_result.unexpected_keys}."
- )
- def _evaluate(self, trial, ignore_keys_for_eval, skip_scheduler=False):
- metrics = self.evaluate(ignore_keys=ignore_keys_for_eval)
- self._report_to_hp_search(trial, self.state.global_step, metrics)
- # Run delayed LR scheduler now that metrics are populated
- if isinstance(self.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau) and not skip_scheduler:
- metric_to_check = self.args.metric_for_best_model
- if not metric_to_check.startswith("eval_"):
- metric_to_check = f"eval_{metric_to_check}"
- try:
- self.lr_scheduler.step(metrics[metric_to_check])
- except KeyError as exc:
- raise KeyError(
- f"The `metric_for_best_model` training argument is set to '{metric_to_check}', "
- f"which is not found in the evaluation metrics. "
- f"The available evaluation metrics are: {list(metrics.keys())}. "
- f"Please ensure that the `compute_metrics` function returns a dictionary that includes '{metric_to_check}' or "
- f"consider changing the `metric_for_best_model` via the TrainingArguments."
- ) from exc
- return metrics
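- # Example (illustrative; values are placeholders): the delayed scheduler step above only matters for a
- # metric-driven scheduler, e.g.
- #
- #     training_args = TrainingArguments(
- #         output_dir="out",
- #         lr_scheduler_type="reduce_lr_on_plateau",
- #         metric_for_best_model="eval_loss",
- #         eval_strategy="steps",
- #     )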
- def _maybe_log_save_evaluate(self, tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval):
- if self.control.should_log and self.state.global_step > self._globalstep_last_logged:
- if is_torch_xla_available():
- xm.mark_step()
- logs: Dict[str, float] = {}
- # all_gather + mean() to get average loss over all processes
- tr_loss_scalar = self._nested_gather(tr_loss).mean().item()
- # reset tr_loss to zero
- tr_loss -= tr_loss
- logs["loss"] = round(tr_loss_scalar / (self.state.global_step - self._globalstep_last_logged), 4)
- if grad_norm is not None:
- logs["grad_norm"] = grad_norm.detach().item() if isinstance(grad_norm, torch.Tensor) else grad_norm
- logs["learning_rate"] = self._get_learning_rate()
- self._total_loss_scalar += tr_loss_scalar
- self._globalstep_last_logged = self.state.global_step
- self.store_flos()
- self.log(logs)
- metrics = None
- if self.control.should_evaluate:
- metrics = self._evaluate(trial, ignore_keys_for_eval)
- if self.control.should_save:
- self._save_checkpoint(model, trial, metrics=metrics)
- self.control = self.callback_handler.on_save(self.args, self.state, self.control)
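- # Note (illustrative; argument names may vary slightly between versions): the log/evaluate/save decisions
- # taken above are driven by the corresponding TrainingArguments, e.g. `logging_steps=50`,
- # `eval_strategy="steps"` with `eval_steps=500`, and `save_strategy="steps"` with `save_steps=500`.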
- def _load_rng_state(self, checkpoint):
- # Load RNG states from `checkpoint`
- if checkpoint is None:
- return
- if self.args.world_size > 1:
- process_index = self.args.process_index
- rng_file = os.path.join(checkpoint, f"rng_state_{process_index}.pth")
- if not os.path.isfile(rng_file):
- logger.info(
- f"Didn't find an RNG file for process {process_index}, if you are resuming a training that "
- "wasn't launched in a distributed fashion, reproducibility is not guaranteed."
- )
- return
- else:
- rng_file = os.path.join(checkpoint, "rng_state.pth")
- if not os.path.isfile(rng_file):
- logger.info(
- "Didn't find an RNG file, if you are resuming a training that was launched in a distributed "
- "fashion, reproducibility is not guaranteed."
- )
- return
- checkpoint_rng_state = torch.load(rng_file)
- random.setstate(checkpoint_rng_state["python"])
- np.random.set_state(checkpoint_rng_state["numpy"])
- torch.random.set_rng_state(checkpoint_rng_state["cpu"])
- if torch.cuda.is_available():
- if self.args.parallel_mode == ParallelMode.DISTRIBUTED:
- torch.cuda.random.set_rng_state_all(checkpoint_rng_state["cuda"])
- else:
- try:
- torch.cuda.random.set_rng_state(checkpoint_rng_state["cuda"])
- except Exception as e:
- logger.info(
- f"Didn't manage to set back the RNG states of the GPU because of the following error:\n {e}"
- "\nThis won't yield the same results as if the training had not been interrupted."
- )
- if is_torch_xla_available():
- xm.set_rng_state(checkpoint_rng_state["xla"])
- if is_torch_npu_available():
- if self.args.parallel_mode == ParallelMode.DISTRIBUTED:
- torch.npu.random.set_rng_state_all(checkpoint_rng_state["npu"])
- else:
- try:
- torch.npu.random.set_rng_state(checkpoint_rng_state["npu"])
- except Exception as e:
- logger.info(
- f"Didn't manage to set back the RNG states of the NPU because of the following error:\n {e}"
- "\nThis won't yield the same results as if the training had not been interrupted."
- )
- if is_torch_mlu_available():
- if self.args.parallel_mode == ParallelMode.DISTRIBUTED:
- torch.mlu.random.set_rng_state_all(checkpoint_rng_state["mlu"])
- else:
- try:
- torch.mlu.random.set_rng_state(checkpoint_rng_state["mlu"])
- except Exception as e:
- logger.info(
- f"Didn't manage to set back the RNG states of the MLU because of the following error:\n {e}"
- "\nThis won't yield the same results as if the training had not been interrupted."
- )
- if is_torch_musa_available():
- if self.args.parallel_mode == ParallelMode.DISTRIBUTED:
- torch.musa.set_rng_state_all(checkpoint_rng_state["musa"])
- else:
- try:
- torch.musa.set_rng_state(checkpoint_rng_state["musa"])
- except Exception as e:
- logger.info(
- f"Didn't manage to set back the RNG states of the MUSA because of the following error:\n {e}"
- "\nThis won't yield the same results as if the training had not been interrupted."
- )
- def _save_checkpoint(self, model, trial, metrics=None):
- # In all cases, including ddp/dp/deepspeed, self.model is always a reference to the model we
- # want to save except FullyShardedDDP.
- # assert unwrap_model(model) is self.model, "internal model should be a reference to self.model"
- # Save model checkpoint
- checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"
- if self.hp_search_backend is None and trial is None:
- self.store_flos()
- run_dir = self._get_output_dir(trial=trial)
- output_dir = os.path.join(run_dir, checkpoint_folder)
- self.save_model(output_dir, _internal_call=True)
- if not self.args.save_only_model:
- # Save optimizer and scheduler
- self._save_optimizer_and_scheduler(output_dir)
- # Save RNG state
- self._save_rng_state(output_dir)
- # Determine the new best metric / best model checkpoint
- if metrics is not None and self.args.metric_for_best_model is not None:
- metric_to_check = self.args.metric_for_best_model
- if not metric_to_check.startswith("eval_"):
- metric_to_check = f"eval_{metric_to_check}"
- try:
- metric_value = metrics[metric_to_check]
- except KeyError as exc:
- raise KeyError(
- f"The `metric_for_best_model` training argument is set to '{metric_to_check}', "
- f"which is not found in the evaluation metrics. "
- f"The available evaluation metrics are: {list(metrics.keys())}. "
- f"Please ensure that the `compute_metrics` function returns a dictionary that includes '{metric_to_check}' or "
- f"consider changing the `metric_for_best_model` via the TrainingArguments."
- ) from exc
- operator = np.greater if self.args.greater_is_better else np.less
- if (
- self.state.best_metric is None
- or self.state.best_model_checkpoint is None
- or operator(metric_value, self.state.best_metric)
- ):
- self.state.best_metric = metric_value
- self.state.best_model_checkpoint = output_dir
- # Save the Trainer state
- if self.args.should_save:
- # Update `ExportableState` callbacks and `TrainerControl` state to where we are currently
- for cb in [
- cb for cb in self.callback_handler.callbacks + [self.control] if isinstance(cb, ExportableState)
- ]:
- cb_name = cb.__class__.__name__
- cb_state = cb.state()
- if isinstance(self.state.stateful_callbacks[cb_name], list):
- self.state.stateful_callbacks[cb_name].append(cb_state)
- else:
- self.state.stateful_callbacks[cb_name] = cb_state
- self.state.save_to_json(os.path.join(output_dir, TRAINER_STATE_NAME))
- if self.args.push_to_hub:
- self._push_from_checkpoint(output_dir)
- # Maybe delete some older checkpoints.
- if self.args.should_save:
- # Solely rely on numerical checkpoint id for rotation.
- # mtime is not reliable, especially on some FUSE filesystems in cloud environments.
- self._rotate_checkpoints(use_mtime=False, output_dir=run_dir)
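- # Example (illustrative; values are placeholders): checkpoint creation and rotation are configured via the
- # saving-related TrainingArguments, e.g.
- #
- #     training_args = TrainingArguments(
- #         output_dir="out",
- #         save_strategy="steps",
- #         save_steps=500,
- #         save_total_limit=2,     # older checkpoints are rotated out by `_rotate_checkpoints`
- #         save_only_model=False,  # also keep optimizer/scheduler/RNG state for exact resumption
- #     )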
- def _save_rng_state(self, output_dir):
- # Gather the RNG states of Python, NumPy and the available torch devices so that a resumed run stays reproducible
- rng_states = {
- "python": random.getstate(),
- "numpy": np.random.get_state(),
- "cpu": torch.random.get_rng_state(),
- }
- if torch.cuda.is_available():
- if self.args.parallel_mode == ParallelMode.DISTRIBUTED:
- # In distributed training, save the CUDA RNG state of every visible device
- rng_states["cuda"] = torch.cuda.random.get_rng_state_all()
- else:
- rng_states["cuda"] = torch.cuda.random.get_rng_state()
- if is_torch_xla_available():
- rng_states["xla"] = xm.get_rng_state()
- if is_torch_npu_available():
- if self.args.parallel_mode == ParallelMode.DISTRIBUTED:
- rng_states["npu"] = torch.npu.random.get_rng_state_all()
- else:
- rng_states["npu"] = torch.npu.random.get_rng_state()
- if is_torch_mlu_available():
- if self.args.parallel_mode == ParallelMode.DISTRIBUTED:
- rng_states["mlu"] = torch.mlu.random.get_rng_state_all()
- else:
- rng_states["mlu"] = torch.mlu.random.get_rng_state()
- if is_torch_musa_available():
- if self.args.parallel_mode == ParallelMode.DISTRIBUTED:
- rng_states["musa"] = torch.musa.get_rng_state_all()
- else:
- rng_states["musa"] = torch.musa.get_rng_state()
- # A process can arrive here before the process 0 has a chance to save the model, in which case output_dir may
- # not yet exist.
- os.makedirs(output_dir, exist_ok=True)
- if self.args.world_size <= 1:
- torch.save(rng_states, os.path.join(output_dir, "rng_state.pth"))
- else:
- torch.save(rng_states, os.path.join(output_dir, f"rng_state_{self.args.process_index}.pth"))
- def _save_optimizer_and_scheduler(self, output_dir):
- if is_torch_xla_available():
- xm.rendezvous("saving_optimizer_states")
- if self.is_fsdp_xla_v1_enabled:
- optm = {
- "optimizer": self.optimizer.state_dict(),
- "shard_metadata": self.model.get_shard_metadata(),
- }
- xm.save(
- optm,
- os.path.join(
- output_dir, f"rank{self.args.process_index}-of-{self.args.world_size}-{OPTIMIZER_NAME}"
- ),
- master_only=False,
- )
- else:
- xm.save(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME))
- with warnings.catch_warnings(record=True) as caught_warnings:
- xm.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))
- reissue_pt_warnings(caught_warnings)
- elif is_sagemaker_mp_enabled():
- opt_state_dict = self.optimizer.local_state_dict(gather_if_shard=False)
- smp.barrier()
- if smp.rdp_rank() == 0 or smp.state.cfg.shard_optimizer_state:
- smp.save(
- opt_state_dict,
- os.path.join(output_dir, OPTIMIZER_NAME),
- partial=True,
- v3=smp.state.cfg.shard_optimizer_state,
- )
- elif self.is_deepspeed_enabled:
- # Under ZeRO-3 the model file itself doesn't get saved, since it wouldn't contain the full gathered
- # weights, unless the DeepSpeed config sets `stage3_gather_16bit_weights_on_model_save` to True
- accept_exclude_frozen_parameters = "exclude_frozen_parameters" in set(
- inspect.signature(self.model_wrapped.save_checkpoint).parameters.keys()
- )
- if accept_exclude_frozen_parameters and _is_peft_model(self.model):
- self.model_wrapped.save_checkpoint(output_dir, exclude_frozen_parameters=True)
- else:
- self.model_wrapped.save_checkpoint(output_dir)
- elif self.is_fsdp_enabled:
- # Save an FSDP-specific checkpoint so training can later be resumed from it
- save_fsdp_model(
- self.accelerator.state.fsdp_plugin, self.accelerator, self.model, output_dir, **_get_fsdp_ckpt_kwargs()
- )
- save_fsdp_optimizer(
- self.accelerator.state.fsdp_plugin, self.accelerator, self.optimizer, self.model, output_dir
- )
- elif self.args.should_save:
- # deepspeed.save_checkpoint above saves model/optim/sched
- torch.save(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME))
- # Save SCHEDULER & SCALER
- is_deepspeed_custom_scheduler = self.is_deepspeed_enabled and not isinstance(
- self.lr_scheduler, DeepSpeedSchedulerWrapper
- )
- if (
- self.args.should_save
- and (not self.is_deepspeed_enabled or is_deepspeed_custom_scheduler)
- and not is_torch_xla_available()
- ):
- with warnings.catch_warnings(record=True) as caught_warnings:
- torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))
- reissue_pt_warnings(caught_warnings)
- def _load_optimizer_and_scheduler(self, checkpoint):
- """If optimizer and scheduler states exist, load them."""
- if checkpoint is None:
- return
- if self.is_deepspeed_enabled:
- # deepspeed loads optimizer/lr_scheduler together with the model in deepspeed_init
- if not isinstance(self.lr_scheduler, DeepSpeedSchedulerWrapper):
- with warnings.catch_warnings(record=True) as caught_warnings:
- self.lr_scheduler.load_state_dict(torch.load(os.path.join(checkpoint, SCHEDULER_NAME)))
- reissue_pt_warnings(caught_warnings)
- return
- checkpoint_file_exists = (
- glob.glob(os.path.join(checkpoint, OPTIMIZER_NAME) + "_*")
- if is_sagemaker_mp_enabled()
- else (
- os.path.isfile(os.path.join(checkpoint, OPTIMIZER_NAME))
- or os.path.isfile(os.path.join(checkpoint, OPTIMIZER_NAME_BIN))
- or (
- os.path.isdir(checkpoint)
- and any(
- OPTIMIZER_NAME_BIN.split(".")[0] in folder_name
- for folder_name in os.listdir(checkpoint)
- if os.path.isdir(os.path.join(checkpoint, folder_name))
- )
- )
- )
- )
- checkpoint_file_exists = (
- glob.glob(os.path.join(checkpoint, f"rank*-of-{self.args.world_size}-{OPTIMIZER_NAME}"))
- if self.is_fsdp_xla_v1_enabled
- else checkpoint_file_exists
- )
- if checkpoint_file_exists and os.path.isfile(os.path.join(checkpoint, SCHEDULER_NAME)):
- # Load in optimizer and scheduler states
- if is_torch_xla_available():
- # On TPU we have to take some extra precautions to properly load the states on the right device.
- if self.is_fsdp_xla_v1_enabled:
- optimizer_state = torch.load(
- os.path.join(
- checkpoint, f"rank{self.args.process_index}-of-{self.args.world_size}-{OPTIMIZER_NAME}"
- ),
- map_location="cpu",
- )
- # We only need `optimizer` when resuming from checkpoint
- optimizer_state = optimizer_state["optimizer"]
- else:
- optimizer_state = torch.load(os.path.join(checkpoint, OPTIMIZER_NAME), map_location="cpu")
- with warnings.catch_warnings(record=True) as caught_warnings:
- lr_scheduler_state = torch.load(os.path.join(checkpoint, SCHEDULER_NAME), map_location="cpu")
- reissue_pt_warnings(caught_warnings)
- xm.send_cpu_data_to_device(optimizer_state, self.args.device)
- xm.send_cpu_data_to_device(lr_scheduler_state, self.args.device)
- self.optimizer.load_state_dict(optimizer_state)
- self.lr_scheduler.load_state_dict(lr_scheduler_state)
- else:
- if is_sagemaker_mp_enabled():
- if os.path.isfile(os.path.join(checkpoint, "user_content.pt")):
- # Optimizer checkpoint was saved with smp >= 1.10
- def opt_load_hook(mod, opt):
- opt.load_state_dict(smp.load(os.path.join(checkpoint, OPTIMIZER_NAME), partial=True))
- else:
- # Optimizer checkpoint was saved with smp < 1.10
- def opt_load_hook(mod, opt):
- if IS_SAGEMAKER_MP_POST_1_10:
- opt.load_state_dict(
- smp.load(os.path.join(checkpoint, OPTIMIZER_NAME), partial=True, back_compat=True)
- )
- else:
- opt.load_state_dict(smp.load(os.path.join(checkpoint, OPTIMIZER_NAME), partial=True))
- self.model_wrapped.register_post_step_hook(opt_load_hook)
- else:
- # We load on the CPU when training on one GPU to avoid GPU-RAM OOM when training big models.
- # In distributed training, however, we load directly on each GPU and accept the GPU OOM risk,
- # since loading on CPU would be even more likely to OOM (the optimizer state would be loaded num_gpu times)
- map_location = self.args.device if self.args.world_size > 1 else "cpu"
- if self.is_fsdp_enabled:
- load_fsdp_optimizer(
- self.accelerator.state.fsdp_plugin,
- self.accelerator,
- self.optimizer,
- self.model,
- checkpoint,
- **_get_fsdp_ckpt_kwargs(),
- )
- else:
- self.optimizer.load_state_dict(
- torch.load(os.path.join(checkpoint, OPTIMIZER_NAME), map_location=map_location)
- )
- with warnings.catch_warnings(record=True) as caught_warnings:
- self.lr_scheduler.load_state_dict(torch.load(os.path.join(checkpoint, SCHEDULER_NAME)))
- reissue_pt_warnings(caught_warnings)
- def _load_callback_state(self):
- """If callback states exist and were passed in, restore their states if enabled"""
- if not self.args.restore_callback_states_from_checkpoint:
- return
- # Callback states are stored in stateful_callbacks
- not_found = []
- new_callbacks = []
- original_callbacks = self.callback_handler.callbacks + [self.control]
- for stored_callback, data in self.state.stateful_callbacks.items():
- if not isinstance(data, list):
- data = [data]
- if any(callback.__class__.__name__ == stored_callback for callback in original_callbacks):
- # We can load/restore from multiple callbacks of the same type.
- duplicates = [
- callback for callback in original_callbacks if callback.__class__.__name__ == stored_callback
- ]
- for callback, callback_data in zip(duplicates, data):
- args = callback_data.get("args", {})
- attributes = callback_data.get("attributes", {})
- new_callback = type(callback)(**args)
- for attribute, value in attributes.items():
- setattr(new_callback, attribute, value)
- if isinstance(callback, TrainerControl):
- # Specifically for restoring the `control` state
- self.control = new_callback
- else:
- new_callbacks.append(new_callback)
- # We remove the existing callback and add it to the list of new callbacks
- self.callback_handler.remove_callback(type(new_callback))
- logger.info("Continuing training from checkpoint, restoring any callbacks that were passed in")
- else:
- not_found.append(stored_callback)
- if len(not_found) > 0:
- logger.warning(
- f"Checkpoint included callbacks not included in current configuration. Ignoring. ({', '.join(not_found)})"
- )
- for callback in new_callbacks:
- self.callback_handler.add_callback(callback)
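- # Note (illustrative): callback state is only restored when
- # `TrainingArguments(restore_callback_states_from_checkpoint=True)` is set and the same callback class was
- # passed to the new `Trainer`; callbacks participate by implementing the `ExportableState` protocol, i.e.
- # exposing their state through a `state()` method, which is what `_save_checkpoint` serializes above.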
- def hyperparameter_search(
- self,
- hp_space: Optional[Callable[["optuna.Trial"], Dict[str, float]]] = None,
- compute_objective: Optional[Callable[[Dict[str, float]], float]] = None,
- n_trials: int = 20,
- direction: Union[str, List[str]] = "minimize",
- backend: Optional[Union["str", HPSearchBackend]] = None,
- hp_name: Optional[Callable[["optuna.Trial"], str]] = None,
- **kwargs,
- ) -> Union[BestRun, List[BestRun]]:
- """
- Launch a hyperparameter search using `optuna`, `Ray Tune`, `SigOpt` or `W&B`. The optimized quantity is
- determined by `compute_objective`, which defaults to a function returning the evaluation loss when no metric
- is provided, and the sum of all metrics otherwise.
- <Tip warning={true}>
- To use this method, you need to have provided a `model_init` when initializing your [`Trainer`]: we need to
- reinitialize the model at each new run. This is incompatible with the `optimizers` argument, so you need to
- subclass [`Trainer`] and override the method [`~Trainer.create_optimizer_and_scheduler`] for custom
- optimizer/scheduler.
- </Tip>
- Args:
- hp_space (`Callable[["optuna.Trial"], Dict[str, float]]`, *optional*):
- A function that defines the hyperparameter search space. Will default to
- [`~trainer_utils.default_hp_space_optuna`] or [`~trainer_utils.default_hp_space_ray`] or
- [`~trainer_utils.default_hp_space_sigopt`] depending on your backend.
- compute_objective (`Callable[[Dict[str, float]], float]`, *optional*):
- A function computing the objective to minimize or maximize from the metrics returned by the `evaluate`
- method. Will default to [`~trainer_utils.default_compute_objective`].
- n_trials (`int`, *optional*, defaults to 20):
- The number of trial runs to test.
- direction (`str` or `List[str]`, *optional*, defaults to `"minimize"`):
- For single-objective optimization, a `str` that is either `"minimize"` or `"maximize"`: pick
- `"minimize"` when optimizing the validation loss and `"maximize"` when optimizing one or several
- metrics. For multi-objective optimization, a `List[str]` of `"minimize"`/`"maximize"` values, one per
- objective, chosen with the same rule.
- backend (`str` or [`~training_utils.HPSearchBackend`], *optional*):
- The backend to use for hyperparameter search. Will default to optuna or Ray Tune or SigOpt, depending
- on which one is installed. If all are installed, will default to optuna.
- hp_name (`Callable[["optuna.Trial"], str]`, *optional*):
- A function that defines the trial/run name. Will default to None.
- kwargs (`Dict[str, Any]`, *optional*):
- Additional keyword arguments passed along to `optuna.create_study` or `ray.tune.run`. For more
- information see:
- - the documentation of
- [optuna.create_study](https://optuna.readthedocs.io/en/stable/reference/generated/optuna.study.create_study.html)
- - the documentation of [tune.run](https://docs.ray.io/en/latest/tune/api_docs/execution.html#tune-run)
- - the documentation of [sigopt](https://app.sigopt.com/docs/endpoints/experiments/create)
- Returns:
- [`trainer_utils.BestRun` or `List[trainer_utils.BestRun]`]: All the information about the best run or best
- runs for multi-objective optimization. Experiment summary can be found in `run_summary` attribute for Ray
- backend.
- """
- if backend is None:
- backend = default_hp_search_backend()
- backend = HPSearchBackend(backend)
- backend_obj = ALL_HYPERPARAMETER_SEARCH_BACKENDS[backend]()
- backend_obj.ensure_available()
- self.hp_search_backend = backend
- if self.model_init is None:
- raise RuntimeError(
- "To use hyperparameter search, you need to pass your model through a model_init function."
- )
- self.hp_space = backend_obj.default_hp_space if hp_space is None else hp_space
- self.hp_name = hp_name
- self.compute_objective = default_compute_objective if compute_objective is None else compute_objective
- best_run = backend_obj.run(self, n_trials, direction, **kwargs)
- self.hp_search_backend = None
- return best_run
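- # Example (illustrative; the model, datasets and search space are placeholders):
- #
- #     def model_init(trial):
- #         return AutoModelForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=2)
- #
- #     def optuna_hp_space(trial):
- #         return {
- #             "learning_rate": trial.suggest_float("learning_rate", 1e-5, 5e-5, log=True),
- #             "per_device_train_batch_size": trial.suggest_categorical("per_device_train_batch_size", [16, 32]),
- #         }
- #
- #     trainer = Trainer(model_init=model_init, args=training_args,
- #                       train_dataset=train_dataset, eval_dataset=eval_dataset)
- #     best_run = trainer.hyperparameter_search(hp_space=optuna_hp_space, backend="optuna",
- #                                              n_trials=20, direction="minimize")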
- def log(self, logs: Dict[str, float]) -> None:
- """
- Log `logs` on the various objects watching training.
- Subclass and override this method to inject custom behavior.
- Args:
- logs (`Dict[str, float]`):
- The values to log.
- """
- if self.state.epoch is not None:
- logs["epoch"] = self.state.epoch
- if self.args.include_num_input_tokens_seen:
- logs["num_input_tokens_seen"] = self.state.num_input_tokens_seen
- output = {**logs, **{"step": self.state.global_step}}
- self.state.log_history.append(output)
- self.control = self.callback_handler.on_log(self.args, self.state, self.control, logs)
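- # Example (illustrative sketch): subclasses can filter or enrich what gets logged, e.g.
- #
- #     class QuietTrainer(Trainer):
- #         def log(self, logs):
- #             logs = {k: v for k, v in logs.items() if k != "grad_norm"}
- #             super().log(logs)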
- def _prepare_input(self, data: Union[torch.Tensor, Any]) -> Union[torch.Tensor, Any]:
- """
- Prepares a single `data` element before feeding it to the model, be it a tensor or a nested list/dictionary of tensors.
- """
- if isinstance(data, Mapping):
- return type(data)({k: self._prepare_input(v) for k, v in data.items()})
- elif isinstance(data, (tuple, list)):
- return type(data)(self._prepare_input(v) for v in data)
- elif isinstance(data, torch.Tensor):
- kwargs = {"device": self.args.device}
- if self.is_deepspeed_enabled and (torch.is_floating_point(data) or torch.is_complex(data)):
- # NLP models' inputs are int/uint and those get adjusted to the right dtype of the
- # embedding. Other models, such as wav2vec2, already have float inputs and thus
- # may need special handling to match the dtype of the model
- kwargs.update({"dtype": self.accelerator.state.deepspeed_plugin.hf_ds_config.dtype()})
- return data.to(**kwargs)
- return data
- def _prepare_inputs(self, inputs: Dict[str, Union[torch.Tensor, Any]]) -> Dict[str, Union[torch.Tensor, Any]]:
- """
- Prepare `inputs` before feeding them to the model, converting them to tensors if they are not already and
- handling potential state.
- """
- inputs = self._prepare_input(inputs)
- if len(inputs) == 0:
- raise ValueError(
- "The batch received was empty, your model won't be able to train on it. Double-check that your "
- f"training dataset contains keys expected by the model: {','.join(self._signature_columns)}."
- )
- if self.args.past_index >= 0 and self._past is not None:
- inputs["mems"] = self._past
- return inputs
- def compute_loss_context_manager(self):
- """
- A helper wrapper to group together context managers.
- """
- return self.autocast_smart_context_manager()
- def autocast_smart_context_manager(self, cache_enabled: Optional[bool] = True):
- """
- A helper wrapper that creates an appropriate context manager for `autocast` while feeding it the desired
- arguments, depending on the situation.
- """
- if self.use_cpu_amp:
- ctx_manager = torch.cpu.amp.autocast(cache_enabled=cache_enabled, dtype=self.amp_dtype)
- else:
- ctx_manager = contextlib.nullcontext()
- return ctx_manager
- def training_step(
- self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]], num_items_in_batch=None
- ) -> torch.Tensor:
- """
- Perform a training step on a batch of inputs.
- Subclass and override to inject custom behavior.
- Args:
- model (`nn.Module`):
- The model to train.
- inputs (`Dict[str, Union[torch.Tensor, Any]]`):
- The inputs and targets of the model.
- The dictionary will be unpacked before being fed to the model. Most models expect the targets under the
- argument `labels`. Check your model's documentation for all accepted arguments.
- Return:
- `torch.Tensor`: The tensor with training loss on this batch.
- """
- model.train()
- if hasattr(self.optimizer, "train") and callable(self.optimizer.train):
- self.optimizer.train()
- inputs = self._prepare_inputs(inputs)
- if is_sagemaker_mp_enabled():
- loss_mb = smp_forward_backward(model, inputs, self.args.gradient_accumulation_steps)
- return loss_mb.reduce_mean().detach().to(self.args.device)
- with self.compute_loss_context_manager():
- loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)
- del inputs
- if (
- self.args.torch_empty_cache_steps is not None
- and self.state.global_step % self.args.torch_empty_cache_steps == 0
- ):
- if is_torch_xpu_available():
- torch.xpu.empty_cache()
- elif is_torch_mlu_available():
- torch.mlu.empty_cache()
- elif is_torch_musa_available():
- torch.musa.empty_cache()
- elif is_torch_npu_available():
- torch.npu.empty_cache()
- elif is_torch_mps_available(min_version="2.0"):
- torch.mps.empty_cache()
- else:
- torch.cuda.empty_cache()
- kwargs = {}
- # For LOMO optimizers you need to explicitly use the learning rate
- if self.args.optim in [OptimizerNames.LOMO, OptimizerNames.ADALOMO]:
- kwargs["learning_rate"] = self._get_learning_rate()
- if self.args.n_gpu > 1:
- loss = loss.mean() # mean() to average on multi-gpu parallel training
- if self.use_apex:
- with amp.scale_loss(loss, self.optimizer) as scaled_loss:
- scaled_loss.backward()
- else:
- self.accelerator.backward(loss, **kwargs)
- # Finally we need to normalize the loss for reporting
- if num_items_in_batch is None:
- return loss.detach() / self.args.gradient_accumulation_steps
- return loss.detach()
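- # Note (illustrative): `TrainingArguments(torch_empty_cache_steps=N)` triggers the cache-clearing branch
- # above roughly every N global steps, trading some throughput for lower peak memory; and when
- # `num_items_in_batch` is not available, the returned loss is divided by `gradient_accumulation_steps`
- # so that the value accumulated into `tr_loss` stays comparable across accumulation settings.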
- def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
- """
- How the loss is computed by Trainer. By default, all models return the loss in the first element.
- Subclass and override for custom behavior.
- """
- if (self.label_smoother is not None or self.compute_loss_func is not None) and "labels" in inputs:
- labels = inputs.pop("labels")
- else:
- labels = None
- if self.model_accepts_loss_kwargs:
- loss_kwargs = {}
- if num_items_in_batch is not None:
- loss_kwargs["num_items_in_batch"] = num_items_in_batch
- inputs = {**inputs, **loss_kwargs}
- outputs = model(**inputs)
- # Save past state if it exists
- # TODO: this needs to be fixed and made cleaner later.
- if self.args.past_index >= 0:
- self._past = outputs[self.args.past_index]
- if labels is not None:
- unwrapped_model = self.accelerator.unwrap_model(model)
- if _is_peft_model(unwrapped_model):
- model_name = unwrapped_model.base_model.model._get_name()
- else:
- model_name = unwrapped_model._get_name()
- # User-defined compute_loss function
- if self.compute_loss_func is not None:
- loss = self.compute_loss_func(outputs, labels, num_items_in_batch=num_items_in_batch)
- elif model_name in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.values():
- loss = self.label_smoother(outputs, labels, shift_labels=True)
- else:
- loss = self.label_smoother(outputs, labels)
- else:
- if isinstance(outputs, dict) and "loss" not in outputs:
- raise ValueError(
- "The model did not return a loss from the inputs, only the following keys: "
- f"{','.join(outputs.keys())}. For reference, the inputs it received are {','.join(inputs.keys())}."
- )
- # We don't use .loss here since the model may return tuples instead of ModelOutput.
- loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0]
- if self.args.average_tokens_across_devices and self.model_accepts_loss_kwargs:
- loss *= self.accelerator.num_processes
- return (loss, outputs) if return_outputs else loss
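- # Example (illustrative sketch of the override pattern; the class weights are made up):
- #
- #     class WeightedLossTrainer(Trainer):
- #         def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
- #             labels = inputs.pop("labels")
- #             outputs = model(**inputs)
- #             logits = outputs.logits
- #             loss_fct = nn.CrossEntropyLoss(weight=torch.tensor([1.0, 2.0], device=logits.device))
- #             loss = loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1))
- #             return (loss, outputs) if return_outputs else loss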
- def is_local_process_zero(self) -> bool:
- """
- Whether or not this process is the local (e.g., on one machine if training in a distributed fashion on several
- machines) main process.
- """
- return self.args.local_process_index == 0
- def is_world_process_zero(self) -> bool:
- """
- Whether or not this process is the global main process (when training in a distributed fashion on several
- machines, this is only going to be `True` for one process).
- """
- # Special case for SageMaker ModelParallel, since there `process_index` is `dp_process_index`, not the global
- # process index.
- if is_sagemaker_mp_enabled():
- return smp.rank() == 0
- else:
- return self.args.process_index == 0
- def save_model(self, output_dir: Optional[str] = None, _internal_call: bool = False):
- """
- Will save the model, so you can reload it using `from_pretrained()`.
- Will only save from the main process.
- """
- if output_dir is None:
- output_dir = self.args.output_dir
- if is_torch_xla_available():
- self._save_tpu(output_dir)
- elif is_sagemaker_mp_enabled():
- # Calling the state_dict needs to be done on the wrapped model and on all processes.
- os.makedirs(output_dir, exist_ok=True)
- state_dict = self.model_wrapped.state_dict()
- if self.args.should_save:
- self._save(output_dir, state_dict=state_dict)
- if IS_SAGEMAKER_MP_POST_1_10:
- # 'user_content.pt' indicates model state_dict saved with smp >= 1.10
- Path(os.path.join(output_dir, "user_content.pt")).touch()
- elif self.is_fsdp_enabled:
- if ("FULL_STATE_DICT" in str(self.accelerator.state.fsdp_plugin.state_dict_type)) and (
- version.parse(accelerate_version) > version.parse("0.24.1")
- ):
- state_dict = self.accelerator.get_state_dict(self.model)
- if self.args.should_save:
- self._save(output_dir, state_dict=state_dict)
- elif self.is_deepspeed_enabled:
- try:
- state_dict = self.accelerator.get_state_dict(self.deepspeed)
- if self.args.should_save:
- self._save(output_dir, state_dict=state_dict)
- except ValueError:
- logger.warning(
- " stage3_gather_16bit_weights_on_model_save=false. Saving the full checkpoint instead, use"
- " zero_to_fp32.py to recover weights"
- )
- if self.args.should_save:
- self._save(output_dir, state_dict={})
- # remove the dummy state_dict
- remove_dummy_checkpoint(self.args.should_save, output_dir, [WEIGHTS_NAME, SAFE_WEIGHTS_NAME])
- self.model_wrapped.save_checkpoint(output_dir)
- elif self.args.should_save:
- self._save(output_dir)
- # Push to the Hub when `save_model` is called by the user.
- if self.args.push_to_hub and not _internal_call:
- self.push_to_hub(commit_message="Model save")
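- # Example (illustrative; the path is a placeholder): `save_model` writes a directory that can be reloaded
- # with `from_pretrained`:
- #
- #     trainer.save_model("out/final")
- #     model = AutoModelForSequenceClassification.from_pretrained("out/final")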
- def _save_tpu(self, output_dir: Optional[str] = None):
- output_dir = output_dir if output_dir is not None else self.args.output_dir
- logger.info(f"Saving model checkpoint to {output_dir}")
- model = self.model
- xm.mark_step()
- if xm.is_master_ordinal(local=False):
- os.makedirs(output_dir, exist_ok=True)
- torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))
- # Save a trained model and configuration using `save_pretrained()`.
- # They can then be reloaded using `from_pretrained()`
- supported_classes = (PushToHubMixin,)
- xm.rendezvous("saving_checkpoint")
- if self.is_fsdp_xla_v1_enabled:
- ckpt = {
- "model": model.state_dict(),
- "shard_metadata": model.get_shard_metadata(),
- }
- ckpt_path = os.path.join(
- output_dir, f"rank{self.args.process_index}-of-{self.args.world_size}-{WEIGHTS_NAME}"
- )
- # All ranks save sharded checkpoint
- xm.save(ckpt, ckpt_path, master_only=False)
- # Make sure all ranks have saved checkpoints
- xm.rendezvous("save_full_checkpoints")
- # The master process saves the full consolidated checkpoint
- if self.args.should_save:
- from torch_xla.distributed.fsdp import consolidate_sharded_model_checkpoints
- full_state_dict, _ = consolidate_sharded_model_checkpoints(
- ckpt_prefix=os.path.join(output_dir, ""),
- ckpt_suffix=f"rank*-of-*-{WEIGHTS_NAME}",
- save_model=False,
- )
- model = model.module.module
- unwrapped_model = self.accelerator.unwrap_model(model)
- if isinstance(unwrapped_model, supported_classes):
- unwrapped_model.save_pretrained(
- output_dir,
- state_dict=full_state_dict,
- save_function=xm.save,
- safe_serialization=self.args.save_safetensors,
- )
- else:
- logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.")
- xm.save(full_state_dict, os.path.join(output_dir, WEIGHTS_NAME))
- elif not isinstance(model, supported_classes):
- if isinstance(self.accelerator.unwrap_model(model), supported_classes):
- self.accelerator.unwrap_model(model).save_pretrained(
- output_dir,
- is_main_process=self.args.should_save,
- state_dict=xm._maybe_convert_to_cpu(model.state_dict()),
- save_function=xm.save,
- safe_serialization=self.args.save_safetensors,
- )
- else:
- logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.")
- state_dict = xm._maybe_convert_to_cpu(model.state_dict())
- xm.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))
- else:
- model.save_pretrained(
- output_dir,
- is_main_process=self.args.should_save,
- save_function=xm.save,
- safe_serialization=self.args.save_safetensors,
- state_dict=xm._maybe_convert_to_cpu(model.state_dict()),
- )
- if self.processing_class is not None and self.args.should_save:
- self.processing_class.save_pretrained(output_dir)
- def _save(self, output_dir: Optional[str] = None, state_dict=None):
- # If we are executing this function, we are the process zero, so we don't check for that.
- output_dir = output_dir if output_dir is not None else self.args.output_dir
- os.makedirs(output_dir, exist_ok=True)
- logger.info(f"Saving model checkpoint to {output_dir}")
- supported_classes = (PreTrainedModel,) if not is_peft_available() else (PreTrainedModel, PeftModel)
- # Save a trained model and configuration using `save_pretrained()`.
- # They can then be reloaded using `from_pretrained()`
- if not isinstance(self.model, supported_classes):
- if state_dict is None:
- state_dict = self.model.state_dict()
- if isinstance(self.accelerator.unwrap_model(self.model), supported_classes):
- self.accelerator.unwrap_model(self.model).save_pretrained(
- output_dir, state_dict=state_dict, safe_serialization=self.args.save_safetensors
- )
- else:
- logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.")
- if self.args.save_safetensors:
- safetensors.torch.save_file(
- state_dict, os.path.join(output_dir, SAFE_WEIGHTS_NAME), metadata={"format": "pt"}
- )
- else:
- torch.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))
- else:
- self.model.save_pretrained(
- output_dir, state_dict=state_dict, safe_serialization=self.args.save_safetensors
- )
- if self.processing_class is not None:
- self.processing_class.save_pretrained(output_dir)
- # Good practice: save your training arguments together with the trained model
- torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))
- def store_flos(self):
- # Storing the number of floating-point operations that went into the model
- if self.args.parallel_mode == ParallelMode.DISTRIBUTED:
- self.state.total_flos += (
- distributed_broadcast_scalars([self.current_flos], device=self.args.device).sum().item()
- )
- self.current_flos = 0
- else:
- self.state.total_flos += self.current_flos
- self.current_flos = 0
- def _sorted_checkpoints(
- self, output_dir=None, checkpoint_prefix=PREFIX_CHECKPOINT_DIR, use_mtime=False
- ) -> List[str]:
- ordering_and_checkpoint_path = []
- glob_checkpoints = [str(x) for x in Path(output_dir).glob(f"{checkpoint_prefix}-*") if os.path.isdir(x)]
- for path in glob_checkpoints:
- if use_mtime:
- ordering_and_checkpoint_path.append((os.path.getmtime(path), path))
- else:
- regex_match = re.match(f".*{checkpoint_prefix}-([0-9]+)", path)
- if regex_match is not None and regex_match.groups() is not None:
- ordering_and_checkpoint_path.append((int(regex_match.groups()[0]), path))
- checkpoints_sorted = sorted(ordering_and_checkpoint_path)
- checkpoints_sorted = [checkpoint[1] for checkpoint in checkpoints_sorted]
- # Make sure we don't delete the best model.
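- # The reordering below pushes the best checkpoint towards the end of the list, so that
- # `_rotate_checkpoints`, which deletes from the front of the sorted list, keeps both it and the most
- # recent checkpoint.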
- if (
- self.state.best_model_checkpoint is not None
- and str(Path(self.state.best_model_checkpoint)) in checkpoints_sorted
- ):
- best_model_index = checkpoints_sorted.index(str(Path(self.state.best_model_checkpoint)))
- for i in range(best_model_index, len(checkpoints_sorted) - 2):
- checkpoints_sorted[i], checkpoints_sorted[i + 1] = checkpoints_sorted[i + 1], checkpoints_sorted[i]
- return checkpoints_sorted
- def _rotate_checkpoints(self, use_mtime=False, output_dir=None) -> None:
- if self.args.save_total_limit is None or self.args.save_total_limit <= 0:
- return
- # Check if we should delete older checkpoint(s)
- checkpoints_sorted = self._sorted_checkpoints(use_mtime=use_mtime, output_dir=output_dir)
- if len(checkpoints_sorted) <= self.args.save_total_limit:
- return
- # If save_total_limit=1 with load_best_model_at_end=True, we could end up deleting the last checkpoint, which
- # we don't do to allow resuming.
- save_total_limit = self.args.save_total_limit
- if (
- self.state.best_model_checkpoint is not None
- and self.args.save_total_limit == 1
- and checkpoints_sorted[-1] != self.state.best_model_checkpoint
- ):
- save_total_limit = 2
- number_of_checkpoints_to_delete = max(0, len(checkpoints_sorted) - save_total_limit)
- checkpoints_to_be_deleted = checkpoints_sorted[:number_of_checkpoints_to_delete]
- for checkpoint in checkpoints_to_be_deleted:
- logger.info(f"Deleting older checkpoint [{checkpoint}] due to args.save_total_limit")
- shutil.rmtree(checkpoint, ignore_errors=True)
- def evaluate(
- self,
- eval_dataset: Optional[Union[Dataset, Dict[str, Dataset]]] = None,
- ignore_keys: Optional[List[str]] = None,
- metric_key_prefix: str = "eval",
- ) -> Dict[str, float]:
- """
- Run evaluation and returns metrics.
- The calling script will be responsible for providing a method to compute metrics, as they are task-dependent
- (pass it to the init `compute_metrics` argument).
- You can also subclass and override this method to inject custom behavior.
- Args:
- eval_dataset (Union[`Dataset`, Dict[str, `Dataset`]], *optional*):
- Pass a dataset if you wish to override `self.eval_dataset`. If it is a [`~datasets.Dataset`], columns
- not accepted by the `model.forward()` method are automatically removed. If it is a dictionary, it will
- evaluate on each dataset, prepending the dictionary key to the metric name. Datasets must implement the
- `__len__` method.
- <Tip>
- If you pass a dictionary with names of datasets as keys and datasets as values, evaluate will run
- separate evaluations on each dataset. This can be useful to monitor how training affects other
- datasets or simply to get a more fine-grained evaluation.
- When used with `load_best_model_at_end`, make sure `metric_for_best_model` references exactly one
- of the datasets. If you, for example, pass in `{"data1": data1, "data2": data2}` for two datasets
- `data1` and `data2`, you could specify `metric_for_best_model="eval_data1_loss"` to use the
- loss on `data1`, or `metric_for_best_model="eval_data2_loss"` for the loss on `data2`.
- </Tip>
- ignore_keys (`List[str]`, *optional*):
- A list of keys in the output of your model (if it is a dictionary) that should be ignored when
- gathering predictions.
- metric_key_prefix (`str`, *optional*, defaults to `"eval"`):
- An optional prefix to be used as the metrics key prefix. For example, the metric "bleu" will be named
- "eval_bleu" if the prefix is "eval" (default).
- Returns:
- A dictionary containing the evaluation loss and the potential metrics computed from the predictions. The
- dictionary also contains the epoch number which comes from the training state.
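- Example (illustrative sketch; `trainer`, `data1` and `data2` are placeholders for an existing
- `Trainer` instance and two evaluation datasets):
- >>> metrics = trainer.evaluate(eval_dataset={"data1": data1, "data2": data2})
- >>> # metric keys get a per-dataset prefix, e.g. "eval_data1_loss" and "eval_data2_loss"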
- """
- # handle multiple eval datasets
- override = eval_dataset is not None
- eval_dataset = eval_dataset if override else self.eval_dataset
- if isinstance(eval_dataset, dict):
- metrics = {}
- for eval_dataset_name, _eval_dataset in eval_dataset.items():
- dataset_metrics = self.evaluate(
- eval_dataset=_eval_dataset if override else eval_dataset_name,
- ignore_keys=ignore_keys,
- metric_key_prefix=f"{metric_key_prefix}_{eval_dataset_name}",
- )
- metrics.update(dataset_metrics)
- return metrics
- # memory metrics - must set up as early as possible
- self._memory_tracker.start()
- eval_dataloader = self.get_eval_dataloader(eval_dataset)
- if self.is_fsdp_xla_v2_enabled:
- eval_dataloader = tpu_spmd_dataloader(eval_dataloader)
- start_time = time.time()
- eval_loop = self.prediction_loop if self.args.use_legacy_prediction_loop else self.evaluation_loop
- output = eval_loop(
- eval_dataloader,
- description="Evaluation",
- # No point gathering the predictions if there are no metrics, otherwise we defer to
- # self.args.prediction_loss_only
- prediction_loss_only=True if self.compute_metrics is None else None,
- ignore_keys=ignore_keys,
- metric_key_prefix=metric_key_prefix,
- )
- total_batch_size = self.args.eval_batch_size * self.args.world_size
- if f"{metric_key_prefix}_jit_compilation_time" in output.metrics:
- start_time += output.metrics[f"{metric_key_prefix}_jit_compilation_time"]
- if f"{metric_key_prefix}_model_preparation_time" in output.metrics:
- start_time += output.metrics[f"{metric_key_prefix}_model_preparation_time"]
- output.metrics.update(
- speed_metrics(
- metric_key_prefix,
- start_time,
- num_samples=output.num_samples,
- num_steps=math.ceil(output.num_samples / total_batch_size),
- )
- )
- self.log(output.metrics)
- if DebugOption.TPU_METRICS_DEBUG in self.args.debug:
- # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.)
- xm.master_print(met.metrics_report())
- self.control = self.callback_handler.on_evaluate(self.args, self.state, self.control, output.metrics)
- self._memory_tracker.stop_and_update_metrics(output.metrics)
- return output.metrics
- def predict(
- self, test_dataset: Dataset, ignore_keys: Optional[List[str]] = None, metric_key_prefix: str = "test"
- ) -> PredictionOutput:
- """
- Run prediction and returns predictions and potential metrics.
- Depending on the dataset and your use case, your test dataset may contain labels. In that case, this method
- will also return metrics, like in `evaluate()`.
- Args:
- test_dataset (`Dataset`):
- Dataset to run the predictions on. If it is a `datasets.Dataset`, columns not accepted by the
- `model.forward()` method are automatically removed. Has to implement the method `__len__`.
- ignore_keys (`List[str]`, *optional*):
- A list of keys in the output of your model (if it is a dictionary) that should be ignored when
- gathering predictions.
- metric_key_prefix (`str`, *optional*, defaults to `"test"`):
- An optional prefix to be used as the metrics key prefix. For example, the metric "bleu" will be named
- "test_bleu" if the prefix is "test" (default).
- <Tip>
- If your predictions or labels have different sequence lengths (for instance, because you're doing dynamic padding
- in a token classification task), the predictions will be padded (on the right) to allow for concatenation into
- one array. The padding index is -100.
- </Tip>
- Returns: *NamedTuple* A namedtuple with the following keys:
- - predictions (`np.ndarray`): The predictions on `test_dataset`.
- - label_ids (`np.ndarray`, *optional*): The labels (if the dataset contained some).
- - metrics (`Dict[str, float]`, *optional*): The potential dictionary of metrics (if the dataset contained
- labels).
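- Example (illustrative sketch; `trainer` and `test_dataset` are placeholders for an existing `Trainer`
- and a labeled test set):
- >>> output = trainer.predict(test_dataset)
- >>> preds = output.predictions.argmax(-1)  # e.g. for a sequence-classification head
- >>> output.metrics  # e.g. "test_loss" (if labels were present), "test_runtime", "test_samples_per_second"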
- """
- # memory metrics - must set up as early as possible
- self._memory_tracker.start()
- test_dataloader = self.get_test_dataloader(test_dataset)
- start_time = time.time()
- eval_loop = self.prediction_loop if self.args.use_legacy_prediction_loop else self.evaluation_loop
- output = eval_loop(
- test_dataloader, description="Prediction", ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix
- )
- total_batch_size = self.args.eval_batch_size * self.args.world_size
- if f"{metric_key_prefix}_jit_compilation_time" in output.metrics:
- start_time += output.metrics[f"{metric_key_prefix}_jit_compilation_time"]
- if f"{metric_key_prefix}_model_preparation_time" in output.metrics:
- start_time += output.metrics[f"{metric_key_prefix}_model_preparation_time"]
- output.metrics.update(
- speed_metrics(
- metric_key_prefix,
- start_time,
- num_samples=output.num_samples,
- num_steps=math.ceil(output.num_samples / total_batch_size),
- )
- )
- self.control = self.callback_handler.on_predict(self.args, self.state, self.control, output.metrics)
- self._memory_tracker.stop_and_update_metrics(output.metrics)
- return PredictionOutput(predictions=output.predictions, label_ids=output.label_ids, metrics=output.metrics)
- def evaluation_loop(
- self,
- dataloader: DataLoader,
- description: str,
- prediction_loss_only: Optional[bool] = None,
- ignore_keys: Optional[List[str]] = None,
- metric_key_prefix: str = "eval",
- ) -> EvalLoopOutput:
- """
- Prediction/evaluation loop, shared by `Trainer.evaluate()` and `Trainer.predict()`.
- Works both with or without labels.
- """
- args = self.args
- prediction_loss_only = prediction_loss_only if prediction_loss_only is not None else args.prediction_loss_only
- # if eval is called w/o train, handle model prep here
- if self.is_deepspeed_enabled and self.deepspeed is None:
- _, _ = deepspeed_init(self, num_training_steps=0, inference=True)
- model = self._wrap_model(self.model, training=False, dataloader=dataloader)
- if len(self.accelerator._models) == 0 and model is self.model:
- start_time = time.time()
- model = (
- self.accelerator.prepare(model)
- if self.is_deepspeed_enabled or self.is_fsdp_enabled
- else self.accelerator.prepare_model(model, evaluation_mode=True)
- )
- self.model_preparation_time = round(time.time() - start_time, 4)
- if self.is_fsdp_enabled:
- self.model = model
- # for the rest of this function `model` is the outside model, whether it was wrapped or not
- if model is not self.model:
- self.model_wrapped = model
- # backward compatibility
- if self.is_deepspeed_enabled:
- self.deepspeed = self.model_wrapped
- # if full fp16 or bf16 eval is wanted and this ``evaluation`` or ``predict`` isn't called
- # while ``train`` is running, cast it to the right dtype first and then put on device
- if not self.is_in_train:
- if args.fp16_full_eval:
- model = model.to(dtype=torch.float16, device=args.device)
- elif args.bf16_full_eval:
- model = model.to(dtype=torch.bfloat16, device=args.device)
- batch_size = self.args.eval_batch_size
- logger.info(f"\n***** Running {description} *****")
- if has_length(dataloader):
- logger.info(f" Num examples = {self.num_examples(dataloader)}")
- else:
- logger.info(" Num examples: Unknown")
- logger.info(f" Batch size = {batch_size}")
- model.eval()
- if hasattr(self.optimizer, "eval") and callable(self.optimizer.eval):
- self.optimizer.eval()
- self.callback_handler.eval_dataloader = dataloader
- # Do this before wrapping.
- eval_dataset = getattr(dataloader, "dataset", None)
- if args.past_index >= 0:
- self._past = None
- # Initialize containers
- all_losses = EvalLoopContainer(self.args.eval_do_concat_batches, padding_index=-100)
- all_preds = EvalLoopContainer(self.args.eval_do_concat_batches, padding_index=-100)
- all_labels = EvalLoopContainer(self.args.eval_do_concat_batches, padding_index=-100)
- all_inputs = EvalLoopContainer(self.args.eval_do_concat_batches, padding_index=-100)
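- # Each `EvalLoopContainer` either concatenates incoming batches (when `eval_do_concat_batches=True`) or
- # keeps them as a list, and is flushed to CPU/NumPy whenever `eval_accumulation_steps` is reached below.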
- metrics = None
- eval_set_kwargs = {}
- # Will be useful when we have an iterable dataset, whose length we don't know in advance.
- observed_num_examples = 0
- # Main evaluation loop
- for step, inputs in enumerate(dataloader):
- # Update the observed num examples
- observed_batch_size = find_batch_size(inputs)
- if observed_batch_size is not None:
- observed_num_examples += observed_batch_size
- # For batch samplers, batch_size is not known by the dataloader in advance.
- if batch_size is None:
- batch_size = observed_batch_size
- # Prediction step
- losses, logits, labels = self.prediction_step(model, inputs, prediction_loss_only, ignore_keys=ignore_keys)
- main_input_name = getattr(self.model, "main_input_name", "input_ids")
- inputs_decode = (
- self._prepare_input(inputs[main_input_name]) if "inputs" in args.include_for_metrics else None
- )
- if is_torch_xla_available():
- xm.mark_step()
- # Update containers
- if losses is not None:
- losses = self.gather_function((losses.repeat(batch_size)))
- all_losses.add(losses)
- if inputs_decode is not None:
- inputs_decode = self.accelerator.pad_across_processes(inputs_decode, dim=1, pad_index=-100)
- inputs_decode = self.gather_function((inputs_decode))
- if not self.args.batch_eval_metrics or description == "Prediction":
- all_inputs.add(inputs_decode)
- if labels is not None:
- # Pad labels here, preparing for preprocess_logits_for_metrics in next logits block.
- labels = self.accelerator.pad_across_processes(labels, dim=1, pad_index=-100)
- if logits is not None:
- logits = self.accelerator.pad_across_processes(logits, dim=1, pad_index=-100)
- if self.preprocess_logits_for_metrics is not None:
- logits = self.preprocess_logits_for_metrics(logits, labels)
- logits = self.gather_function((logits))
- if not self.args.batch_eval_metrics or description == "Prediction":
- all_preds.add(logits)
- if labels is not None:
- labels = self.gather_function((labels))
- if not self.args.batch_eval_metrics or description == "Prediction":
- all_labels.add(labels)
- self.control = self.callback_handler.on_prediction_step(args, self.state, self.control)
- if self.args.batch_eval_metrics:
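- # With `batch_eval_metrics`, `compute_metrics` is called on every batch and is expected to accumulate
- # its own state; the final metrics are only materialized when `compute_result=True` on the last step,
- # which keeps predictions from piling up in memory.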
- if self.compute_metrics is not None and logits is not None and labels is not None:
- is_last_step = self.accelerator.gradient_state.end_of_dataloader
- batch_kwargs = {}
- batch_kwargs["losses"] = losses if "loss" in args.include_for_metrics else None
- batch_kwargs["inputs"] = inputs if "inputs" in args.include_for_metrics else None
- metrics = self.compute_metrics(
- EvalPrediction(predictions=logits, label_ids=labels, **batch_kwargs),
- compute_result=is_last_step,
- )
- del losses, logits, labels, inputs
- torch.cuda.empty_cache()
- # Gather all tensors and put them back on the CPU if we have done enough accumulation steps.
- elif args.eval_accumulation_steps is not None and (step + 1) % args.eval_accumulation_steps == 0:
- all_losses.to_cpu_and_numpy()
- all_preds.to_cpu_and_numpy()
- all_labels.to_cpu_and_numpy()
- all_inputs.to_cpu_and_numpy()
- del losses, logits, labels, inputs
- torch.cuda.empty_cache()
- # After all calls to `.gather_function`, reset to `gather_for_metrics`:
- self.gather_function = self.accelerator.gather_for_metrics
- if args.past_index and hasattr(self, "_past"):
- # Clean the state at the end of the evaluation loop
- delattr(self, "_past")
- # Gather all remaining tensors and put them back on the CPU
- all_losses = all_losses.get_arrays()
- all_preds = all_preds.get_arrays()
- all_labels = all_labels.get_arrays()
- all_inputs = all_inputs.get_arrays()
- # Number of samples
- if has_length(eval_dataset):
- num_samples = len(eval_dataset)
- # The instance check is imprecise: it does not actually check the type, only whether the dataset has the right
- # methods. Therefore, we also need to make sure it has the `num_examples` attribute.
- elif isinstance(eval_dataset, IterableDatasetShard) and getattr(eval_dataset, "num_examples", 0) > 0:
- num_samples = eval_dataset.num_examples
- else:
- if has_length(dataloader):
- num_samples = self.num_examples(dataloader)
- else: # both len(dataloader.dataset) and len(dataloader) fail
- num_samples = observed_num_examples
- if num_samples == 0 and observed_num_examples > 0:
- num_samples = observed_num_examples
- # Metrics!
- if (
- self.compute_metrics is not None
- and all_preds is not None
- and all_labels is not None
- and not self.args.batch_eval_metrics
- ):
- eval_set_kwargs["losses"] = all_losses if "loss" in args.include_for_metrics else None
- eval_set_kwargs["inputs"] = all_inputs if "inputs" in args.include_for_metrics else None
- metrics = self.compute_metrics(
- EvalPrediction(predictions=all_preds, label_ids=all_labels, **eval_set_kwargs)
- )
- elif metrics is None:
- metrics = {}
- # To be JSON-serializable, we need to remove numpy types or zero-d tensors
- metrics = denumpify_detensorize(metrics)
- if isinstance(all_losses, list) and all_losses:
- metrics[f"{metric_key_prefix}_loss"] = np.concatenate(all_losses).mean().item()
- elif isinstance(all_losses, np.ndarray):
- metrics[f"{metric_key_prefix}_loss"] = all_losses.mean().item()
- if hasattr(self, "jit_compilation_time"):
- metrics[f"{metric_key_prefix}_jit_compilation_time"] = self.jit_compilation_time
- if hasattr(self, "model_preparation_time"):
- metrics[f"{metric_key_prefix}_model_preparation_time"] = self.model_preparation_time
- # Prefix all keys with metric_key_prefix + '_'
- for key in list(metrics.keys()):
- if not key.startswith(f"{metric_key_prefix}_"):
- metrics[f"{metric_key_prefix}_{key}"] = metrics.pop(key)
- return EvalLoopOutput(predictions=all_preds, label_ids=all_labels, metrics=metrics, num_samples=num_samples)
- def _nested_gather(self, tensors, name=None):
- """
- Gather the value of `tensors` (a tensor or a list/tuple of nested tensors) across all processes and
- concatenate them.
- """
- if tensors is None:
- return
- if is_torch_xla_available():
- if name is None:
- name = "nested_gather"
- tensors = nested_xla_mesh_reduce(tensors, name)
- elif is_sagemaker_mp_enabled():
- tensors = smp_gather(tensors)
- elif (self.args.distributed_state is not None and self.args.distributed_state.distributed_type != "NO") or (
- self.args.distributed_state is None and self.args.local_rank != -1
- ):
- tensors = distributed_concat(tensors)
- return tensors
- def prediction_step(
- self,
- model: nn.Module,
- inputs: Dict[str, Union[torch.Tensor, Any]],
- prediction_loss_only: bool,
- ignore_keys: Optional[List[str]] = None,
- ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]:
- """
- Perform an evaluation step on `model` using `inputs`.
- Subclass and override to inject custom behavior.
- Args:
- model (`nn.Module`):
- The model to evaluate.
- inputs (`Dict[str, Union[torch.Tensor, Any]]`):
- The inputs and targets of the model.
- The dictionary will be unpacked before being fed to the model. Most models expect the targets under the
- argument `labels`. Check your model's documentation for all accepted arguments.
- prediction_loss_only (`bool`):
- Whether or not to return the loss only.
- ignore_keys (`List[str]`, *optional*):
- A list of keys in the output of your model (if it is a dictionary) that should be ignored when
- gathering predictions.
- Return:
- Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]: A tuple with the loss,
- logits and labels (each being optional).
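- When `prediction_loss_only=True`, the returned logits and labels are `None`.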
- """
- has_labels = False if len(self.label_names) == 0 else all(inputs.get(k) is not None for k in self.label_names)
- # For CLIP-like models capable of returning loss values.
- # If `return_loss` is not specified or is `None` in `inputs`, we check whether the default value of `return_loss`
- # is `True` in `model.forward`.
- return_loss = inputs.get("return_loss", None)
- if return_loss is None:
- return_loss = self.can_return_loss
- loss_without_labels = True if len(self.label_names) == 0 and return_loss else False
- inputs = self._prepare_inputs(inputs)
- if ignore_keys is None:
- if hasattr(self.model, "config"):
- ignore_keys = getattr(self.model.config, "keys_to_ignore_at_inference", [])
- else:
- ignore_keys = []
- # labels may be popped when computing the loss (label smoothing for instance) so we grab them first.
- if has_labels or loss_without_labels:
- labels = nested_detach(tuple(inputs.get(name) for name in self.label_names))
- if len(labels) == 1:
- labels = labels[0]
- else:
- labels = None
- with torch.no_grad():
- if is_sagemaker_mp_enabled():
- raw_outputs = smp_forward_only(model, inputs)
- if has_labels or loss_without_labels:
- if isinstance(raw_outputs, dict):
- loss_mb = raw_outputs["loss"]
- logits_mb = tuple(v for k, v in raw_outputs.items() if k not in ignore_keys + ["loss"])
- else:
- loss_mb = raw_outputs[0]
- logits_mb = raw_outputs[1:]
- loss = loss_mb.reduce_mean().detach().cpu()
- logits = smp_nested_concat(logits_mb)
- else:
- loss = None
- if isinstance(raw_outputs, dict):
- logits_mb = tuple(v for k, v in raw_outputs.items() if k not in ignore_keys)
- else:
- logits_mb = raw_outputs
- logits = smp_nested_concat(logits_mb)
- else:
- if has_labels or loss_without_labels:
- with self.compute_loss_context_manager():
- loss, outputs = self.compute_loss(model, inputs, return_outputs=True)
- loss = loss.mean().detach()
- if isinstance(outputs, dict):
- logits = tuple(v for k, v in outputs.items() if k not in ignore_keys + ["loss"])
- else:
- logits = outputs[1:]
- else:
- loss = None
- with self.compute_loss_context_manager():
- outputs = model(**inputs)
- if isinstance(outputs, dict):
- logits = tuple(v for k, v in outputs.items() if k not in ignore_keys)
- else:
- logits = outputs
- # TODO: this needs to be fixed and made cleaner later.
- if self.args.past_index >= 0:
- self._past = outputs[self.args.past_index - 1]
- if prediction_loss_only:
- return (loss, None, None)
- logits = nested_detach(logits)
- if len(logits) == 1:
- logits = logits[0]
- return (loss, logits, labels)
- def floating_point_ops(self, inputs: Dict[str, Union[torch.Tensor, Any]]):
- """
- For models that inherit from [`PreTrainedModel`], uses that method to compute the number of floating point
- operations for every backward + forward pass. If using another model, either implement such a method in the
- model or subclass and override this method.
- Args:
- inputs (`Dict[str, Union[torch.Tensor, Any]]`):
- The inputs and targets of the model.
- Returns:
- `int`: The number of floating-point operations.
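- For standard `PreTrainedModel` subclasses this is roughly `6 * non-embedding parameters * input tokens`
- per combined forward and backward pass (an estimate rather than an exact operation count).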
- """
- if hasattr(self.model, "floating_point_ops"):
- return self.model.floating_point_ops(inputs)
- else:
- return 0
- def init_hf_repo(self, token: Optional[str] = None):
- """
- Initializes a repository on the Hugging Face Hub named `self.args.hub_model_id` (or, if unset, the name of the output directory).
- """
- # Only on process zero
- if not self.is_world_process_zero():
- return
- if self.args.hub_model_id is None:
- repo_name = Path(self.args.output_dir).absolute().name
- else:
- repo_name = self.args.hub_model_id
- token = token if token is not None else self.args.hub_token
- repo_url = create_repo(repo_name, token=token, private=self.args.hub_private_repo, exist_ok=True)
- self.hub_model_id = repo_url.repo_id
- self.push_in_progress = None
- def create_model_card(
- self,
- language: Optional[str] = None,
- license: Optional[str] = None,
- tags: Union[str, List[str], None] = None,
- model_name: Optional[str] = None,
- finetuned_from: Optional[str] = None,
- tasks: Union[str, List[str], None] = None,
- dataset_tags: Union[str, List[str], None] = None,
- dataset: Union[str, List[str], None] = None,
- dataset_args: Union[str, List[str], None] = None,
- ):
- """
- Creates a draft of a model card using the information available to the `Trainer`.
- Args:
- language (`str`, *optional*):
- The language of the model (if applicable)
- license (`str`, *optional*):
- The license of the model. Will default to the license of the pretrained model used, if the original
- model given to the `Trainer` comes from a repo on the Hub.
- tags (`str` or `List[str]`, *optional*):
- Some tags to be included in the metadata of the model card.
- model_name (`str`, *optional*):
- The name of the model.
- finetuned_from (`str`, *optional*):
- The name of the model used to fine-tune this one (if applicable). Will default to the name of the repo
- of the original model given to the `Trainer` (if it comes from the Hub).
- tasks (`str` or `List[str]`, *optional*):
- One or several task identifiers, to be included in the metadata of the model card.
- dataset_tags (`str` or `List[str]`, *optional*):
- One or several dataset tags, to be included in the metadata of the model card.
- dataset (`str` or `List[str]`, *optional*):
- One or several dataset identifiers, to be included in the metadata of the model card.
- dataset_args (`str` or `List[str]`, *optional*):
- One or several dataset arguments, to be included in the metadata of the model card.
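- Example (illustrative sketch; the argument values below are placeholders):
- >>> trainer.create_model_card(language="en", license="apache-2.0", tasks="text-classification")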
- """
- if not self.is_world_process_zero():
- return
- model_card_filepath = os.path.join(self.args.output_dir, "README.md")
- is_peft_library = False
- if os.path.exists(model_card_filepath):
- library_name = ModelCard.load(model_card_filepath).data.get("library_name")
- is_peft_library = library_name == "peft"
- # Append existing tags in `tags`
- existing_tags = ModelCard.load(model_card_filepath).data.tags
- if tags is not None and existing_tags is not None:
- if isinstance(tags, str):
- tags = [tags]
- for tag in existing_tags:
- if tag not in tags:
- tags.append(tag)
- training_summary = TrainingSummary.from_trainer(
- self,
- language=language,
- license=license,
- tags=tags,
- model_name=model_name,
- finetuned_from=finetuned_from,
- tasks=tasks,
- dataset_tags=dataset_tags,
- dataset=dataset,
- dataset_args=dataset_args,
- )
- model_card = training_summary.to_model_card()
- with open(model_card_filepath, "w") as f:
- f.write(model_card)
- if is_peft_library:
- self.accelerator.unwrap_model(self.model).create_or_update_model_card(self.args.output_dir)
- def _push_from_checkpoint(self, checkpoint_folder):
- # Only push from one node.
- if not self.is_world_process_zero() or self.args.hub_strategy == HubStrategy.END:
- return
- # If we haven't finished the last push, we don't do this one unless args.hub_always_push=True.
- if not self.args.hub_always_push and self.push_in_progress is not None and not self.push_in_progress.is_done():
- return
- output_dir = self.args.output_dir
- # To avoid a new synchronization of all model weights, we just copy the file from the checkpoint folder
- modeling_files = [CONFIG_NAME, WEIGHTS_NAME, SAFE_WEIGHTS_NAME]
- # Add sharded checkpoints if we have an index
- for index_file in [WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_INDEX_NAME]:
- index_path = os.path.join(checkpoint_folder, index_file)
- if os.path.isfile(index_path):
- modeling_files.append(index_file)
- with open(index_path) as f:
- index = json.loads(f.read())
- shard_files = list(set(index["weight_map"].values()))
- modeling_files.extend(shard_files)
- if is_peft_available():
- modeling_files.extend([ADAPTER_CONFIG_NAME, ADAPTER_WEIGHTS_NAME, ADAPTER_SAFE_WEIGHTS_NAME])
- for modeling_file in modeling_files:
- if os.path.isfile(os.path.join(checkpoint_folder, modeling_file)):
- shutil.copy(os.path.join(checkpoint_folder, modeling_file), os.path.join(output_dir, modeling_file))
- # Saving the processing class is fast and we don't know how many files it may have spawned, so we resave it to be sure.
- if self.processing_class is not None:
- self.processing_class.save_pretrained(output_dir)
- # Same for the training arguments
- torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))
- if self.args.save_strategy == IntervalStrategy.STEPS:
- commit_message = f"Training in progress, step {self.state.global_step}"
- else:
- commit_message = f"Training in progress, epoch {int(self.state.epoch)}"
- model_push_job = upload_folder(
- repo_id=self.hub_model_id,
- folder_path=output_dir,
- commit_message=commit_message,
- token=self.args.hub_token,
- run_as_future=True,
- ignore_patterns=["_*", f"{PREFIX_CHECKPOINT_DIR}-*"],
- )
- push_jobs = [model_push_job]
- if self.args.hub_strategy in [HubStrategy.CHECKPOINT, HubStrategy.ALL_CHECKPOINTS]:
- path_in_repo = (
- "last-checkpoint" if self.args.hub_strategy == HubStrategy.CHECKPOINT else Path(checkpoint_folder).name
- )
- checkpoint_push = upload_folder(
- repo_id=self.hub_model_id,
- folder_path=checkpoint_folder,
- path_in_repo=path_in_repo,
- commit_message=commit_message + ", checkpoint",
- token=self.args.hub_token,
- run_as_future=True,
- )
- push_jobs.append(checkpoint_push)
- if self.push_in_progress is None or self.push_in_progress.is_done():
- self.push_in_progress = PushInProgress(push_jobs)
- else:
- self.push_in_progress.jobs.extend(push_jobs)
- def _finish_current_push(self):
- if not hasattr(self, "push_in_progress"):
- return
- if self.push_in_progress is not None and not self.push_in_progress.is_done():
- logger.info("Waiting for the current checkpoint push to be finished, this might take a couple of minutes.")
- self.push_in_progress.wait_until_done()
- def push_to_hub(
- self,
- commit_message: Optional[str] = "End of training",
- blocking: bool = True,
- token: Optional[str] = None,
- revision: Optional[str] = None,
- **kwargs,
- ) -> str:
- """
- Upload `self.model` and `self.processing_class` to the 🤗 model hub on the repo `self.args.hub_model_id`.
- Parameters:
- commit_message (`str`, *optional*, defaults to `"End of training"`):
- Message to commit while pushing.
- blocking (`bool`, *optional*, defaults to `True`):
- Whether the function should return only when the `git push` has finished.
- token (`str`, *optional*, defaults to `None`):
- Token with write permission. Overrides the `hub_token` set in the Trainer's original arguments.
- revision (`str`, *optional*):
- The git revision to commit from. Defaults to the head of the "main" branch.
- kwargs (`Dict[str, Any]`, *optional*):
- Additional keyword arguments passed along to [`~Trainer.create_model_card`].
- Returns:
- The URL of the repository where the model was pushed if `blocking=True`, or a `Future` object tracking the
- progress of the commit if `blocking=False`.
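- Example (illustrative sketch; assumes training has finished and the Hub repo is configured):
- >>> trainer.push_to_hub(commit_message="Training complete", blocking=False)
- >>> # returns immediately; the upload keeps running as a background job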
- """
- model_name = kwargs.pop("model_name", None)
- if model_name is None and self.args.should_save:
- if self.args.hub_model_id is None:
- model_name = Path(self.args.output_dir).name
- else:
- model_name = self.args.hub_model_id.split("/")[-1]
- token = token if token is not None else self.args.hub_token
- # In case the user calls this method with args.push_to_hub = False
- if self.hub_model_id is None:
- self.init_hf_repo(token=token)
- # Needs to be executed on all processes for TPU training, but will only save on the process determined by
- # self.args.should_save.
- self.save_model(_internal_call=True)
- # Only push from one node.
- if not self.is_world_process_zero():
- return
- # Add additional tags in the case the model has already some tags and users pass
- # "tags" argument to `push_to_hub` so that trainer automatically handles internal tags
- # from all models since Trainer does not call `model.push_to_hub`.
- if getattr(self.model, "model_tags", None) is not None:
- if "tags" not in kwargs:
- kwargs["tags"] = []
- # If it is a string, convert it to a list
- if isinstance(kwargs["tags"], str):
- kwargs["tags"] = [kwargs["tags"]]
- for model_tag in self.model.model_tags:
- if model_tag not in kwargs["tags"]:
- kwargs["tags"].append(model_tag)
- self.create_model_card(model_name=model_name, **kwargs)
- # Wait for the current upload to be finished.
- self._finish_current_push()
- return upload_folder(
- repo_id=self.hub_model_id,
- folder_path=self.args.output_dir,
- commit_message=commit_message,
- token=token,
- run_as_future=not blocking,
- ignore_patterns=["_*", f"{PREFIX_CHECKPOINT_DIR}-*"],
- revision=revision,
- )
- #
- # Deprecated code
- #
- def prediction_loop(
- self,
- dataloader: DataLoader,
- description: str,
- prediction_loss_only: Optional[bool] = None,
- ignore_keys: Optional[List[str]] = None,
- metric_key_prefix: str = "eval",
- ) -> EvalLoopOutput:
- """
- Prediction/evaluation loop, shared by `Trainer.evaluate()` and `Trainer.predict()`.
- Works both with or without labels.
- """
- args = self.args
- if not has_length(dataloader):
- raise ValueError("dataloader must implement a working __len__")
- prediction_loss_only = prediction_loss_only if prediction_loss_only is not None else args.prediction_loss_only
- # if eval is called w/o train, handle model prep here
- if self.is_deepspeed_enabled and self.deepspeed is None:
- _, _ = deepspeed_init(self, num_training_steps=0, inference=True)
- model = self._wrap_model(self.model, training=False, dataloader=dataloader)
- if len(self.accelerator._models) == 0 and model is self.model:
- model = (
- self.accelerator.prepare(model)
- if self.is_deepspeed_enabled or self.is_fsdp_enabled
- else self.accelerator.prepare_model(model, evaluation_mode=True)
- )
- if self.is_fsdp_enabled:
- self.model = model
- # for the rest of this function `model` is the outside model, whether it was wrapped or not
- if model is not self.model:
- self.model_wrapped = model
- # backward compatibility
- if self.is_deepspeed_enabled:
- self.deepspeed = self.model_wrapped
- # if full fp16 or bf16 eval is wanted and this ``evaluation`` or ``predict`` isn't called
- # while ``train`` is running, cast it to the right dtype first and then put on device
- if not self.is_in_train:
- if args.fp16_full_eval:
- model = model.to(dtype=torch.float16, device=args.device)
- elif args.bf16_full_eval:
- model = model.to(dtype=torch.bfloat16, device=args.device)
- batch_size = dataloader.batch_size
- num_examples = self.num_examples(dataloader)
- logger.info(f"\n***** Running {description} *****")
- logger.info(f" Num examples = {num_examples}")
- logger.info(f" Batch size = {batch_size}")
- losses_host: torch.Tensor = None
- preds_host: Union[torch.Tensor, List[torch.Tensor]] = None
- labels_host: Union[torch.Tensor, List[torch.Tensor]] = None
- inputs_host: Union[torch.Tensor, List[torch.Tensor]] = None
- metrics: Optional[dict] = None
- eval_set_kwargs: dict = {}
- world_size = max(1, args.world_size)
- eval_losses_gatherer = DistributedTensorGatherer(world_size, num_examples, make_multiple_of=batch_size)
- if not prediction_loss_only:
- # The actual number of eval samples can be greater than num_examples in distributed settings (when we pass
- # a batch size to the sampler).
- make_multiple_of = None
- if hasattr(dataloader, "sampler") and isinstance(dataloader.sampler, SequentialDistributedSampler):
- make_multiple_of = dataloader.sampler.batch_size
- preds_gatherer = DistributedTensorGatherer(world_size, num_examples, make_multiple_of=make_multiple_of)
- labels_gatherer = DistributedTensorGatherer(world_size, num_examples, make_multiple_of=make_multiple_of)
- inputs_gatherer = DistributedTensorGatherer(world_size, num_examples, make_multiple_of=make_multiple_of)
- model.eval()
- if hasattr(self.optimizer, "eval") and callable(self.optimizer.eval):
- self.optimizer.eval()
- if args.past_index >= 0:
- self._past = None
- self.callback_handler.eval_dataloader = dataloader
- for step, inputs in enumerate(dataloader):
- loss, logits, labels = self.prediction_step(model, inputs, prediction_loss_only, ignore_keys=ignore_keys)
- main_input_name = getattr(self.model, "main_input_name", "input_ids")
- inputs_decode = (
- self._prepare_input(inputs[main_input_name]) if "inputs" in args.include_for_metrics else None
- )
- if loss is not None:
- losses = loss.repeat(batch_size)
- losses_host = losses if losses_host is None else torch.cat((losses_host, losses), dim=0)
- if logits is not None:
- preds_host = logits if preds_host is None else nested_concat(preds_host, logits, padding_index=-100)
- if labels is not None:
- labels_host = labels if labels_host is None else nested_concat(labels_host, labels, padding_index=-100)
- if inputs_decode is not None:
- inputs_host = (
- inputs_decode
- if inputs_host is None
- else nested_concat(inputs_host, inputs_decode, padding_index=-100)
- )
- self.control = self.callback_handler.on_prediction_step(args, self.state, self.control)
- if self.args.batch_eval_metrics:
- if self.compute_metrics is not None and preds_host is not None and labels_host is not None:
- is_last_step = self.accelerator.gradient_state.end_of_dataloader
- batch_kwargs = {}
- batch_kwargs["losses"] = losses_host if "loss" in args.include_for_metrics else None
- batch_kwargs["inputs"] = inputs_host if "inputs" in args.include_for_metrics else None
- metrics = self.compute_metrics(
- EvalPrediction(predictions=preds_host, label_ids=labels_host, **batch_kwargs),
- compute_result=is_last_step,
- )
- if self.args.batch_eval_metrics or (
- args.eval_accumulation_steps is not None and (step + 1) % args.eval_accumulation_steps == 0
- ):
- # Gather all tensors and put them back on the CPU if we have done enough accumulation steps.
- eval_losses_gatherer.add_arrays(self._gather_and_numpify(losses_host, "eval_losses"))
- if not prediction_loss_only:
- preds_gatherer.add_arrays(self._gather_and_numpify(preds_host, "eval_preds"))
- labels_gatherer.add_arrays(self._gather_and_numpify(labels_host, "eval_label_ids"))
- inputs_gatherer.add_arrays(self._gather_and_numpify(inputs_host, "eval_inputs_ids"))
- # Set back to None to begin a new accumulation
- del losses_host, preds_host, labels_host, inputs_host
- torch.cuda.empty_cache()
- losses_host, preds_host, labels_host, inputs_host = None, None, None, None
- if args.past_index and hasattr(self, "_past"):
- # Clean the state at the end of the evaluation loop
- delattr(self, "_past")
- # Gather all remaining tensors and put them back on the CPU
- eval_losses_gatherer.add_arrays(self._gather_and_numpify(losses_host, "eval_losses"))
- if not prediction_loss_only:
- preds_gatherer.add_arrays(self._gather_and_numpify(preds_host, "eval_preds"))
- labels_gatherer.add_arrays(self._gather_and_numpify(labels_host, "eval_label_ids"))
- inputs_gatherer.add_arrays(self._gather_and_numpify(inputs_host, "eval_inputs_ids"))
- eval_loss = eval_losses_gatherer.finalize()
- preds = preds_gatherer.finalize() if not prediction_loss_only else None
- label_ids = labels_gatherer.finalize() if not prediction_loss_only else None
- inputs_ids = inputs_gatherer.finalize() if not prediction_loss_only else None
- if (
- self.compute_metrics is not None
- and preds is not None
- and label_ids is not None
- and not self.args.batch_eval_metrics
- ):
- eval_set_kwargs["losses"] = eval_loss if "loss" in args.include_for_metrics else None
- eval_set_kwargs["inputs"] = inputs_ids if "inputs" in args.include_for_metrics else None
- metrics = self.compute_metrics(EvalPrediction(predictions=preds, label_ids=label_ids, **eval_set_kwargs))
- elif metrics is None:
- metrics = {}
- # To be JSON-serializable, we need to remove numpy types or zero-d tensors
- metrics = denumpify_detensorize(metrics)
- if eval_loss is not None:
- metrics[f"{metric_key_prefix}_loss"] = eval_loss.mean().item()
- # Prefix all keys with metric_key_prefix + '_'
- for key in list(metrics.keys()):
- if not key.startswith(f"{metric_key_prefix}_"):
- metrics[f"{metric_key_prefix}_{key}"] = metrics.pop(key)
- return EvalLoopOutput(predictions=preds, label_ids=label_ids, metrics=metrics, num_samples=num_examples)
- def _gather_and_numpify(self, tensors, name):
- """
- Gather the value of `tensors` (a tensor or a list/tuple of nested tensors) across processes and convert them
- to numpy before concatenating them.
- """
- if tensors is None:
- return
- if is_torch_xla_available():
- tensors = nested_xla_mesh_reduce(tensors, name)
- elif is_sagemaker_mp_enabled():
- tensors = smp_gather(tensors)
- elif self.args.parallel_mode == ParallelMode.DISTRIBUTED:
- tensors = distributed_concat(tensors)
- return nested_numpify(tensors)
- def _add_sm_patterns_to_gitignore(self) -> None:
- """Add SageMaker Checkpointing patterns to .gitignore file."""
- # Make sure we only do this on the main process
- if not self.is_world_process_zero():
- return
- patterns = ["*.sagemaker-uploading", "*.sagemaker-uploaded"]
- # Get current .gitignore content
- if os.path.exists(os.path.join(self.repo.local_dir, ".gitignore")):
- with open(os.path.join(self.repo.local_dir, ".gitignore"), "r") as f:
- current_content = f.read()
- else:
- current_content = ""
- # Add the patterns to .gitignore
- content = current_content
- for pattern in patterns:
- if pattern not in content:
- if content.endswith("\n"):
- content += pattern
- else:
- content += f"\n{pattern}"
- # Write the .gitignore file if it has changed
- if content != current_content:
- with open(os.path.join(self.repo.local_dir, ".gitignore"), "w") as f:
- logger.debug(f"Writing .gitignore file. Content: {content}")
- f.write(content)
- self.repo.git_add(".gitignore")
- # avoid race condition with git status
- time.sleep(0.5)
- if not self.repo.is_repo_clean():
- self.repo.git_commit("Add *.sagemaker patterns to .gitignore.")
- self.repo.git_push()
- def create_accelerator_and_postprocess(self):
- # We explicitly don't rely on the `Accelerator` to do gradient accumulation
- grad_acc_kwargs = {}
- if is_accelerate_available("0.28.0") and self.args.accelerator_config.gradient_accumulation_kwargs is not None:
- grad_acc_kwargs = self.args.accelerator_config.gradient_accumulation_kwargs
- # check if num_steps is attempted to be passed in gradient_accumulation_kwargs
- if "num_steps" in grad_acc_kwargs:
- if self.args.gradient_accumulation_steps > 1:
- # raise because we do not know which setting is intended.
- raise ValueError(
- "The `AcceleratorConfig`'s `num_steps` is set but `gradient_accumulation_steps` is greater than 1 in the passed `TrainingArguments`"
- "If using the passed `AcceleratorConfig` is desired, do not set the `TrainingArguments` `gradient_accumulation_steps`."
- )
- else:
- self.args.gradient_accumulation_steps = grad_acc_kwargs["num_steps"]
- accelerator_config = self.args.accelerator_config.to_dict()
- if is_accelerate_available("0.28.0"):
- dataloader_config = DataLoaderConfiguration(
- split_batches=accelerator_config.pop("split_batches"),
- dispatch_batches=accelerator_config.pop("dispatch_batches"),
- even_batches=accelerator_config.pop("even_batches"),
- use_seedable_sampler=accelerator_config.pop("use_seedable_sampler"),
- )
- if is_accelerate_available("1.1.0"):
- dataloader_config.data_seed = self.args.data_seed
- non_blocking = accelerator_config.pop("non_blocking")
- if not is_accelerate_available("0.30.0"):
- if non_blocking:
- raise ImportError(
- "`non_blocking` is only supported in accelerate v0.30.0 and above. Please upgrade accelerate to use this feature."
- )
- else:
- if non_blocking and not self.args.dataloader_pin_memory:
- logger.warning(
- "`non_blocking` is enabled but `dataloader_pin_memory` is not. For the best performance, it's recommended to enable both."
- )
- dataloader_config.non_blocking = non_blocking
- # this would have been updated above, no need for it anymore
- accelerator_config.pop("gradient_accumulation_kwargs")
- args = {
- "deepspeed_plugin": self.args.deepspeed_plugin,
- }
- if is_accelerate_available("0.28.0"):
- args["dataloader_config"] = dataloader_config
- else:
- args.update(accelerator_config)
- # create accelerator object
- self.accelerator = Accelerator(**args)
- # some Trainer classes need to use `gather` instead of `gather_for_metrics`, thus we store a flag
- self.gather_function = self.accelerator.gather_for_metrics
- if "use_gather_object" in inspect.signature(self.gather_function).parameters.keys():
- self.gather_function = functools.partial(
- self.gather_function, use_gather_object=self.args.eval_use_gather_object
- )
- # deepspeed and accelerate flags covering both trainer args and accelerate launcher
- self.is_deepspeed_enabled = getattr(self.accelerator.state, "deepspeed_plugin", None) is not None
- self.is_fsdp_enabled = getattr(self.accelerator.state, "fsdp_plugin", None) is not None
- # post accelerator creation setup
- if self.is_fsdp_enabled:
- fsdp_plugin = self.accelerator.state.fsdp_plugin
- fsdp_plugin.limit_all_gathers = self.args.fsdp_config.get(
- "limit_all_gathers", fsdp_plugin.limit_all_gathers
- )
- fsdp_plugin.activation_checkpointing = self.args.fsdp_config.get(
- "activation_checkpointing", fsdp_plugin.activation_checkpointing
- )
- if fsdp_plugin.activation_checkpointing and self.args.gradient_checkpointing:
- raise ValueError(
- "The activation_checkpointing in FSDP config and the gradient_checkpointing in training arg "
- "can't be set to True simultaneously. Please use FSDP's activation_checkpointing logic "
- "when using FSDP."
- )
- if self.is_deepspeed_enabled and getattr(self.args, "hf_deepspeed_config", None) is None:
- self.propagate_args_to_deepspeed()
- # `save_only_model` can't be used with DeepSpeed/FSDP along with `load_best_model_at_end`
- if (
- self.args.save_only_model
- and (self.is_deepspeed_enabled or self.is_fsdp_enabled)
- and self.args.load_best_model_at_end
- ):
- wrapper = "DeepSpeed" if self.is_deepspeed_enabled else "FSDP"
- raise ValueError(f"{wrapper} can't be used with `save_only_model` along with `load_best_model_at_end`.")
- # `auto_find_batch_size` isn't supported yet with DeepSpeed Zero-3
- if (
- self.is_deepspeed_enabled
- and self.accelerator.state.deepspeed_plugin.zero_stage == 3
- and self.args.auto_find_batch_size
- ):
- raise ValueError(
- "`auto_find_batch_size` isn't supported yet with DeepSpeed Zero-3. Please consider using Zero-2, Zero-1, or FSDP"
- )
- def propagate_args_to_deepspeed(self, auto_find_batch_size=False):
- """
- Sets values in the deepspeed plugin based on the Trainer args
- """
- from transformers.integrations.deepspeed import HfTrainerDeepSpeedConfig
- ds_plugin = self.accelerator.state.deepspeed_plugin
- ds_plugin.hf_ds_config = HfTrainerDeepSpeedConfig(ds_plugin.hf_ds_config.config)
- ds_plugin.deepspeed_config = ds_plugin.hf_ds_config.config
- ds_plugin.hf_ds_config.trainer_config_process(self.args, auto_find_batch_size)
- def _fsdp_qlora_plugin_updates(self):
- if self.is_fsdp_enabled and _is_peft_model(self.model):
- from peft import LoraConfig
- from peft.utils.other import fsdp_auto_wrap_policy
- if isinstance(self.model.active_peft_config, LoraConfig):
- fsdp_plugin = self.accelerator.state.fsdp_plugin
- fsdp_plugin.auto_wrap_policy = fsdp_auto_wrap_policy(self.model)
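- # For 4-bit bitsandbytes (QLoRA) models whose quantized storage dtype is floating point, FSDP mixed
- # precision is aligned with that storage dtype below so the flattened parameters share one dtype.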
- if (
- getattr(self.model, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES
- and self.model.hf_quantizer.quantization_config.bnb_4bit_quant_storage.is_floating_point
- and version.parse(accelerate_version) > version.parse("0.27.0")
- ):
- fsdp_plugin.set_mixed_precision(
- self.model.hf_quantizer.quantization_config.bnb_4bit_quant_storage, override=True
- )
- def get_batch_samples(self, epoch_iterator, num_batches):
- batch_samples = []
- num_items_in_batch = None
- for _ in range(num_batches):
- try:
- batch_samples += [next(epoch_iterator)]
- except StopIteration:
- break
- # Keep default behavior the same
- if not self.model_accepts_loss_kwargs:
- return batch_samples, None
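- # When the model accepts loss kwargs, count the label tokens that are not the ignore index (-100)
- # across the prefetched batches so the loss can be normalized consistently under gradient accumulation.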
- if len(batch_samples) > 0 and "labels" in batch_samples[0]:
- # For now we don't support object detection
- try:
- num_items_in_batch = sum([(batch["labels"].ne(-100)).sum() for batch in batch_samples])
- except (TypeError, AttributeError):
- pass
- if self.args.average_tokens_across_devices:
- num_items_in_batch = self.accelerator.gather(num_items_in_batch).sum().item()
- return batch_samples, num_items_in_batch
|