modeling_utils.py 268 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018201920202021202220232024202520262027202820292030203120322033203420352036203720382039204020412042204320442045204620472048204920502051205220532054205520562057205820592060206120622063206420652066206720682069207020712072207320742075207620772078207920802081208220832084208520862087208820892090209120922093209420952096209720982099210021012102210321042105210621072108210921102111211221132114211521162117211821192120212121222123212421252126212721282129213021312132213321342135213621372138213921402141214221432144214521462147214821492150215121522153215421552156215721582159216021612162216321642165216621672168216921702171217221732174217521762177217821792180218121822183218421852186218721882189219021912192219321942195219621972198219922002201220222032204220522062207220822092210221122122213221422152216221722182219222022212222222322242225222622272228222922302231223222332234223522362237223822392240224122422243224422452246224722482249225022512252225322542255225622572258225922602261226222632264226522662267226822692270227122722273227422752276227
72278227922802281228222832284228522862287228822892290229122922293229422952296229722982299230023012302230323042305230623072308230923102311231223132314231523162317231823192320232123222323232423252326232723282329233023312332233323342335233623372338233923402341234223432344234523462347234823492350235123522353235423552356235723582359236023612362236323642365236623672368236923702371237223732374237523762377237823792380238123822383238423852386238723882389239023912392239323942395239623972398239924002401240224032404240524062407240824092410241124122413241424152416241724182419242024212422242324242425242624272428242924302431243224332434243524362437243824392440244124422443244424452446244724482449245024512452245324542455245624572458245924602461246224632464246524662467246824692470247124722473247424752476247724782479248024812482248324842485248624872488248924902491249224932494249524962497249824992500250125022503250425052506250725082509251025112512251325142515251625172518251925202521252225232524252525262527252825292530253125322533253425352536253725382539254025412542254325442545254625472548254925502551255225532554255525562557255825592560256125622563256425652566256725682569257025712572257325742575257625772578257925802581258225832584258525862587258825892590259125922593259425952596259725982599260026012602260326042605260626072608260926102611261226132614261526162617261826192620262126222623262426252626262726282629263026312632263326342635263626372638263926402641264226432644264526462647264826492650265126522653265426552656265726582659266026612662266326642665266626672668266926702671267226732674267526762677267826792680268126822683268426852686268726882689269026912692269326942695269626972698269927002701270227032704270527062707270827092710271127122713271427152716271727182719272027212722272327242725272627272728272927302731273227332734273527362737273827392740274127422743274427452746274727482749275027512752275327542755275627572758275927602761276227632764276527662767276827692770277127722773277427752776277
72778277927802781278227832784278527862787278827892790279127922793279427952796279727982799280028012802280328042805280628072808280928102811281228132814281528162817281828192820282128222823282428252826282728282829283028312832283328342835283628372838283928402841284228432844284528462847284828492850285128522853285428552856285728582859286028612862286328642865286628672868286928702871287228732874287528762877287828792880288128822883288428852886288728882889289028912892289328942895289628972898289929002901290229032904290529062907290829092910291129122913291429152916291729182919292029212922292329242925292629272928292929302931293229332934293529362937293829392940294129422943294429452946294729482949295029512952295329542955295629572958295929602961296229632964296529662967296829692970297129722973297429752976297729782979298029812982298329842985298629872988298929902991299229932994299529962997299829993000300130023003300430053006300730083009301030113012301330143015301630173018301930203021302230233024302530263027302830293030303130323033303430353036303730383039304030413042304330443045304630473048304930503051305230533054305530563057305830593060306130623063306430653066306730683069307030713072307330743075307630773078307930803081308230833084308530863087308830893090309130923093309430953096309730983099310031013102310331043105310631073108310931103111311231133114311531163117311831193120312131223123312431253126312731283129313031313132313331343135313631373138313931403141314231433144314531463147314831493150315131523153315431553156315731583159316031613162316331643165316631673168316931703171317231733174317531763177317831793180318131823183318431853186318731883189319031913192319331943195319631973198319932003201320232033204320532063207320832093210321132123213321432153216321732183219322032213222322332243225322632273228322932303231323232333234323532363237323832393240324132423243324432453246324732483249325032513252325332543255325632573258325932603261326232633264326532663267326832693270327132723273327432753276327
73278327932803281328232833284328532863287328832893290329132923293329432953296329732983299330033013302330333043305330633073308330933103311331233133314331533163317331833193320332133223323332433253326332733283329333033313332333333343335333633373338333933403341334233433344334533463347334833493350335133523353335433553356335733583359336033613362336333643365336633673368336933703371337233733374337533763377337833793380338133823383338433853386338733883389339033913392339333943395339633973398339934003401340234033404340534063407340834093410341134123413341434153416341734183419342034213422342334243425342634273428342934303431343234333434343534363437343834393440344134423443344434453446344734483449345034513452345334543455345634573458345934603461346234633464346534663467346834693470347134723473347434753476347734783479348034813482348334843485348634873488348934903491349234933494349534963497349834993500350135023503350435053506350735083509351035113512351335143515351635173518351935203521352235233524352535263527352835293530353135323533353435353536353735383539354035413542354335443545354635473548354935503551355235533554355535563557355835593560356135623563356435653566356735683569357035713572357335743575357635773578357935803581358235833584358535863587358835893590359135923593359435953596359735983599360036013602360336043605360636073608360936103611361236133614361536163617361836193620362136223623362436253626362736283629363036313632363336343635363636373638363936403641364236433644364536463647364836493650365136523653365436553656365736583659366036613662366336643665366636673668366936703671367236733674367536763677367836793680368136823683368436853686368736883689369036913692369336943695369636973698369937003701370237033704370537063707370837093710371137123713371437153716371737183719372037213722372337243725372637273728372937303731373237333734373537363737373837393740374137423743374437453746374737483749375037513752375337543755375637573758375937603761376237633764376537663767376837693770377137723773377437753776377
73778377937803781378237833784378537863787378837893790379137923793379437953796379737983799380038013802380338043805380638073808380938103811381238133814381538163817381838193820382138223823382438253826382738283829383038313832383338343835383638373838383938403841384238433844384538463847384838493850385138523853385438553856385738583859386038613862386338643865386638673868386938703871387238733874387538763877387838793880388138823883388438853886388738883889389038913892389338943895389638973898389939003901390239033904390539063907390839093910391139123913391439153916391739183919392039213922392339243925392639273928392939303931393239333934393539363937393839393940394139423943394439453946394739483949395039513952395339543955395639573958395939603961396239633964396539663967396839693970397139723973397439753976397739783979398039813982398339843985398639873988398939903991399239933994399539963997399839994000400140024003400440054006400740084009401040114012401340144015401640174018401940204021402240234024402540264027402840294030403140324033403440354036403740384039404040414042404340444045404640474048404940504051405240534054405540564057405840594060406140624063406440654066406740684069407040714072407340744075407640774078407940804081408240834084408540864087408840894090409140924093409440954096409740984099410041014102410341044105410641074108410941104111411241134114411541164117411841194120412141224123412441254126412741284129413041314132413341344135413641374138413941404141414241434144414541464147414841494150415141524153415441554156415741584159416041614162416341644165416641674168416941704171417241734174417541764177417841794180418141824183418441854186418741884189419041914192419341944195419641974198419942004201420242034204420542064207420842094210421142124213421442154216421742184219422042214222422342244225422642274228422942304231423242334234423542364237423842394240424142424243424442454246424742484249425042514252425342544255425642574258425942604261426242634264426542664267426842694270427142724273427442754276427
74278427942804281428242834284428542864287428842894290429142924293429442954296429742984299430043014302430343044305430643074308430943104311431243134314431543164317431843194320432143224323432443254326432743284329433043314332433343344335433643374338433943404341434243434344434543464347434843494350435143524353435443554356435743584359436043614362436343644365436643674368436943704371437243734374437543764377437843794380438143824383438443854386438743884389439043914392439343944395439643974398439944004401440244034404440544064407440844094410441144124413441444154416441744184419442044214422442344244425442644274428442944304431443244334434443544364437443844394440444144424443444444454446444744484449445044514452445344544455445644574458445944604461446244634464446544664467446844694470447144724473447444754476447744784479448044814482448344844485448644874488448944904491449244934494449544964497449844994500450145024503450445054506450745084509451045114512451345144515451645174518451945204521452245234524452545264527452845294530453145324533453445354536453745384539454045414542454345444545454645474548454945504551455245534554455545564557455845594560456145624563456445654566456745684569457045714572457345744575457645774578457945804581458245834584458545864587458845894590459145924593459445954596459745984599460046014602460346044605460646074608460946104611461246134614461546164617461846194620462146224623462446254626462746284629463046314632463346344635463646374638463946404641464246434644464546464647464846494650465146524653465446554656465746584659466046614662466346644665466646674668466946704671467246734674467546764677467846794680468146824683468446854686468746884689469046914692469346944695469646974698469947004701470247034704470547064707470847094710471147124713471447154716471747184719472047214722472347244725472647274728472947304731473247334734473547364737473847394740474147424743474447454746474747484749475047514752475347544755475647574758475947604761476247634764476547664767476847694770477147724773477447754776477
74778477947804781478247834784478547864787478847894790479147924793479447954796479747984799480048014802480348044805480648074808480948104811481248134814481548164817481848194820482148224823482448254826482748284829483048314832483348344835483648374838483948404841484248434844484548464847484848494850485148524853485448554856485748584859486048614862486348644865486648674868486948704871487248734874487548764877487848794880488148824883488448854886488748884889489048914892489348944895489648974898489949004901490249034904490549064907490849094910491149124913491449154916491749184919492049214922492349244925492649274928492949304931493249334934493549364937493849394940494149424943494449454946494749484949495049514952495349544955495649574958495949604961496249634964496549664967496849694970497149724973497449754976497749784979498049814982498349844985498649874988498949904991499249934994499549964997499849995000500150025003500450055006500750085009501050115012501350145015501650175018501950205021502250235024502550265027502850295030503150325033503450355036503750385039504050415042504350445045504650475048504950505051505250535054505550565057505850595060506150625063506450655066506750685069507050715072507350745075507650775078507950805081508250835084508550865087508850895090509150925093509450955096509750985099510051015102510351045105510651075108510951105111511251135114511551165117511851195120512151225123512451255126512751285129513051315132513351345135513651375138513951405141514251435144514551465147514851495150515151525153515451555156515751585159516051615162516351645165516651675168516951705171517251735174517551765177517851795180518151825183518451855186518751885189519051915192519351945195519651975198519952005201520252035204520552065207520852095210521152125213521452155216521752185219522052215222522352245225522652275228522952305231523252335234523552365237523852395240524152425243524452455246524752485249525052515252525352545255525652575258525952605261526252635264526552665267526852695270527152725273527452755276527
75278527952805281528252835284528552865287528852895290529152925293529452955296529752985299530053015302530353045305530653075308530953105311531253135314531553165317531853195320532153225323532453255326532753285329533053315332533353345335533653375338533953405341534253435344534553465347534853495350535153525353535453555356535753585359536053615362536353645365536653675368536953705371537253735374537553765377537853795380538153825383538453855386538753885389539053915392539353945395539653975398539954005401540254035404540554065407540854095410541154125413541454155416541754185419542054215422542354245425542654275428542954305431543254335434543554365437543854395440544154425443544454455446544754485449545054515452545354545455545654575458545954605461546254635464546554665467546854695470547154725473547454755476547754785479548054815482548354845485548654875488548954905491549254935494549554965497549854995500550155025503550455055506550755085509551055115512551355145515551655175518551955205521
  1. # coding=utf-8
  2. # Copyright 2018 The Google AI Language Team Authors, Facebook AI Research authors and The HuggingFace Inc. team.
  3. # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. import collections
  17. import copy
  18. import functools
  19. import gc
  20. import importlib.metadata
  21. import inspect
  22. import itertools
  23. import json
  24. import os
  25. import re
  26. import shutil
  27. import tempfile
  28. import warnings
  29. from contextlib import contextmanager
  30. from dataclasses import dataclass
  31. from functools import partial, wraps
  32. from threading import Thread
  33. from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
  34. from zipfile import is_zipfile
  35. import torch
  36. from huggingface_hub import split_torch_state_dict_into_shards
  37. from packaging import version
  38. from torch import Tensor, nn
  39. from torch.nn import CrossEntropyLoss, Identity
  40. from torch.utils.checkpoint import checkpoint
  41. from .activations import get_activation
  42. from .configuration_utils import PretrainedConfig
  43. from .dynamic_module_utils import custom_object_save
  44. from .generation import GenerationConfig, GenerationMixin
  45. from .integrations import PeftAdapterMixin, deepspeed_config, is_deepspeed_zero3_enabled
  46. from .loss.loss_utils import LOSS_MAPPING
  47. from .pytorch_utils import ( # noqa: F401
  48. Conv1D,
  49. apply_chunking_to_forward,
  50. find_pruneable_heads_and_indices,
  51. id_tensor_storage,
  52. is_torch_greater_or_equal_than_1_13,
  53. prune_conv1d_layer,
  54. prune_layer,
  55. prune_linear_layer,
  56. )
  57. from .quantizers import AutoHfQuantizer, HfQuantizer
  58. from .quantizers.quantizers_utils import get_module_from_name
  59. from .safetensors_conversion import auto_conversion
  60. from .utils import (
  61. ACCELERATE_MIN_VERSION,
  62. ADAPTER_SAFE_WEIGHTS_NAME,
  63. ADAPTER_WEIGHTS_NAME,
  64. CONFIG_NAME,
  65. DUMMY_INPUTS,
  66. FLAX_WEIGHTS_NAME,
  67. SAFE_WEIGHTS_INDEX_NAME,
  68. SAFE_WEIGHTS_NAME,
  69. TF2_WEIGHTS_NAME,
  70. TF_WEIGHTS_NAME,
  71. WEIGHTS_INDEX_NAME,
  72. WEIGHTS_NAME,
  73. ContextManagers,
  74. ModelOutput,
  75. PushToHubMixin,
  76. cached_file,
  77. copy_func,
  78. download_url,
  79. extract_commit_hash,
  80. has_file,
  81. is_accelerate_available,
  82. is_bitsandbytes_available,
  83. is_flash_attn_2_available,
  84. is_offline_mode,
  85. is_optimum_available,
  86. is_peft_available,
  87. is_remote_url,
  88. is_safetensors_available,
  89. is_torch_sdpa_available,
  90. is_torch_xla_available,
  91. logging,
  92. replace_return_docstrings,
  93. strtobool,
  94. )
  95. from .utils.hub import convert_file_size_to_int, create_and_tag_model_card, get_checkpoint_shard_files
  96. from .utils.import_utils import (
  97. ENV_VARS_TRUE_VALUES,
  98. is_sagemaker_mp_enabled,
  99. is_torch_fx_proxy,
  100. is_torchdynamo_compiling,
  101. )
  102. from .utils.quantization_config import BitsAndBytesConfig, QuantizationMethod
  103. XLA_USE_BF16 = os.environ.get("XLA_USE_BF16", "0").upper()
  104. XLA_DOWNCAST_BF16 = os.environ.get("XLA_DOWNCAST_BF16", "0").upper()
# Optional accelerate integration: these names are only bound when accelerate is installed.
if is_accelerate_available():
    from accelerate import dispatch_model, infer_auto_device_map, init_empty_weights
    from accelerate.hooks import add_hook_to_module
    from accelerate.utils import (
        check_tied_parameters_on_same_device,
        extract_model_from_parallel,
        find_tied_parameters,
        get_balanced_memory,
        get_max_memory,
        load_offloaded_weights,
        offload_weight,
        save_offload_index,
        set_module_tensor_to_device,
    )

    accelerate_version = version.parse(importlib.metadata.version("accelerate"))
    # `get_state_dict_from_offload` is only imported when the installed accelerate is >= 0.31.
    if accelerate_version >= version.parse("0.31"):
        from accelerate.utils.modeling import get_state_dict_from_offload

# Optional safetensors integration: the reader/loader/saver are only bound when safetensors is installed.
if is_safetensors_available():
    from safetensors import safe_open
    from safetensors.torch import load_file as safe_load_file
    from safetensors.torch import save_file as safe_save_file
  126. logger = logging.get_logger(__name__)
  127. _init_weights = True
  128. def is_fsdp_enabled():
  129. return (
  130. torch.distributed.is_available()
  131. and torch.distributed.is_initialized()
  132. and strtobool(os.environ.get("ACCELERATE_USE_FSDP", "False")) == 1
  133. and strtobool(os.environ.get("FSDP_CPU_RAM_EFFICIENT_LOADING", "False")) == 1
  134. )
  135. def is_local_dist_rank_0():
  136. return (
  137. torch.distributed.is_available()
  138. and torch.distributed.is_initialized()
  139. and int(os.environ.get("LOCAL_RANK", -1)) == 0
  140. )
# SageMaker model parallelism: bind `smp` and record whether the installed smdistributed version is >= 1.10.
# When SageMaker MP is not enabled, the flag is simply False.
if is_sagemaker_mp_enabled():
    import smdistributed.modelparallel.torch as smp
    from smdistributed.modelparallel import __version__ as SMP_VERSION

    IS_SAGEMAKER_MP_POST_1_10 = version.parse(SMP_VERSION) >= version.parse("1.10")
else:
    IS_SAGEMAKER_MP_POST_1_10 = False

# `find_adapter_config_file` is only bound when PEFT is available.
if is_peft_available():
    from .utils import find_adapter_config_file
  149. TORCH_INIT_FUNCTIONS = {
  150. "uniform_": nn.init.uniform_,
  151. "normal_": nn.init.normal_,
  152. "trunc_normal_": nn.init.trunc_normal_,
  153. "constant_": nn.init.constant_,
  154. "xavier_uniform_": nn.init.xavier_uniform_,
  155. "xavier_normal_": nn.init.xavier_normal_,
  156. "kaiming_uniform_": nn.init.kaiming_uniform_,
  157. "kaiming_normal_": nn.init.kaiming_normal_,
  158. "uniform": nn.init.uniform,
  159. "normal": nn.init.normal,
  160. "xavier_uniform": nn.init.xavier_uniform,
  161. "xavier_normal": nn.init.xavier_normal,
  162. "kaiming_uniform": nn.init.kaiming_uniform,
  163. "kaiming_normal": nn.init.kaiming_normal,
  164. }
  165. @contextmanager
  166. def no_init_weights(_enable=True):
  167. """
  168. Context manager to globally disable weight initialization to speed up loading large models.
  169. TODO(Patrick): Delete safety argument `_enable=True` at next major version. .
  170. """
  171. global _init_weights
  172. old_init_weights = _init_weights
  173. if _enable:
  174. _init_weights = False
  175. def _skip_init(*args, **kwargs):
  176. pass
  177. # # Save the original initialization functions
  178. for name, init_func in TORCH_INIT_FUNCTIONS.items():
  179. setattr(torch.nn.init, name, _skip_init)
  180. try:
  181. yield
  182. finally:
  183. _init_weights = old_init_weights
  184. if _enable:
  185. # # Restore the original initialization functions
  186. for name, init_func in TORCH_INIT_FUNCTIONS.items():
  187. setattr(torch.nn.init, name, init_func)
  188. def get_parameter_device(parameter: Union[nn.Module, "ModuleUtilsMixin"]):
  189. try:
  190. return next(parameter.parameters()).device
  191. except StopIteration:
  192. # For nn.DataParallel compatibility in PyTorch 1.5
  193. def find_tensor_attributes(module: nn.Module) -> List[Tuple[str, Tensor]]:
  194. tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
  195. return tuples
  196. gen = parameter._named_members(get_members_fn=find_tensor_attributes)
  197. first_tuple = next(gen)
  198. return first_tuple[1].device
  199. def get_first_parameter_dtype(parameter: Union[nn.Module, "ModuleUtilsMixin"]):
  200. """
  201. Returns the first parameter dtype (can be non-floating) or asserts if none were found.
  202. """
  203. try:
  204. return next(parameter.parameters()).dtype
  205. except StopIteration:
  206. # For nn.DataParallel compatibility in PyTorch > 1.5
  207. def find_tensor_attributes(module: nn.Module) -> List[Tuple[str, Tensor]]:
  208. tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
  209. return tuples
  210. gen = parameter._named_members(get_members_fn=find_tensor_attributes)
  211. first_tuple = next(gen)
  212. return first_tuple[1].dtype
  213. def get_parameter_dtype(parameter: Union[nn.Module, "ModuleUtilsMixin"]):
  214. """
  215. Returns the first found floating dtype in parameters if there is one, otherwise returns the last dtype it found.
  216. """
  217. last_dtype = None
  218. for t in parameter.parameters():
  219. last_dtype = t.dtype
  220. if t.is_floating_point():
  221. # Adding fix for https://github.com/pytorch/xla/issues/4152
  222. # Fixes issue where the model code passes a value that is out of range for XLA_USE_BF16=1
  223. # and XLA_DOWNCAST_BF16=1 so the conversion would cast it to -inf
  224. # NOTE: `is_torch_xla_available()` is checked last as it induces a graph break in torch dynamo
  225. if XLA_USE_BF16 in ENV_VARS_TRUE_VALUES and is_torch_xla_available():
  226. return torch.bfloat16
  227. if XLA_DOWNCAST_BF16 in ENV_VARS_TRUE_VALUES and is_torch_xla_available():
  228. if t.dtype == torch.float:
  229. return torch.bfloat16
  230. if t.dtype == torch.double:
  231. return torch.float32
  232. return t.dtype
  233. if last_dtype is not None:
  234. # if no floating dtype was found return whatever the first dtype is
  235. return last_dtype
  236. # For nn.DataParallel compatibility in PyTorch > 1.5
  237. def find_tensor_attributes(module: nn.Module) -> List[Tuple[str, Tensor]]:
  238. tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
  239. return tuples
  240. gen = parameter._named_members(get_members_fn=find_tensor_attributes)
  241. last_tuple = None
  242. for tuple in gen:
  243. last_tuple = tuple
  244. if tuple[1].is_floating_point():
  245. return tuple[1].dtype
  246. if last_tuple is not None:
  247. # fallback to the last dtype
  248. return last_tuple[1].dtype
  249. # fallback to buffer dtype
  250. for t in parameter.buffers():
  251. last_dtype = t.dtype
  252. if t.is_floating_point():
  253. return t.dtype
  254. return last_dtype
  255. def get_state_dict_float_dtype(state_dict):
  256. """
  257. Returns the first found floating dtype in `state_dict` or asserts if none were found.
  258. """
  259. for t in state_dict.values():
  260. if t.is_floating_point():
  261. return t.dtype
  262. raise ValueError("couldn't find any floating point dtypes in state_dict")
  263. def get_state_dict_dtype(state_dict):
  264. """
  265. Returns the first found floating dtype in `state_dict` if there is one, otherwise returns the first dtype.
  266. """
  267. for t in state_dict.values():
  268. if t.is_floating_point():
  269. return t.dtype
  270. # if no floating dtype was found return whatever the first dtype is
  271. else:
  272. return next(state_dict.values()).dtype
  273. def dtype_byte_size(dtype):
  274. """
  275. Returns the size (in bytes) occupied by one parameter of type `dtype`.
  276. Example:
  277. ```py
  278. >>> dtype_byte_size(torch.float32)
  279. 4
  280. ```
  281. """
  282. if dtype == torch.bool:
  283. return 1 / 8
  284. bit_search = re.search(r"[^\d](\d+)_?", str(dtype))
  285. if bit_search is None:
  286. raise ValueError(f"`dtype` is not a valid dtype: {dtype}.")
  287. bit_size = int(bit_search.groups()[0])
  288. return bit_size // 8
  289. def check_support_param_buffer_assignment(model_to_load, state_dict, start_prefix=""):
  290. """
  291. Checks if `model_to_load` supports param buffer assignment (such
  292. as when loading in empty weights) by first checking
  293. if the model explicitly disables it, then by ensuring that the state dict keys
  294. are a subset of the model's parameters.
  295. Note: We fully disable this if we are using `deepspeed`
  296. """
  297. if len([key for key in state_dict if key.startswith(start_prefix)]) == 0:
  298. return False
  299. if is_deepspeed_zero3_enabled():
  300. return False
  301. # Some models explicitly do not support param buffer assignment
  302. if not getattr(model_to_load, "_supports_param_buffer_assignment", True):
  303. logger.debug(
  304. f"{model_to_load.__class__.__name__} does not support param buffer assignment, loading will be slower"
  305. )
  306. return False
  307. # If the model does, the incoming `state_dict` and the `model_to_load` must be the same dtype
  308. first_key = list(model_to_load.state_dict().keys())[0]
  309. if start_prefix + first_key in state_dict:
  310. return state_dict[start_prefix + first_key].dtype == model_to_load.state_dict()[first_key].dtype
  311. # For cases when the `state_dict` doesn't contain real weights to the model (`test_model_weights_reload_no_missing_tied_weights`)
  312. return False
  313. def shard_checkpoint(
  314. state_dict: Dict[str, torch.Tensor], max_shard_size: Union[int, str] = "10GB", weights_name: str = WEIGHTS_NAME
  315. ):
  316. """
  317. Splits a model state dictionary in sub-checkpoints so that the final size of each sub-checkpoint does not exceed a
  318. given size.
  319. The sub-checkpoints are determined by iterating through the `state_dict` in the order of its keys, so there is no
  320. optimization made to make each sub-checkpoint as close as possible to the maximum size passed. For example, if the
  321. limit is 10GB and we have weights of sizes [6GB, 6GB, 2GB, 6GB, 2GB, 2GB] they will get sharded as [6GB], [6+2GB],
  322. [6+2+2GB] and not [6+2+2GB], [6+2GB], [6GB].
  323. <Tip warning={true}>
  324. If one of the model's weight is bigger than `max_shard_size`, it will end up in its own sub-checkpoint which will
  325. have a size greater than `max_shard_size`.
  326. </Tip>
  327. Args:
  328. state_dict (`Dict[str, torch.Tensor]`): The state dictionary of a model to save.
  329. max_shard_size (`int` or `str`, *optional*, defaults to `"10GB"`):
  330. The maximum size of each sub-checkpoint. If expressed as a string, needs to be digits followed by a unit
  331. (like `"5MB"`).
  332. weights_name (`str`, *optional*, defaults to `"pytorch_model.bin"`):
  333. The name of the model save file.
  334. """
  335. logger.warning(
  336. "Note that `shard_checkpoint` is deprecated and will be removed in v4.44. We recommend you using "
  337. "split_torch_state_dict_into_shards from huggingface_hub library"
  338. )
  339. max_shard_size = convert_file_size_to_int(max_shard_size)
  340. sharded_state_dicts = [{}]
  341. last_block_size = 0
  342. total_size = 0
  343. storage_id_to_block = {}
  344. for key, weight in state_dict.items():
  345. # when bnb serialization is used the weights in the state dict can be strings
  346. # check: https://github.com/huggingface/transformers/pull/24416 for more details
  347. if isinstance(weight, str):
  348. continue
  349. else:
  350. storage_id = id_tensor_storage(weight)
  351. # If a `weight` shares the same underlying storage as another tensor, we put `weight` in the same `block`
  352. if storage_id in storage_id_to_block and weight.device != torch.device("meta"):
  353. block_id = storage_id_to_block[storage_id]
  354. sharded_state_dicts[block_id][key] = weight
  355. continue
  356. weight_size = weight.numel() * dtype_byte_size(weight.dtype)
  357. # If this weight is going to tip up over the maximal size, we split, but only if we have put at least one
  358. # weight in the current shard.
  359. if last_block_size + weight_size > max_shard_size and len(sharded_state_dicts[-1]) > 0:
  360. sharded_state_dicts.append({})
  361. last_block_size = 0
  362. sharded_state_dicts[-1][key] = weight
  363. last_block_size += weight_size
  364. total_size += weight_size
  365. storage_id_to_block[storage_id] = len(sharded_state_dicts) - 1
  366. # If we only have one shard, we return it
  367. if len(sharded_state_dicts) == 1:
  368. return {weights_name: sharded_state_dicts[0]}, None
  369. # Otherwise, let's build the index
  370. weight_map = {}
  371. shards = {}
  372. for idx, shard in enumerate(sharded_state_dicts):
  373. shard_file = weights_name.replace(".bin", f"-{idx+1:05d}-of-{len(sharded_state_dicts):05d}.bin")
  374. shard_file = shard_file.replace(
  375. ".safetensors", f"-{idx + 1:05d}-of-{len(sharded_state_dicts):05d}.safetensors"
  376. )
  377. shards[shard_file] = shard
  378. for key in shard.keys():
  379. weight_map[key] = shard_file
  380. # Add the metadata
  381. metadata = {"total_size": total_size}
  382. index = {"metadata": metadata, "weight_map": weight_map}
  383. return shards, index
  384. def load_sharded_checkpoint(model, folder, strict=True, prefer_safe=True):
  385. """
  386. This is the same as
  387. [`torch.nn.Module.load_state_dict`](https://pytorch.org/docs/stable/generated/torch.nn.Module.html?highlight=load_state_dict#torch.nn.Module.load_state_dict)
  388. but for a sharded checkpoint.
  389. This load is performed efficiently: each checkpoint shard is loaded one by one in RAM and deleted after being
  390. loaded in the model.
  391. Args:
  392. model (`torch.nn.Module`): The model in which to load the checkpoint.
  393. folder (`str` or `os.PathLike`): A path to a folder containing the sharded checkpoint.
  394. strict (`bool`, *optional`, defaults to `True`):
  395. Whether to strictly enforce that the keys in the model state dict match the keys in the sharded checkpoint.
  396. prefer_safe (`bool`, *optional*, defaults to `False`)
  397. If both safetensors and PyTorch save files are present in checkpoint and `prefer_safe` is True, the
  398. safetensors files will be loaded. Otherwise, PyTorch files are always loaded when possible.
  399. Returns:
  400. `NamedTuple`: A named tuple with `missing_keys` and `unexpected_keys` fields
  401. - `missing_keys` is a list of str containing the missing keys
  402. - `unexpected_keys` is a list of str containing the unexpected keys
  403. """
  404. # Load the index
  405. index_file = os.path.join(folder, WEIGHTS_INDEX_NAME)
  406. safe_index_file = os.path.join(folder, SAFE_WEIGHTS_INDEX_NAME)
  407. index_present = os.path.isfile(index_file)
  408. safe_index_present = os.path.isfile(safe_index_file)
  409. if not index_present and not (safe_index_present and is_safetensors_available()):
  410. filenames = (
  411. (WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_INDEX_NAME) if is_safetensors_available() else (WEIGHTS_INDEX_NAME,)
  412. )
  413. raise ValueError(f"Can't find a checkpoint index ({' or '.join(filenames)}) in {folder}.")
  414. load_safe = False
  415. if safe_index_present:
  416. if prefer_safe:
  417. if is_safetensors_available():
  418. load_safe = True # load safe due to preference
  419. else:
  420. logger.warning(
  421. f"Cannot load sharded checkpoint at {folder} safely since safetensors is not installed!"
  422. )
  423. elif not index_present:
  424. load_safe = True # load safe since we have no other choice
  425. load_index = safe_index_file if load_safe else index_file
  426. with open(load_index, "r", encoding="utf-8") as f:
  427. index = json.load(f)
  428. shard_files = list(set(index["weight_map"].values()))
  429. # If strict=True, error before loading any of the state dicts.
  430. loaded_keys = index["weight_map"].keys()
  431. model_keys = model.state_dict().keys()
  432. missing_keys = [key for key in model_keys if key not in loaded_keys]
  433. unexpected_keys = [key for key in loaded_keys if key not in model_keys]
  434. if strict and (len(missing_keys) > 0 or len(unexpected_keys) > 0):
  435. error_message = f"Error(s) in loading state_dict for {model.__class__.__name__}"
  436. if len(missing_keys) > 0:
  437. str_missing_keys = ",".join([f'"{k}"' for k in missing_keys])
  438. error_message += f"\nMissing key(s): {str_missing_keys}."
  439. if len(unexpected_keys) > 0:
  440. str_unexpected_keys = ",".join([f'"{k}"' for k in unexpected_keys])
  441. error_message += f"\nMissing key(s): {str_unexpected_keys}."
  442. raise RuntimeError(error_message)
  443. weights_only_kwarg = {"weights_only": True} if is_torch_greater_or_equal_than_1_13 else {}
  444. loader = safe_load_file if load_safe else partial(torch.load, map_location="cpu", **weights_only_kwarg)
  445. for shard_file in shard_files:
  446. state_dict = loader(os.path.join(folder, shard_file))
  447. model.load_state_dict(state_dict, strict=False)
  448. # Make sure memory is freed before we load the next state dict.
  449. del state_dict
  450. gc.collect()
  451. # Return the same thing as PyTorch load_state_dict function.
  452. return torch.nn.modules.module._IncompatibleKeys(missing_keys, unexpected_keys)
  453. def load_state_dict(
  454. checkpoint_file: Union[str, os.PathLike],
  455. is_quantized: bool = False,
  456. map_location: Optional[Union[str, torch.device]] = None,
  457. weights_only: bool = True,
  458. ):
  459. """
  460. Reads a PyTorch checkpoint file, returning properly formatted errors if they arise.
  461. """
  462. if checkpoint_file.endswith(".safetensors") and is_safetensors_available():
  463. # Check format of the archive
  464. with safe_open(checkpoint_file, framework="pt") as f:
  465. metadata = f.metadata()
  466. if metadata.get("format") not in ["pt", "tf", "flax", "mlx"]:
  467. raise OSError(
  468. f"The safetensors archive passed at {checkpoint_file} does not contain the valid metadata. Make sure "
  469. "you save your model with the `save_pretrained` method."
  470. )
  471. return safe_load_file(checkpoint_file)
  472. try:
  473. if map_location is None:
  474. if (
  475. (
  476. is_deepspeed_zero3_enabled()
  477. and torch.distributed.is_initialized()
  478. and torch.distributed.get_rank() > 0
  479. )
  480. or (is_fsdp_enabled() and not is_local_dist_rank_0())
  481. ) and not is_quantized:
  482. map_location = "meta"
  483. else:
  484. map_location = "cpu"
  485. extra_args = {}
  486. # mmap can only be used with files serialized with zipfile-based format.
  487. if (
  488. isinstance(checkpoint_file, str)
  489. and map_location != "meta"
  490. and version.parse(torch.__version__) >= version.parse("2.1.0")
  491. and is_zipfile(checkpoint_file)
  492. ):
  493. extra_args = {"mmap": True}
  494. weights_only_kwarg = {"weights_only": weights_only} if is_torch_greater_or_equal_than_1_13 else {}
  495. return torch.load(
  496. checkpoint_file,
  497. map_location=map_location,
  498. **weights_only_kwarg,
  499. **extra_args,
  500. )
  501. except Exception as e:
  502. try:
  503. with open(checkpoint_file) as f:
  504. if f.read(7) == "version":
  505. raise OSError(
  506. "You seem to have cloned a repository without having git-lfs installed. Please install "
  507. "git-lfs and run `git lfs install` followed by `git lfs pull` in the folder "
  508. "you cloned."
  509. )
  510. else:
  511. raise ValueError(
  512. f"Unable to locate the file {checkpoint_file} which is necessary to load this pretrained "
  513. "model. Make sure you have saved the model properly."
  514. ) from e
  515. except (UnicodeDecodeError, ValueError):
  516. raise OSError(
  517. f"Unable to load weights from pytorch checkpoint file for '{checkpoint_file}' "
  518. f"at '{checkpoint_file}'. "
  519. "If you tried to load a PyTorch model from a TF 2.0 checkpoint, please set from_tf=True."
  520. )
  521. def set_initialized_submodules(model, state_dict_keys):
  522. """
  523. Sets the `_is_hf_initialized` flag in all submodules of a given model when all its weights are in the loaded state
  524. dict.
  525. """
  526. not_initialized_submodules = {}
  527. for module_name, module in model.named_modules():
  528. loaded_keys = {k.replace(f"{module_name}.", "") for k in state_dict_keys if k.startswith(f"{module_name}.")}
  529. # When checking if the root module is loaded all state_dict_keys must be used.
  530. if module_name == "":
  531. loaded_keys = set(state_dict_keys)
  532. if loaded_keys.issuperset(module.state_dict()):
  533. module._is_hf_initialized = True
  534. else:
  535. not_initialized_submodules[module_name] = module
  536. return not_initialized_submodules
  537. def _end_ptr(tensor: torch.Tensor) -> int:
  538. # extract the end of the pointer if the tensor is a slice of a bigger tensor
  539. if tensor.nelement():
  540. stop = tensor.view(-1)[-1].data_ptr() + tensor.element_size()
  541. else:
  542. stop = tensor.data_ptr()
  543. return stop
  544. def _get_tied_weight_keys(module: nn.Module, prefix=""):
  545. tied_weight_keys = []
  546. if getattr(module, "_tied_weights_keys", None) is not None:
  547. names = [f"{prefix}.{k}" if prefix else k for k in module._tied_weights_keys]
  548. tied_weight_keys.extend(names)
  549. if getattr(module, "_dynamic_tied_weights_keys", None) is not None:
  550. names = [f"{prefix}.{k}" if prefix else k for k in module._dynamic_tied_weights_keys]
  551. tied_weight_keys.extend(names)
  552. for name, submodule in module.named_children():
  553. local_prefix = f"{prefix}.{name}" if prefix else name
  554. tied_weight_keys.extend(_get_tied_weight_keys(submodule, prefix=local_prefix))
  555. return tied_weight_keys
  556. def _find_disjoint(tensors: List[Set[str]], state_dict: Dict[str, torch.Tensor]) -> Tuple[List[Set[str]], List[str]]:
  557. filtered_tensors = []
  558. for shared in tensors:
  559. if len(shared) < 2:
  560. filtered_tensors.append(shared)
  561. continue
  562. areas = []
  563. for name in shared:
  564. tensor = state_dict[name]
  565. areas.append((tensor.data_ptr(), _end_ptr(tensor), name))
  566. areas.sort()
  567. _, last_stop, last_name = areas[0]
  568. filtered_tensors.append({last_name})
  569. for start, stop, name in areas[1:]:
  570. if start >= last_stop:
  571. filtered_tensors.append({name})
  572. else:
  573. filtered_tensors[-1].add(name)
  574. last_stop = stop
  575. disjoint_tensors = []
  576. shared_tensors = []
  577. for tensors in filtered_tensors:
  578. if len(tensors) == 1:
  579. disjoint_tensors.append(tensors.pop())
  580. else:
  581. shared_tensors.append(tensors)
  582. return shared_tensors, disjoint_tensors
  583. def _find_identical(tensors: List[Set[str]], state_dict: Dict[str, torch.Tensor]) -> Tuple[List[Set[str]], Set[str]]:
  584. shared_tensors = []
  585. identical = []
  586. for shared in tensors:
  587. if len(shared) < 2:
  588. continue
  589. areas = collections.defaultdict(set)
  590. for name in shared:
  591. tensor = state_dict[name]
  592. area = (tensor.device, tensor.data_ptr(), _end_ptr(tensor))
  593. areas[area].add(name)
  594. if len(areas) == 1:
  595. identical.append(shared)
  596. else:
  597. shared_tensors.append(shared)
  598. return shared_tensors, identical
  599. def _load_state_dict_into_model(model_to_load, state_dict, start_prefix, assign_to_params_buffers=False):
  600. # Convert old format to new format if needed from a PyTorch state_dict
  601. old_keys = []
  602. new_keys = []
  603. renamed_keys = {}
  604. renamed_gamma = {}
  605. renamed_beta = {}
  606. warning_msg = f"A pretrained model of type `{model_to_load.__class__.__name__}` "
  607. for key in state_dict.keys():
  608. new_key = None
  609. if "gamma" in key:
  610. # We add only the first key as an example
  611. new_key = key.replace("gamma", "weight")
  612. renamed_gamma[key] = new_key if not renamed_gamma else renamed_gamma
  613. if "beta" in key:
  614. # We add only the first key as an example
  615. new_key = key.replace("beta", "bias")
  616. renamed_beta[key] = new_key if not renamed_beta else renamed_beta
  617. if new_key:
  618. old_keys.append(key)
  619. new_keys.append(new_key)
  620. renamed_keys = {**renamed_gamma, **renamed_beta}
  621. if renamed_keys:
  622. warning_msg += "contains parameters that have been renamed internally (a few are listed below but more are present in the model):\n"
  623. for old_key, new_key in renamed_keys.items():
  624. warning_msg += f"* `{old_key}` -> `{new_key}`\n"
  625. warning_msg += "If you are using a model from the Hub, consider submitting a PR to adjust these weights and help future users."
  626. logger.info_once(warning_msg)
  627. for old_key, new_key in zip(old_keys, new_keys):
  628. state_dict[new_key] = state_dict.pop(old_key)
  629. # copy state_dict so _load_from_state_dict can modify it
  630. metadata = getattr(state_dict, "_metadata", None)
  631. state_dict = state_dict.copy()
  632. if metadata is not None:
  633. state_dict._metadata = metadata
  634. error_msgs = []
  635. # PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants
  636. # so we need to apply the function recursively.
  637. def load(module: nn.Module, state_dict, prefix="", assign_to_params_buffers=False):
  638. local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
  639. local_metadata["assign_to_params_buffers"] = assign_to_params_buffers
  640. args = (state_dict, prefix, local_metadata, True, [], [], error_msgs)
  641. # Parameters of module and children will start with prefix. We can exit early if there are none in this
  642. # state_dict
  643. if len([key for key in state_dict if key.startswith(prefix)]) > 0:
  644. if is_deepspeed_zero3_enabled():
  645. import deepspeed
  646. # In sharded models, each shard has only part of the full state_dict, so only gather
  647. # parameters that are in the current state_dict.
  648. named_parameters = dict(module.named_parameters(prefix=prefix[:-1], recurse=False))
  649. params_to_gather = [named_parameters[k] for k in state_dict.keys() if k in named_parameters]
  650. if len(params_to_gather) > 0:
  651. # because zero3 puts placeholders in model params, this context
  652. # manager gathers (unpartitions) the params of the current layer, then loads from
  653. # the state dict and then re-partitions them again
  654. with deepspeed.zero.GatheredParameters(params_to_gather, modifier_rank=0):
  655. if torch.distributed.get_rank() == 0:
  656. module._load_from_state_dict(*args)
  657. else:
  658. module._load_from_state_dict(*args)
  659. for name, child in module._modules.items():
  660. if child is not None:
  661. load(child, state_dict, prefix + name + ".", assign_to_params_buffers)
  662. load(model_to_load, state_dict, prefix=start_prefix, assign_to_params_buffers=assign_to_params_buffers)
  663. # Delete `state_dict` so it could be collected by GC earlier. Note that `state_dict` is a copy of the argument, so
  664. # it's safe to delete it.
  665. del state_dict
  666. return error_msgs
  667. def find_submodule_and_param_name(model, long_key, start_prefix):
  668. """
  669. A helper util to find the last sub-module and the param/buffer name. If `start_prefix` is supplied it'll be removed
  670. from the start of the key
  671. """
  672. if len(start_prefix) > 0 and long_key.startswith(start_prefix):
  673. long_key = ".".join(long_key.split(".")[1:])
  674. split_key = long_key.split(".")
  675. submodule = model
  676. while len(split_key) > 1:
  677. if hasattr(submodule, split_key[0]):
  678. submodule = getattr(submodule, split_key[0])
  679. del split_key[0]
  680. else:
  681. submodule = None
  682. break
  683. if submodule == model:
  684. submodule = None
  685. return submodule, split_key[0]
  686. def _move_model_to_meta(model, loaded_state_dict_keys, start_prefix):
  687. """
  688. Moves `loaded_state_dict_keys` in model to meta device which frees up the memory taken by those params.
  689. `start_prefix` is used for models which insert their name into model keys, e.g. `bert` in
  690. `bert.pooler.dense.weight`
  691. """
  692. # dematerialize param storage for keys that are going to be replaced by state_dict, by
  693. # putting those on the meta device
  694. for k in loaded_state_dict_keys:
  695. submodule, param_name = find_submodule_and_param_name(model, k, start_prefix)
  696. if submodule is not None:
  697. # selectively switch to the meta device only those params/buffers that will
  698. # be next replaced from state_dict. This a complex way to do p.to_("meta")
  699. # since we have no in-place to_ for tensors.
  700. new_val = getattr(submodule, param_name)
  701. if isinstance(new_val, torch.nn.Parameter):
  702. # isinstance returns False for Params on meta device, so switch after the check
  703. new_val = torch.nn.Parameter(new_val.to("meta"))
  704. else:
  705. new_val = new_val.to("meta")
  706. setattr(submodule, param_name, new_val)
  707. def _load_state_dict_into_meta_model(
  708. model,
  709. state_dict,
  710. start_prefix,
  711. expected_keys,
  712. device_map=None,
  713. offload_folder=None,
  714. offload_index=None,
  715. state_dict_folder=None,
  716. state_dict_index=None,
  717. dtype=None,
  718. hf_quantizer=None,
  719. is_safetensors=False,
  720. keep_in_fp32_modules=None,
  721. unexpected_keys=None, # passing `unexpected` for cleanup from quantization items
  722. pretrained_model_name_or_path=None, # for flagging the user when the model contains renamed keys
  723. ):
  724. """
  725. This is somewhat similar to `_load_state_dict_into_model`, but deals with a model that has some or all of its
  726. params on a `meta` device. It replaces the model params with the data from the `state_dict`, while moving the
  727. params back to the normal device, but only for `loaded_state_dict_keys`.
  728. `start_prefix` is used for models which insert their name into model keys, e.g. `bert` in
  729. `bert.pooler.dense.weight`
  730. """
  731. # XXX: remaining features to implement to be fully compatible with _load_state_dict_into_model
  732. # - deepspeed zero 3 support
  733. # - need to copy metadata if any - see _load_state_dict_into_model
  734. # - handling error_msgs - mimicking the error handling in module._load_from_state_dict()
  735. error_msgs = []
  736. old_keys = []
  737. new_keys = []
  738. renamed_gamma = {}
  739. renamed_beta = {}
  740. is_quantized = hf_quantizer is not None
  741. warning_msg = f"This model {type(model)}"
  742. for key in state_dict.keys():
  743. new_key = None
  744. if "gamma" in key:
  745. # We add only the first key as an example
  746. new_key = key.replace("gamma", "weight")
  747. renamed_gamma[key] = new_key if not renamed_gamma else renamed_gamma
  748. if "beta" in key:
  749. # We add only the first key as an example
  750. new_key = key.replace("beta", "bias")
  751. renamed_beta[key] = new_key if not renamed_beta else renamed_beta
  752. # To reproduce `_load_state_dict_into_model` behaviour, we need to manually rename parametrized weigth norm, if necessary.
  753. if hasattr(nn.utils.parametrizations, "weight_norm"):
  754. if "weight_g" in key:
  755. new_key = key.replace("weight_g", "parametrizations.weight.original0")
  756. if "weight_v" in key:
  757. new_key = key.replace("weight_v", "parametrizations.weight.original1")
  758. else:
  759. if "parametrizations.weight.original0" in key:
  760. new_key = key.replace("parametrizations.weight.original0", "weight_g")
  761. if "parametrizations.weight.original1" in key:
  762. new_key = key.replace("parametrizations.weight.original1", "weight_v")
  763. if new_key:
  764. old_keys.append(key)
  765. new_keys.append(new_key)
  766. renamed_keys = {**renamed_gamma, **renamed_beta}
  767. if renamed_keys:
  768. warning_msg += "contains parameters that have been renamed internally (a few are listed below but more are present in the model):\n"
  769. for old_key, new_key in renamed_keys.items():
  770. warning_msg += f"* `{old_key}` -> `{new_key}`\n"
  771. warning_msg += "If you are using a model from the Hub, consider submitting a PR to adjust these weights and help future users."
  772. logger.info_once(warning_msg)
  773. for old_key, new_key in zip(old_keys, new_keys):
  774. state_dict[new_key] = state_dict.pop(old_key)
  775. is_torch_e4m3fn_available = hasattr(torch, "float8_e4m3fn")
  776. for param_name, param in state_dict.items():
  777. if param_name not in expected_keys:
  778. continue
  779. if param_name.startswith(start_prefix):
  780. param_name = param_name[len(start_prefix) :]
  781. module_name = param_name
  782. set_module_kwargs = {}
  783. # We convert floating dtypes to the `dtype` passed except for float8_e4m3fn type. We also want to keep the buffers/params
  784. # in int/uint/bool and not cast them.
  785. is_param_float8_e4m3fn = is_torch_e4m3fn_available and param.dtype == torch.float8_e4m3fn
  786. if dtype is not None and torch.is_floating_point(param) and not is_param_float8_e4m3fn:
  787. if (
  788. keep_in_fp32_modules is not None
  789. and any(
  790. module_to_keep_in_fp32 in param_name.split(".") for module_to_keep_in_fp32 in keep_in_fp32_modules
  791. )
  792. and dtype == torch.float16
  793. ):
  794. param = param.to(torch.float32)
  795. # For backward compatibility with older versions of `accelerate`
  796. # TODO: @sgugger replace this check with version check at the next `accelerate` release
  797. if "dtype" in list(inspect.signature(set_module_tensor_to_device).parameters):
  798. set_module_kwargs["dtype"] = torch.float32
  799. else:
  800. param = param.to(dtype)
  801. # For compatibility with PyTorch load_state_dict which converts state dict dtype to existing dtype in model, and which
  802. # uses `param.copy_(input_param)` that preserves the contiguity of the parameter in the model.
  803. # Reference: https://github.com/pytorch/pytorch/blob/db79ceb110f6646523019a59bbd7b838f43d4a86/torch/nn/modules/module.py#L2040C29-L2040C29
  804. old_param = model
  805. splits = param_name.split(".")
  806. for split in splits:
  807. # We shouldn't hit the default value unless for quant methods like hqq that modifies expected_keys.
  808. old_param = getattr(old_param, split, None)
  809. if old_param is None:
  810. break
  811. if not isinstance(old_param, (torch.nn.Parameter, torch.Tensor)):
  812. old_param = None
  813. if old_param is not None:
  814. if dtype is None:
  815. param = param.to(old_param.dtype)
  816. if old_param.is_contiguous():
  817. param = param.contiguous()
  818. set_module_kwargs["value"] = param
  819. if device_map is None:
  820. param_device = "cpu"
  821. else:
  822. # find next higher level module that is defined in device_map:
  823. # bert.lm_head.weight -> bert.lm_head -> bert -> ''
  824. while len(module_name) > 0 and module_name not in device_map:
  825. module_name = ".".join(module_name.split(".")[:-1])
  826. if module_name == "" and "" not in device_map:
  827. # TODO: group all errors and raise at the end.
  828. raise ValueError(f"{param_name} doesn't have any device set.")
  829. param_device = device_map[module_name]
  830. if param_device == "disk":
  831. if not is_safetensors:
  832. offload_index = offload_weight(param, param_name, offload_folder, offload_index)
  833. elif param_device == "cpu" and state_dict_index is not None:
  834. state_dict_index = offload_weight(param, param_name, state_dict_folder, state_dict_index)
  835. elif (
  836. not is_quantized
  837. or (not hf_quantizer.requires_parameters_quantization)
  838. or (
  839. not hf_quantizer.check_quantized_param(
  840. model, param, param_name, state_dict, param_device=param_device, device_map=device_map
  841. )
  842. )
  843. ):
  844. if is_fsdp_enabled():
  845. param_device = "cpu" if is_local_dist_rank_0() else "meta"
  846. # For backward compatibility with older versions of `accelerate` and for non-quantized params
  847. set_module_tensor_to_device(model, param_name, param_device, **set_module_kwargs)
  848. else:
  849. hf_quantizer.create_quantized_param(model, param, param_name, param_device, state_dict, unexpected_keys)
  850. # For quantized modules with FSDP/DeepSpeed Stage 3, we need to quantize the parameter on the GPU
  851. # and then cast it to CPU to avoid excessive memory usage on each GPU
  852. # in comparison to the sharded model across GPUs.
  853. if is_fsdp_enabled() or is_deepspeed_zero3_enabled():
  854. module, tensor_name = get_module_from_name(model, param_name)
  855. value = getattr(module, tensor_name)
  856. param_to = "cpu"
  857. if is_fsdp_enabled() and not is_local_dist_rank_0():
  858. param_to = "meta"
  859. value = type(value)(value.data.to(param_to), **value.__dict__)
  860. setattr(module, tensor_name, value)
  861. # TODO: consider removing used param_parts from state_dict before return
  862. return error_msgs, offload_index, state_dict_index
  863. def _add_variant(weights_name: str, variant: Optional[str] = None) -> str:
  864. if variant is not None:
  865. splits = weights_name.split(".")
  866. splits = splits[:-1] + [variant] + splits[-1:]
  867. weights_name = ".".join(splits)
  868. return weights_name
  869. class ModuleUtilsMixin:
  870. """
  871. A few utilities for `torch.nn.Modules`, to be used as a mixin.
  872. """
  873. @staticmethod
  874. def _hook_rss_memory_pre_forward(module, *args, **kwargs):
  875. try:
  876. import psutil
  877. except ImportError:
  878. raise ImportError("You need to install psutil (pip install psutil) to use memory tracing.")
  879. process = psutil.Process(os.getpid())
  880. mem = process.memory_info()
  881. module.mem_rss_pre_forward = mem.rss
  882. return None
  883. @staticmethod
  884. def _hook_rss_memory_post_forward(module, *args, **kwargs):
  885. try:
  886. import psutil
  887. except ImportError:
  888. raise ImportError("You need to install psutil (pip install psutil) to use memory tracing.")
  889. process = psutil.Process(os.getpid())
  890. mem = process.memory_info()
  891. module.mem_rss_post_forward = mem.rss
  892. mem_rss_diff = module.mem_rss_post_forward - module.mem_rss_pre_forward
  893. module.mem_rss_diff = mem_rss_diff + (module.mem_rss_diff if hasattr(module, "mem_rss_diff") else 0)
  894. return None
  895. def add_memory_hooks(self):
  896. """
  897. Add a memory hook before and after each sub-module forward pass to record increase in memory consumption.
  898. Increase in memory consumption is stored in a `mem_rss_diff` attribute for each module and can be reset to zero
  899. with `model.reset_memory_hooks_state()`.
  900. """
  901. for module in self.modules():
  902. module.register_forward_pre_hook(self._hook_rss_memory_pre_forward)
  903. module.register_forward_hook(self._hook_rss_memory_post_forward)
  904. self.reset_memory_hooks_state()
  905. def reset_memory_hooks_state(self):
  906. """
  907. Reset the `mem_rss_diff` attribute of each module (see [`~modeling_utils.ModuleUtilsMixin.add_memory_hooks`]).
  908. """
  909. for module in self.modules():
  910. module.mem_rss_diff = 0
  911. module.mem_rss_post_forward = 0
  912. module.mem_rss_pre_forward = 0
  @property
  def device(self) -> torch.device:
      """
      `torch.device`: The device reported by `get_parameter_device` for this module.

      The result assumes that all the module parameters are on the same device.
      """
      parameter_device = get_parameter_device(self)
      return parameter_device
  @property
  def dtype(self) -> torch.dtype:
      """
      `torch.dtype`: The dtype reported by `get_parameter_dtype` for this module.

      The result assumes that all the module parameters have the same dtype.
      """
      parameter_dtype = get_parameter_dtype(self)
      return parameter_dtype
  def invert_attention_mask(self, encoder_attention_mask: Tensor) -> Tensor:
      """
      Invert an attention mask into an additive mask.

      Positions holding 1 become 0 and positions holding 0 become the most negative value of `self.dtype`.

      Args:
          encoder_attention_mask (`torch.Tensor`):
              An attention mask with ones for positions to attend to and zeros for positions to ignore. It is
              either 2D (`[batch_size, seq_length]`) or 3D (`[batch_size, from_seq_length, to_seq_length]`).

      Returns:
          `torch.Tensor`: The inverted attention mask in `self.dtype`. A singleton head dimension is inserted at
          index 1 so that it broadcasts over attention heads.

      Raises:
          `ValueError`: If `encoder_attention_mask` is neither 2D nor 3D.
      """
      if encoder_attention_mask.dim() == 3:
          encoder_extended_attention_mask = encoder_attention_mask[:, None, :, :]
      elif encoder_attention_mask.dim() == 2:
          encoder_extended_attention_mask = encoder_attention_mask[:, None, None, :]
      else:
          # Any other rank would leave the extended mask undefined; fail with an explicit message instead.
          raise ValueError(
              f"Wrong shape for encoder_attention_mask (shape {encoder_attention_mask.shape}): "
              "expected a 2D or 3D tensor"
          )
      # T5 has a mask that can compare sequence ids; this could be simulated with
      # `mask == mask.transpose(-1, -2)` but is intentionally not applied here.
      # Cf. https://github.com/tensorflow/mesh/blob/8d2465e9bc93129b913b5ccc6a59aa97abd96ec6/mesh_tensorflow
      # /transformer/transformer_layers.py#L270
      encoder_extended_attention_mask = encoder_extended_attention_mask.to(dtype=self.dtype)  # fp16 compatibility
      encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * torch.finfo(self.dtype).min
      return encoder_extended_attention_mask
  946. @staticmethod
  947. def create_extended_attention_mask_for_decoder(input_shape, attention_mask, device=None):
  948. if device is not None:
  949. warnings.warn(
  950. "The `device` argument is deprecated and will be removed in v5 of Transformers.", FutureWarning
  951. )
  952. else:
  953. device = attention_mask.device
  954. batch_size, seq_length = input_shape
  955. seq_ids = torch.arange(seq_length, device=device)
  956. causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None]
  957. # in case past_key_values are used we need to add a prefix ones mask to the causal mask
  958. # causal and attention masks must have same type with pytorch version < 1.3
  959. causal_mask = causal_mask.to(attention_mask.dtype)
  960. if causal_mask.shape[1] < attention_mask.shape[1]:
  961. prefix_seq_len = attention_mask.shape[1] - causal_mask.shape[1]
  962. causal_mask = torch.cat(
  963. [
  964. torch.ones((batch_size, seq_length, prefix_seq_len), device=device, dtype=causal_mask.dtype),
  965. causal_mask,
  966. ],
  967. axis=-1,
  968. )
  969. extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
  970. return extended_attention_mask
  def get_extended_attention_mask(
      self,
      attention_mask: Tensor,
      input_shape: Tuple[int, ...],
      device: Optional[torch.device] = None,
      dtype: Optional[torch.dtype] = None,
  ) -> Tensor:
      """
      Makes broadcastable attention and causal masks so that future and masked tokens are ignored.

      Arguments:
          attention_mask (`torch.Tensor`):
              Mask with ones indicating tokens to attend to, zeros for tokens to ignore. It is either 2D
              (`[batch_size, seq_length]`) or 3D (`[batch_size, from_seq_length, to_seq_length]`).
          input_shape (`Tuple[int]`):
              The shape of the input to the model. It is only used on the decoder path and in error messages.
          device (`torch.device`, *optional*):
              Deprecated. Passing it emits a `FutureWarning`. On the decoder path it is forwarded to
              `create_extended_attention_mask_for_decoder`.
          dtype (`torch.dtype`, *optional*):
              Dtype of the returned mask. Defaults to `self.dtype`.

      Returns:
          `torch.Tensor`: The extended additive attention mask in `dtype`. It is 0.0 for positions to attend to
          and `torch.finfo(dtype).min` for masked positions, with a singleton head dimension at index 1.

      Raises:
          `ValueError`: If `attention_mask` is neither 2D nor 3D.
      """
      if dtype is None:
          dtype = self.dtype

      if not (attention_mask.dim() == 2 and self.config.is_decoder):
          # show warning only if it won't be shown in `create_extended_attention_mask_for_decoder`
          if device is not None:
              warnings.warn(
                  "The `device` argument is deprecated and will be removed in v5 of Transformers.", FutureWarning
              )
      # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
      # ourselves in which case we just need to make it broadcastable to all heads.
      if attention_mask.dim() == 3:
          extended_attention_mask = attention_mask[:, None, :, :]
      elif attention_mask.dim() == 2:
          # Provided a padding mask of dimensions [batch_size, seq_length]
          # - if the model is a decoder, apply a causal mask in addition to the padding mask
          # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
          if self.config.is_decoder:
              extended_attention_mask = ModuleUtilsMixin.create_extended_attention_mask_for_decoder(
                  input_shape, attention_mask, device
              )
          else:
              extended_attention_mask = attention_mask[:, None, None, :]
      else:
          raise ValueError(
              f"Wrong shape for input_ids (shape {input_shape}) or attention_mask (shape {attention_mask.shape})"
          )

      # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
      # masked positions, this operation will create a tensor which is 0.0 for
      # positions we want to attend and the dtype's smallest value for masked positions.
      # Since we are adding it to the raw scores before the softmax, this is
      # effectively the same as removing these entirely.
      extended_attention_mask = extended_attention_mask.to(dtype=dtype)  # fp16 compatibility
      extended_attention_mask = (1.0 - extended_attention_mask) * torch.finfo(dtype).min
      return extended_attention_mask
  def get_head_mask(
      self, head_mask: Optional[Tensor], num_hidden_layers: int, is_attention_chunked: bool = False
  ) -> Tensor:
      """
      Prepare the head mask if needed.

      Args:
          head_mask (`torch.Tensor` with shape `[num_heads]` or `[num_hidden_layers x num_heads]`, *optional*):
              The mask indicating if we should keep the heads or not (1.0 for keep, 0.0 for discard).
          num_hidden_layers (`int`):
              The number of hidden layers in the model.
          is_attention_chunked (`bool`, *optional*, defaults to `False`):
              Whether or not the attentions scores are computed by chunks or not. Only the literal `True`
              triggers the extra trailing dimension.

      Returns:
          `torch.Tensor` with shape `[num_hidden_layers x batch x num_heads x seq_length x seq_length]` (plus a
          trailing singleton dimension when chunked), or a list holding `None` for each layer when `head_mask` is
          `None`.
      """
      if head_mask is None:
          return [None] * num_hidden_layers

      five_d_mask = self._convert_head_mask_to_5d(head_mask, num_hidden_layers)
      if is_attention_chunked is True:
          return five_d_mask.unsqueeze(-1)
      return five_d_mask
  1041. def _convert_head_mask_to_5d(self, head_mask, num_hidden_layers):
  1042. """-> [num_hidden_layers x batch x num_heads x seq_length x seq_length]"""
  1043. if head_mask.dim() == 1:
  1044. head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
  1045. head_mask = head_mask.expand(num_hidden_layers, -1, -1, -1, -1)
  1046. elif head_mask.dim() == 2:
  1047. head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1) # We can specify head_mask for each layer
  1048. assert head_mask.dim() == 5, f"head_mask.dim != 5, instead {head_mask.dim()}"
  1049. head_mask = head_mask.to(dtype=self.dtype) # switch to float if need + fp16 compatibility
  1050. return head_mask
  def num_parameters(self, only_trainable: bool = False, exclude_embeddings: bool = False) -> int:
      """
      Get number of (optionally, trainable or non-embeddings) parameters in the module.

      For models loaded in 4bit (`is_loaded_in_4bit`), each `bnb.nn.Params4bit` parameter counts as
      `numel() * 2 * num_bytes`. Here `num_bytes` comes from `element_size()`, else `quant_storage.itemsize`,
      else 1.

      Args:
          only_trainable (`bool`, *optional*, defaults to `False`):
              Whether or not to return only the number of trainable parameters
          exclude_embeddings (`bool`, *optional*, defaults to `False`):
              Whether or not to return only the number of non-embeddings parameters (the `weight` of every
              `nn.Embedding` sub-module is skipped)

      Returns:
          `int`: The number of parameters.

      Raises:
          `ValueError`: If the model reports being loaded in 4bit but `bitsandbytes` is not available.
      """
      if exclude_embeddings:
          # A set keeps the membership test below O(1) per parameter instead of scanning a list each time.
          embedding_param_names = {
              f"{name}.weight" for name, module_type in self.named_modules() if isinstance(module_type, nn.Embedding)
          }
          total_parameters = [
              parameter for name, parameter in self.named_parameters() if name not in embedding_param_names
          ]
      else:
          total_parameters = list(self.parameters())

      is_loaded_in_4bit = getattr(self, "is_loaded_in_4bit", False)
      if is_loaded_in_4bit:
          if is_bitsandbytes_available():
              import bitsandbytes as bnb
          else:
              raise ValueError(
                  "bitsandbytes is not installed but it seems that the model has been loaded in 4bit precision, something went wrong"
                  " make sure to install bitsandbytes with `pip install bitsandbytes`. You also need a GPU. "
              )

      total_numel = 0
      for param in total_parameters:
          if not (param.requires_grad or not only_trainable):
              continue
          if is_loaded_in_4bit and isinstance(param, bnb.nn.Params4bit):
              # For 4bit models, we need to multiply the number of parameters by 2 as half of the parameters are
              # used for the 4bit quantization (uint8 tensors are stored)
              if hasattr(param, "element_size"):
                  num_bytes = param.element_size()
              elif hasattr(param, "quant_storage"):
                  num_bytes = param.quant_storage.itemsize
              else:
                  num_bytes = 1
              total_numel += param.numel() * 2 * num_bytes
          else:
              total_numel += param.numel()
      return total_numel
  def estimate_tokens(self, input_dict: Dict[str, Union[torch.Tensor, Any]]) -> int:
      """
      Helper function to estimate the total number of tokens from the model inputs.

      The estimate is the number of elements of `input_dict[self.main_input_name]`. When that key is missing, a
      warning is logged once per model (tracked in `self.warnings_issued`) and 0 is returned.

      Args:
          input_dict (`dict`): The model inputs.

      Returns:
          `int`: The total number of tokens.
      """
      if not hasattr(self, "warnings_issued"):
          self.warnings_issued = {}
      if self.main_input_name in input_dict:
          return input_dict[self.main_input_name].numel()
      elif "estimate_tokens" not in self.warnings_issued:
          logger.warning(
              "Could not estimate the number of tokens of the input, floating-point operations will not be computed"
          )
          self.warnings_issued["estimate_tokens"] = True
      return 0
  def floating_point_ops(
      self, input_dict: Dict[str, Union[torch.Tensor, Any]], exclude_embeddings: bool = True
  ) -> int:
      """
      Get number of (optionally, non-embeddings) floating-point operations for the forward and backward passes of a
      batch with this transformer model.

      The estimate is `6 * estimate_tokens(input_dict) * num_parameters(exclude_embeddings=exclude_embeddings)`.
      This default approximation neglects the quadratic dependency on the number of tokens (valid if
      `12 * d_model << sequence_length`), as laid out in [this paper](https://arxiv.org/pdf/2001.08361.pdf)
      section 2.1. It should be overridden for transformers with parameter re-use, e.g. Albert or Universal
      Transformers, or if doing long-range modeling with very high sequence lengths.

      Args:
          input_dict (`dict`):
              The model inputs; the token count is obtained from them with `estimate_tokens`.
          exclude_embeddings (`bool`, *optional*, defaults to `True`):
              If `True`, embedding parameters are left out of the parameter count.

      Returns:
          `int`: The number of floating-point operations.
      """
      return 6 * self.estimate_tokens(input_dict) * self.num_parameters(exclude_embeddings=exclude_embeddings)
  1134. # TODO (joao): remove `GenerationMixin` inheritance in v4.50
  1135. class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin, PeftAdapterMixin):
  1136. r"""
  1137. Base class for all models.
  1138. [`PreTrainedModel`] takes care of storing the configuration of the models and handles methods for loading,
  1139. downloading and saving models as well as a few methods common to all models to:
  1140. - resize the input embeddings,
  1141. - prune heads in the self-attention heads.
  1142. Class attributes (overridden by derived classes):
  1143. - **config_class** ([`PretrainedConfig`]) -- A subclass of [`PretrainedConfig`] to use as configuration class
  1144. for this model architecture.
  1145. - **load_tf_weights** (`Callable`) -- A python *method* for loading a TensorFlow checkpoint in a PyTorch model,
  1146. taking as arguments:
  1147. - **model** ([`PreTrainedModel`]) -- An instance of the model on which to load the TensorFlow checkpoint.
  1148. - **config** ([`PreTrainedConfig`]) -- An instance of the configuration associated to the model.
  1149. - **path** (`str`) -- A path to the TensorFlow checkpoint.
  1150. - **base_model_prefix** (`str`) -- A string indicating the attribute associated to the base model in derived
  1151. classes of the same architecture adding modules on top of the base model.
  1152. - **is_parallelizable** (`bool`) -- A flag indicating whether this model supports model parallelization.
  1153. - **main_input_name** (`str`) -- The name of the principal input to the model (often `input_ids` for NLP
  1154. models, `pixel_values` for vision models and `input_values` for speech models).
  1155. """
  1156. config_class = None
  1157. base_model_prefix = ""
  1158. main_input_name = "input_ids"
  1159. model_tags = None
  1160. _auto_class = None
  1161. _no_split_modules = None
  1162. _skip_keys_device_placement = None
  1163. _keep_in_fp32_modules = None
  1164. # a list of `re` patterns of `state_dict` keys that should be removed from the list of missing
  1165. # keys we find (keys inside the model but not in the checkpoint) and avoid unnecessary warnings.
  1166. _keys_to_ignore_on_load_missing = None
  1167. # a list of `re` patterns of `state_dict` keys that should be removed from the list of
  1168. # unexpected keys we find (keys inside the checkpoint but not the model) and avoid unnecessary
  1169. # warnings.
  1170. _keys_to_ignore_on_load_unexpected = None
  1171. # a list of `state_dict` keys to ignore when saving the model (useful for keys that aren't
  1172. # trained, but which are either deterministic or tied variables)
  1173. _keys_to_ignore_on_save = None
  1174. # a list of `state_dict` keys that are potentially tied to another key in the state_dict.
  1175. _tied_weights_keys = None
  1176. is_parallelizable = False
  1177. supports_gradient_checkpointing = False
  1178. _is_stateful = False
  1179. # Flash Attention 2 support
  1180. _supports_flash_attn_2 = False
  1181. # SDPA support
  1182. _supports_sdpa = False
  1183. # Has support for a `Cache` instance as `past_key_values`? Does it support a `StaticCache`?
  1184. _supports_cache_class = False
  1185. _supports_static_cache = False
  1186. # Has support for a `QuantoQuantizedCache` instance as `past_key_values`
  1187. _supports_quantized_cache = False
  @property
  def dummy_inputs(self) -> Dict[str, torch.Tensor]:
      """
      `Dict[str, torch.Tensor]`: A minimal `{"input_ids": ...}` batch built from `DUMMY_INPUTS`, usable to run a
      forward pass in the network.
      """
      dummy_ids = torch.tensor(DUMMY_INPUTS)
      return {"input_ids": dummy_ids}
  @property
  def framework(self) -> str:
      """
      `str`: Identifier of the backing framework. It is always `"pt"`, marking this as a PyTorch model.
      """
      framework_id = "pt"
      return framework_id
  def __init__(self, config: PretrainedConfig, *inputs, **kwargs):
      # Initialize the base classes before any attribute is set on the instance.
      super().__init__()
      if not isinstance(config, PretrainedConfig):
          raise ValueError(
              f"Parameter config in `{self.__class__.__name__}(config)` should be an instance of class "
              "`PretrainedConfig`. To create a model from a pretrained model use "
              f"`model = {self.__class__.__name__}.from_pretrained(PRETRAINED_MODEL_NAME)`"
          )

      # Resolve the attention implementation unless this config has already been processed.
      attn_already_resolved = getattr(config, "_attn_implementation_autoset", False)
      if not attn_already_resolved:
          config = self._autoset_attn_implementation(
              config, torch_dtype=torch.get_default_dtype(), check_device_map=False
          )

      # Save config and origin of the pretrained weights if given in model
      self.config = config
      self.name_or_path = config.name_or_path
      self.warnings_issued = {}
      if self.can_generate():
          self.generation_config = GenerationConfig.from_model_config(config)
      else:
          self.generation_config = None

      # Shadow the class attribute with a per-instance copy, so models like `InstructBlipForConditionalGeneration`
      # can update it dynamically (e.g. when a different language_model is used) without touching the class.
      self._keep_in_fp32_modules = copy.copy(self.__class__._keep_in_fp32_modules)
  def post_init(self):
      """
      Finalization step run at the end of each Transformer model initialization.

      It is meant for work that needs the model's modules to exist. It first initializes the weights, then
      applies the legacy `gradient_checkpointing` config flag.
      """
      # 1) weight initialization
      self.init_weights()
      # 2) legacy config flag -> gradient checkpointing
      self._backward_compatibility_gradient_checkpointing()
  def dequantize(self):
      """
      Dequantize the model through its `hf_quantizer` and return the quantizer's result.

      This requires the model to have been quantized by a method that supports dequantization.

      Raises:
          `ValueError`: If the model has no `hf_quantizer` (or it is `None`).
      """
      quantizer = getattr(self, "hf_quantizer", None)
      if quantizer is not None:
          return quantizer.dequantize(self)
      raise ValueError("You need to first quantize your model in order to dequantize it")
  1237. def _backward_compatibility_gradient_checkpointing(self):
  1238. if self.supports_gradient_checkpointing and getattr(self.config, "gradient_checkpointing", False):
  1239. self.gradient_checkpointing_enable()
  1240. # Remove the attribute now that is has been consumed, so it's no saved in the config.
  1241. delattr(self.config, "gradient_checkpointing")
  def add_model_tags(self, tags: Union[List[str], str]) -> None:
      r"""
      Add custom tags into the model that gets pushed to the Hugging Face Hub. Will
      not overwrite existing tags in the model.

      Args:
          tags (`Union[List[str], str]`):
              The desired tags to inject in the model. A single string is treated as one tag; tags already
              present are skipped.

      Examples:

      ```python
      from transformers import AutoModel

      model = AutoModel.from_pretrained("google-bert/bert-base-cased")

      model.add_model_tags(["custom", "custom-bert"])

      # Push the model to your namespace with the name "my-custom-bert".
      model.push_to_hub("my-custom-bert")
      ```
      """
      if isinstance(tags, str):
          tags = [tags]

      # Work on a copy stored on the instance. `model_tags` may be a list declared at class level, which is
      # shared by all instances; appending to it in place would leak tags between models.
      current_tags = list(self.model_tags) if self.model_tags is not None else []
      for tag in tags:
          if tag not in current_tags:
              current_tags.append(tag)
      self.model_tags = current_tags
  @classmethod
  def _from_config(cls, config, **kwargs):
      """
      Instantiate the model from `config` under all the context managers it should be initialized under.

      A deep copy of `config` is used, so the caller's config is not modified.

      Args:
          config ([`PretrainedConfig`]):
              Configuration of the model to build.
          torch_dtype (`torch.dtype`, *optional*):
              Override the default `torch.dtype` and load the model under this dtype. Defaults to
              `torch.get_default_dtype()`.
          use_flash_attention_2 (`bool`, *optional*, defaults to `False`):
              Deprecated way to request Flash Attention 2.
          attn_implementation (*optional*):
              Attention implementation to set on the copied config. Defaults to the one the user already set on
              `config`.
          kwargs:
              Remaining keyword arguments are forwarded to the model constructor.

      Returns:
          The instantiated model.
      """
      # when we init a model from within another model (e.g. VLMs) and dispatch on FA2
      # a warning is raised that dtype should be fp16. Since we never pass dtype from within
      # modeling code, we can try to infer it here same way as done in `from_pretrained`
      torch_dtype = kwargs.pop("torch_dtype", torch.get_default_dtype())
      use_flash_attention_2 = kwargs.pop("use_flash_attention_2", False)

      # override default dtype if needed
      dtype_orig = None
      if torch_dtype is not None:
          dtype_orig = cls._set_default_torch_dtype(torch_dtype)

      # The global default dtype may have been changed above. Restore it in `finally` so that an error raised
      # while resolving the attention implementation or building the model does not leak it to the caller.
      try:
          config = copy.deepcopy(config)  # We do not want to modify the config inplace in _from_config.

          if config._attn_implementation_internal is not None:
              # In this case, the config has been created with the attn_implementation set by the user, which we
              # should respect.
              attn_implementation = config._attn_implementation_internal
          else:
              attn_implementation = None
          config._attn_implementation = kwargs.pop("attn_implementation", attn_implementation)
          if not getattr(config, "_attn_implementation_autoset", False):
              config = cls._autoset_attn_implementation(
                  config,
                  use_flash_attention_2=use_flash_attention_2,
                  check_device_map=False,
                  torch_dtype=torch_dtype,
              )

          if is_deepspeed_zero3_enabled():
              import deepspeed

              logger.info("Detected DeepSpeed ZeRO-3: activating zero.init() for this model")
              # this immediately partitions the model across all gpus, to avoid the overhead in time
              # and memory copying it on CPU or each GPU first
              with deepspeed.zero.Init(config_dict_or_path=deepspeed_config()):
                  model = cls(config, **kwargs)
          else:
              model = cls(config, **kwargs)
      finally:
          # restore default dtype if it was modified
          if dtype_orig is not None:
              torch.set_default_dtype(dtype_orig)

      return model
  1310. @classmethod
  1311. def _autoset_attn_implementation(
  1312. cls,
  1313. config,
  1314. use_flash_attention_2: bool = False,
  1315. torch_dtype: Optional[torch.dtype] = None,
  1316. device_map: Optional[Union[str, Dict[str, int]]] = None,
  1317. check_device_map: bool = True,
  1318. ):
  1319. """
  1320. Automatically checks and dispatches to a default attention implementation. In order of priority:
  1321. 1. An implementation specified in `config._attn_implementation` (due for example to the argument attn_implementation="sdpa" in from_pretrained).
  1322. 2. DEPRECATED: if use_flash_attention_2 is set to `True` and `flash_attn` is available, flash attention. (`LlamaFlashAttention` for example)
  1323. 3. SDPA implementation, if available and supported by the model type. (`LlamaSdpaAttention` for example)
  1324. 4. The default model's implementation otherwise (`LlamaAttention` for example) .
  1325. """
  1326. # Here we use config._attn_implementation_internal to check whether the attention implementation was explicitely set by the user.
  1327. # The property `PretrainedConfig._attn_implementation` is never `None`, for backward compatibility (always fall back on "eager").
  1328. # The `hasattr` here is used as some Transformers tests for some reason do not call PretrainedConfig __init__ (e.g. test_no_super_init_config_and_model)
  1329. requested_attn_implementation = None
  1330. if hasattr(config, "_attn_implementation_internal") and config._attn_implementation_internal is not None:
  1331. if config._attn_implementation != "flash_attention_2" and use_flash_attention_2:
  1332. raise ValueError(
  1333. f'Both attn_implementation="{config._attn_implementation}" and `use_flash_attention_2=True` were used when loading the model, which are not compatible.'
  1334. ' We recommend to just use `attn_implementation="flash_attention_2"` when loading the model.'
  1335. )
  1336. if not isinstance(config._attn_implementation, dict) and config._attn_implementation not in [
  1337. "eager",
  1338. "sdpa",
  1339. "flash_attention_2",
  1340. ]:
  1341. message = f'Specified `attn_implementation="{config._attn_implementation}"` is not supported. The only possible arguments are `attn_implementation="eager"` (manual attention implementation)'
  1342. if cls._supports_flash_attn_2:
  1343. message += ', `"attn_implementation=flash_attention_2"` (implementation using flash attention 2)'
  1344. if cls._supports_sdpa:
  1345. message += ', `"attn_implementation=sdpa"` (implementation using torch.nn.functional.scaled_dot_product_attention)'
  1346. raise ValueError(message + ".")
  1347. # If a config is passed with a preset attn_implementation, we skip the automatic dispatch and use the user-provided config, with hard checks that the requested attention implementation is available.
  1348. requested_attn_implementation = config._attn_implementation_internal
  1349. # Composite models consisting of several PretrainedModels have to specify attention impl as a dict
  1350. # where keys are sub-config names. But most people will specify one `str` which means that should dispatch it
  1351. # for all sub-models.
  1352. # Below we check if a config is composite and manually prepare a dict of attn impl if not already passed as a dict.
  1353. # Later each sub-module will dispatch with its own attn impl, by calling `XXXModel._from_config(config.text_config)`
  1354. # If any of sub-modules doesn't support requested attn, an error will be raised. See https://github.com/huggingface/transformers/pull/32238
  1355. for key in config:
  1356. if isinstance(getattr(config, key), PretrainedConfig):
  1357. sub_config = getattr(config, key)
  1358. curr_attn_implementation = (
  1359. requested_attn_implementation
  1360. if not isinstance(requested_attn_implementation, dict)
  1361. else requested_attn_implementation.get(key, None)
  1362. )
  1363. sub_config._attn_implementation_internal = curr_attn_implementation
  1364. if use_flash_attention_2:
  1365. logger.warning_once(
  1366. 'The model was loaded with use_flash_attention_2=True, which is deprecated and may be removed in a future release. Please use `attn_implementation="flash_attention_2"` instead.'
  1367. )
  1368. config._attn_implementation = "flash_attention_2"
  1369. if config._attn_implementation == "flash_attention_2":
  1370. cls._check_and_enable_flash_attn_2(
  1371. config,
  1372. torch_dtype=torch_dtype,
  1373. device_map=device_map,
  1374. hard_check_only=False,
  1375. check_device_map=check_device_map,
  1376. )
  1377. elif requested_attn_implementation in [None, "sdpa"] and not is_torch_xla_available():
  1378. # use_flash_attention_2 takes priority over SDPA, hence SDPA treated in this elif.
  1379. config = cls._check_and_enable_sdpa(
  1380. config,
  1381. hard_check_only=False if requested_attn_implementation is None else True,
  1382. )
  1383. if (
  1384. torch.version.hip is not None
  1385. and config._attn_implementation == "sdpa"
  1386. and torch.cuda.device_count() > 1
  1387. ):
  1388. logger.warning_once(
  1389. "Using the `SDPA` attention implementation on multi-gpu setup with ROCM may lead to performance issues due to the FA backend. Disabling it to use alternative backends."
  1390. )
  1391. torch.backends.cuda.enable_flash_sdp(False)
  1392. elif isinstance(requested_attn_implementation, dict):
  1393. config._attn_implementation = None
  1394. else:
  1395. config._attn_implementation = "eager"
  1396. config._attn_implementation_autoset = True
  1397. return config
  1398. @classmethod
  1399. def _set_default_torch_dtype(cls, dtype: torch.dtype) -> torch.dtype:
  1400. """
  1401. Change the default dtype and return the previous one. This is needed when wanting to instantiate the model
  1402. under specific dtype.
  1403. Args:
  1404. dtype (`torch.dtype`):
  1405. a floating dtype to set to.
  1406. Returns:
  1407. `torch.dtype`: the original `dtype` that can be used to restore `torch.set_default_dtype(dtype)` if it was
  1408. modified. If it wasn't, returns `None`.
  1409. Note `set_default_dtype` currently only works with floating-point types and asserts if for example,
  1410. `torch.int64` is passed. So if a non-float `dtype` is passed this functions will throw an exception.
  1411. """
  1412. if not dtype.is_floating_point:
  1413. raise ValueError(
  1414. f"Can't instantiate {cls.__name__} model under dtype={dtype} since it is not a floating point dtype"
  1415. )
  1416. logger.info(f"Instantiating {cls.__name__} model under default dtype {dtype}.")
  1417. dtype_orig = torch.get_default_dtype()
  1418. torch.set_default_dtype(dtype)
  1419. return dtype_orig
  @property
  def base_model(self) -> nn.Module:
      """
      `torch.nn.Module`: The main body of the model.

      This is the attribute named by `base_model_prefix`, or the model itself when there is no such attribute.
      """
      prefix = self.base_model_prefix
      return getattr(self, prefix, self)
  @classmethod
  def can_generate(cls) -> bool:
      """
      Returns whether this model can generate sequences with `.generate()`.

      The checks, in order, are:
      1. `GenerationMixin` appears among the direct bases.
      2. The class defines its own `generate` (e.g. time series models).
      3. A base class other than `PreTrainedModel` exposes `can_generate()` and it returns `True`.
      4. Backward compatibility: `prepare_inputs_for_generation` is overridden (the pre-v4.45 detection rule). In
         this case a deprecation warning is logged.

      Returns:
          `bool`: Whether this model can generate sequences with `.generate()`.
      """
      inherits_mixin_directly = "GenerationMixin" in str(cls.__bases__)
      if inherits_mixin_directly or str(cls.__name__) in str(cls.generate):
          return True

      # recursive check through parents that know how to answer
      if any(
          hasattr(parent, "can_generate") and "PreTrainedModel" not in str(parent) and parent.can_generate()
          for parent in cls.__bases__
      ):
          return True

      overrides_prepare_inputs = "GenerationMixin" not in str(cls.prepare_inputs_for_generation)
      if not overrides_prepare_inputs:
          # Otherwise, can't generate
          return False

      logger.warning_once(
          f"{cls.__name__} has generative capabilities, as `prepare_inputs_for_generation` is explicitly "
          "overwritten. However, it doesn't directly inherit from `GenerationMixin`. From 👉v4.50👈 onwards, "
          "`PreTrainedModel` will NOT inherit from `GenerationMixin`, and this model will lose the ability "
          "to call `generate` and other related functions."
          "\n - If you're using `trust_remote_code=True`, you can get rid of this warning by loading the "
          "model with an auto class. See https://huggingface.co/docs/transformers/en/model_doc/auto#auto-classes"
          "\n - If you are the owner of the model architecture code, please modify your model class such that "
          "it inherits from `GenerationMixin` (after `PreTrainedModel`, otherwise you'll get an exception)."
          "\n - If you are not the owner of the model architecture class, please contact the model code owner "
          "to update it."
      )
      return True
  1463. @classmethod
  1464. def _check_and_enable_flash_attn_2(
  1465. cls,
  1466. config,
  1467. torch_dtype: Optional[torch.dtype] = None,
  1468. device_map: Optional[Union[str, Dict[str, int]]] = None,
  1469. check_device_map: bool = True,
  1470. hard_check_only: bool = False,
  1471. ) -> PretrainedConfig:
  1472. """
  1473. Checks the availability of Flash Attention 2 and compatibility with the current model.
  1474. If all checks pass and `hard_check_only` is False, the method will set the config attribute `attn_implementation` to "flash_attention_2" so that the model can initialize the correct attention module.
  1475. """
  1476. if not cls._supports_flash_attn_2:
  1477. raise ValueError(
  1478. f"{cls.__name__} does not support Flash Attention 2.0 yet. Please request to add support where"
  1479. f" the model is hosted, on its model hub page: https://huggingface.co/{config._name_or_path}/discussions/new"
  1480. " or in the Transformers GitHub repo: https://github.com/huggingface/transformers/issues/new"
  1481. )
  1482. if not is_flash_attn_2_available():
  1483. preface = "FlashAttention2 has been toggled on, but it cannot be used due to the following error:"
  1484. install_message = "Please refer to the documentation of https://huggingface.co/docs/transformers/perf_infer_gpu_one#flashattention-2 to install Flash Attention 2."
  1485. if importlib.util.find_spec("flash_attn") is None:
  1486. raise ImportError(f"{preface} the package flash_attn seems to be not installed. {install_message}")
  1487. flash_attention_version = version.parse(importlib.metadata.version("flash_attn"))
  1488. if torch.version.cuda:
  1489. if flash_attention_version < version.parse("2.1.0"):
  1490. raise ImportError(
  1491. f"{preface} you need flash_attn package version to be greater or equal than 2.1.0. Detected version {flash_attention_version}. {install_message}"
  1492. )
  1493. elif not torch.cuda.is_available():
  1494. raise ValueError(
  1495. f"{preface} Flash Attention 2 is not available on CPU. Please make sure torch can access a CUDA device."
  1496. )
  1497. else:
  1498. raise ImportError(f"{preface} Flash Attention 2 is not available. {install_message}")
  1499. elif torch.version.hip:
  1500. if flash_attention_version < version.parse("2.0.4"):
  1501. raise ImportError(
  1502. f"{preface} you need flash_attn package version to be greater or equal than 2.0.4. Make sure to have that version installed - detected version {flash_attention_version}. {install_message}"
  1503. )
  1504. else:
  1505. raise ImportError(f"{preface} Flash Attention 2 is not available. {install_message}")
  1506. _is_bettertransformer = getattr(cls, "use_bettertransformer", False)
  1507. if _is_bettertransformer:
  1508. raise ValueError(
  1509. "Flash Attention 2 and BetterTransformer API are not compatible. Please make sure to disable BetterTransformers by doing model.reverse_bettertransformer()"
  1510. )
  1511. if torch_dtype is None:
  1512. logger.warning_once(
  1513. "You are attempting to use Flash Attention 2.0 without specifying a torch dtype. This might lead to unexpected behaviour"
  1514. )
  1515. elif torch_dtype is not None and torch_dtype not in [torch.float16, torch.bfloat16]:
  1516. logger.warning_once(
  1517. "Flash Attention 2.0 only supports torch.float16 and torch.bfloat16 dtypes, but"
  1518. f" the current dype in {cls.__name__} is {torch_dtype}. You should run training or inference using Automatic Mixed-Precision via the `with torch.autocast(device_type='torch_device'):` decorator,"
  1519. ' or load the model with the `torch_dtype` argument. Example: `model = AutoModel.from_pretrained("openai/whisper-tiny", attn_implementation="flash_attention_2", torch_dtype=torch.float16)`'
  1520. )
  1521. # The check `torch.empty(0).device.type != "cuda"` is needed as the model may be initialized after `torch.set_default_device` has been called,
  1522. # or the model may be initialized under the context manager `with torch.device("cuda"):`.
  1523. if check_device_map and device_map is None and torch.empty(0).device.type != "cuda":
  1524. if torch.cuda.is_available():
  1525. logger.warning_once(
  1526. "You are attempting to use Flash Attention 2.0 with a model not initialized on GPU. Make sure to move the model to GPU"
  1527. " after initializing it on CPU with `model.to('cuda')`."
  1528. )
  1529. else:
  1530. raise ValueError(
  1531. "You are attempting to use Flash Attention 2.0 with a model not initialized on GPU and with no GPU available. "
  1532. "This is not supported yet. Please make sure to have access to a GPU and either initialise the model on a GPU by passing a device_map "
  1533. "or initialising the model on CPU and then moving it to GPU."
  1534. )
  1535. elif (
  1536. check_device_map
  1537. and device_map is not None
  1538. and isinstance(device_map, dict)
  1539. and ("cpu" in device_map.values() or "disk" in device_map.values())
  1540. ):
  1541. raise ValueError(
  1542. "You are attempting to use Flash Attention 2.0 with a model dispatched on CPU or disk. This is not supported. Please make sure to "
  1543. "initialise the model on a GPU by passing a device_map that contains only GPU devices as keys."
  1544. )
  1545. if not hard_check_only:
  1546. config._attn_implementation = "flash_attention_2"
  1547. return config
  @classmethod
  def _check_and_enable_sdpa(cls, config, hard_check_only: bool = False) -> PretrainedConfig:
      """
      Checks the availability of SDPA (`torch.nn.functional.scaled_dot_product_attention`) for a given model.

      If all checks pass and `hard_check_only` is False, the method will set the config attribute
      `_attn_implementation` to "sdpa" so that the model can initialize the correct attention module.

      Args:
          config (`PretrainedConfig`):
              The model configuration. Its `_attn_implementation` attribute may be updated in place.
          hard_check_only (`bool`, *optional*, defaults to `False`):
              If `True`, raise when SDPA cannot be used instead of silently keeping the current implementation,
              and never modify `config`.

      Returns:
          `PretrainedConfig`: The same `config` object that was passed in.

      Raises:
          ValueError: If `hard_check_only` is `True` and the model class does not support SDPA.
          ImportError: If `hard_check_only` is `True` and the PyTorch SDPA requirements are not met.
      """
      sdpa_available = is_torch_sdpa_available()

      if hard_check_only:
          if not cls._supports_sdpa:
              raise ValueError(
                  f"{cls.__name__} does not support an attention implementation through torch.nn.functional.scaled_dot_product_attention yet."
                  " Please request the support for this architecture: https://github.com/huggingface/transformers/issues/28005. If you believe"
                  ' this error is a bug, please open an issue in Transformers GitHub repository and load your model with the argument `attn_implementation="eager"` meanwhile. Example: `model = AutoModel.from_pretrained("openai/whisper-tiny", attn_implementation="eager")`'
              )
          if not sdpa_available:
              raise ImportError(
                  "PyTorch SDPA requirements in Transformers are not met. Please install torch>=2.1.1."
              )

      # Soft path: if SDPA cannot be used, return the config unchanged.
      if not sdpa_available or not cls._supports_sdpa:
          return config

      # Classes flagged with `use_bettertransformer` do not get SDPA enabled here.
      if getattr(cls, "use_bettertransformer", False):
          return config

      if not hard_check_only:
          config._attn_implementation = "sdpa"
      return config
  def enable_input_require_grads(self):
      """
      Enables the gradients for the input embeddings. This is useful for fine-tuning adapter weights while keeping
      the model weights fixed.

      A forward hook is registered on the module returned by `get_input_embeddings()`. The hook marks that module's
      output as requiring gradients. The hook handle is stored in `self._require_grads_hook` so that
      `disable_input_require_grads()` can remove it.

      Calling this method again first removes the hook registered by the previous call. At most one such hook is
      therefore active, and `disable_input_require_grads()` always removes it.
      """
      previous_hook = getattr(self, "_require_grads_hook", None)
      if previous_hook is not None:
          # Without this, a repeated call would overwrite the handle and leave the earlier hook attached forever.
          previous_hook.remove()

      def make_inputs_require_grads(module, input, output):
          output.requires_grad_(True)

      self._require_grads_hook = self.get_input_embeddings().register_forward_hook(make_inputs_require_grads)
  def disable_input_require_grads(self):
      """
      Detach the forward hook that `enable_input_require_grads()` stored in `self._require_grads_hook`.

      After this call, the input embedding outputs no longer have `requires_grad` forced on by that hook.
      """
      handle = self._require_grads_hook
      handle.remove()
  def get_input_embeddings(self) -> nn.Module:
      """
      Look up the input embedding module by delegating to the wrapped base model.

      The base model is the attribute named by `self.base_model_prefix`. When no such attribute exists, this
      object has nothing to delegate to.

      Returns:
          `nn.Module`: The module that maps vocabulary ids to hidden states.

      Raises:
          NotImplementedError: If there is no distinct base model to delegate to.
      """
      inner = getattr(self, self.base_model_prefix, self)
      if inner is self:
          raise NotImplementedError
      return inner.get_input_embeddings()
  def set_input_embeddings(self, value: nn.Module):
      """
      Install a new input embedding module by delegating to the wrapped base model.

      Args:
          value (`nn.Module`): The module that maps vocabulary ids to hidden states.

      Raises:
          NotImplementedError: If the attribute named by `self.base_model_prefix` does not exist, so there is no
              base model to delegate to.
      """
      inner = getattr(self, self.base_model_prefix, self)
      if inner is self:
          raise NotImplementedError
      inner.set_input_embeddings(value)
  def get_output_embeddings(self) -> nn.Module:
      """
      Return the module that maps hidden states back to the vocabulary.

      This default implementation has no such head and returns `None`. Models that own an output projection
      override this method.

      Returns:
          `nn.Module` or `None`: The output embedding module, or `None` at this level.
      """
      # No output head here; subclasses that have one override this method.
      return None
  1615. def _init_weights(self, module):
  1616. """
  1617. Initialize the weights. This method should be overridden by derived class and is
  1618. the only initialization method that will be called when loading a checkpoint
  1619. using `from_pretrained`. Any attempt to initialize outside of this function
  1620. will be useless as the torch.nn.init function are all replaced with skip.
  1621. """
  1622. pass
  1623. def _initialize_weights(self, module):
  1624. """
  1625. Initialize the weights if they are not already initialized.
  1626. """
  1627. if getattr(module, "_is_hf_initialized", False):
  1628. return
  1629. self._init_weights(module)
  1630. module._is_hf_initialized = True
  def tie_weights(self):
      """
      Tie the input embeddings to the output embeddings, optionally tie encoder to decoder, then run submodule hooks.

      The method works in three steps:

      1. If `config.tie_word_embeddings` is true (assumed true when absent) and the model has output embeddings,
         `_tie_or_clone_weights` is called. Under TorchScript it copies the weights instead of sharing them,
         since TorchScript cannot handle parameter sharing.
      2. If both `config.is_encoder_decoder` and `config.tie_encoder_decoder` are true, the base model (when
         present) becomes the owner. Its encoder is tied to its decoder, and the tied parameter names are stored
         on it as `_dynamic_tied_weights_keys`.
      3. `_tie_weights()` is called on every module of the owner that defines it.

      NOTE(review): when step 2 swaps in the base model, step 3 only visits the base model's submodules. Modules
      outside it are not visited. Confirm this is intended.
      """
      config = self.config
      if getattr(config, "tie_word_embeddings", True):
          head = self.get_output_embeddings()
          if head is not None:
              self._tie_or_clone_weights(head, self.get_input_embeddings())

      owner = self
      tie_encoder_to_decoder = getattr(config, "is_encoder_decoder", False) and getattr(
          config, "tie_encoder_decoder", False
      )
      if tie_encoder_to_decoder:
          if hasattr(owner, owner.base_model_prefix):
              owner = getattr(owner, owner.base_model_prefix)
          # Setting a dynamic variable instead of `_tied_weights_keys` because it's a class
          # attributed not an instance member, therefore modifying it will modify the entire class
          # Leading to issues on subsequent calls by different tests or subsequent calls.
          owner._dynamic_tied_weights_keys = owner._tie_encoder_decoder_weights(
              owner.encoder, owner.decoder, owner.base_model_prefix, "encoder"
          )

      for submodule in owner.modules():
          if hasattr(submodule, "_tie_weights"):
              submodule._tie_weights()
  @staticmethod
  def _tie_encoder_decoder_weights(
      encoder: nn.Module, decoder: nn.Module, base_model_prefix: str, base_encoder_name: str
  ) -> List[str]:
      """
      Share the decoder's parameters with the encoder by walking both module trees in parallel.

      Whenever the walk reaches a decoder module that has a `weight` attribute, the matching encoder module's
      `weight` is re-pointed at the decoder's parameter. The same happens for `bias` whenever the decoder module
      has a `bias` attribute. The walk does not descend below such a module.

      Args:
          encoder (`nn.Module`):
              Module whose parameters are replaced by the decoder's.
          decoder (`nn.Module`):
              Module whose parameters are shared into the encoder.
          base_model_prefix (`str`):
              Root of the `/`-separated module paths listed in the warning about encoder submodules that were not
              tied.
          base_encoder_name (`str`):
              Prefix of the dotted parameter names collected in the returned list.

      Returns:
          `List[str]`: Names of the encoder parameters that were tied, built as
          `f"{base_encoder_name}{dotted_path}.weight"` (or `.bias`).

      Raises:
          AssertionError: If the two trees do not line up where a tie or a descent is attempted.
          ValueError: If the walk goes deeper than 500 levels while descending into a named (non-numeric)
              submodule. This is treated as a sign of a circular module reference.
      """
      uninitialized_encoder_weights: List[str] = []
      tied_weights: List[str] = []
      if decoder.__class__ != encoder.__class__:
          logger.info(
              f"{decoder.__class__} and {encoder.__class__} are not equal. In this case make sure that all encoder"
              " weights are correctly initialized."
          )

      def tie_encoder_to_decoder_recursively(
          decoder_pointer: nn.Module,
          encoder_pointer: nn.Module,
          module_name: str,
          base_encoder_name: str,
          uninitialized_encoder_weights: List[str],
          depth: int = 0,
          total_decoder_name: str = "",
          total_encoder_name: str = "",
      ) -> None:
          # Ties the subtree rooted at `decoder_pointer` / `encoder_pointer`.
          # `module_name` is the `/`-separated path used when reporting untied encoder children.
          # `total_*_name` are the dotted paths accumulated from the root so far.
          assert isinstance(decoder_pointer, nn.Module) and isinstance(
              encoder_pointer, nn.Module
          ), f"{decoder_pointer} and {encoder_pointer} have to be of type nn.Module"
          if hasattr(decoder_pointer, "weight"):
              # Leaf reached: share the parameter object(s) and stop descending.
              assert hasattr(encoder_pointer, "weight")
              encoder_pointer.weight = decoder_pointer.weight
              tied_weights.append(f"{base_encoder_name}{total_encoder_name}.weight")
              if hasattr(decoder_pointer, "bias"):
                  assert hasattr(encoder_pointer, "bias")
                  tied_weights.append(f"{base_encoder_name}{total_encoder_name}.bias")
                  encoder_pointer.bias = decoder_pointer.bias
              return

          encoder_modules = encoder_pointer._modules
          decoder_modules = decoder_pointer._modules
          if len(decoder_modules) > 0:
              assert (
                  len(encoder_modules) > 0
              ), f"Encoder module {encoder_pointer} does not match decoder module {decoder_pointer}"

              # Start with every encoder child marked as untied; each successful descent below removes one.
              all_encoder_weights = {module_name + "/" + sub_name for sub_name in encoder_modules.keys()}
              # Index offset applied to numbered children (e.g. entries of a layer list). It is decremented each
              # time a numbered decoder child is skipped as decoder-only.
              encoder_layer_pos = 0
              for name, module in decoder_modules.items():
                  if name.isdigit():
                      encoder_name = str(int(name) + encoder_layer_pos)
                      decoder_name = name
                      if not isinstance(decoder_modules[decoder_name], type(encoder_modules[encoder_name])) and len(
                          encoder_modules
                      ) != len(decoder_modules):
                          # this can happen if the name corresponds to the position in a list module list of layers
                          # in this case the decoder has added a cross-attention that the encoder does not have
                          # thus skip this step and subtract one layer pos from encoder
                          encoder_layer_pos -= 1
                          continue
                  elif name not in encoder_modules:
                      # Decoder-only named child: nothing to tie it to.
                      continue
                  elif depth > 500:
                      # NOTE: the depth guard is only checked on this named-child branch, not for numbered children.
                      raise ValueError(
                          "Max depth of recursive function `tie_encoder_to_decoder` reached. It seems that there is"
                          " a circular dependency between two or more `nn.Modules` of your model."
                      )
                  else:
                      decoder_name = encoder_name = name
                  tie_encoder_to_decoder_recursively(
                      decoder_modules[decoder_name],
                      encoder_modules[encoder_name],
                      module_name + "/" + name,
                      base_encoder_name,
                      uninitialized_encoder_weights,
                      depth=depth + 1,
                      total_encoder_name=f"{total_encoder_name}.{encoder_name}",
                      total_decoder_name=f"{total_decoder_name}.{decoder_name}",
                  )
                  all_encoder_weights.remove(module_name + "/" + encoder_name)

              # Whatever encoder children were never matched stay untied; report them to the caller's list.
              uninitialized_encoder_weights += list(all_encoder_weights)

      # tie weights recursively
      tie_encoder_to_decoder_recursively(
          decoder, encoder, base_model_prefix, base_encoder_name, uninitialized_encoder_weights
      )

      if len(uninitialized_encoder_weights) > 0:
          logger.warning(
              f"The following encoder weights were not tied to the decoder {uninitialized_encoder_weights}"
          )
      return tied_weights
  1737. def _tie_or_clone_weights(self, output_embeddings, input_embeddings):
  1738. """Tie or clone module weights depending of whether we are using TorchScript or not"""
  1739. if self.config.torchscript:
  1740. output_embeddings.weight = nn.Parameter(input_embeddings.weight.clone())
  1741. else:
  1742. output_embeddings.weight = input_embeddings.weight
  1743. if getattr(output_embeddings, "bias", None) is not None:
  1744. output_embeddings.bias.data = nn.functional.pad(
  1745. output_embeddings.bias.data,
  1746. (
  1747. 0,
  1748. output_embeddings.weight.shape[0] - output_embeddings.bias.shape[0],
  1749. ),
  1750. "constant",
  1751. 0,
  1752. )
  1753. if hasattr(output_embeddings, "out_features") and hasattr(input_embeddings, "num_embeddings"):
  1754. output_embeddings.out_features = input_embeddings.num_embeddings
  def _get_no_split_modules(self, device_map: str):
      """
      Get the modules of the model that should not be split when using device_map. We iterate through the modules to
      get the underlying `_no_split_modules`.

      Every `PreTrainedModel` found in the module tree contributes the class names in its `_no_split_modules`.
      The children of a module whose class name is already in the collected set are not inspected further.

      Args:
          device_map (`str`):
              The device map value. Options are ["auto", "balanced", "balanced_low_0", "sequential"]. It is only
              used in the error message.

      Returns:
          `List[str]`: List of modules that should not be split

      Raises:
          ValueError: If a `PreTrainedModel` in the tree has `_no_split_modules` set to `None`.
      """
      no_split_names = set()
      pending = [self]
      while pending:
          module = pending.pop()
          # A module already declared as no-split is kept whole, so its children do not need checking.
          if module.__class__.__name__ in no_split_names:
              continue
          if isinstance(module, PreTrainedModel):
              if module._no_split_modules is None:
                  raise ValueError(
                      f"{module.__class__.__name__} does not support `device_map='{device_map}'`. To implement support, the model "
                      "class needs to implement the `_no_split_modules` attribute."
                  )
              no_split_names.update(module._no_split_modules)
          pending.extend(module.children())
      return list(no_split_names)
  def resize_token_embeddings(
      self,
      new_num_tokens: Optional[int] = None,
      pad_to_multiple_of: Optional[int] = None,
      mean_resizing: bool = True,
  ) -> nn.Embedding:
      """
      Resizes the input token embeddings matrix of the model when the requested size differs from its current size.

      Takes care of tying weights embeddings afterwards if the model class has a `tie_weights()` method. When a
      resize is requested, `config.get_text_config().vocab_size` and `self.vocab_size` are updated to the new
      number of embedding rows.

      Arguments:
          new_num_tokens (`int`, *optional*):
              The new number of tokens in the embedding matrix. Increasing the size will add newly initialized
              vectors at the end. Reducing the size will remove vectors from the end. If not provided or `None`
              (and `pad_to_multiple_of` is not set), just returns a pointer to the input tokens
              `torch.nn.Embedding` module of the model without doing anything.
          pad_to_multiple_of (`int`, *optional*):
              If set, the embedding matrix is padded to a multiple of the provided value. If `new_num_tokens` is
              `None`, the current embedding is just padded to a multiple of `pad_to_multiple_of`.

              This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability
              `>= 7.0` (Volta), or on TPUs which benefit from having sequence lengths be a multiple of 128. For more
              details about this, or help on choosing the correct value for resizing, refer to this guide:
              https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html#requirements-tc
          mean_resizing (`bool`, *optional*, defaults to `True`):
              Whether to initialize the added embeddings from a multivariate normal distribution that has the old
              embeddings' mean and covariance. If `False`, they are initialized from a normal distribution with a
              mean of zero and std equal to `config.initializer_range`.

              Setting `mean_resizing` to `True` is useful when increasing the size of the embeddings of causal
              language models. Initializing the new embeddings with the old embeddings' mean reduces the
              KL-divergence between the next-token probabilities before and after adding them, so the generated
              tokens' probabilities are barely affected.
              Refer to this article for more information: https://nlp.stanford.edu/~johnhew/vocab-expansion.html

      Return:
          `torch.nn.Embedding`: Pointer to the input tokens Embeddings Module of the model.
      """
      model_embeds = self._resize_token_embeddings(new_num_tokens, pad_to_multiple_of, mean_resizing)
      if new_num_tokens is None and pad_to_multiple_of is None:
          return model_embeds

      # Since we are basically reusing the same old embeddings with new weight values, gathering is required
      is_quantized = hasattr(self, "hf_quantizer") and self.hf_quantizer is not None
      if is_deepspeed_zero3_enabled() and not is_quantized:
          import deepspeed

          with deepspeed.zero.GatheredParameters(model_embeds.weight, modifier_rank=None):
              vocab_size = model_embeds.weight.shape[0]
      else:
          vocab_size = model_embeds.weight.shape[0]

      # Update base model and current model config.
      self.config.get_text_config().vocab_size = vocab_size
      self.vocab_size = vocab_size

      # Tie weights again if needed
      self.tie_weights()

      return model_embeds
  1829. def _resize_token_embeddings(self, new_num_tokens, pad_to_multiple_of=None, mean_resizing=True):
  1830. old_embeddings = self.get_input_embeddings()
  1831. new_embeddings = self._get_resized_embeddings(
  1832. old_embeddings, new_num_tokens, pad_to_multiple_of, mean_resizing
  1833. )
  1834. if hasattr(old_embeddings, "_hf_hook"):
  1835. hook = old_embeddings._hf_hook
  1836. add_hook_to_module(new_embeddings, hook)
  1837. old_embeddings_requires_grad = old_embeddings.weight.requires_grad
  1838. new_embeddings.requires_grad_(old_embeddings_requires_grad)
  1839. self.set_input_embeddings(new_embeddings)
  1840. is_quantized = hasattr(self, "hf_quantizer") and self.hf_quantizer is not None
  1841. # Update new_num_tokens with the actual size of new_embeddings
  1842. if pad_to_multiple_of is not None:
  1843. if is_deepspeed_zero3_enabled() and not is_quantized:
  1844. import deepspeed
  1845. with deepspeed.zero.GatheredParameters(new_embeddings.weight, modifier_rank=None):
  1846. new_num_tokens = new_embeddings.weight.shape[0]
  1847. else:
  1848. new_num_tokens = new_embeddings.weight.shape[0]
  1849. # if word embeddings are not tied, make sure that lm head is resized as well
  1850. if self.get_output_embeddings() is not None and not self.config.tie_word_embeddings:
  1851. old_lm_head = self.get_output_embeddings()
  1852. if isinstance(old_lm_head, torch.nn.Embedding):
  1853. new_lm_head = self._get_resized_embeddings(old_lm_head, new_num_tokens, mean_resizing=mean_resizing)
  1854. else:
  1855. new_lm_head = self._get_resized_lm_head(old_lm_head, new_num_tokens, mean_resizing=mean_resizing)
  1856. if hasattr(old_lm_head, "_hf_hook"):
  1857. hook = old_lm_head._hf_hook
  1858. add_hook_to_module(new_lm_head, hook)
  1859. old_lm_head_requires_grad = old_lm_head.weight.requires_grad
  1860. new_lm_head.requires_grad_(old_lm_head_requires_grad)
  1861. self.set_output_embeddings(new_lm_head)
  1862. return self.get_input_embeddings()
  1863. def _get_resized_embeddings(
  1864. self,
  1865. old_embeddings: nn.Embedding,
  1866. new_num_tokens: Optional[int] = None,
  1867. pad_to_multiple_of: Optional[int] = None,
  1868. mean_resizing: bool = True,
  1869. ) -> nn.Embedding:
  1870. """
  1871. Build a resized Embedding Module from a provided token Embedding Module. Increasing the size will add newly
  1872. initialized vectors at the end. Reducing the size will remove vectors from the end
  1873. Args:
  1874. old_embeddings (`torch.nn.Embedding`):
  1875. Old embeddings to be resized.
  1876. new_num_tokens (`int`, *optional*):
  1877. New number of tokens in the embedding matrix.
  1878. Increasing the size will add newly initialized vectors at the end. Reducing the size will remove
  1879. vectors from the end. If not provided or `None`, just returns a pointer to the input tokens
  1880. `torch.nn.Embedding` module of the model without doing anything.
  1881. pad_to_multiple_of (`int`, *optional*):
  1882. If set will pad the embedding matrix to a multiple of the provided value. If `new_num_tokens` is set to
  1883. `None` will just pad the embedding to a multiple of `pad_to_multiple_of`.
  1884. This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability
  1885. `>= 7.5` (Volta), or on TPUs which benefit from having sequence lengths be a multiple of 128. For more
  1886. details about this, or help on choosing the correct value for resizing, refer to this guide:
  1887. https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html#requirements-tc
  1888. mean_resizing (`bool`):
  1889. Whether to initialize the added embeddings from a multivariate normal distribution that has old embeddings' mean and
  1890. covariance or to initialize them with a normal distribution that has a mean of zero and std equals `config.initializer_range`.
  1891. Setting `mean_resizing` to `True` is useful when increasing the size of the embeddings of causal language models,
  1892. where the generated tokens' probabilities will not be affected by the added embeddings because initializing the new embeddings with the
  1893. old embeddings' mean will reduce the kl-divergence between the next token probability before and after adding the new embeddings.
  1894. Refer to this article for more information: https://nlp.stanford.edu/~johnhew/vocab-expansion.html
  1895. Return:
  1896. `torch.nn.Embedding`: Pointer to the resized Embedding Module or the old Embedding Module if
  1897. `new_num_tokens` is `None`
  1898. """
  1899. if pad_to_multiple_of is not None:
  1900. if not isinstance(pad_to_multiple_of, int):
  1901. raise ValueError(
  1902. f"Asking to pad the embedding matrix to a multiple of `{pad_to_multiple_of}`, which is not and integer. Please make sure to pass an integer"
  1903. )
  1904. if new_num_tokens is None:
  1905. new_num_tokens = old_embeddings.weight.shape[0]
  1906. new_num_tokens = ((new_num_tokens + pad_to_multiple_of - 1) // pad_to_multiple_of) * pad_to_multiple_of
  1907. else:
  1908. logger.info(
  1909. "You are resizing the embedding layer without providing a `pad_to_multiple_of` parameter. This means that the new embedding"
  1910. f" dimension will be {new_num_tokens}. This might induce some performance reduction as *Tensor Cores* will not be available."
  1911. " For more details about this, or help on choosing the correct value for resizing, refer to this guide:"
  1912. " https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html#requirements-tc"
  1913. )
  1914. if new_num_tokens is None:
  1915. return old_embeddings
  1916. is_quantized = hasattr(self, "hf_quantizer") and self.hf_quantizer is not None
  1917. if is_deepspeed_zero3_enabled() and not is_quantized:
  1918. import deepspeed
  1919. with deepspeed.zero.GatheredParameters(old_embeddings.weight, modifier_rank=None):
  1920. old_num_tokens, old_embedding_dim = old_embeddings.weight.size()
  1921. else:
  1922. old_num_tokens, old_embedding_dim = old_embeddings.weight.size()
  1923. if old_num_tokens == new_num_tokens and not is_deepspeed_zero3_enabled():
  1924. return old_embeddings
  1925. if not isinstance(old_embeddings, nn.Embedding):
  1926. raise TypeError(
  1927. f"Old embeddings are of type {type(old_embeddings)}, which is not an instance of {nn.Embedding}. You"
  1928. " should either use a different resize function or make sure that `old_embeddings` are an instance of"
  1929. f" {nn.Embedding}."
  1930. )
  1931. # Build new embeddings
  1932. # When using DeepSpeed ZeRO-3, we shouldn't create new embeddings with DeepSpeed init
  1933. # because the shape of the new embedding layer is used across various modeling files
  1934. # as well as to update config vocab size. Shape will be 0 when using DeepSpeed init leading
  1935. # to errors when training.
  1936. new_embeddings = nn.Embedding(
  1937. new_num_tokens,
  1938. old_embedding_dim,
  1939. device=old_embeddings.weight.device,
  1940. dtype=old_embeddings.weight.dtype,
  1941. )
  1942. if new_num_tokens > old_num_tokens and not mean_resizing:
  1943. # initialize new embeddings (in particular added tokens) with a mean of 0 and std equals `config.initializer_range`.
  1944. self._init_weights(new_embeddings)
  1945. elif new_num_tokens > old_num_tokens and mean_resizing:
  1946. # initialize new embeddings (in particular added tokens). The new embeddings will be initialized
  1947. # from a multivariate normal distribution that has old embeddings' mean and covariance.
  1948. # as described in this article: https://nlp.stanford.edu/~johnhew/vocab-expansion.html
  1949. logger.warning_once(
  1950. "The new embeddings will be initialized from a multivariate normal distribution that has old embeddings' mean and covariance. "
  1951. "As described in this article: https://nlp.stanford.edu/~johnhew/vocab-expansion.html. "
  1952. "To disable this, use `mean_resizing=False`"
  1953. )
  1954. added_num_tokens = new_num_tokens - old_num_tokens
  1955. if is_deepspeed_zero3_enabled() and not is_quantized:
  1956. import deepspeed
  1957. with deepspeed.zero.GatheredParameters([old_embeddings.weight], modifier_rank=None):
  1958. self._init_added_embeddings_weights_with_mean(
  1959. old_embeddings, new_embeddings, old_embedding_dim, old_num_tokens, added_num_tokens
  1960. )
  1961. else:
  1962. self._init_added_embeddings_weights_with_mean(
  1963. old_embeddings, new_embeddings, old_embedding_dim, old_num_tokens, added_num_tokens
  1964. )
  1965. # Copy token embeddings from the previous weights
  1966. # numbers of tokens to copy
  1967. n = min(old_num_tokens, new_num_tokens)
  1968. if is_deepspeed_zero3_enabled() and not is_quantized:
  1969. import deepspeed
  1970. params = [old_embeddings.weight, new_embeddings.weight]
  1971. with deepspeed.zero.GatheredParameters(params, modifier_rank=0):
  1972. new_embeddings.weight.data[:n, :] = old_embeddings.weight.data[:n, :]
  1973. else:
  1974. new_embeddings.weight.data[:n, :] = old_embeddings.weight.data[:n, :]
  1975. # Replace weights in old_embeddings and return to maintain the same embedding type.
  1976. # This ensures correct functionality when a Custom Embedding class is passed as input.
  1977. # The input and output embedding types remain consistent. (c.f. https://github.com/huggingface/transformers/pull/31979)
  1978. if is_deepspeed_zero3_enabled() and not is_quantized:
  1979. import deepspeed
  1980. params = [old_embeddings.weight, new_embeddings.weight]
  1981. with deepspeed.zero.GatheredParameters(params, modifier_rank=0):
  1982. old_embeddings.weight = new_embeddings.weight
  1983. old_embeddings.num_embeddings = new_embeddings.weight.data.shape[0]
  1984. # If the new number of tokens is smaller than the original `padding_idx`, the `padding_idx`
  1985. # will be set to `None` in the resized embeddings.
  1986. if old_embeddings.padding_idx is not None and (new_num_tokens - 1) < old_embeddings.padding_idx:
  1987. old_embeddings.padding_idx = None
  1988. else:
  1989. old_embeddings.weight.data = new_embeddings.weight.data
  1990. old_embeddings.num_embeddings = new_embeddings.weight.data.shape[0]
  1991. if old_embeddings.padding_idx is not None and (new_num_tokens - 1) < old_embeddings.padding_idx:
  1992. old_embeddings.padding_idx = None
  1993. return old_embeddings
  1994. def _get_resized_lm_head(
  1995. self,
  1996. old_lm_head: nn.Linear,
  1997. new_num_tokens: Optional[int] = None,
  1998. transposed: Optional[bool] = False,
  1999. mean_resizing: bool = True,
  2000. ) -> nn.Linear:
  2001. """
  2002. Build a resized Linear Module from a provided old Linear Module. Increasing the size will add newly initialized
  2003. vectors at the end. Reducing the size will remove vectors from the end
  2004. Args:
  2005. old_lm_head (`torch.nn.Linear`):
  2006. Old lm head liner layer to be resized.
  2007. new_num_tokens (`int`, *optional*):
  2008. New number of tokens in the linear matrix.
  2009. Increasing the size will add newly initialized vectors at the end. Reducing the size will remove
  2010. vectors from the end. If not provided or `None`, just returns a pointer to the input tokens
  2011. `torch.nn.Linear` module of the model without doing anything. transposed (`bool`, *optional*, defaults
  2012. to `False`): Whether `old_lm_head` is transposed or not. If True `old_lm_head.size()` is `lm_head_dim,
  2013. vocab_size` else `vocab_size, lm_head_dim`.
  2014. mean_resizing (`bool`):
  2015. Whether to initialize the added embeddings from a multivariate normal distribution that has old embeddings' mean and
  2016. covariance or to initialize them with a normal distribution that has a mean of zero and std equals `config.initializer_range`.
  2017. Setting `mean_resizing` to `True` is useful when increasing the size of the embeddings of causal language models,
  2018. where the generated tokens' probabilities will not be affected by the added embeddings because initializing the new embeddings with the
  2019. old embeddings' mean will reduce the kl-divergence between the next token probability before and after adding the new embeddings.
  2020. Refer to this article for more information: https://nlp.stanford.edu/~johnhew/vocab-expansion.html
  2021. Return:
  2022. `torch.nn.Linear`: Pointer to the resized Linear Module or the old Linear Module if `new_num_tokens` is
  2023. `None`
  2024. """
  2025. if new_num_tokens is None:
  2026. return old_lm_head
  2027. is_quantized = hasattr(self, "hf_quantizer") and self.hf_quantizer is not None
  2028. if is_deepspeed_zero3_enabled() and not is_quantized:
  2029. import deepspeed
  2030. with deepspeed.zero.GatheredParameters(old_lm_head.weight, modifier_rank=None):
  2031. old_num_tokens, old_lm_head_dim = (
  2032. old_lm_head.weight.size() if not transposed else old_lm_head.weight.t().size()
  2033. )
  2034. else:
  2035. old_num_tokens, old_lm_head_dim = (
  2036. old_lm_head.weight.size() if not transposed else old_lm_head.weight.t().size()
  2037. )
  2038. if old_num_tokens == new_num_tokens and not is_deepspeed_zero3_enabled():
  2039. return old_lm_head
  2040. if not isinstance(old_lm_head, nn.Linear):
  2041. raise TypeError(
  2042. f"Old language model head is of type {type(old_lm_head)}, which is not an instance of {nn.Linear}. You"
  2043. " should either use a different resize function or make sure that `old_lm_head` are an instance of"
  2044. f" {nn.Linear}."
  2045. )
  2046. # Build new lm head
  2047. new_lm_head_shape = (old_lm_head_dim, new_num_tokens) if not transposed else (new_num_tokens, old_lm_head_dim)
  2048. has_new_lm_head_bias = old_lm_head.bias is not None
  2049. # When using DeepSpeed ZeRO-3, we shouldn't create new embeddings with DeepSpeed init
  2050. # because the shape of the new embedding layer is used across various modeling files
  2051. # as well as to update config vocab size. Shape will be 0 when using DeepSpeed init leading
  2052. # to errors when training.
  2053. new_lm_head = nn.Linear(
  2054. *new_lm_head_shape,
  2055. bias=has_new_lm_head_bias,
  2056. device=old_lm_head.weight.device,
  2057. dtype=old_lm_head.weight.dtype,
  2058. )
  2059. if new_num_tokens > old_num_tokens and not mean_resizing:
  2060. # initialize new embeddings (in particular added tokens) with a mean of 0 and std equals `config.initializer_range`.
  2061. self._init_weights(new_lm_head)
  2062. elif new_num_tokens > old_num_tokens and mean_resizing:
  2063. # initialize new lm_head weights (in particular added tokens). The new lm_head weights
  2064. # will be initialized from a multivariate normal distribution that has old embeddings' mean and covariance.
  2065. # as described in this article: https://nlp.stanford.edu/~johnhew/vocab-expansion.html
  2066. logger.warning_once(
  2067. "The new lm_head weights will be initialized from a multivariate normal distribution that has old embeddings' mean and covariance. "
  2068. "As described in this article: https://nlp.stanford.edu/~johnhew/vocab-expansion.html. "
  2069. "To disable this, use `mean_resizing=False`"
  2070. )
  2071. added_num_tokens = new_num_tokens - old_num_tokens
  2072. if is_deepspeed_zero3_enabled() and not is_quantized:
  2073. import deepspeed
  2074. params = [old_lm_head.weight]
  2075. if has_new_lm_head_bias:
  2076. params += [old_lm_head.bias]
  2077. with deepspeed.zero.GatheredParameters(params, modifier_rank=None):
  2078. self._init_added_lm_head_weights_with_mean(
  2079. old_lm_head, new_lm_head, old_lm_head_dim, old_num_tokens, added_num_tokens, transposed
  2080. )
  2081. if has_new_lm_head_bias:
  2082. self._init_added_lm_head_bias_with_mean(old_lm_head, new_lm_head, added_num_tokens)
  2083. else:
  2084. self._init_added_lm_head_weights_with_mean(
  2085. old_lm_head, new_lm_head, old_lm_head_dim, old_num_tokens, added_num_tokens, transposed
  2086. )
  2087. if has_new_lm_head_bias:
  2088. self._init_added_lm_head_bias_with_mean(old_lm_head, new_lm_head, added_num_tokens)
  2089. num_tokens_to_copy = min(old_num_tokens, new_num_tokens)
  2090. if is_deepspeed_zero3_enabled() and not is_quantized:
  2091. import deepspeed
  2092. params = [old_lm_head.weight, old_lm_head.bias, new_lm_head.weight, new_lm_head.bias]
  2093. with deepspeed.zero.GatheredParameters(params, modifier_rank=0):
  2094. self._copy_lm_head_original_to_resized(
  2095. new_lm_head, old_lm_head, num_tokens_to_copy, transposed, has_new_lm_head_bias
  2096. )
  2097. else:
  2098. self._copy_lm_head_original_to_resized(
  2099. new_lm_head, old_lm_head, num_tokens_to_copy, transposed, has_new_lm_head_bias
  2100. )
  2101. return new_lm_head
  2102. def _init_added_embeddings_weights_with_mean(
  2103. self, old_embeddings, new_embeddings, old_embedding_dim, old_num_tokens, added_num_tokens
  2104. ):
  2105. old_embeddings_weight = old_embeddings.weight.data.to(torch.float32)
  2106. mean_embeddings = torch.mean(old_embeddings_weight, axis=0)
  2107. old_centered_embeddings = old_embeddings_weight - mean_embeddings
  2108. covariance = old_centered_embeddings.T @ old_centered_embeddings / old_num_tokens
  2109. # Check if the covariance is positive definite.
  2110. eigenvalues = torch.linalg.eigvals(covariance)
  2111. is_covariance_psd = bool(
  2112. (covariance == covariance.T).all() and not torch.is_complex(eigenvalues) and (eigenvalues > 0).all()
  2113. )
  2114. if is_covariance_psd:
  2115. # If covariances is positive definite, a distribution can be created. and we can sample new weights from it.
  2116. distribution = torch.distributions.multivariate_normal.MultivariateNormal(
  2117. mean_embeddings, covariance_matrix=1e-9 * covariance
  2118. )
  2119. new_embeddings.weight.data[-1 * added_num_tokens :, :] = distribution.sample(
  2120. sample_shape=(added_num_tokens,)
  2121. ).to(old_embeddings.weight.dtype)
  2122. else:
  2123. # Otherwise, just initialize with the mean. because distribtion will not be created.
  2124. new_embeddings.weight.data[-1 * added_num_tokens :, :] = (
  2125. mean_embeddings[None, :].repeat(added_num_tokens, 1).to(old_embeddings.weight.dtype)
  2126. )
  2127. def _init_added_lm_head_weights_with_mean(
  2128. self,
  2129. old_lm_head,
  2130. new_lm_head,
  2131. old_lm_head_dim,
  2132. old_num_tokens,
  2133. added_num_tokens,
  2134. transposed=False,
  2135. ):
  2136. if transposed:
  2137. # Transpose to the desired shape for the function.
  2138. new_lm_head.weight.data = new_lm_head.weight.data.T
  2139. old_lm_head.weight.data = old_lm_head.weight.data.T
  2140. # The same initilization logic as Embeddings.
  2141. self._init_added_embeddings_weights_with_mean(
  2142. old_lm_head, new_lm_head, old_lm_head_dim, old_num_tokens, added_num_tokens
  2143. )
  2144. if transposed:
  2145. # Transpose again to the correct shape.
  2146. new_lm_head.weight.data = new_lm_head.weight.data.T
  2147. old_lm_head.weight.data = old_lm_head.weight.data.T
  2148. def _init_added_lm_head_bias_with_mean(self, old_lm_head, new_lm_head, added_num_tokens):
  2149. bias_mean = torch.mean(old_lm_head.bias.data, axis=0, dtype=torch.float32)
  2150. bias_std = torch.std(old_lm_head.bias.data, axis=0).to(torch.float32)
  2151. new_lm_head.bias.data[-1 * added_num_tokens :].normal_(mean=bias_mean, std=1e-9 * bias_std)
  2152. def _copy_lm_head_original_to_resized(
  2153. self, new_lm_head, old_lm_head, num_tokens_to_copy, transposed, has_new_lm_head_bias
  2154. ):
  2155. # Copy old lm head weights to new lm head
  2156. if not transposed:
  2157. new_lm_head.weight.data[:num_tokens_to_copy, :] = old_lm_head.weight.data[:num_tokens_to_copy, :]
  2158. else:
  2159. new_lm_head.weight.data[:, :num_tokens_to_copy] = old_lm_head.weight.data[:, :num_tokens_to_copy]
  2160. # Copy bias weights to new lm head
  2161. if has_new_lm_head_bias:
  2162. new_lm_head.bias.data[:num_tokens_to_copy] = old_lm_head.bias.data[:num_tokens_to_copy]
  def resize_position_embeddings(self, new_num_position_embeddings: int):
      """
      Resize the position embeddings of the model to `new_num_position_embeddings`.

      The base class provides no implementation; architectures that support resizing must override this method.

      Raises:
          NotImplementedError: whenever this base implementation is called.
      """
      # The message names the module that defines the class. The previous text contained a stray backtick and
      # built a garbled path of the form `modeling_<full.module.path>.py`.
      raise NotImplementedError(
          f"`resize_position_embeddings` is not implemented for {self.__class__}. To implement it, you should "
          f"overwrite this method in the class {self.__class__} in `{self.__class__.__module__}`"
      )
  def get_position_embeddings(self) -> Union[nn.Embedding, Tuple[nn.Embedding]]:
      """
      Return the position embeddings of the model.

      The base class provides no implementation; architectures that expose position embeddings must override it.

      Raises:
          NotImplementedError: whenever this base implementation is called.
      """
      # The message names the module that defines the class. The previous text contained a stray backtick and
      # built a garbled path of the form `modeling_<full.module.path>.py`.
      raise NotImplementedError(
          f"`get_position_embeddings` is not implemented for {self.__class__}. To implement it, you should "
          f"overwrite this method in the class {self.__class__} in `{self.__class__.__module__}`"
      )
  def init_weights(self):
      """
      Prune the heads recorded in the config and, when the global `_init_weights` flag is truthy, initialize and tie
      the weights. A custom `PreTrainedModel` should implement its initialization logic in `_init_weights`.
      """
      heads_from_config = self.config.pruned_heads
      if heads_from_config:
          self.prune_heads(heads_from_config)

      if not _init_weights:
          # When weights are not initialized here, tying is skipped as well, since from_pretrained(...) ties
          # weights anyway.
          return

      self.apply(self._initialize_weights)
      self.tie_weights()
  def prune_heads(self, heads_to_prune: Dict[int, List[int]]):
      """
      Prunes heads of the base model and records them in the config.

      Arguments:
          heads_to_prune (`Dict[int, List[int]]`):
              Mapping from layer index (`int`) to the list of head indices (`int`) to prune in that layer. For
              instance {1: [0, 2], 2: [2, 3]} will prune heads 0 and 2 on layer 1 and heads 2 and 3 on layer 2.
      """
      recorded = self.config.pruned_heads
      for layer_index, new_heads in heads_to_prune.items():
          # Keep the union of heads pruned earlier and the ones pruned now.
          merged = set(recorded.get(layer_index, [])) | set(new_heads)
          # Stored as a list because the config must stay JSON-serializable.
          recorded[layer_index] = list(merged)
      self.base_model._prune_heads(heads_to_prune)
  def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None):
      """
      Activates gradient checkpointing for the current model.

      Note that in other frameworks this feature can be referred to as "activation checkpointing" or "checkpoint
      activations".

      We pass the `__call__` method of the modules instead of `forward` because `__call__` attaches all the hooks of
      the module. https://discuss.pytorch.org/t/any-different-between-model-input-and-model-forward-input/3690/2

      Args:
          gradient_checkpointing_kwargs (dict, *optional*):
              Additional keyword arguments passed along to the `torch.utils.checkpoint.checkpoint` function.
              Defaults to `{"use_reentrant": True}` when not given.

      Raises:
          ValueError: if the model does not support gradient checkpointing.
      """
      if not self.supports_gradient_checkpointing:
          raise ValueError(f"{self.__class__.__name__} does not support gradient checkpointing.")

      if gradient_checkpointing_kwargs is None:
          gradient_checkpointing_kwargs = {"use_reentrant": True}

      gradient_checkpointing_func = functools.partial(checkpoint, **gradient_checkpointing_kwargs)

      # For the old GC format (transformers < 4.35.0) used by models that live on the Hub, fall back to the
      # overwritten `_set_gradient_checkpointing` method, which takes a `value` argument.
      _is_using_old_format = "value" in inspect.signature(self._set_gradient_checkpointing).parameters

      if not _is_using_old_format:
          self._set_gradient_checkpointing(enable=True, gradient_checkpointing_func=gradient_checkpointing_func)
      else:
          self.apply(partial(self._set_gradient_checkpointing, value=True))
          # The trailing space keeps the two concatenated sentences separated.
          logger.warning(
              "You are using an old version of the checkpointing format that is deprecated (We will also silently ignore `gradient_checkpointing_kwargs` in case you passed it). "
              "Please update to the new format on your modeling file. To use the new format, you need to completely remove the definition of the method `_set_gradient_checkpointing` in your model."
          )

      if getattr(self, "_hf_peft_config_loaded", False):
          # With PEFT + gradient checkpointing + Trainer, only LoRA layers require grad. The input must require
          # grad so that gradients propagate through the frozen layers. PEFT does the same:
          # https://github.com/huggingface/peft/blob/85013987aa82aa1af3da1236b6902556ce3e483e/src/peft/peft_model.py#L334
          self.enable_input_require_grads()
  2234. def _set_gradient_checkpointing(self, enable: bool = True, gradient_checkpointing_func: Callable = checkpoint):
  2235. is_gradient_checkpointing_set = False
  2236. # Apply it on the top-level module in case the top-level modules supports it
  2237. # for example, LongT5Stack inherits from `PreTrainedModel`.
  2238. if hasattr(self, "gradient_checkpointing"):
  2239. self._gradient_checkpointing_func = gradient_checkpointing_func
  2240. self.gradient_checkpointing = enable
  2241. is_gradient_checkpointing_set = True
  2242. for module in self.modules():
  2243. if hasattr(module, "gradient_checkpointing"):
  2244. module._gradient_checkpointing_func = gradient_checkpointing_func
  2245. module.gradient_checkpointing = enable
  2246. is_gradient_checkpointing_set = True
  2247. if not is_gradient_checkpointing_set:
  2248. raise ValueError(
  2249. f"{self.__class__.__name__} is not compatible with gradient checkpointing. Make sure all the architecture support it by setting a boolean attribute"
  2250. " `gradient_checkpointing` to modules of the model that uses checkpointing."
  2251. )
  def gradient_checkpointing_disable(self):
      """
      Deactivates gradient checkpointing for the current model.

      Note that in other frameworks this feature can be referred to as "activation checkpointing" or "checkpoint
      activations".

      Models that do not support gradient checkpointing are left untouched by the checkpointing part. When PEFT
      adapters are loaded, input gradients are disabled as well.
      """
      if self.supports_gradient_checkpointing:
          # For the old GC format (transformers < 4.35.0) used by models that live on the Hub, fall back to the
          # overwritten `_set_gradient_checkpointing` method, which takes a `value` argument.
          _is_using_old_format = "value" in inspect.signature(self._set_gradient_checkpointing).parameters
          if not _is_using_old_format:
              self._set_gradient_checkpointing(enable=False)
          else:
              # The trailing space keeps the two concatenated sentences separated.
              logger.warning(
                  "You are using an old version of the checkpointing format that is deprecated (We will also silently ignore `gradient_checkpointing_kwargs` in case you passed it). "
                  "Please update to the new format on your modeling file. To use the new format, you need to completely remove the definition of the method `_set_gradient_checkpointing` in your model."
              )
              self.apply(partial(self._set_gradient_checkpointing, value=False))

      if getattr(self, "_hf_peft_config_loaded", False):
          self.disable_input_require_grads()
  @property
  def is_gradient_checkpointing(self) -> bool:
      """
      True when at least one module returned by `self.modules()` has a truthy `gradient_checkpointing` attribute.

      Note that in other frameworks this feature can be referred to as "activation checkpointing" or "checkpoint
      activations".
      """
      for submodule in self.modules():
          if getattr(submodule, "gradient_checkpointing", False):
              return True
      return False
  def save_pretrained(
      self,
      save_directory: Union[str, os.PathLike],
      is_main_process: bool = True,
      state_dict: Optional[dict] = None,
      save_function: Callable = torch.save,
      push_to_hub: bool = False,
      max_shard_size: Union[int, str] = "5GB",
      safe_serialization: bool = True,
      variant: Optional[str] = None,
      token: Optional[Union[str, bool]] = None,
      save_peft_format: bool = True,
      **kwargs,
  ):
      """
      Save a model and its configuration file to a directory, so that it can be re-loaded using the
      [`~PreTrainedModel.from_pretrained`] class method.

      Arguments:
          save_directory (`str` or `os.PathLike`):
              Directory to which to save. Will be created if it doesn't exist.
          is_main_process (`bool`, *optional*, defaults to `True`):
              Whether the process calling this is the main process or not. Useful when in distributed training like
              TPUs and need to call this function on all processes. In this case, set `is_main_process=True` only on
              the main process to avoid race conditions.
          state_dict (nested dictionary of `torch.Tensor`):
              The state dictionary of the model to save. Will default to `self.state_dict()`, but can be used to only
              save parts of the model or if special precautions need to be taken when recovering the state dictionary
              of a model (like when using model parallelism).
          save_function (`Callable`):
              The function to use to save the state dictionary. Useful on distributed training like TPUs when one
              needs to replace `torch.save` by another method.
          push_to_hub (`bool`, *optional*, defaults to `False`):
              Whether or not to push your model to the Hugging Face model hub after saving it. You can specify the
              repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
              namespace).
          max_shard_size (`int` or `str`, *optional*, defaults to `"5GB"`):
              The maximum size for a checkpoint before being sharded. Checkpoints shard will then be each of size
              lower than this size. If expressed as a string, needs to be digits followed by a unit (like `"5MB"`).
              We default it to 5GB in order for models to be able to run easily on free-tier google colab instances
              without CPU OOM issues.

              <Tip warning={true}>

              If a single weight of the model is bigger than `max_shard_size`, it will be in its own checkpoint shard
              which will be bigger than `max_shard_size`.

              </Tip>

          safe_serialization (`bool`, *optional*, defaults to `True`):
              Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`).
          variant (`str`, *optional*):
              If specified, weights are saved in the format pytorch_model.<variant>.bin.
          token (`str` or `bool`, *optional*):
              The token to use as HTTP bearer authorization for remote files. If `True`, or not specified, will use
              the token generated when running `huggingface-cli login` (stored in `~/.huggingface`).
          save_peft_format (`bool`, *optional*, defaults to `True`):
              For backward compatibility with PEFT library, in case adapter weights are attached to the model, all
              keys of the state dict of adapters need to be prepended with `base_model.model`. Advanced users can
              disable this behaviour by setting `save_peft_format` to `False`.
          kwargs (`Dict[str, Any]`, *optional*):
              Additional key word arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.

      Raises:
          ValueError: if both `token` and the deprecated `use_auth_token` are given, if the model is quantized with a
              non-serializable method, or if several PEFT adapters are active.
          ImportError: if `safe_serialization` is requested without `safetensors`, or if offloaded parameters are
              saved with `accelerate` < 0.31.
          RuntimeError: if `safe_serialization` is used and shared tensors cannot be de-duplicated.
      """
      use_auth_token = kwargs.pop("use_auth_token", None)
      ignore_metadata_errors = kwargs.pop("ignore_metadata_errors", False)

      if use_auth_token is not None:
          warnings.warn(
              "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.",
              FutureWarning,
          )
          if token is not None:
              raise ValueError(
                  "`token` and `use_auth_token` are both specified. Please set only the argument `token`."
              )
          token = use_auth_token

      if token is not None:
          kwargs["token"] = token

      _hf_peft_config_loaded = getattr(self, "_hf_peft_config_loaded", False)

      hf_quantizer = getattr(self, "hf_quantizer", None)
      quantization_serializable = (
          hf_quantizer is not None
          and isinstance(hf_quantizer, HfQuantizer)
          and hf_quantizer.is_serializable(safe_serialization=safe_serialization)
      )

      # Quantized models may only be saved when the quantizer supports it, unless only PEFT adapters are saved.
      if hf_quantizer is not None and not _hf_peft_config_loaded and not quantization_serializable:
          raise ValueError(
              f"The model is quantized with {hf_quantizer.quantization_config.quant_method} and is not serializable - check out the warnings from"
              " the logger on the traceback to understand the reason why the quantized model is not serializable."
          )

      if "save_config" in kwargs:
          warnings.warn(
              "`save_config` is deprecated and will be removed in v5 of Transformers. Use `is_main_process` instead."
          )
          is_main_process = kwargs.pop("save_config")
      if safe_serialization and not is_safetensors_available():
          raise ImportError("`safe_serialization` requires the `safetensors` library: `pip install safetensors`.")

      if os.path.isfile(save_directory):
          logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
          return

      os.makedirs(save_directory, exist_ok=True)

      if push_to_hub:
          commit_message = kwargs.pop("commit_message", None)
          # `save_directory` may be an `os.PathLike` (see the docstring). Convert it before splitting: a
          # `pathlib.Path` has no `.split` method.
          default_repo_id = os.fspath(save_directory).split(os.path.sep)[-1]
          repo_id = kwargs.pop("repo_id", default_repo_id)
          repo_id = self._create_repo(repo_id, **kwargs)
          # Snapshot timestamps so that only files modified by this save are uploaded at the end.
          files_timestamps = self._get_files_timestamps(save_directory)

      # Only save the model itself if we are using distributed training
      model_to_save = unwrap_model(self)

      # save the string version of dtype to the config, e.g. convert torch.float32 => "float32"
      # we currently don't use this setting automatically, but may start to use with v5
      dtype = get_parameter_dtype(model_to_save)
      model_to_save.config.torch_dtype = str(dtype).split(".")[1]

      # Attach architecture to the config
      model_to_save.config.architectures = [model_to_save.__class__.__name__]

      # Unset attn implementation so it can be set to another one when loading back
      model_to_save.config._attn_implementation_autoset = False

      # If we have a custom model, we copy the file defining it in the folder and set the attributes so it can be
      # loaded from the Hub.
      if self._auto_class is not None:
          custom_object_save(self, save_directory, config=self.config)

      # Save the config (main process only)
      if is_main_process:
          if not _hf_peft_config_loaded:
              # If the model config has set attributes that should be in the generation config, move them there.
              misplaced_generation_parameters = model_to_save.config._get_non_default_generation_parameters()
              if self.can_generate() and len(misplaced_generation_parameters) > 0:
                  warnings.warn(
                      "Moving the following attributes in the config to the generation config: "
                      f"{misplaced_generation_parameters}. You are seeing this warning because you've set "
                      "generation parameters in the model config, as opposed to in the generation config.",
                      UserWarning,
                  )
                  for param_name, param_value in misplaced_generation_parameters.items():
                      setattr(model_to_save.generation_config, param_name, param_value)
                      setattr(model_to_save.config, param_name, None)

              model_to_save.config.save_pretrained(save_directory)
          if self.can_generate():
              model_to_save.generation_config.save_pretrained(save_directory)

          if _hf_peft_config_loaded:
              logger.info(
                  "Detected adapters on the model, saving the model in the PEFT format, only adapter weights will be saved."
              )
              # Only the adapter weights are saved; this replaces any `state_dict` argument.
              state_dict = model_to_save.get_adapter_state_dict()

              if save_peft_format:
                  logger.info(
                      "To match the expected format of the PEFT library, all keys of the state dict of adapters will be pre-pended with `base_model.model`."
                  )
                  peft_state_dict = {}
                  for key, value in state_dict.items():
                      peft_state_dict[f"base_model.model.{key}"] = value
                  state_dict = peft_state_dict

              active_adapter = self.active_adapters()

              if len(active_adapter) > 1:
                  raise ValueError(
                      "Multiple active adapters detected, saving multiple active adapters is not supported yet. You can save adapters separately one by one "
                      "by iteratively calling `model.set_adapter(adapter_name)` then `model.save_pretrained(...)`"
                  )
              active_adapter = active_adapter[0]

              current_peft_config = self.peft_config[active_adapter]
              current_peft_config.save_pretrained(save_directory)

      # for offloaded modules: maps each state-dict key to the submodule that owns it
      module_map = {}

      # Save the model
      if state_dict is None:
          # if any model parameters are offloaded, make module map
          if (
              hasattr(self, "hf_device_map")
              and len(set(self.hf_device_map.values())) > 1
              and ("cpu" in self.hf_device_map.values() or "disk" in self.hf_device_map.values())
          ):
              warnings.warn(
                  "Attempting to save a model with offloaded modules. Ensure that unallocated cpu memory exceeds the `shard_size` (5GB default)"
              )
              for name, module in model_to_save.named_modules():
                  if name == "":
                      continue
                  module_state_dict = module.state_dict()

                  for key in module_state_dict:
                      module_map[name + f".{key}"] = module
          state_dict = model_to_save.state_dict()

      # Translate state_dict from smp to hf if saving with smp >= 1.10
      if IS_SAGEMAKER_MP_POST_1_10:
          for smp_to_hf, _ in smp.state.module_manager.translate_functions:
              state_dict = smp_to_hf(state_dict)

      # Handle the case where some state_dict keys shouldn't be saved
      if self._keys_to_ignore_on_save is not None:
          for ignore_key in self._keys_to_ignore_on_save:
              if ignore_key in state_dict.keys():
                  del state_dict[ignore_key]
      if safe_serialization:
          # Safetensors does not allow tensor aliasing.
          # We're going to remove aliases before saving
          ptrs = collections.defaultdict(list)
          for name, tensor in state_dict.items():
              # Sometimes in the state_dict we have non-tensor objects.
              # e.g. in bitsandbytes we have some `str` objects in the state_dict
              if isinstance(tensor, torch.Tensor):
                  ptrs[id_tensor_storage(tensor)].append(name)
              else:
                  # In the non-tensor case, fall back to the pointer of the object itself
                  ptrs[id(tensor)].append(name)

          # These are all the pointers of shared tensors
          if hasattr(self, "hf_device_map"):
              # if the model has offloaded parameters, we must check using find_tied_parameters()
              tied_params = find_tied_parameters(self)
              if tied_params:
                  # NOTE(review): only the first group returned by `find_tied_parameters` is considered here; confirm
                  # whether models with several tied groups should take all groups into account.
                  tied_names = tied_params[0]
                  shared_ptrs = {
                      ptr: names for ptr, names in ptrs.items() if any(name in tied_names for name in names)
                  }
              else:
                  shared_ptrs = {}
          else:
              shared_ptrs = {ptr: names for ptr, names in ptrs.items() if len(names) > 1}

          # Recursively descend to find tied weight keys
          _tied_weights_keys = _get_tied_weight_keys(self)
          error_names = []
          to_delete_names = set()
          for names in shared_ptrs.values():
              # Removing the keys which are declared as known duplicates on
              # load. This allows to make sure the name which is kept is consistent.
              if _tied_weights_keys is not None:
                  found = 0
                  for name in sorted(names):
                      matches_pattern = any(re.search(pat, name) for pat in _tied_weights_keys)
                      if matches_pattern and name in state_dict:
                          found += 1
                          # Keep at least one name of the group.
                          if found < len(names):
                              to_delete_names.add(name)
          # We are entering a place where the weights and the transformers configuration do NOT match.
          shared_names, disjoint_names = _find_disjoint(shared_ptrs.values(), state_dict)
          # Those are actually tensor sharing but disjoint from each other, we can safely clone them
          # Reloaded won't have the same property, but it shouldn't matter in any meaningful way.
          for name in disjoint_names:
              state_dict[name] = state_dict[name].clone()

          # When not all duplicates have been cleaned, still remove those keys, but put a clear warning.
          # If the link between tensors was done at runtime then `from_pretrained` will not get
          # the key back leading to random tensor. A proper warning will be shown
          # during reload (if applicable), but since the file is not necessarily compatible with
          # the config, better show a proper warning.
          shared_names, identical_names = _find_identical(shared_names, state_dict)
          # delete tensors that have identical storage
          for inames in identical_names:
              known = inames.intersection(to_delete_names)
              for name in known:
                  del state_dict[name]
              unknown = inames.difference(to_delete_names)
              if len(unknown) > 1:
                  error_names.append(unknown)

          if shared_names:
              error_names.append(set(shared_names))

          if len(error_names) > 0:
              raise RuntimeError(
                  f"The weights trying to be saved contained shared tensors {error_names} that are mismatching the transformers base configuration. Try saving using `safe_serialization=False` or remove this tensor sharing.",
              )

      # Shard the model if it is too big.
      if not _hf_peft_config_loaded:
          weights_name = SAFE_WEIGHTS_NAME if safe_serialization else WEIGHTS_NAME
          weights_name = _add_variant(weights_name, variant)
      else:
          weights_name = ADAPTER_SAFE_WEIGHTS_NAME if safe_serialization else ADAPTER_WEIGHTS_NAME

      filename_pattern = weights_name.replace(".bin", "{suffix}.bin").replace(".safetensors", "{suffix}.safetensors")
      state_dict_split = split_torch_state_dict_into_shards(
          state_dict, filename_pattern=filename_pattern, max_shard_size=max_shard_size
      )
      # Save index if sharded
      index = None
      if state_dict_split.is_sharded:
          index = {
              "metadata": state_dict_split.metadata,
              "weight_map": state_dict_split.tensor_to_filename,
          }

      # Clean the folder from a previous save
      for filename in os.listdir(save_directory):
          full_filename = os.path.join(save_directory, filename)
          # If we have a shard file that is not going to be replaced, we delete it, but only from the main process
          # in distributed settings to avoid race conditions.
          weights_no_suffix = weights_name.replace(".bin", "").replace(".safetensors", "")

          # make sure that file to be deleted matches format of sharded file, e.g. pytorch_model-00001-of-00005
          filename_no_suffix = filename.replace(".bin", "").replace(".safetensors", "")
          reg = re.compile(r"(.*?)-\d{5}-of-\d{5}")

          if (
              filename.startswith(weights_no_suffix)
              and os.path.isfile(full_filename)
              and filename not in state_dict_split.filename_to_tensors.keys()
              and is_main_process
              and reg.fullmatch(filename_no_suffix) is not None
          ):
              os.remove(full_filename)
      # Save the model
      filename_to_tensors = state_dict_split.filename_to_tensors.items()
      if module_map:
          filename_to_tensors = logging.tqdm(filename_to_tensors, desc="Saving checkpoint shards")
      for shard_file, tensors in filename_to_tensors:
          shard = {tensor: state_dict[tensor].contiguous() for tensor in tensors}
          # remake shard with onloaded parameters if necessary
          if module_map:
              if accelerate_version < version.parse("0.31"):
                  raise ImportError(
                      f"You need accelerate version to be greater or equal than 0.31 to save models with offloaded parameters. Detected version {accelerate_version}. "
                      f"Please upgrade accelerate with `pip install -U accelerate`"
                  )
              # init state_dict for this shard
              shard_state_dict = {name: "" for name in shard}
              for module_name in shard:
                  module = module_map[module_name]
                  # update state dict with onloaded parameters
                  shard_state_dict = get_state_dict_from_offload(module, module_name, shard_state_dict)

              # assign shard to be the completed state dict
              shard = shard_state_dict
              del shard_state_dict
              gc.collect()

          if safe_serialization:
              # At some point we will need to deal better with save_function (used for TPU and other distributed
              # joyfulness), but for now this enough.
              safe_save_file(shard, os.path.join(save_directory, shard_file), metadata={"format": "pt"})
          else:
              save_function(shard, os.path.join(save_directory, shard_file))

      if index is None:
          path_to_weights = os.path.join(save_directory, weights_name)
          logger.info(f"Model weights saved in {path_to_weights}")
      else:
          save_index_file = SAFE_WEIGHTS_INDEX_NAME if safe_serialization else WEIGHTS_INDEX_NAME
          save_index_file = os.path.join(save_directory, _add_variant(save_index_file, variant))
          # Save the index as well
          with open(save_index_file, "w", encoding="utf-8") as f:
              content = json.dumps(index, indent=2, sort_keys=True) + "\n"
              f.write(content)
          logger.info(
              f"The model is bigger than the maximum size per checkpoint ({max_shard_size}) and is going to be "
              f"split in {len(state_dict_split.filename_to_tensors)} checkpoint shards. You can find where each parameters has been saved in the "
              f"index located at {save_index_file}."
          )

      if push_to_hub:
          # Eventually create an empty model card
          model_card = create_and_tag_model_card(
              repo_id, self.model_tags, token=token, ignore_metadata_errors=ignore_metadata_errors
          )

          # Update model card if needed:
          model_card.save(os.path.join(save_directory, "README.md"))

          self._upload_modified_files(
              save_directory,
              repo_id,
              files_timestamps,
              commit_message=commit_message,
              token=token,
          )
  @wraps(PushToHubMixin.push_to_hub)
  def push_to_hub(self, *args, **kwargs):
      # No docstring on purpose: `@wraps` copies the docstring of `PushToHubMixin.push_to_hub` onto this method.
      # Merges `self.model_tags` with any per-call `tags` (a string or a list) before delegating to the mixin.

      # Work on a copy: appending to `self.model_tags` directly would permanently add per-call tags to the model.
      tags = list(self.model_tags) if self.model_tags is not None else []

      # `tags=None` is treated like "no extra tags" instead of failing on iteration.
      tags_kwargs = kwargs.get("tags") or []
      if isinstance(tags_kwargs, str):
          tags_kwargs = [tags_kwargs]

      for tag in tags_kwargs:
          if tag not in tags:
              tags.append(tag)

      if tags:
          kwargs["tags"] = tags
      return super().push_to_hub(*args, **kwargs)
  def get_memory_footprint(self, return_buffers=True):
      r"""
      Get the memory footprint of a model. This will return the memory footprint of the current model in bytes.
      Useful to benchmark the memory footprint of the current model and design some tests. Solution inspired from the
      PyTorch discussions: https://discuss.pytorch.org/t/gpu-memory-that-model-uses/56822/2

      Arguments:
          return_buffers (`bool`, *optional*, defaults to `True`):
              Whether to return the size of the buffer tensors in the computation of the memory footprint. Buffers
              are tensors that do not require gradients and not registered as parameters. E.g. mean and std in batch
              norm layers. Please see: https://discuss.pytorch.org/t/what-pytorch-means-by-buffers/120266/2

      Returns:
          The total size in bytes: `nelement() * element_size()` summed over every parameter, plus every buffer
          when `return_buffers=True`. A model without parameters (and buffers) yields `0`.
      """
      # Generator expressions avoid materialising a throwaway list of per-tensor sizes before summing.
      mem = sum(param.nelement() * param.element_size() for param in self.parameters())
      if return_buffers:
          mem += sum(buf.nelement() * buf.element_size() for buf in self.buffers())
      return mem
  @wraps(torch.nn.Module.cuda)
  def cuda(self, *args, **kwargs):
      # No docstring here on purpose: `@wraps` replaces it with `torch.nn.Module.cuda`'s docstring.
      # Refuse device moves the active quantization backend cannot support, otherwise defer to the parent `cuda`.
      quantization_method = getattr(self, "quantization_method", None)
      if quantization_method == QuantizationMethod.HQQ:
          raise ValueError("`.cuda` is not supported for HQQ-quantized models.")
      # Checks if the model has been loaded in 4-bit or 8-bit with BNB
      if quantization_method == QuantizationMethod.BITS_AND_BYTES:
          if getattr(self, "is_loaded_in_8bit", False):
              raise ValueError(
                  "Calling `cuda()` is not supported for `8-bit` quantized models. "
                  " Please use the model as it is, since the model has already been set to the correct devices."
              )
          if version.parse(importlib.metadata.version("bitsandbytes")) < version.parse("0.43.2"):
              raise ValueError(
                  "Calling `cuda()` is not supported for `4-bit` quantized models with the installed version of bitsandbytes. "
                  f"The current device is `{self.device}`. If you intended to move the model, please install bitsandbytes >= 0.43.2."
              )
          # Non-8-bit bitsandbytes models with bitsandbytes >= 0.43.2 are movable: fall through to the parent call
          # (previously this path silently returned `None` without moving the model).
      return super().cuda(*args, **kwargs)
  @wraps(torch.nn.Module.to)
  def to(self, *args, **kwargs):
      # No docstring here on purpose: `@wraps` replaces it with `torch.nn.Module.to`'s docstring.
      #
      # For BNB/GPTQ models, we prevent users from casting the model to another dtype to restrict unwanted
      # behaviours. The correct API is to load the model with the desired dtype directly through
      # `from_pretrained`.
      #
      # `torch.nn.Module.to` can receive a target dtype explicitly (`dtype=...` or a positional `torch.dtype`)
      # or implicitly through a tensor (`to(tensor)` / `to(tensor=...)`), whose dtype is adopted along with its
      # device. Every one of these forms counts as a dtype cast here.
      dtype_present_in_args = (
          "dtype" in kwargs
          or "tensor" in kwargs
          or any(isinstance(arg, (torch.dtype, torch.Tensor)) for arg in args)
      )

      quantization_method = getattr(self, "quantization_method", None)
      if quantization_method == QuantizationMethod.HQQ:
          raise ValueError("`.to` is not supported for HQQ-quantized models.")

      # Checks if the model has been loaded in 4-bit or 8-bit with BNB
      if quantization_method == QuantizationMethod.BITS_AND_BYTES:
          if dtype_present_in_args:
              raise ValueError(
                  "You cannot cast a bitsandbytes model in a new `dtype`. Make sure to load the model using `from_pretrained` using the"
                  " desired `dtype` by passing the correct `torch_dtype` argument."
              )
          if getattr(self, "is_loaded_in_8bit", False):
              raise ValueError(
                  "`.to` is not supported for `8-bit` bitsandbytes models. Please use the model as it is, since the"
                  " model has already been set to the correct devices and casted to the correct `dtype`."
              )
          # Moving the remaining (4-bit) bitsandbytes models requires bitsandbytes >= 0.43.2.
          if version.parse(importlib.metadata.version("bitsandbytes")) < version.parse("0.43.2"):
              raise ValueError(
                  "Calling `to()` is not supported for `4-bit` quantized models with the installed version of bitsandbytes. "
                  f"The current device is `{self.device}`. If you intended to move the model, please install bitsandbytes >= 0.43.2."
              )
      elif quantization_method == QuantizationMethod.GPTQ:
          if dtype_present_in_args:
              raise ValueError(
                  "You cannot cast a GPTQ model in a new `dtype`. Make sure to load the model using `from_pretrained` using the desired"
                  " `dtype` by passing the correct `torch_dtype` argument."
              )
      return super().to(*args, **kwargs)
  def half(self, *args):
      """
      Delegate to the parent class's `half()` unless the model is quantized.

      Quantized models (those whose `is_quantized` attribute is truthy) are refused, because their weights were
      already placed in the intended `dtype` at load time.

      Raises:
          `ValueError`: If the model is quantized.
      """
      is_quantized_model = getattr(self, "is_quantized", False)
      if is_quantized_model:
          raise ValueError(
              "`.half()` is not supported for quantized model. Please use the model as it is, since the"
              " model has already been casted to the correct `dtype`."
          )
      return super().half(*args)
  def float(self, *args):
      """
      Delegate to the parent class's `float()` unless the model is quantized.

      Quantized models (those whose `is_quantized` attribute is truthy) are refused, because their weights were
      already placed in the intended `dtype` at load time.

      Raises:
          `ValueError`: If the model is quantized.
      """
      is_quantized_model = getattr(self, "is_quantized", False)
      if is_quantized_model:
          raise ValueError(
              "`.float()` is not supported for quantized model. Please use the model as it is, since the"
              " model has already been casted to the correct `dtype`."
          )
      return super().float(*args)
  2721. @classmethod
  2722. def from_pretrained(
  2723. cls,
  2724. pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
  2725. *model_args,
  2726. config: Optional[Union[PretrainedConfig, str, os.PathLike]] = None,
  2727. cache_dir: Optional[Union[str, os.PathLike]] = None,
  2728. ignore_mismatched_sizes: bool = False,
  2729. force_download: bool = False,
  2730. local_files_only: bool = False,
  2731. token: Optional[Union[str, bool]] = None,
  2732. revision: str = "main",
  2733. use_safetensors: bool = None,
  2734. weights_only: bool = True,
  2735. **kwargs,
  2736. ) -> "PreTrainedModel":
  2737. r"""
  2738. Instantiate a pretrained pytorch model from a pre-trained model configuration.
  2739. The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated). To train
  2740. the model, you should first set it back in training mode with `model.train()`.
  2741. The warning *Weights from XXX not initialized from pretrained model* means that the weights of XXX do not come
  2742. pretrained with the rest of the model. It is up to you to train those weights with a downstream fine-tuning
  2743. task.
  2744. The warning *Weights from XXX not used in YYY* means that the layer XXX is not used by YYY, therefore those
  2745. weights are discarded.
  2746. If model weights are the same precision as the base model (and is a supported model), weights will be lazily loaded
  2747. in using the `meta` device and brought into memory once an input is passed through that layer regardless of
  2748. `low_cpu_mem_usage`.
  2749. Parameters:
  2750. pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):
  2751. Can be either:
  2752. - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
  2753. - A path to a *directory* containing model weights saved using
  2754. [`~PreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`.
  2755. - A path or url to a *tensorflow index checkpoint file* (e.g, `./tf_model/model.ckpt.index`). In
  2756. this case, `from_tf` should be set to `True` and a configuration object should be provided as
  2757. `config` argument. This loading path is slower than converting the TensorFlow checkpoint in a
  2758. PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards.
  2759. - A path or url to a model folder containing a *flax checkpoint file* in *.msgpack* format (e.g,
  2760. `./flax_model/` containing `flax_model.msgpack`). In this case, `from_flax` should be set to
  2761. `True`.
  2762. - `None` if you are both providing the configuration and state dictionary (resp. with keyword
  2763. arguments `config` and `state_dict`).
  2764. model_args (sequence of positional arguments, *optional*):
  2765. All remaining positional arguments will be passed to the underlying model's `__init__` method.
  2766. config (`Union[PretrainedConfig, str, os.PathLike]`, *optional*):
  2767. Can be either:
  2768. - an instance of a class derived from [`PretrainedConfig`],
  2769. - a string or path valid as input to [`~PretrainedConfig.from_pretrained`].
  2770. Configuration for the model to use instead of an automatically loaded configuration. Configuration can
  2771. be automatically loaded when:
  2772. - The model is a model provided by the library (loaded with the *model id* string of a pretrained
  2773. model).
  2774. - The model was saved using [`~PreTrainedModel.save_pretrained`] and is reloaded by supplying the
  2775. save directory.
  2776. - The model is loaded by supplying a local directory as `pretrained_model_name_or_path` and a
  2777. configuration JSON file named *config.json* is found in the directory.
  2778. state_dict (`Dict[str, torch.Tensor]`, *optional*):
  2779. A state dictionary to use instead of a state dictionary loaded from saved weights file.
  2780. This option can be used if you want to create a model from a pretrained configuration but load your own
  2781. weights. In this case though, you should check if using [`~PreTrainedModel.save_pretrained`] and
  2782. [`~PreTrainedModel.from_pretrained`] is not a simpler option.
  2783. cache_dir (`Union[str, os.PathLike]`, *optional*):
  2784. Path to a directory in which a downloaded pretrained model configuration should be cached if the
  2785. standard cache should not be used.
  2786. from_tf (`bool`, *optional*, defaults to `False`):
  2787. Load the model weights from a TensorFlow checkpoint save file (see docstring of
  2788. `pretrained_model_name_or_path` argument).
  2789. from_flax (`bool`, *optional*, defaults to `False`):
  2790. Load the model weights from a Flax checkpoint save file (see docstring of
  2791. `pretrained_model_name_or_path` argument).
  2792. ignore_mismatched_sizes (`bool`, *optional*, defaults to `False`):
  2793. Whether or not to raise an error if some of the weights from the checkpoint do not have the same size
  2794. as the weights of the model (if for instance, you are instantiating a model with 10 labels from a
  2795. checkpoint with 3 labels).
  2796. force_download (`bool`, *optional*, defaults to `False`):
  2797. Whether or not to force the (re-)download of the model weights and configuration files, overriding the
  2798. cached versions if they exist.
  2799. resume_download:
  2800. Deprecated and ignored. All downloads are now resumed by default when possible.
  2801. Will be removed in v5 of Transformers.
  2802. proxies (`Dict[str, str]`, *optional*):
  2803. A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
  2804. 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
  2805. output_loading_info(`bool`, *optional*, defaults to `False`):
  2806. Whether ot not to also return a dictionary containing missing keys, unexpected keys and error messages.
  2807. local_files_only(`bool`, *optional*, defaults to `False`):
  2808. Whether or not to only look at local files (i.e., do not try to download the model).
  2809. token (`str` or `bool`, *optional*):
  2810. The token to use as HTTP bearer authorization for remote files. If `True`, or not specified, will use
  2811. the token generated when running `huggingface-cli login` (stored in `~/.huggingface`).
  2812. revision (`str`, *optional*, defaults to `"main"`):
  2813. The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
  2814. git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
  2815. identifier allowed by git.
  2816. <Tip>
  2817. To test a pull request you made on the Hub, you can pass `revision="refs/pr/<pr_number>"`.
  2818. </Tip>
  2819. mirror (`str`, *optional*):
  2820. Mirror source to accelerate downloads in China. If you are from China and have an accessibility
  2821. problem, you can set this option to resolve it. Note that we do not guarantee the timeliness or safety.
  2822. Please refer to the mirror site for more information.
  2823. _fast_init(`bool`, *optional*, defaults to `True`):
  2824. Whether or not to disable fast initialization.
  2825. <Tip warning={true}>
  2826. One should only disable *_fast_init* to ensure backwards compatibility with `transformers.__version__ <
  2827. 4.6.0` for seeded model initialization. This argument will be removed at the next major version. See
  2828. [pull request 11471](https://github.com/huggingface/transformers/pull/11471) for more information.
  2829. </Tip>
  2830. attn_implementation (`str`, *optional*):
  2831. The attention implementation to use in the model (if relevant). Can be any of `"eager"` (manual implementation of the attention), `"sdpa"` (using [`F.scaled_dot_product_attention`](https://pytorch.org/docs/master/generated/torch.nn.functional.scaled_dot_product_attention.html)), or `"flash_attention_2"` (using [Dao-AILab/flash-attention](https://github.com/Dao-AILab/flash-attention)). By default, if available, SDPA will be used for torch>=2.1.1. The default is otherwise the manual `"eager"` implementation.
  2832. > Parameters for big model inference
  2833. low_cpu_mem_usage(`bool`, *optional*):
  2834. Tries not to use more than 1x model size in CPU memory (including peak memory) while loading the model.
  2835. Generally should be combined with a `device_map` (such as `"auto"`) for best results.
  2836. This is an experimental feature and a subject to change at any moment.
  2837. </Tip>
  2838. If the model weights are in the same precision as the model loaded in, `low_cpu_mem_usage` (without
  2839. `device_map`) is redundant and will not provide any benefit in regards to CPU memory usage. However,
  2840. this should still be enabled if you are passing in a `device_map`.
  2841. </Tip>
  2842. torch_dtype (`str` or `torch.dtype`, *optional*):
  2843. Override the default `torch.dtype` and load the model under a specific `dtype`. The different options
  2844. are:
  2845. 1. `torch.float16` or `torch.bfloat16` or `torch.float`: load in a specified
  2846. `dtype`, ignoring the model's `config.torch_dtype` if one exists. If not specified
  2847. - the model will get loaded in `torch.float` (fp32).
  2848. 2. `"auto"` - A `torch_dtype` entry in the `config.json` file of the model will be
  2849. attempted to be used. If this entry isn't found then next check the `dtype` of the first weight in
  2850. the checkpoint that's of a floating point type and use that as `dtype`. This will load the model
  2851. using the `dtype` it was saved in at the end of the training. It can't be used as an indicator of how
  2852. the model was trained. Since it could be trained in one of half precision dtypes, but saved in fp32.
  2853. 3. A string that is a valid `torch.dtype`. E.g. "float32" loads the model in `torch.float32`, "float16" loads in `torch.float16` etc.
  2854. <Tip>
  2855. For some models the `dtype` they were trained in is unknown - you may try to check the model's paper or
  2856. reach out to the authors and ask them to add this information to the model's card and to insert the
  2857. `torch_dtype` entry in `config.json` on the hub.
  2858. </Tip>
  2859. device_map (`str` or `Dict[str, Union[int, str, torch.device]]` or `int` or `torch.device`, *optional*):
  2860. A map that specifies where each submodule should go. It doesn't need to be refined to each
  2861. parameter/buffer name, once a given module name is inside, every submodule of it will be sent to the
  2862. same device. If we only pass the device (*e.g.*, `"cpu"`, `"cuda:1"`, `"mps"`, or a GPU ordinal rank
  2863. like `1`) on which the model will be allocated, the device map will map the entire model to this
  2864. device. Passing `device_map = 0` means put the whole model on GPU 0.
  2865. To have Accelerate compute the most optimized `device_map` automatically, set `device_map="auto"`. For
  2866. more information about each option see [designing a device
  2867. map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map).
  2868. max_memory (`Dict`, *optional*):
  2869. A dictionary device identifier to maximum memory. Will default to the maximum memory available for each
  2870. GPU and the available CPU RAM if unset.
  2871. offload_folder (`str` or `os.PathLike`, *optional*):
  2872. If the `device_map` contains any value `"disk"`, the folder where we will offload weights.
  2873. offload_state_dict (`bool`, *optional*):
  2874. If `True`, will temporarily offload the CPU state dict to the hard drive to avoid getting out of CPU
  2875. RAM if the weight of the CPU state dict + the biggest shard of the checkpoint does not fit. Defaults to
  2876. `True` when there is some disk offload.
  2877. offload_buffers (`bool`, *optional*):
  2878. Whether or not to offload the buffers with the model parameters.
  2879. quantization_config (`Union[QuantizationConfigMixin,Dict]`, *optional*):
  2880. A dictionary of configuration parameters or a QuantizationConfigMixin object for quantization (e.g
  2881. bitsandbytes, gptq). There may be other quantization-related kwargs, including `load_in_4bit` and
  2882. `load_in_8bit`, which are parsed by QuantizationConfigParser. Supported only for bitsandbytes
  2883. quantizations and not preferred. consider inserting all such arguments into quantization_config
  2884. instead.
  2885. subfolder (`str`, *optional*, defaults to `""`):
  2886. In case the relevant files are located inside a subfolder of the model repo on huggingface.co, you can
  2887. specify the folder name here.
  2888. variant (`str`, *optional*):
  2889. If specified load weights from `variant` filename, *e.g.* pytorch_model.<variant>.bin. `variant` is
  2890. ignored when using `from_tf` or `from_flax`.
  2891. use_safetensors (`bool`, *optional*, defaults to `None`):
  2892. Whether or not to use `safetensors` checkpoints. Defaults to `None`. If not specified and `safetensors`
  2893. is not installed, it will be set to `False`.
  2894. weights_only (`bool`, *optional*, defaults to `True`):
  2895. Indicates whether unpickler should be restricted to loading only tensors, primitive types,
  2896. dictionaries and any types added via torch.serialization.add_safe_globals().
  2897. When set to False, we can load wrapper tensor subclass weights.
  2898. kwargs (remaining dictionary of keyword arguments, *optional*):
  2899. Can be used to update the configuration object (after it being loaded) and initiate the model (e.g.,
  2900. `output_attentions=True`). Behaves differently depending on whether a `config` is provided or
  2901. automatically loaded:
  2902. - If a configuration is provided with `config`, `**kwargs` will be directly passed to the
  2903. underlying model's `__init__` method (we assume all relevant updates to the configuration have
  2904. already been done)
  2905. - If a configuration is not provided, `kwargs` will be first passed to the configuration class
  2906. initialization function ([`~PretrainedConfig.from_pretrained`]). Each key of `kwargs` that
  2907. corresponds to a configuration attribute will be used to override said attribute with the
  2908. supplied `kwargs` value. Remaining keys that do not correspond to any configuration attribute
  2909. will be passed to the underlying model's `__init__` function.
  2910. <Tip>
  2911. Activate the special ["offline-mode"](https://huggingface.co/transformers/installation.html#offline-mode) to
  2912. use this method in a firewalled environment.
  2913. </Tip>
  2914. Examples:
  2915. ```python
  2916. >>> from transformers import BertConfig, BertModel
  2917. >>> # Download model and configuration from huggingface.co and cache.
  2918. >>> model = BertModel.from_pretrained("google-bert/bert-base-uncased")
  2919. >>> # Model was saved using *save_pretrained('./test/saved_model/')* (for example purposes, not runnable).
  2920. >>> model = BertModel.from_pretrained("./test/saved_model/")
  2921. >>> # Update configuration during loading.
  2922. >>> model = BertModel.from_pretrained("google-bert/bert-base-uncased", output_attentions=True)
  2923. >>> assert model.config.output_attentions == True
  2924. >>> # Loading from a TF checkpoint file instead of a PyTorch model (slower, for example purposes, not runnable).
  2925. >>> config = BertConfig.from_json_file("./tf_model/my_tf_model_config.json")
  2926. >>> model = BertModel.from_pretrained("./tf_model/my_tf_checkpoint.ckpt.index", from_tf=True, config=config)
  2927. >>> # Loading from a Flax checkpoint file instead of a PyTorch model (slower)
  2928. >>> model = BertModel.from_pretrained("google-bert/bert-base-uncased", from_flax=True)
  2929. ```
  2930. * `low_cpu_mem_usage` algorithm:
  2931. This is an experimental function that loads the model using ~1x model size CPU memory
  2932. Here is how it works:
  2933. 1. save which state_dict keys we have
  2934. 2. drop state_dict before the model is created, since the latter takes 1x model size CPU memory
  2935. 3. after the model has been instantiated switch to the meta device all params/buffers that
  2936. are going to be replaced from the loaded state_dict
  2937. 4. load state_dict 2nd time
  2938. 5. replace the params/buffers from the state_dict
  2939. Currently, it can't handle deepspeed ZeRO stage 3 and ignores loading errors
  2940. """
  2941. state_dict = kwargs.pop("state_dict", None)
  2942. from_tf = kwargs.pop("from_tf", False)
  2943. from_flax = kwargs.pop("from_flax", False)
  2944. resume_download = kwargs.pop("resume_download", None)
  2945. proxies = kwargs.pop("proxies", None)
  2946. output_loading_info = kwargs.pop("output_loading_info", False)
  2947. use_auth_token = kwargs.pop("use_auth_token", None)
  2948. trust_remote_code = kwargs.pop("trust_remote_code", None)
  2949. _ = kwargs.pop("mirror", None)
  2950. from_pipeline = kwargs.pop("_from_pipeline", None)
  2951. from_auto_class = kwargs.pop("_from_auto", False)
  2952. _fast_init = kwargs.pop("_fast_init", True)
  2953. torch_dtype = kwargs.pop("torch_dtype", None)
  2954. low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", None)
  2955. device_map = kwargs.pop("device_map", None)
  2956. max_memory = kwargs.pop("max_memory", None)
  2957. offload_folder = kwargs.pop("offload_folder", None)
  2958. offload_state_dict = kwargs.pop("offload_state_dict", False)
  2959. offload_buffers = kwargs.pop("offload_buffers", False)
  2960. load_in_8bit = kwargs.pop("load_in_8bit", False)
  2961. load_in_4bit = kwargs.pop("load_in_4bit", False)
  2962. quantization_config = kwargs.pop("quantization_config", None)
  2963. subfolder = kwargs.pop("subfolder", "")
  2964. commit_hash = kwargs.pop("_commit_hash", None)
  2965. variant = kwargs.pop("variant", None)
  2966. adapter_kwargs = kwargs.pop("adapter_kwargs", {})
  2967. adapter_name = kwargs.pop("adapter_name", "default")
  2968. use_flash_attention_2 = kwargs.pop("use_flash_attention_2", False)
  2969. generation_config = kwargs.pop("generation_config", None)
  2970. gguf_file = kwargs.pop("gguf_file", None)
  2971. # Cache path to the GGUF file
  2972. gguf_path = None
  2973. if is_fsdp_enabled():
  2974. low_cpu_mem_usage = True
  2975. if use_auth_token is not None:
  2976. warnings.warn(
  2977. "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.",
  2978. FutureWarning,
  2979. )
  2980. if token is not None:
  2981. raise ValueError(
  2982. "`token` and `use_auth_token` are both specified. Please set only the argument `token`."
  2983. )
  2984. token = use_auth_token
  2985. if token is not None and adapter_kwargs is not None and "token" not in adapter_kwargs:
  2986. adapter_kwargs["token"] = token
  2987. if use_safetensors is None and not is_safetensors_available():
  2988. use_safetensors = False
  2989. if trust_remote_code is True:
  2990. logger.warning(
  2991. "The argument `trust_remote_code` is to be used with Auto classes. It has no effect here and is"
  2992. " ignored."
  2993. )
  2994. if gguf_file is not None and not is_accelerate_available():
  2995. raise ValueError("accelerate is required when loading a GGUF file `pip install accelerate`.")
  2996. if commit_hash is None:
  2997. if not isinstance(config, PretrainedConfig):
  2998. # We make a call to the config file first (which may be absent) to get the commit hash as soon as possible
  2999. resolved_config_file = cached_file(
  3000. pretrained_model_name_or_path,
  3001. CONFIG_NAME,
  3002. cache_dir=cache_dir,
  3003. force_download=force_download,
  3004. resume_download=resume_download,
  3005. proxies=proxies,
  3006. local_files_only=local_files_only,
  3007. token=token,
  3008. revision=revision,
  3009. subfolder=subfolder,
  3010. _raise_exceptions_for_gated_repo=False,
  3011. _raise_exceptions_for_missing_entries=False,
  3012. _raise_exceptions_for_connection_errors=False,
  3013. )
  3014. commit_hash = extract_commit_hash(resolved_config_file, commit_hash)
  3015. else:
  3016. commit_hash = getattr(config, "_commit_hash", None)
  3017. if is_peft_available():
  3018. _adapter_model_path = adapter_kwargs.pop("_adapter_model_path", None)
  3019. if _adapter_model_path is None:
  3020. _adapter_model_path = find_adapter_config_file(
  3021. pretrained_model_name_or_path,
  3022. cache_dir=cache_dir,
  3023. force_download=force_download,
  3024. resume_download=resume_download,
  3025. proxies=proxies,
  3026. local_files_only=local_files_only,
  3027. _commit_hash=commit_hash,
  3028. **adapter_kwargs,
  3029. )
  3030. if _adapter_model_path is not None and os.path.isfile(_adapter_model_path):
  3031. with open(_adapter_model_path, "r", encoding="utf-8") as f:
  3032. _adapter_model_path = pretrained_model_name_or_path
  3033. pretrained_model_name_or_path = json.load(f)["base_model_name_or_path"]
  3034. else:
  3035. _adapter_model_path = None
  3036. # change device_map into a map if we passed an int, a str or a torch.device
  3037. if isinstance(device_map, torch.device):
  3038. device_map = {"": device_map}
  3039. elif isinstance(device_map, str) and device_map not in ["auto", "balanced", "balanced_low_0", "sequential"]:
  3040. try:
  3041. device_map = {"": torch.device(device_map)}
  3042. except RuntimeError:
  3043. raise ValueError(
  3044. "When passing device_map as a string, the value needs to be a device name (e.g. cpu, cuda:0) or "
  3045. f"'auto', 'balanced', 'balanced_low_0', 'sequential' but found {device_map}."
  3046. )
  3047. elif isinstance(device_map, int):
  3048. if device_map < 0:
  3049. raise ValueError(
  3050. "You can't pass device_map as a negative int. If you want to put the model on the cpu, pass device_map = 'cpu' "
  3051. )
  3052. else:
  3053. device_map = {"": device_map}
  3054. if device_map is not None:
  3055. if low_cpu_mem_usage is None:
  3056. low_cpu_mem_usage = True
  3057. elif not low_cpu_mem_usage:
  3058. raise ValueError("Passing along a `device_map` requires `low_cpu_mem_usage=True`")
  3059. if low_cpu_mem_usage:
  3060. if is_deepspeed_zero3_enabled():
  3061. raise ValueError(
  3062. "DeepSpeed Zero-3 is not compatible with `low_cpu_mem_usage=True` or with passing a `device_map`."
  3063. )
  3064. elif not is_accelerate_available():
  3065. raise ImportError(
  3066. f"Using `low_cpu_mem_usage=True` or a `device_map` requires Accelerate: `pip install 'accelerate>={ACCELERATE_MIN_VERSION}'`"
  3067. )
  3068. # handling bnb config from kwargs, remove after `load_in_{4/8}bit` deprecation.
  3069. if load_in_4bit or load_in_8bit:
  3070. if quantization_config is not None:
  3071. raise ValueError(
  3072. "You can't pass `load_in_4bit`or `load_in_8bit` as a kwarg when passing "
  3073. "`quantization_config` argument at the same time."
  3074. )
  3075. # preparing BitsAndBytesConfig from kwargs
  3076. config_dict = {k: v for k, v in kwargs.items() if k in inspect.signature(BitsAndBytesConfig).parameters}
  3077. config_dict = {**config_dict, "load_in_4bit": load_in_4bit, "load_in_8bit": load_in_8bit}
  3078. quantization_config, kwargs = BitsAndBytesConfig.from_dict(
  3079. config_dict=config_dict, return_unused_kwargs=True, **kwargs
  3080. )
  3081. logger.warning(
  3082. "The `load_in_4bit` and `load_in_8bit` arguments are deprecated and will be removed in the future versions. "
  3083. "Please, pass a `BitsAndBytesConfig` object in `quantization_config` argument instead."
  3084. )
  3085. from_pt = not (from_tf | from_flax)
  3086. user_agent = {"file_type": "model", "framework": "pytorch", "from_auto_class": from_auto_class}
  3087. if from_pipeline is not None:
  3088. user_agent["using_pipeline"] = from_pipeline
  3089. if is_offline_mode() and not local_files_only:
  3090. logger.info("Offline mode: forcing local_files_only=True")
  3091. local_files_only = True
  3092. # Load config if we don't provide a configuration
  3093. if not isinstance(config, PretrainedConfig):
  3094. config_path = config if config is not None else pretrained_model_name_or_path
  3095. config, model_kwargs = cls.config_class.from_pretrained(
  3096. config_path,
  3097. cache_dir=cache_dir,
  3098. return_unused_kwargs=True,
  3099. force_download=force_download,
  3100. resume_download=resume_download,
  3101. proxies=proxies,
  3102. local_files_only=local_files_only,
  3103. token=token,
  3104. revision=revision,
  3105. subfolder=subfolder,
  3106. _from_auto=from_auto_class,
  3107. _from_pipeline=from_pipeline,
  3108. **kwargs,
  3109. )
  3110. else:
  3111. # In case one passes a config to `from_pretrained` + "attn_implementation"
  3112. # override the `_attn_implementation` attribute to `attn_implementation` of the kwargs
  3113. # Please see: https://github.com/huggingface/transformers/issues/28038
  3114. # Overwrite `config._attn_implementation` by the one from the kwargs --> in auto-factory
  3115. # we pop attn_implementation from the kwargs but this handles the case where users
  3116. # passes manually the config to `from_pretrained`.
  3117. config = copy.deepcopy(config)
  3118. kwarg_attn_imp = kwargs.pop("attn_implementation", None)
  3119. if kwarg_attn_imp is not None:
  3120. config._attn_implementation = kwarg_attn_imp
  3121. model_kwargs = kwargs
  3122. pre_quantized = getattr(config, "quantization_config", None) is not None
  3123. if pre_quantized or quantization_config is not None:
  3124. if pre_quantized:
  3125. config.quantization_config = AutoHfQuantizer.merge_quantization_configs(
  3126. config.quantization_config, quantization_config
  3127. )
  3128. else:
  3129. config.quantization_config = quantization_config
  3130. hf_quantizer = AutoHfQuantizer.from_config(config.quantization_config, pre_quantized=pre_quantized)
  3131. else:
  3132. hf_quantizer = None
  3133. if hf_quantizer is not None:
  3134. hf_quantizer.validate_environment(
  3135. torch_dtype=torch_dtype, from_tf=from_tf, from_flax=from_flax, device_map=device_map
  3136. )
  3137. torch_dtype = hf_quantizer.update_torch_dtype(torch_dtype)
  3138. device_map = hf_quantizer.update_device_map(device_map)
  3139. # In order to ensure popular quantization methods are supported. Can be disable with `disable_telemetry`
  3140. user_agent["quant"] = hf_quantizer.quantization_config.quant_method.value
  3141. # Force-set to `True` for more mem efficiency
  3142. if low_cpu_mem_usage is None:
  3143. low_cpu_mem_usage = True
  3144. logger.warning("`low_cpu_mem_usage` was None, now default to True since model is quantized.")
  3145. is_quantized = hf_quantizer is not None
  3146. # This variable will flag if we're loading a sharded checkpoint. In this case the archive file is just the
  3147. # index of the files.
  3148. is_sharded = False
  3149. sharded_metadata = None
  3150. # Load model
  3151. loading_info = None
  3152. # Keep in fp32 modules
  3153. keep_in_fp32_modules = None
  3154. use_keep_in_fp32_modules = False
  3155. if gguf_file is not None and hf_quantizer is not None:
  3156. raise ValueError(
  3157. "You cannot combine Quantization and loading a model from a GGUF file, try again by making sure you did not passed a `quantization_config` or that you did not load a quantized model from the Hub."
  3158. )
  3159. if pretrained_model_name_or_path is not None and gguf_file is None:
  3160. pretrained_model_name_or_path = str(pretrained_model_name_or_path)
  3161. is_local = os.path.isdir(pretrained_model_name_or_path)
  3162. if is_local:
  3163. if from_tf and os.path.isfile(
  3164. os.path.join(pretrained_model_name_or_path, subfolder, TF_WEIGHTS_NAME + ".index")
  3165. ):
  3166. # Load from a TF 1.0 checkpoint in priority if from_tf
  3167. archive_file = os.path.join(pretrained_model_name_or_path, subfolder, TF_WEIGHTS_NAME + ".index")
  3168. elif from_tf and os.path.isfile(
  3169. os.path.join(pretrained_model_name_or_path, subfolder, TF2_WEIGHTS_NAME)
  3170. ):
  3171. # Load from a TF 2.0 checkpoint in priority if from_tf
  3172. archive_file = os.path.join(pretrained_model_name_or_path, subfolder, TF2_WEIGHTS_NAME)
  3173. elif from_flax and os.path.isfile(
  3174. os.path.join(pretrained_model_name_or_path, subfolder, FLAX_WEIGHTS_NAME)
  3175. ):
  3176. # Load from a Flax checkpoint in priority if from_flax
  3177. archive_file = os.path.join(pretrained_model_name_or_path, subfolder, FLAX_WEIGHTS_NAME)
  3178. elif use_safetensors is not False and os.path.isfile(
  3179. os.path.join(pretrained_model_name_or_path, subfolder, _add_variant(SAFE_WEIGHTS_NAME, variant))
  3180. ):
  3181. # Load from a safetensors checkpoint
  3182. archive_file = os.path.join(
  3183. pretrained_model_name_or_path, subfolder, _add_variant(SAFE_WEIGHTS_NAME, variant)
  3184. )
  3185. elif use_safetensors is not False and os.path.isfile(
  3186. os.path.join(
  3187. pretrained_model_name_or_path, subfolder, _add_variant(SAFE_WEIGHTS_INDEX_NAME, variant)
  3188. )
  3189. ):
  3190. # Load from a sharded safetensors checkpoint
  3191. archive_file = os.path.join(
  3192. pretrained_model_name_or_path, subfolder, _add_variant(SAFE_WEIGHTS_INDEX_NAME, variant)
  3193. )
  3194. is_sharded = True
  3195. elif not use_safetensors and os.path.isfile(
  3196. os.path.join(pretrained_model_name_or_path, subfolder, _add_variant(WEIGHTS_NAME, variant))
  3197. ):
  3198. # Load from a PyTorch checkpoint
  3199. archive_file = os.path.join(
  3200. pretrained_model_name_or_path, subfolder, _add_variant(WEIGHTS_NAME, variant)
  3201. )
  3202. elif not use_safetensors and os.path.isfile(
  3203. os.path.join(pretrained_model_name_or_path, subfolder, _add_variant(WEIGHTS_INDEX_NAME, variant))
  3204. ):
  3205. # Load from a sharded PyTorch checkpoint
  3206. archive_file = os.path.join(
  3207. pretrained_model_name_or_path, subfolder, _add_variant(WEIGHTS_INDEX_NAME, variant)
  3208. )
  3209. is_sharded = True
  3210. # At this stage we don't have a weight file so we will raise an error.
  3211. elif not use_safetensors and (
  3212. os.path.isfile(os.path.join(pretrained_model_name_or_path, subfolder, TF_WEIGHTS_NAME + ".index"))
  3213. or os.path.isfile(os.path.join(pretrained_model_name_or_path, subfolder, TF2_WEIGHTS_NAME))
  3214. ):
  3215. raise EnvironmentError(
  3216. f"Error no file named {_add_variant(WEIGHTS_NAME, variant)} found in directory"
  3217. f" {pretrained_model_name_or_path} but there is a file for TensorFlow weights. Use"
  3218. " `from_tf=True` to load this model from those weights."
  3219. )
  3220. elif not use_safetensors and os.path.isfile(
  3221. os.path.join(pretrained_model_name_or_path, subfolder, FLAX_WEIGHTS_NAME)
  3222. ):
  3223. raise EnvironmentError(
  3224. f"Error no file named {_add_variant(WEIGHTS_NAME, variant)} found in directory"
  3225. f" {pretrained_model_name_or_path} but there is a file for Flax weights. Use `from_flax=True`"
  3226. " to load this model from those weights."
  3227. )
  3228. elif use_safetensors:
  3229. raise EnvironmentError(
  3230. f"Error no file named {_add_variant(SAFE_WEIGHTS_NAME, variant)} found in directory"
  3231. f" {pretrained_model_name_or_path}."
  3232. )
  3233. else:
  3234. raise EnvironmentError(
  3235. f"Error no file named {_add_variant(WEIGHTS_NAME, variant)}, {_add_variant(SAFE_WEIGHTS_NAME, variant)},"
  3236. f" {TF2_WEIGHTS_NAME}, {TF_WEIGHTS_NAME + '.index'} or {FLAX_WEIGHTS_NAME} found in directory"
  3237. f" {pretrained_model_name_or_path}."
  3238. )
  3239. elif os.path.isfile(os.path.join(subfolder, pretrained_model_name_or_path)):
  3240. archive_file = pretrained_model_name_or_path
  3241. is_local = True
  3242. elif os.path.isfile(os.path.join(subfolder, pretrained_model_name_or_path + ".index")):
  3243. if not from_tf:
  3244. raise ValueError(
  3245. f"We found a TensorFlow checkpoint at {pretrained_model_name_or_path + '.index'}, please set "
  3246. "from_tf to True to load from this checkpoint."
  3247. )
  3248. archive_file = os.path.join(subfolder, pretrained_model_name_or_path + ".index")
  3249. is_local = True
  3250. elif is_remote_url(pretrained_model_name_or_path):
  3251. filename = pretrained_model_name_or_path
  3252. resolved_archive_file = download_url(pretrained_model_name_or_path)
  3253. else:
  3254. # set correct filename
  3255. if from_tf:
  3256. filename = TF2_WEIGHTS_NAME
  3257. elif from_flax:
  3258. filename = FLAX_WEIGHTS_NAME
  3259. elif use_safetensors is not False:
  3260. filename = _add_variant(SAFE_WEIGHTS_NAME, variant)
  3261. else:
  3262. filename = _add_variant(WEIGHTS_NAME, variant)
  3263. try:
  3264. # Load from URL or cache if already cached
  3265. cached_file_kwargs = {
  3266. "cache_dir": cache_dir,
  3267. "force_download": force_download,
  3268. "proxies": proxies,
  3269. "resume_download": resume_download,
  3270. "local_files_only": local_files_only,
  3271. "token": token,
  3272. "user_agent": user_agent,
  3273. "revision": revision,
  3274. "subfolder": subfolder,
  3275. "_raise_exceptions_for_gated_repo": False,
  3276. "_raise_exceptions_for_missing_entries": False,
  3277. "_commit_hash": commit_hash,
  3278. }
  3279. resolved_archive_file = cached_file(pretrained_model_name_or_path, filename, **cached_file_kwargs)
  3280. # Since we set _raise_exceptions_for_missing_entries=False, we don't get an exception but a None
  3281. # result when internet is up, the repo and revision exist, but the file does not.
  3282. if resolved_archive_file is None and filename == _add_variant(SAFE_WEIGHTS_NAME, variant):
  3283. # Maybe the checkpoint is sharded, we try to grab the index name in this case.
  3284. resolved_archive_file = cached_file(
  3285. pretrained_model_name_or_path,
  3286. _add_variant(SAFE_WEIGHTS_INDEX_NAME, variant),
  3287. **cached_file_kwargs,
  3288. )
  3289. if resolved_archive_file is not None:
  3290. is_sharded = True
  3291. elif use_safetensors:
  3292. if revision == "main":
  3293. resolved_archive_file, revision, is_sharded = auto_conversion(
  3294. pretrained_model_name_or_path, **cached_file_kwargs
  3295. )
  3296. cached_file_kwargs["revision"] = revision
  3297. if resolved_archive_file is None:
  3298. raise EnvironmentError(
  3299. f"{pretrained_model_name_or_path} does not appear to have a file named"
  3300. f" {_add_variant(SAFE_WEIGHTS_NAME, variant)} or {_add_variant(SAFE_WEIGHTS_INDEX_NAME, variant)} "
  3301. "and thus cannot be loaded with `safetensors`. Please make sure that the model has "
  3302. "been saved with `safe_serialization=True` or do not set `use_safetensors=True`."
  3303. )
  3304. else:
  3305. # This repo has no safetensors file of any kind, we switch to PyTorch.
  3306. filename = _add_variant(WEIGHTS_NAME, variant)
  3307. resolved_archive_file = cached_file(
  3308. pretrained_model_name_or_path, filename, **cached_file_kwargs
  3309. )
  3310. if resolved_archive_file is None and filename == _add_variant(WEIGHTS_NAME, variant):
  3311. # Maybe the checkpoint is sharded, we try to grab the index name in this case.
  3312. resolved_archive_file = cached_file(
  3313. pretrained_model_name_or_path,
  3314. _add_variant(WEIGHTS_INDEX_NAME, variant),
  3315. **cached_file_kwargs,
  3316. )
  3317. if resolved_archive_file is not None:
  3318. is_sharded = True
  3319. if not local_files_only and not is_offline_mode():
  3320. if resolved_archive_file is not None:
  3321. if filename in [WEIGHTS_NAME, WEIGHTS_INDEX_NAME]:
  3322. # If the PyTorch file was found, check if there is a safetensors file on the repository
  3323. # If there is no safetensors file on the repositories, start an auto conversion
  3324. safe_weights_name = SAFE_WEIGHTS_INDEX_NAME if is_sharded else SAFE_WEIGHTS_NAME
  3325. has_file_kwargs = {
  3326. "revision": revision,
  3327. "proxies": proxies,
  3328. "token": token,
  3329. "cache_dir": cache_dir,
  3330. "local_files_only": local_files_only,
  3331. }
  3332. cached_file_kwargs = {
  3333. "cache_dir": cache_dir,
  3334. "force_download": force_download,
  3335. "resume_download": resume_download,
  3336. "local_files_only": local_files_only,
  3337. "user_agent": user_agent,
  3338. "subfolder": subfolder,
  3339. "_raise_exceptions_for_gated_repo": False,
  3340. "_raise_exceptions_for_missing_entries": False,
  3341. "_commit_hash": commit_hash,
  3342. **has_file_kwargs,
  3343. }
  3344. if not has_file(pretrained_model_name_or_path, safe_weights_name, **has_file_kwargs):
  3345. Thread(
  3346. target=auto_conversion,
  3347. args=(pretrained_model_name_or_path,),
  3348. kwargs={"ignore_errors_during_conversion": True, **cached_file_kwargs},
  3349. name="Thread-autoconversion",
  3350. ).start()
  3351. else:
  3352. # Otherwise, no PyTorch file was found, maybe there is a TF or Flax model file.
  3353. # We try those to give a helpful error message.
  3354. has_file_kwargs = {
  3355. "revision": revision,
  3356. "proxies": proxies,
  3357. "token": token,
  3358. "cache_dir": cache_dir,
  3359. "local_files_only": local_files_only,
  3360. }
  3361. if has_file(pretrained_model_name_or_path, TF2_WEIGHTS_NAME, **has_file_kwargs):
  3362. raise EnvironmentError(
  3363. f"{pretrained_model_name_or_path} does not appear to have a file named"
  3364. f" {_add_variant(WEIGHTS_NAME, variant)} but there is a file for TensorFlow weights."
  3365. " Use `from_tf=True` to load this model from those weights."
  3366. )
  3367. elif has_file(pretrained_model_name_or_path, FLAX_WEIGHTS_NAME, **has_file_kwargs):
  3368. raise EnvironmentError(
  3369. f"{pretrained_model_name_or_path} does not appear to have a file named"
  3370. f" {_add_variant(WEIGHTS_NAME, variant)} but there is a file for Flax weights. Use"
  3371. " `from_flax=True` to load this model from those weights."
  3372. )
  3373. elif variant is not None and has_file(
  3374. pretrained_model_name_or_path, WEIGHTS_NAME, **has_file_kwargs
  3375. ):
  3376. raise EnvironmentError(
  3377. f"{pretrained_model_name_or_path} does not appear to have a file named"
  3378. f" {_add_variant(WEIGHTS_NAME, variant)} but there is a file without the variant"
  3379. f" {variant}. Use `variant=None` to load this model from those weights."
  3380. )
  3381. else:
  3382. raise EnvironmentError(
  3383. f"{pretrained_model_name_or_path} does not appear to have a file named"
  3384. f" {_add_variant(WEIGHTS_NAME, variant)}, {_add_variant(SAFE_WEIGHTS_NAME, variant)},"
  3385. f" {TF2_WEIGHTS_NAME}, {TF_WEIGHTS_NAME} or {FLAX_WEIGHTS_NAME}."
  3386. )
  3387. except EnvironmentError:
  3388. # Raise any environment error raise by `cached_file`. It will have a helpful error message adapted
  3389. # to the original exception.
  3390. raise
  3391. except Exception as e:
  3392. # For any other exception, we throw a generic error.
  3393. raise EnvironmentError(
  3394. f"Can't load the model for '{pretrained_model_name_or_path}'. If you were trying to load it"
  3395. " from 'https://huggingface.co/models', make sure you don't have a local directory with the"
  3396. f" same name. Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a"
  3397. f" directory containing a file named {_add_variant(WEIGHTS_NAME, variant)},"
  3398. f" {TF2_WEIGHTS_NAME}, {TF_WEIGHTS_NAME} or {FLAX_WEIGHTS_NAME}."
  3399. ) from e
  3400. if is_local:
  3401. logger.info(f"loading weights file {archive_file}")
  3402. resolved_archive_file = archive_file
  3403. else:
  3404. logger.info(f"loading weights file {filename} from cache at {resolved_archive_file}")
  3405. elif gguf_file:
  3406. from .modeling_gguf_pytorch_utils import load_gguf_checkpoint
  3407. # Case 1: the GGUF file is present locally
  3408. if os.path.isfile(gguf_file):
  3409. gguf_path = gguf_file
  3410. # Case 2: The GGUF path is a location on the Hub
  3411. # Load from URL or cache if already cached
  3412. else:
  3413. cached_file_kwargs = {
  3414. "cache_dir": cache_dir,
  3415. "force_download": force_download,
  3416. "proxies": proxies,
  3417. "resume_download": resume_download,
  3418. "local_files_only": local_files_only,
  3419. "token": token,
  3420. "user_agent": user_agent,
  3421. "revision": revision,
  3422. "subfolder": subfolder,
  3423. "_raise_exceptions_for_gated_repo": False,
  3424. "_raise_exceptions_for_missing_entries": False,
  3425. "_commit_hash": commit_hash,
  3426. }
  3427. gguf_path = cached_file(pretrained_model_name_or_path, gguf_file, **cached_file_kwargs)
  3428. state_dict = load_gguf_checkpoint(gguf_path, return_tensors=True)["tensors"]
  3429. resolved_archive_file = None
  3430. is_sharded = False
  3431. else:
  3432. resolved_archive_file = None
  3433. # We'll need to download and cache each checkpoint shard if the checkpoint is sharded.
  3434. if is_sharded:
  3435. # resolved_archive_file becomes a list of files that point to the different checkpoint shards in this case.
  3436. resolved_archive_file, sharded_metadata = get_checkpoint_shard_files(
  3437. pretrained_model_name_or_path,
  3438. resolved_archive_file,
  3439. cache_dir=cache_dir,
  3440. force_download=force_download,
  3441. proxies=proxies,
  3442. resume_download=resume_download,
  3443. local_files_only=local_files_only,
  3444. token=token,
  3445. user_agent=user_agent,
  3446. revision=revision,
  3447. subfolder=subfolder,
  3448. _commit_hash=commit_hash,
  3449. )
  3450. if (
  3451. is_safetensors_available()
  3452. and isinstance(resolved_archive_file, str)
  3453. and resolved_archive_file.endswith(".safetensors")
  3454. ):
  3455. with safe_open(resolved_archive_file, framework="pt") as f:
  3456. metadata = f.metadata()
  3457. if metadata.get("format") == "pt":
  3458. pass
  3459. elif metadata.get("format") == "tf":
  3460. from_tf = True
  3461. logger.info("A TensorFlow safetensors file is being loaded in a PyTorch model.")
  3462. elif metadata.get("format") == "flax":
  3463. from_flax = True
  3464. logger.info("A Flax safetensors file is being loaded in a PyTorch model.")
  3465. elif metadata.get("format") == "mlx":
  3466. # This is a mlx file, we assume weights are compatible with pt
  3467. pass
  3468. else:
  3469. raise ValueError(
  3470. f"Incompatible safetensors file. File metadata is not ['pt', 'tf', 'flax', 'mlx'] but {metadata.get('format')}"
  3471. )
  3472. from_pt = not (from_tf | from_flax)
  3473. # load pt weights early so that we know which dtype to init the model under
  3474. if from_pt:
  3475. if not is_sharded and state_dict is None:
  3476. # Time to load the checkpoint
  3477. state_dict = load_state_dict(resolved_archive_file, weights_only=weights_only)
  3478. # set dtype to instantiate the model under:
  3479. # 1. If torch_dtype is not None, we use that dtype
  3480. # 2. If torch_dtype is "auto", we auto-detect dtype from the loaded state_dict, by checking its first
  3481. # weights entry that is of a floating type - we assume all floating dtype weights are of the same dtype
  3482. # we also may have config.torch_dtype available, but we won't rely on it till v5
  3483. dtype_orig = None
  3484. if torch_dtype is not None:
  3485. if isinstance(torch_dtype, str):
  3486. if torch_dtype == "auto":
  3487. if hasattr(config, "torch_dtype") and config.torch_dtype is not None:
  3488. torch_dtype = config.torch_dtype
  3489. logger.info(f"Will use torch_dtype={torch_dtype} as defined in model's config object")
  3490. else:
  3491. if is_sharded and "dtype" in sharded_metadata:
  3492. torch_dtype = sharded_metadata["dtype"]
  3493. elif not is_sharded:
  3494. torch_dtype = get_state_dict_dtype(state_dict)
  3495. else:
  3496. one_state_dict = load_state_dict(resolved_archive_file[0], weights_only=weights_only)
  3497. torch_dtype = get_state_dict_dtype(one_state_dict)
  3498. del one_state_dict # free CPU memory
  3499. logger.info(
  3500. "Since the `torch_dtype` attribute can't be found in model's config object, "
  3501. "will use torch_dtype={torch_dtype} as derived from model's weights"
  3502. )
  3503. elif hasattr(torch, torch_dtype):
  3504. torch_dtype = getattr(torch, torch_dtype)
  3505. else:
  3506. raise ValueError(
  3507. f'`torch_dtype` can be one of: `torch.dtype`, `"auto"` or a string of a valid `torch.dtype`, but received {torch_dtype}'
  3508. )
  3509. dtype_orig = cls._set_default_torch_dtype(torch_dtype)
  3510. # Check if `_keep_in_fp32_modules` is not None
  3511. use_keep_in_fp32_modules = (cls._keep_in_fp32_modules is not None) and (
  3512. (torch_dtype == torch.float16) or hasattr(hf_quantizer, "use_keep_in_fp32_modules")
  3513. )
  3514. if is_sharded:
  3515. loaded_state_dict_keys = sharded_metadata["all_checkpoint_keys"]
  3516. else:
  3517. loaded_state_dict_keys = list(state_dict.keys())
  3518. if gguf_path is None and (low_cpu_mem_usage or (use_keep_in_fp32_modules and is_accelerate_available())):
  3519. # In case some weights need to be kept in float32 and accelerate is not installed,
  3520. # we later on want to take the path where state_dict is not None, that is the one
  3521. # that do not require accelerate.
  3522. state_dict = None
  3523. config.name_or_path = pretrained_model_name_or_path
  3524. # Instantiate model.
  3525. init_contexts = [no_init_weights(_enable=_fast_init)]
  3526. if is_deepspeed_zero3_enabled() and not is_quantized:
  3527. import deepspeed
  3528. logger.info("Detected DeepSpeed ZeRO-3: activating zero.init() for this model")
  3529. init_contexts = [deepspeed.zero.Init(config_dict_or_path=deepspeed_config())] + init_contexts
  3530. elif low_cpu_mem_usage:
  3531. if not is_accelerate_available():
  3532. raise ImportError(
  3533. f"Using `low_cpu_mem_usage=True` or a `device_map` requires Accelerate: `pip install 'accelerate>={ACCELERATE_MIN_VERSION}'`"
  3534. )
  3535. init_contexts.append(init_empty_weights())
  3536. config = copy.deepcopy(config) # We do not want to modify the config inplace in from_pretrained.
  3537. if not getattr(config, "_attn_implementation_autoset", False):
  3538. config = cls._autoset_attn_implementation(
  3539. config, use_flash_attention_2=use_flash_attention_2, torch_dtype=torch_dtype, device_map=device_map
  3540. )
  3541. with ContextManagers(init_contexts):
  3542. # Let's make sure we don't run the init function of buffer modules
  3543. model = cls(config, *model_args, **model_kwargs)
  3544. # make sure we use the model's config since the __init__ call might have copied it
  3545. config = model.config
  3546. # Check first if we are `from_pt`
  3547. if use_keep_in_fp32_modules:
  3548. if is_accelerate_available() and not is_deepspeed_zero3_enabled():
  3549. low_cpu_mem_usage = True
  3550. keep_in_fp32_modules = model._keep_in_fp32_modules
  3551. else:
  3552. keep_in_fp32_modules = []
  3553. if hf_quantizer is not None:
  3554. hf_quantizer.preprocess_model(
  3555. model=model, device_map=device_map, keep_in_fp32_modules=keep_in_fp32_modules
  3556. )
  3557. # We store the original dtype for quantized models as we cannot easily retrieve it
  3558. # once the weights have been quantized
  3559. # Note that once you have loaded a quantized model, you can't change its dtype so this will
  3560. # remain a single source of truth
  3561. config._pre_quantization_dtype = torch_dtype
  3562. if isinstance(device_map, str):
  3563. special_dtypes = {}
  3564. if hf_quantizer is not None:
  3565. special_dtypes.update(hf_quantizer.get_special_dtypes_update(model, torch_dtype))
  3566. special_dtypes.update(
  3567. {
  3568. name: torch.float32
  3569. for name, _ in model.named_parameters()
  3570. if any(m in name for m in keep_in_fp32_modules)
  3571. }
  3572. )
  3573. target_dtype = torch_dtype
  3574. if hf_quantizer is not None:
  3575. target_dtype = hf_quantizer.adjust_target_dtype(target_dtype)
  3576. no_split_modules = model._get_no_split_modules(device_map)
  3577. if device_map not in ["auto", "balanced", "balanced_low_0", "sequential"]:
  3578. raise ValueError(
  3579. "If passing a string for `device_map`, please choose 'auto', 'balanced', 'balanced_low_0' or "
  3580. "'sequential'."
  3581. )
  3582. device_map_kwargs = {"no_split_module_classes": no_split_modules}
  3583. if "special_dtypes" in inspect.signature(infer_auto_device_map).parameters:
  3584. device_map_kwargs["special_dtypes"] = special_dtypes
  3585. elif len(special_dtypes) > 0:
  3586. logger.warning(
  3587. "This model has some weights that should be kept in higher precision, you need to upgrade "
  3588. "`accelerate` to properly deal with them (`pip install --upgrade accelerate`)."
  3589. )
  3590. if device_map != "sequential":
  3591. max_memory = get_balanced_memory(
  3592. model,
  3593. dtype=target_dtype,
  3594. low_zero=(device_map == "balanced_low_0"),
  3595. max_memory=max_memory,
  3596. **device_map_kwargs,
  3597. )
  3598. else:
  3599. max_memory = get_max_memory(max_memory)
  3600. if hf_quantizer is not None:
  3601. max_memory = hf_quantizer.adjust_max_memory(max_memory)
  3602. device_map_kwargs["max_memory"] = max_memory
  3603. # Make sure tied weights are tied before creating the device map.
  3604. model.tie_weights()
  3605. device_map = infer_auto_device_map(model, dtype=target_dtype, **device_map_kwargs)
  3606. if hf_quantizer is not None:
  3607. hf_quantizer.validate_environment(device_map=device_map)
  3608. elif device_map is not None:
  3609. model.tie_weights()
  3610. tied_params = find_tied_parameters(model)
  3611. # check if we don't have tied param in different devices
  3612. check_tied_parameters_on_same_device(tied_params, device_map)
  3613. if from_tf:
  3614. if resolved_archive_file.endswith(".index"):
  3615. # Load from a TensorFlow 1.X checkpoint - provided by original authors
  3616. model = cls.load_tf_weights(model, config, resolved_archive_file[:-6]) # Remove the '.index'
  3617. else:
  3618. # Load from our TensorFlow 2.0 checkpoints
  3619. try:
  3620. from .modeling_tf_pytorch_utils import load_tf2_checkpoint_in_pytorch_model
  3621. model, loading_info = load_tf2_checkpoint_in_pytorch_model(
  3622. model, resolved_archive_file, allow_missing_keys=True, output_loading_info=True
  3623. )
  3624. except ImportError:
  3625. logger.error(
  3626. "Loading a TensorFlow model in PyTorch, requires both PyTorch and TensorFlow to be installed."
  3627. " Please see https://pytorch.org/ and https://www.tensorflow.org/install/ for installation"
  3628. " instructions."
  3629. )
  3630. raise
  3631. elif from_flax:
  3632. try:
  3633. from .modeling_flax_pytorch_utils import load_flax_checkpoint_in_pytorch_model
  3634. model = load_flax_checkpoint_in_pytorch_model(model, resolved_archive_file)
  3635. except ImportError:
  3636. logger.error(
  3637. "Loading a Flax model in PyTorch, requires both PyTorch and Flax to be installed. Please see"
  3638. " https://pytorch.org/ and https://flax.readthedocs.io/en/latest/installation.html for"
  3639. " installation instructions."
  3640. )
  3641. raise
  3642. elif from_pt:
  3643. # restore default dtype
  3644. if dtype_orig is not None:
  3645. torch.set_default_dtype(dtype_orig)
  3646. (
  3647. model,
  3648. missing_keys,
  3649. unexpected_keys,
  3650. mismatched_keys,
  3651. offload_index,
  3652. error_msgs,
  3653. ) = cls._load_pretrained_model(
  3654. model,
  3655. state_dict,
  3656. loaded_state_dict_keys, # XXX: rename?
  3657. resolved_archive_file,
  3658. pretrained_model_name_or_path,
  3659. ignore_mismatched_sizes=ignore_mismatched_sizes,
  3660. sharded_metadata=sharded_metadata,
  3661. _fast_init=_fast_init,
  3662. low_cpu_mem_usage=low_cpu_mem_usage,
  3663. device_map=device_map,
  3664. offload_folder=offload_folder,
  3665. offload_state_dict=offload_state_dict,
  3666. dtype=torch_dtype,
  3667. hf_quantizer=hf_quantizer,
  3668. keep_in_fp32_modules=keep_in_fp32_modules,
  3669. gguf_path=gguf_path,
  3670. weights_only=weights_only,
  3671. )
  3672. # make sure token embedding weights are still tied if needed
  3673. model.tie_weights()
  3674. # Set model in evaluation mode to deactivate DropOut modules by default
  3675. model.eval()
  3676. # If it is a model with generation capabilities, attempt to load the generation config
  3677. if model.can_generate() and generation_config is not None:
  3678. logger.info("The user-defined `generation_config` will be used to override the default generation config.")
  3679. model.generation_config = model.generation_config.from_dict(generation_config.to_dict())
  3680. elif model.can_generate() and pretrained_model_name_or_path is not None:
  3681. try:
  3682. model.generation_config = GenerationConfig.from_pretrained(
  3683. pretrained_model_name_or_path,
  3684. cache_dir=cache_dir,
  3685. force_download=force_download,
  3686. resume_download=resume_download,
  3687. proxies=proxies,
  3688. local_files_only=local_files_only,
  3689. token=token,
  3690. revision=revision,
  3691. subfolder=subfolder,
  3692. _from_auto=from_auto_class,
  3693. _from_pipeline=from_pipeline,
  3694. **kwargs,
  3695. )
  3696. except OSError:
  3697. logger.info(
  3698. "Generation config file not found, using a generation config created from the model config."
  3699. )
  3700. pass
  3701. # Dispatch model with hooks on all devices if necessary
  3702. if device_map is not None:
  3703. device_map_kwargs = {
  3704. "device_map": device_map,
  3705. "offload_dir": offload_folder,
  3706. "offload_index": offload_index,
  3707. "offload_buffers": offload_buffers,
  3708. }
  3709. if "skip_keys" in inspect.signature(dispatch_model).parameters:
  3710. device_map_kwargs["skip_keys"] = model._skip_keys_device_placement
  3711. # For HQQ method we force-set the hooks for single GPU envs
  3712. if (
  3713. "force_hooks" in inspect.signature(dispatch_model).parameters
  3714. and hf_quantizer is not None
  3715. and hf_quantizer.quantization_config.quant_method == QuantizationMethod.HQQ
  3716. ):
  3717. device_map_kwargs["force_hooks"] = True
  3718. if (
  3719. hf_quantizer is not None
  3720. and hf_quantizer.quantization_config.quant_method == QuantizationMethod.FBGEMM_FP8
  3721. and isinstance(device_map, dict)
  3722. and ("cpu" in device_map.values() or "disk" in device_map.values())
  3723. ):
  3724. device_map_kwargs["offload_buffers"] = True
  3725. if not is_fsdp_enabled() and not is_deepspeed_zero3_enabled():
  3726. dispatch_model(model, **device_map_kwargs)
  3727. if hf_quantizer is not None:
  3728. hf_quantizer.postprocess_model(model)
  3729. model.hf_quantizer = hf_quantizer
  3730. if _adapter_model_path is not None:
  3731. model.load_adapter(
  3732. _adapter_model_path,
  3733. adapter_name=adapter_name,
  3734. token=token,
  3735. adapter_kwargs=adapter_kwargs,
  3736. )
  3737. if output_loading_info:
  3738. if loading_info is None:
  3739. loading_info = {
  3740. "missing_keys": missing_keys,
  3741. "unexpected_keys": unexpected_keys,
  3742. "mismatched_keys": mismatched_keys,
  3743. "error_msgs": error_msgs,
  3744. }
  3745. return model, loading_info
  3746. return model
  3747. @classmethod
  3748. def _load_pretrained_model(
  3749. cls,
  3750. model,
  3751. state_dict,
  3752. loaded_keys,
  3753. resolved_archive_file,
  3754. pretrained_model_name_or_path,
  3755. ignore_mismatched_sizes=False,
  3756. sharded_metadata=None,
  3757. _fast_init=True,
  3758. low_cpu_mem_usage=False,
  3759. device_map=None,
  3760. offload_folder=None,
  3761. offload_state_dict=None,
  3762. dtype=None,
  3763. hf_quantizer=None,
  3764. keep_in_fp32_modules=None,
  3765. gguf_path=None,
  3766. weights_only=True,
  3767. ):
  3768. is_safetensors = False
  3769. is_quantized = hf_quantizer is not None
  3770. state_dict_folder = None
  3771. state_dict_index = None
  3772. if device_map is not None and "disk" in device_map.values():
  3773. archive_file = (
  3774. resolved_archive_file[0] if isinstance(resolved_archive_file, (list, tuple)) else resolved_archive_file
  3775. )
  3776. is_safetensors = archive_file.endswith(".safetensors")
  3777. if offload_folder is None and not is_safetensors:
  3778. raise ValueError(
  3779. "The current `device_map` had weights offloaded to the disk. Please provide an `offload_folder`"
  3780. " for them. Alternatively, make sure you have `safetensors` installed if the model you are using"
  3781. " offers the weights in this format."
  3782. )
  3783. if offload_folder is not None:
  3784. os.makedirs(offload_folder, exist_ok=True)
  3785. if offload_state_dict is None:
  3786. offload_state_dict = True
  3787. is_sharded_safetensors = is_safetensors and sharded_metadata is not None
  3788. # tie the model weights before retrieving the state_dict
  3789. model.tie_weights()
  3790. # Retrieve missing & unexpected_keys
  3791. model_state_dict = model.state_dict()
  3792. expected_keys = list(model_state_dict.keys())
  3793. prefix = model.base_model_prefix
  3794. if hf_quantizer is not None:
  3795. expected_keys = hf_quantizer.update_expected_keys(model, expected_keys, loaded_keys)
  3796. def _fix_key(key):
  3797. if "beta" in key:
  3798. return key.replace("beta", "bias")
  3799. if "gamma" in key:
  3800. return key.replace("gamma", "weight")
  3801. # to avoid logging parametrized weight norm renaming
  3802. if hasattr(nn.utils.parametrizations, "weight_norm"):
  3803. if "weight_g" in key:
  3804. return key.replace("weight_g", "parametrizations.weight.original0")
  3805. if "weight_v" in key:
  3806. return key.replace("weight_v", "parametrizations.weight.original1")
  3807. else:
  3808. if "parametrizations.weight.original0" in key:
  3809. return key.replace("parametrizations.weight.original0", "weight_g")
  3810. if "parametrizations.weight.original1" in key:
  3811. return key.replace("parametrizations.weight.original1", "weight_v")
  3812. return key
  3813. original_loaded_keys = loaded_keys
  3814. loaded_keys = [_fix_key(key) for key in loaded_keys]
  3815. if len(prefix) > 0:
  3816. has_prefix_module = any(s.startswith(prefix) for s in loaded_keys)
  3817. expects_prefix_module = any(s.startswith(prefix) for s in expected_keys)
  3818. else:
  3819. has_prefix_module = False
  3820. expects_prefix_module = False
  3821. # key re-naming operations are never done on the keys
  3822. # that are loaded, but always on the keys of the newly initialized model
  3823. remove_prefix_from_model = not has_prefix_module and expects_prefix_module
  3824. add_prefix_to_model = has_prefix_module and not expects_prefix_module
  3825. if remove_prefix_from_model:
  3826. _prefix = f"{prefix}."
  3827. expected_keys_not_prefixed = [s for s in expected_keys if not s.startswith(_prefix)]
  3828. expected_keys = [s[len(_prefix) :] if s.startswith(_prefix) else s for s in expected_keys]
  3829. elif add_prefix_to_model:
  3830. expected_keys = [".".join([prefix, s]) for s in expected_keys]
  3831. missing_keys = sorted(set(expected_keys) - set(loaded_keys))
  3832. unexpected_keys = set(loaded_keys) - set(expected_keys)
  3833. # Remove nonpersistent buffers from unexpected keys: they are not in the state dict but will be in the model
  3834. # buffers
  3835. model_buffers = {n for n, _ in model.named_buffers()}
  3836. if remove_prefix_from_model:
  3837. model_buffers = {key[len(_prefix) :] if key.startswith(_prefix) else key for key in model_buffers}
  3838. elif add_prefix_to_model:
  3839. model_buffers = {".".join([prefix, key]) for key in model_buffers}
  3840. unexpected_keys = sorted(unexpected_keys - model_buffers)
  3841. model.tie_weights()
  3842. if device_map is None and not is_fsdp_enabled() and not is_deepspeed_zero3_enabled():
  3843. ptrs = collections.defaultdict(list)
  3844. for name, tensor in model.state_dict().items():
  3845. id_tensor = id_tensor_storage(tensor)
  3846. ptrs[id_tensor].append(name)
  3847. # These are all the pointers of shared tensors.
  3848. tied_params = [names for _, names in ptrs.items() if len(names) > 1]
  3849. else:
  3850. # id function doesn't work for meta tensor so we need this function
  3851. tied_params = find_tied_parameters(model)
  3852. for group in tied_params:
  3853. if remove_prefix_from_model:
  3854. group = [key[len(_prefix) :] if key.startswith(_prefix) else key for key in group]
  3855. elif add_prefix_to_model:
  3856. group = [".".join([prefix, key]) for key in group]
  3857. missing_in_group = [k for k in missing_keys if k in group]
  3858. if len(missing_in_group) > 0 and len(missing_in_group) < len(group):
  3859. missing_keys = [k for k in missing_keys if k not in missing_in_group]
  3860. # Some models may have keys that are not in the state by design, removing them before needlessly warning
  3861. # the user.
  3862. if cls._keys_to_ignore_on_load_missing is not None:
  3863. for pat in cls._keys_to_ignore_on_load_missing:
  3864. missing_keys = [k for k in missing_keys if re.search(pat, k) is None]
  3865. if cls._keys_to_ignore_on_load_unexpected is not None:
  3866. for pat in cls._keys_to_ignore_on_load_unexpected:
  3867. unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
  3868. if hf_quantizer is not None:
  3869. missing_keys = hf_quantizer.update_missing_keys(model, missing_keys, prefix)
  3870. # retrieve weights on meta device and put them back on CPU.
  3871. # This is not ideal in terms of memory, but if we don't do that not, we can't initialize them in the next step
  3872. if low_cpu_mem_usage:
  3873. for key in missing_keys:
  3874. if key in list(model_state_dict.keys()):
  3875. key = key
  3876. elif f"{prefix}.{key}" in list(model_state_dict.keys()):
  3877. key = f"{prefix}.{key}"
  3878. elif key.startswith(prefix) and ".".join(key.split(".")[1:]) in list(model_state_dict.keys()):
  3879. key = ".".join(key.split(".")[1:])
  3880. param = model_state_dict[key]
  3881. # upcast in fp32 if any
  3882. target_dtype = dtype
  3883. if (
  3884. keep_in_fp32_modules is not None
  3885. and dtype == torch.float16
  3886. and any(
  3887. module_to_keep_in_fp32 in key.split(".") for module_to_keep_in_fp32 in keep_in_fp32_modules
  3888. )
  3889. ):
  3890. target_dtype = torch.float32
  3891. if param.device == torch.device("meta"):
  3892. value = torch.empty(*param.size(), dtype=target_dtype)
  3893. if (
  3894. not is_quantized
  3895. or (getattr(hf_quantizer, "requires_parameters_quantization", False))
  3896. or not hf_quantizer.check_quantized_param(
  3897. model, param_value=value, param_name=key, state_dict={}
  3898. )
  3899. ):
  3900. set_module_tensor_to_device(model, key, "cpu", value)
  3901. else:
  3902. hf_quantizer.create_quantized_param(model, value, key, "cpu", state_dict, unexpected_keys)
  3903. # retrieve uninitialized modules and initialize before maybe overriding that with the pretrained weights.
  3904. if _fast_init:
  3905. if not ignore_mismatched_sizes:
  3906. if remove_prefix_from_model:
  3907. _loaded_keys = [f"{prefix}.{k}" for k in loaded_keys]
  3908. elif add_prefix_to_model:
  3909. _loaded_keys = [k[len(prefix) + 1 :] for k in loaded_keys]
  3910. else:
  3911. _loaded_keys = loaded_keys
  3912. not_initialized_submodules = set_initialized_submodules(model, _loaded_keys)
  3913. # If we're about to tie the output embeds to the input embeds we don't need to init them
  3914. if hasattr(model.config, "tie_word_embeddings") and model.config.tie_word_embeddings:
  3915. output_embeddings = model.get_output_embeddings()
  3916. if output_embeddings is not None:
  3917. # Still need to initialize if there is a bias term since biases are not tied.
  3918. if not hasattr(output_embeddings, "bias") or output_embeddings.bias is None:
  3919. output_embeddings._is_hf_initialized = True
  3920. else:
  3921. not_initialized_submodules = dict(model.named_modules())
  3922. # This will only initialize submodules that are not marked as initialized by the line above.
  3923. if is_deepspeed_zero3_enabled() and not is_quantized:
  3924. import deepspeed
  3925. not_initialized_parameters = list(
  3926. set(
  3927. itertools.chain.from_iterable(
  3928. submodule.parameters(recurse=False) for submodule in not_initialized_submodules.values()
  3929. )
  3930. )
  3931. )
  3932. with deepspeed.zero.GatheredParameters(not_initialized_parameters, modifier_rank=0):
  3933. model.apply(model._initialize_weights)
  3934. else:
  3935. model.apply(model._initialize_weights)
  3936. # Set some modules to fp32 if any
  3937. if keep_in_fp32_modules is not None:
  3938. for name, param in model.named_parameters():
  3939. if any(module_to_keep_in_fp32 in name.split(".") for module_to_keep_in_fp32 in keep_in_fp32_modules):
  3940. # param = param.to(torch.float32) does not work here as only in the local scope.
  3941. param.data = param.data.to(torch.float32)
  3942. # Make sure we are able to load base models as well as derived models (with heads)
  3943. start_prefix = ""
  3944. model_to_load = model
  3945. if len(cls.base_model_prefix) > 0 and not hasattr(model, cls.base_model_prefix) and has_prefix_module:
  3946. start_prefix = cls.base_model_prefix + "."
  3947. if len(cls.base_model_prefix) > 0 and hasattr(model, cls.base_model_prefix) and not has_prefix_module:
  3948. model_to_load = getattr(model, cls.base_model_prefix)
  3949. base_model_expected_keys = list(model_to_load.state_dict().keys())
  3950. if any(key in expected_keys_not_prefixed and key not in base_model_expected_keys for key in loaded_keys):
  3951. raise ValueError(
  3952. "The state dictionary of the model you are trying to load is corrupted. Are you sure it was "
  3953. "properly saved?"
  3954. )
  3955. if device_map is not None:
  3956. device_map = {k.replace(f"{cls.base_model_prefix}.", ""): v for k, v in device_map.items()}
  3957. def _find_mismatched_keys(
  3958. state_dict,
  3959. model_state_dict,
  3960. loaded_keys,
  3961. add_prefix_to_model,
  3962. remove_prefix_from_model,
  3963. ignore_mismatched_sizes,
  3964. ):
  3965. mismatched_keys = []
  3966. if ignore_mismatched_sizes:
  3967. for checkpoint_key in loaded_keys:
  3968. # If the checkpoint is sharded, we may not have the key here.
  3969. if checkpoint_key not in state_dict:
  3970. continue
  3971. model_key = checkpoint_key
  3972. if remove_prefix_from_model:
  3973. # The model key starts with `prefix` but `checkpoint_key` doesn't so we add it.
  3974. model_key = f"{prefix}.{checkpoint_key}"
  3975. elif add_prefix_to_model:
  3976. # The model key doesn't start with `prefix` but `checkpoint_key` does so we remove it.
  3977. model_key = ".".join(checkpoint_key.split(".")[1:])
  3978. if (
  3979. model_key in model_state_dict
  3980. and state_dict[checkpoint_key].shape != model_state_dict[model_key].shape
  3981. ):
  3982. if (
  3983. state_dict[checkpoint_key].shape[-1] == 1
  3984. and state_dict[checkpoint_key].numel() * 2 == model_state_dict[model_key].numel()
  3985. ):
  3986. # This skips size mismatches for 4-bit weights. Two 4-bit values share an 8-bit container, causing size differences.
  3987. # Without matching with module type or paramter type it seems like a practical way to detect valid 4bit weights.
  3988. pass
  3989. else:
  3990. mismatched_keys.append(
  3991. (checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape)
  3992. )
  3993. del state_dict[checkpoint_key]
  3994. return mismatched_keys
  3995. if resolved_archive_file is not None:
  3996. folder = os.path.sep.join(resolved_archive_file[0].split(os.path.sep)[:-1])
  3997. else:
  3998. folder = None
  3999. if device_map is not None and is_safetensors:
  4000. param_device_map = expand_device_map(device_map, original_loaded_keys, start_prefix)
  4001. str_dtype = str(dtype).replace("torch.", "") if dtype is not None else "float32"
  4002. if sharded_metadata is None:
  4003. archive_file = (
  4004. resolved_archive_file[0]
  4005. if isinstance(resolved_archive_file, (list, tuple))
  4006. else resolved_archive_file
  4007. )
  4008. weight_map = {p: archive_file for p in original_loaded_keys}
  4009. else:
  4010. weight_map = {p: os.path.join(folder, f) for p, f in sharded_metadata["weight_map"].items()}
  4011. offload_index = {
  4012. p[len(start_prefix) :]: {"safetensors_file": f, "weight_name": p, "dtype": str_dtype}
  4013. for p, f in weight_map.items()
  4014. if p.startswith(start_prefix) and param_device_map[p[len(start_prefix) :]] == "disk"
  4015. }
  4016. else:
  4017. offload_index = None
  4018. if state_dict is not None:
  4019. # Whole checkpoint
  4020. mismatched_keys = _find_mismatched_keys(
  4021. state_dict,
  4022. model_state_dict,
  4023. original_loaded_keys,
  4024. add_prefix_to_model,
  4025. remove_prefix_from_model,
  4026. ignore_mismatched_sizes,
  4027. )
  4028. # For GGUF models `state_dict` is never set to None as the state dict is always small
  4029. if gguf_path:
  4030. error_msgs, offload_index, state_dict_index = _load_state_dict_into_meta_model(
  4031. model_to_load,
  4032. state_dict,
  4033. start_prefix,
  4034. expected_keys,
  4035. device_map=device_map,
  4036. offload_folder=offload_folder,
  4037. offload_index=offload_index,
  4038. state_dict_folder=state_dict_folder,
  4039. state_dict_index=state_dict_index,
  4040. dtype=dtype,
  4041. hf_quantizer=hf_quantizer,
  4042. is_safetensors=is_safetensors,
  4043. keep_in_fp32_modules=keep_in_fp32_modules,
  4044. unexpected_keys=unexpected_keys,
  4045. )
  4046. else:
  4047. # Sharded checkpoint or whole but low_cpu_mem_usage==True
  4048. assign_to_params_buffers = check_support_param_buffer_assignment(
  4049. model_to_load, state_dict, start_prefix
  4050. )
  4051. error_msgs = _load_state_dict_into_model(
  4052. model_to_load, state_dict, start_prefix, assign_to_params_buffers
  4053. )
  4054. else:
  4055. # This should always be a list but, just to be sure.
  4056. if not isinstance(resolved_archive_file, list):
  4057. resolved_archive_file = [resolved_archive_file]
  4058. error_msgs = []
  4059. mismatched_keys = []
  4060. if not is_safetensors:
  4061. offload_index = {} if device_map is not None and "disk" in device_map.values() else None
  4062. if offload_state_dict:
  4063. state_dict_folder = tempfile.mkdtemp()
  4064. state_dict_index = {}
  4065. else:
  4066. state_dict_folder = None
  4067. state_dict_index = None
  4068. if is_sharded_safetensors:
  4069. disk_only_shard_files = get_disk_only_shard_files(
  4070. device_map, sharded_metadata=sharded_metadata, start_prefix=start_prefix
  4071. )
  4072. disk_only_shard_files = [os.path.join(folder, f) for f in disk_only_shard_files]
  4073. else:
  4074. disk_only_shard_files = []
  4075. if len(resolved_archive_file) > 1:
  4076. resolved_archive_file = logging.tqdm(resolved_archive_file, desc="Loading checkpoint shards")
  4077. assign_to_params_buffers = None
  4078. for shard_file in resolved_archive_file:
  4079. # Skip the load for shards that only contain disk-offloaded weights when using safetensors for the offload.
  4080. if shard_file in disk_only_shard_files:
  4081. continue
  4082. map_location = None
  4083. if (
  4084. device_map is not None
  4085. and hf_quantizer is not None
  4086. and hf_quantizer.quantization_config.quant_method == QuantizationMethod.TORCHAO
  4087. and hf_quantizer.quantization_config.quant_type == "int4_weight_only"
  4088. ):
  4089. map_location = torch.device([d for d in device_map.values() if d not in ["cpu", "disk"]][0])
  4090. state_dict = load_state_dict(
  4091. shard_file, is_quantized=is_quantized, map_location=map_location, weights_only=weights_only
  4092. )
  4093. # Mistmatched keys contains tuples key/shape1/shape2 of weights in the checkpoint that have a shape not
  4094. # matching the weights in the model.
  4095. mismatched_keys += _find_mismatched_keys(
  4096. state_dict,
  4097. model_state_dict,
  4098. original_loaded_keys,
  4099. add_prefix_to_model,
  4100. remove_prefix_from_model,
  4101. ignore_mismatched_sizes,
  4102. )
  4103. if low_cpu_mem_usage:
  4104. if is_fsdp_enabled() and not is_local_dist_rank_0() and not is_quantized:
  4105. for key, param in model_to_load.state_dict().items():
  4106. if param.device == torch.device("meta"):
  4107. set_module_tensor_to_device(
  4108. model_to_load, key, "cpu", torch.empty(*param.size(), dtype=dtype)
  4109. )
  4110. else:
  4111. new_error_msgs, offload_index, state_dict_index = _load_state_dict_into_meta_model(
  4112. model_to_load,
  4113. state_dict,
  4114. start_prefix,
  4115. expected_keys,
  4116. device_map=device_map,
  4117. offload_folder=offload_folder,
  4118. offload_index=offload_index,
  4119. state_dict_folder=state_dict_folder,
  4120. state_dict_index=state_dict_index,
  4121. dtype=dtype,
  4122. hf_quantizer=hf_quantizer,
  4123. is_safetensors=is_safetensors,
  4124. keep_in_fp32_modules=keep_in_fp32_modules,
  4125. unexpected_keys=unexpected_keys,
  4126. )
  4127. error_msgs += new_error_msgs
  4128. else:
  4129. # Sharded checkpoint or whole but low_cpu_mem_usage==True
  4130. if assign_to_params_buffers is None:
  4131. assign_to_params_buffers = check_support_param_buffer_assignment(
  4132. model_to_load, state_dict, start_prefix
  4133. )
  4134. error_msgs += _load_state_dict_into_model(
  4135. model_to_load, state_dict, start_prefix, assign_to_params_buffers
  4136. )
  4137. # force memory release
  4138. del state_dict
  4139. gc.collect()
  4140. if offload_index is not None and len(offload_index) > 0:
  4141. if model != model_to_load:
  4142. # We need to add the prefix of the base model
  4143. prefix = cls.base_model_prefix
  4144. if not is_safetensors:
  4145. for weight_name in offload_index:
  4146. shutil.move(
  4147. os.path.join(offload_folder, f"{weight_name}.dat"),
  4148. os.path.join(offload_folder, f"{prefix}.{weight_name}.dat"),
  4149. )
  4150. offload_index = {f"{prefix}.{key}": value for key, value in offload_index.items()}
  4151. if not is_safetensors:
  4152. save_offload_index(offload_index, offload_folder)
  4153. offload_index = None
  4154. if offload_state_dict:
  4155. # Load back temporarily offloaded state dict
  4156. load_offloaded_weights(model_to_load, state_dict_index, state_dict_folder)
  4157. shutil.rmtree(state_dict_folder)
  4158. if len(error_msgs) > 0:
  4159. error_msg = "\n\t".join(error_msgs)
  4160. if "size mismatch" in error_msg:
  4161. error_msg += (
  4162. "\n\tYou may consider adding `ignore_mismatched_sizes=True` in the model `from_pretrained` method."
  4163. )
  4164. raise RuntimeError(f"Error(s) in loading state_dict for {model.__class__.__name__}:\n\t{error_msg}")
  4165. if len(unexpected_keys) > 0:
  4166. archs = [] if model.config.architectures is None else model.config.architectures
  4167. warner = logger.warning if model.__class__.__name__ in archs else logger.info
  4168. warner(
  4169. f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when"
  4170. f" initializing {model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are"
  4171. f" initializing {model.__class__.__name__} from the checkpoint of a model trained on another task or"
  4172. " with another architecture (e.g. initializing a BertForSequenceClassification model from a"
  4173. " BertForPreTraining model).\n- This IS NOT expected if you are initializing"
  4174. f" {model.__class__.__name__} from the checkpoint of a model that you expect to be exactly identical"
  4175. " (initializing a BertForSequenceClassification model from a BertForSequenceClassification model)."
  4176. )
  4177. else:
  4178. logger.info(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n")
  4179. if len(missing_keys) > 0:
  4180. logger.warning(
  4181. f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
  4182. f" {pretrained_model_name_or_path} and are newly initialized: {missing_keys}\nYou should probably"
  4183. " TRAIN this model on a down-stream task to be able to use it for predictions and inference."
  4184. )
  4185. elif len(mismatched_keys) == 0:
  4186. logger.info(
  4187. f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at"
  4188. f" {pretrained_model_name_or_path}.\nIf your task is similar to the task the model of the checkpoint"
  4189. f" was trained on, you can already use {model.__class__.__name__} for predictions without further"
  4190. " training."
  4191. )
  4192. if len(mismatched_keys) > 0:
  4193. mismatched_warning = "\n".join(
  4194. [
  4195. f"- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated"
  4196. for key, shape1, shape2 in mismatched_keys
  4197. ]
  4198. )
  4199. logger.warning(
  4200. f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
  4201. f" {pretrained_model_name_or_path} and are newly initialized because the shapes did not"
  4202. f" match:\n{mismatched_warning}\nYou should probably TRAIN this model on a down-stream task to be able"
  4203. " to use it for predictions and inference."
  4204. )
  4205. return model, missing_keys, unexpected_keys, mismatched_keys, offload_index, error_msgs
  4206. def retrieve_modules_from_names(self, names, add_prefix=False, remove_prefix=False):
  4207. module_keys = {".".join(key.split(".")[:-1]) for key in names}
  4208. # torch.nn.ParameterList is a special case where two parameter keywords
  4209. # are appended to the module name, *e.g.* bert.special_embeddings.0
  4210. module_keys = module_keys.union(
  4211. {".".join(key.split(".")[:-2]) for key in names if len(key) > 0 and key[-1].isdigit()}
  4212. )
  4213. retrieved_modules = []
  4214. # retrieve all modules that has at least one missing weight name
  4215. for name, module in self.named_modules():
  4216. if remove_prefix:
  4217. _prefix = f"{self.base_model_prefix}."
  4218. name = name[len(_prefix) :] if name.startswith(_prefix) else name
  4219. elif add_prefix:
  4220. name = ".".join([self.base_model_prefix, name]) if len(name) > 0 else self.base_model_prefix
  4221. if name in module_keys:
  4222. retrieved_modules.append(module)
  4223. return retrieved_modules
  4224. @staticmethod
  4225. def _load_pretrained_model_low_mem(
  4226. model,
  4227. loaded_state_dict_keys,
  4228. resolved_archive_file,
  4229. start_prefix="",
  4230. hf_quantizer=None,
  4231. pretrained_model_name_or_path=None,
  4232. weights_only=True,
  4233. ):
  4234. """
  4235. This is an experimental function that loads the model using ~1.x model size CPU memory
  4236. Before you call it do:
  4237. 1. save which state_dict keys are available
  4238. 2. drop state_dict before model is created, since the latter takes 1x model size memory
  4239. Here then we continue:
  4240. 3. switch to the meta device all params/buffers that are going to be replaced from the loaded state_dict
  4241. 4. load state_dict 2nd time
  4242. 5. replace the params/buffers from the state_dict
  4243. Currently, it doesn't handle missing_keys, unexpected_keys, mismatched_keys. It can't handle deepspeed. To
  4244. handle bitsandbytes, needs non-empty hf_quantizer argument.
  4245. """
  4246. _move_model_to_meta(model, loaded_state_dict_keys, start_prefix)
  4247. state_dict = load_state_dict(resolved_archive_file, weights_only=weights_only)
  4248. expected_keys = loaded_state_dict_keys # plug for missing expected_keys. TODO: replace with proper keys
  4249. error_msgs = _load_state_dict_into_meta_model(
  4250. model,
  4251. state_dict,
  4252. start_prefix,
  4253. expected_keys=expected_keys,
  4254. hf_quantizer=hf_quantizer,
  4255. )
  4256. return error_msgs
  4257. @classmethod
  4258. def register_for_auto_class(cls, auto_class="AutoModel"):
  4259. """
  4260. Register this class with a given auto class. This should only be used for custom models as the ones in the
  4261. library are already mapped with an auto class.
  4262. <Tip warning={true}>
  4263. This API is experimental and may have some slight breaking changes in the next releases.
  4264. </Tip>
  4265. Args:
  4266. auto_class (`str` or `type`, *optional*, defaults to `"AutoModel"`):
  4267. The auto class to register this new model with.
  4268. """
  4269. if not isinstance(auto_class, str):
  4270. auto_class = auto_class.__name__
  4271. import transformers.models.auto as auto_module
  4272. if not hasattr(auto_module, auto_class):
  4273. raise ValueError(f"{auto_class} is not a valid auto class.")
  4274. cls._auto_class = auto_class
  4275. def to_bettertransformer(self) -> "PreTrainedModel":
  4276. """
  4277. Converts the model to use [PyTorch's native attention
  4278. implementation](https://pytorch.org/docs/stable/generated/torch.nn.MultiheadAttention.html), integrated to
  4279. Transformers through [Optimum library](https://huggingface.co/docs/optimum/bettertransformer/overview). Only a
  4280. subset of all Transformers models are supported.
  4281. PyTorch's attention fastpath allows to speed up inference through kernel fusions and the use of [nested
  4282. tensors](https://pytorch.org/docs/stable/nested.html). Detailed benchmarks can be found in [this blog
  4283. post](https://medium.com/pytorch/bettertransformer-out-of-the-box-performance-for-huggingface-transformers-3fbe27d50ab2).
  4284. Returns:
  4285. [`PreTrainedModel`]: The model converted to BetterTransformer.
  4286. """
  4287. if not is_optimum_available():
  4288. raise ImportError("The package `optimum` is required to use Better Transformer.")
  4289. from optimum.version import __version__ as optimum_version
  4290. if version.parse(optimum_version) < version.parse("1.7.0"):
  4291. raise ImportError(
  4292. f"Please install optimum>=1.7.0 to use Better Transformer. The version {optimum_version} was found."
  4293. )
  4294. from optimum.bettertransformer import BetterTransformer
  4295. return BetterTransformer.transform(self)
  4296. def reverse_bettertransformer(self):
  4297. """
  4298. Reverts the transformation from [`~PreTrainedModel.to_bettertransformer`] so that the original modeling is
  4299. used, for example in order to save the model.
  4300. Returns:
  4301. [`PreTrainedModel`]: The model converted back to the original modeling.
  4302. """
  4303. if not is_optimum_available():
  4304. raise ImportError("The package `optimum` is required to use Better Transformer.")
  4305. from optimum.version import __version__ as optimum_version
  4306. if version.parse(optimum_version) < version.parse("1.7.0"):
  4307. raise ImportError(
  4308. f"Please install optimum>=1.7.0 to use Better Transformer. The version {optimum_version} was found."
  4309. )
  4310. from optimum.bettertransformer import BetterTransformer
  4311. return BetterTransformer.reverse(self)
  4312. def warn_if_padding_and_no_attention_mask(self, input_ids, attention_mask):
  4313. """
  4314. Shows a one-time warning if the input_ids appear to contain padding and no attention mask was given.
  4315. """
  4316. # Skip the check during tracing.
  4317. if is_torch_fx_proxy(input_ids) or torch.jit.is_tracing() or is_torchdynamo_compiling():
  4318. return
  4319. if (attention_mask is not None) or (self.config.pad_token_id is None):
  4320. return
  4321. # Check only the first and last input IDs to reduce overhead.
  4322. if self.config.pad_token_id in input_ids[:, [-1, 0]]:
  4323. warn_string = (
  4324. "We strongly recommend passing in an `attention_mask` since your input_ids may be padded. See "
  4325. "https://huggingface.co/docs/transformers/troubleshooting"
  4326. "#incorrect-output-when-padding-tokens-arent-masked."
  4327. )
  4328. # If the pad token is equal to either BOS, EOS, or SEP, we do not know whether the user should use an
  4329. # attention_mask or not. In this case, we should still show a warning because this is a rare case.
  4330. if (
  4331. (self.config.bos_token_id is not None and self.config.bos_token_id == self.config.pad_token_id)
  4332. or (self.config.eos_token_id is not None and self.config.eos_token_id == self.config.pad_token_id)
  4333. or (self.config.sep_token_id is not None and self.config.sep_token_id == self.config.pad_token_id)
  4334. ):
  4335. warn_string += (
  4336. f"\nYou may ignore this warning if your `pad_token_id` ({self.config.pad_token_id}) is identical "
  4337. f"to the `bos_token_id` ({self.config.bos_token_id}), `eos_token_id` ({self.config.eos_token_id}), "
  4338. f"or the `sep_token_id` ({self.config.sep_token_id}), and your input is not padded."
  4339. )
  4340. logger.warning_once(warn_string)
  4341. @property
  4342. def _is_quantized_training_enabled(self):
  4343. warnings.warn(
  4344. "`_is_quantized_training_enabled` is going to be deprecated in transformers 4.39.0. Please use `model.hf_quantizer.is_trainable` instead",
  4345. FutureWarning,
  4346. )
  4347. if not hasattr(self, "hf_quantizer"):
  4348. return False
  4349. return self.hf_quantizer.is_trainable
  4350. @property
  4351. def loss_function(self):
  4352. if getattr(self.config, "loss_type", None) is not None:
  4353. loss_type = self.config.loss_type
  4354. else:
  4355. loss_type = self.__class__.__name__
  4356. if loss_type not in LOSS_MAPPING:
  4357. loss_groups = f"({'|'.join(LOSS_MAPPING)})"
  4358. loss_type = re.findall(loss_groups, self.__class__.__name__)
  4359. if len(loss_type) > 0:
  4360. loss_type = loss_type[0]
  4361. else:
  4362. loss_type = None
  4363. if loss_type is None or loss_type not in LOSS_MAPPING and getattr(self.config, "loss_type", None) is not None:
  4364. logger.warning_once(
  4365. f"`loss_type={loss_type}` was set in the config but it is unrecognised."
  4366. f"Using the default loss: `ForCausalLMLoss`."
  4367. )
  4368. loss_type = "ForCausalLM"
  4369. return LOSS_MAPPING[loss_type]
# Rebind `push_to_hub` on `PreTrainedModel` to a copy and fill in its docstring template for models.
# NOTE(review): presumably `copy_func` returns an independent function object, so that formatting the docstring
# below does not alter the docstring of the shared original (likely defined on a hub mixin) — confirm.
PreTrainedModel.push_to_hub = copy_func(PreTrainedModel.push_to_hub)
# `__doc__` is `None` when docstrings are stripped (e.g. `python -OO`); then there is nothing to format.
if PreTrainedModel.push_to_hub.__doc__ is not None:
    # Substitute the `{object}`, `{object_class}` and `{object_files}` fields of the docstring template.
    PreTrainedModel.push_to_hub.__doc__ = PreTrainedModel.push_to_hub.__doc__.format(
        object="model", object_class="AutoModel", object_files="model file"
    )
  4375. class PoolerStartLogits(nn.Module):
  4376. """
  4377. Compute SQuAD start logits from sequence hidden states.
  4378. Args:
  4379. config ([`PretrainedConfig`]):
  4380. The config used by the model, will be used to grab the `hidden_size` of the model.
  4381. """
  4382. def __init__(self, config: PretrainedConfig):
  4383. super().__init__()
  4384. self.dense = nn.Linear(config.hidden_size, 1)
  4385. def forward(
  4386. self, hidden_states: torch.FloatTensor, p_mask: Optional[torch.FloatTensor] = None
  4387. ) -> torch.FloatTensor:
  4388. """
  4389. Args:
  4390. hidden_states (`torch.FloatTensor` of shape `(batch_size, seq_len, hidden_size)`):
  4391. The final hidden states of the model.
  4392. p_mask (`torch.FloatTensor` of shape `(batch_size, seq_len)`, *optional*):
  4393. Mask for tokens at invalid position, such as query and special symbols (PAD, SEP, CLS). 1.0 means token
  4394. should be masked.
  4395. Returns:
  4396. `torch.FloatTensor`: The start logits for SQuAD.
  4397. """
  4398. x = self.dense(hidden_states).squeeze(-1)
  4399. if p_mask is not None:
  4400. if get_parameter_dtype(self) == torch.float16:
  4401. x = x * (1 - p_mask) - 65500 * p_mask
  4402. else:
  4403. x = x * (1 - p_mask) - 1e30 * p_mask
  4404. return x
  4405. class PoolerEndLogits(nn.Module):
  4406. """
  4407. Compute SQuAD end logits from sequence hidden states.
  4408. Args:
  4409. config ([`PretrainedConfig`]):
  4410. The config used by the model, will be used to grab the `hidden_size` of the model and the `layer_norm_eps`
  4411. to use.
  4412. """
  4413. def __init__(self, config: PretrainedConfig):
  4414. super().__init__()
  4415. self.dense_0 = nn.Linear(config.hidden_size * 2, config.hidden_size)
  4416. self.activation = nn.Tanh()
  4417. self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  4418. self.dense_1 = nn.Linear(config.hidden_size, 1)
  4419. def forward(
  4420. self,
  4421. hidden_states: torch.FloatTensor,
  4422. start_states: Optional[torch.FloatTensor] = None,
  4423. start_positions: Optional[torch.LongTensor] = None,
  4424. p_mask: Optional[torch.FloatTensor] = None,
  4425. ) -> torch.FloatTensor:
  4426. """
  4427. Args:
  4428. hidden_states (`torch.FloatTensor` of shape `(batch_size, seq_len, hidden_size)`):
  4429. The final hidden states of the model.
  4430. start_states (`torch.FloatTensor` of shape `(batch_size, seq_len, hidden_size)`, *optional*):
  4431. The hidden states of the first tokens for the labeled span.
  4432. start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  4433. The position of the first token for the labeled span.
  4434. p_mask (`torch.FloatTensor` of shape `(batch_size, seq_len)`, *optional*):
  4435. Mask for tokens at invalid position, such as query and special symbols (PAD, SEP, CLS). 1.0 means token
  4436. should be masked.
  4437. <Tip>
  4438. One of `start_states` or `start_positions` should be not `None`. If both are set, `start_positions` overrides
  4439. `start_states`.
  4440. </Tip>
  4441. Returns:
  4442. `torch.FloatTensor`: The end logits for SQuAD.
  4443. """
  4444. assert (
  4445. start_states is not None or start_positions is not None
  4446. ), "One of start_states, start_positions should be not None"
  4447. if start_positions is not None:
  4448. slen, hsz = hidden_states.shape[-2:]
  4449. start_positions = start_positions[:, None, None].expand(-1, -1, hsz) # shape (bsz, 1, hsz)
  4450. start_states = hidden_states.gather(-2, start_positions) # shape (bsz, 1, hsz)
  4451. start_states = start_states.expand(-1, slen, -1) # shape (bsz, slen, hsz)
  4452. x = self.dense_0(torch.cat([hidden_states, start_states], dim=-1))
  4453. x = self.activation(x)
  4454. x = self.LayerNorm(x)
  4455. x = self.dense_1(x).squeeze(-1)
  4456. if p_mask is not None:
  4457. if get_parameter_dtype(self) == torch.float16:
  4458. x = x * (1 - p_mask) - 65500 * p_mask
  4459. else:
  4460. x = x * (1 - p_mask) - 1e30 * p_mask
  4461. return x
  4462. class PoolerAnswerClass(nn.Module):
  4463. """
  4464. Compute SQuAD 2.0 answer class from classification and start tokens hidden states.
  4465. Args:
  4466. config ([`PretrainedConfig`]):
  4467. The config used by the model, will be used to grab the `hidden_size` of the model.
  4468. """
  4469. def __init__(self, config):
  4470. super().__init__()
  4471. self.dense_0 = nn.Linear(config.hidden_size * 2, config.hidden_size)
  4472. self.activation = nn.Tanh()
  4473. self.dense_1 = nn.Linear(config.hidden_size, 1, bias=False)
  4474. def forward(
  4475. self,
  4476. hidden_states: torch.FloatTensor,
  4477. start_states: Optional[torch.FloatTensor] = None,
  4478. start_positions: Optional[torch.LongTensor] = None,
  4479. cls_index: Optional[torch.LongTensor] = None,
  4480. ) -> torch.FloatTensor:
  4481. """
  4482. Args:
  4483. hidden_states (`torch.FloatTensor` of shape `(batch_size, seq_len, hidden_size)`):
  4484. The final hidden states of the model.
  4485. start_states (`torch.FloatTensor` of shape `(batch_size, seq_len, hidden_size)`, *optional*):
  4486. The hidden states of the first tokens for the labeled span.
  4487. start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  4488. The position of the first token for the labeled span.
  4489. cls_index (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  4490. Position of the CLS token for each sentence in the batch. If `None`, takes the last token.
  4491. <Tip>
  4492. One of `start_states` or `start_positions` should be not `None`. If both are set, `start_positions` overrides
  4493. `start_states`.
  4494. </Tip>
  4495. Returns:
  4496. `torch.FloatTensor`: The SQuAD 2.0 answer class.
  4497. """
  4498. # No dependency on end_feature so that we can obtain one single `cls_logits` for each sample.
  4499. hsz = hidden_states.shape[-1]
  4500. assert (
  4501. start_states is not None or start_positions is not None
  4502. ), "One of start_states, start_positions should be not None"
  4503. if start_positions is not None:
  4504. start_positions = start_positions[:, None, None].expand(-1, -1, hsz) # shape (bsz, 1, hsz)
  4505. start_states = hidden_states.gather(-2, start_positions).squeeze(-2) # shape (bsz, hsz)
  4506. if cls_index is not None:
  4507. cls_index = cls_index[:, None, None].expand(-1, -1, hsz) # shape (bsz, 1, hsz)
  4508. cls_token_state = hidden_states.gather(-2, cls_index).squeeze(-2) # shape (bsz, hsz)
  4509. else:
  4510. cls_token_state = hidden_states[:, -1, :] # shape (bsz, hsz)
  4511. x = self.dense_0(torch.cat([start_states, cls_token_state], dim=-1))
  4512. x = self.activation(x)
  4513. x = self.dense_1(x).squeeze(-1)
  4514. return x
@dataclass
class SquadHeadOutput(ModelOutput):
    """
    Base class for outputs of question answering models using a [`~modeling_utils.SQuADHead`].

    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned if both `start_positions` and `end_positions` are provided):
            Training loss: the average of the start token and end token cross-entropy losses, plus half of the
            `is_impossible` binary cross-entropy loss when both `cls_index` and `is_impossible` are provided.
        start_top_log_probs (`torch.FloatTensor` of shape `(batch_size, config.start_n_top)`, *optional*, returned if `start_positions` or `end_positions` is not provided):
            Scores for the top config.start_n_top start token possibilities (beam-search). Note: despite the name,
            `SQuADHead` computes these with a softmax, so they are probabilities rather than log-probabilities.
        start_top_index (`torch.LongTensor` of shape `(batch_size, config.start_n_top)`, *optional*, returned if `start_positions` or `end_positions` is not provided):
            Indices for the top config.start_n_top start token possibilities (beam-search).
        end_top_log_probs (`torch.FloatTensor` of shape `(batch_size, config.start_n_top * config.end_n_top)`, *optional*, returned if `start_positions` or `end_positions` is not provided):
            Scores for the top `config.start_n_top * config.end_n_top` end token possibilities (beam-search). As for
            `start_top_log_probs`, `SQuADHead` fills these with softmax probabilities, not log-probabilities.
        end_top_index (`torch.LongTensor` of shape `(batch_size, config.start_n_top * config.end_n_top)`, *optional*, returned if `start_positions` or `end_positions` is not provided):
            Indices for the top `config.start_n_top * config.end_n_top` end token possibilities (beam-search).
        cls_logits (`torch.FloatTensor` of shape `(batch_size,)`, *optional*, returned if `start_positions` or `end_positions` is not provided):
            Logits (before any sigmoid) for the `is_impossible` label of the answers.
    """

    # Field order mirrors the plain tuples returned by `SQuADHead.forward` when `return_dict=False`:
    # `(loss,)` in training, and the remaining five fields (in this order) in inference.
    loss: Optional[torch.FloatTensor] = None
    start_top_log_probs: Optional[torch.FloatTensor] = None
    start_top_index: Optional[torch.LongTensor] = None
    end_top_log_probs: Optional[torch.FloatTensor] = None
    end_top_index: Optional[torch.LongTensor] = None
    cls_logits: Optional[torch.FloatTensor] = None
  4541. class SQuADHead(nn.Module):
  4542. r"""
  4543. A SQuAD head inspired by XLNet.
  4544. Args:
  4545. config ([`PretrainedConfig`]):
  4546. The config used by the model, will be used to grab the `hidden_size` of the model and the `layer_norm_eps`
  4547. to use.
  4548. """
  4549. def __init__(self, config):
  4550. super().__init__()
  4551. self.start_n_top = config.start_n_top
  4552. self.end_n_top = config.end_n_top
  4553. self.start_logits = PoolerStartLogits(config)
  4554. self.end_logits = PoolerEndLogits(config)
  4555. self.answer_class = PoolerAnswerClass(config)
  4556. @replace_return_docstrings(output_type=SquadHeadOutput, config_class=PretrainedConfig)
  4557. def forward(
  4558. self,
  4559. hidden_states: torch.FloatTensor,
  4560. start_positions: Optional[torch.LongTensor] = None,
  4561. end_positions: Optional[torch.LongTensor] = None,
  4562. cls_index: Optional[torch.LongTensor] = None,
  4563. is_impossible: Optional[torch.LongTensor] = None,
  4564. p_mask: Optional[torch.FloatTensor] = None,
  4565. return_dict: bool = False,
  4566. ) -> Union[SquadHeadOutput, Tuple[torch.FloatTensor]]:
  4567. """
  4568. Args:
  4569. hidden_states (`torch.FloatTensor` of shape `(batch_size, seq_len, hidden_size)`):
  4570. Final hidden states of the model on the sequence tokens.
  4571. start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  4572. Positions of the first token for the labeled span.
  4573. end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  4574. Positions of the last token for the labeled span.
  4575. cls_index (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  4576. Position of the CLS token for each sentence in the batch. If `None`, takes the last token.
  4577. is_impossible (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  4578. Whether the question has a possible answer in the paragraph or not.
  4579. p_mask (`torch.FloatTensor` of shape `(batch_size, seq_len)`, *optional*):
  4580. Mask for tokens at invalid position, such as query and special symbols (PAD, SEP, CLS). 1.0 means token
  4581. should be masked.
  4582. return_dict (`bool`, *optional*, defaults to `False`):
  4583. Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
  4584. Returns:
  4585. """
  4586. start_logits = self.start_logits(hidden_states, p_mask=p_mask)
  4587. if start_positions is not None and end_positions is not None:
  4588. # If we are on multi-GPU, let's remove the dimension added by batch splitting
  4589. for x in (start_positions, end_positions, cls_index, is_impossible):
  4590. if x is not None and x.dim() > 1:
  4591. x.squeeze_(-1)
  4592. # during training, compute the end logits based on the ground truth of the start position
  4593. end_logits = self.end_logits(hidden_states, start_positions=start_positions, p_mask=p_mask)
  4594. loss_fct = CrossEntropyLoss()
  4595. start_loss = loss_fct(start_logits, start_positions)
  4596. end_loss = loss_fct(end_logits, end_positions)
  4597. total_loss = (start_loss + end_loss) / 2
  4598. if cls_index is not None and is_impossible is not None:
  4599. # Predict answerability from the representation of CLS and START
  4600. cls_logits = self.answer_class(hidden_states, start_positions=start_positions, cls_index=cls_index)
  4601. loss_fct_cls = nn.BCEWithLogitsLoss()
  4602. cls_loss = loss_fct_cls(cls_logits, is_impossible)
  4603. # note(zhiliny): by default multiply the loss by 0.5 so that the scale is comparable to start_loss and end_loss
  4604. total_loss += cls_loss * 0.5
  4605. return SquadHeadOutput(loss=total_loss) if return_dict else (total_loss,)
  4606. else:
  4607. # during inference, compute the end logits based on beam search
  4608. bsz, slen, hsz = hidden_states.size()
  4609. start_log_probs = nn.functional.softmax(start_logits, dim=-1) # shape (bsz, slen)
  4610. start_top_log_probs, start_top_index = torch.topk(
  4611. start_log_probs, self.start_n_top, dim=-1
  4612. ) # shape (bsz, start_n_top)
  4613. start_top_index_exp = start_top_index.unsqueeze(-1).expand(-1, -1, hsz) # shape (bsz, start_n_top, hsz)
  4614. start_states = torch.gather(hidden_states, -2, start_top_index_exp) # shape (bsz, start_n_top, hsz)
  4615. start_states = start_states.unsqueeze(1).expand(-1, slen, -1, -1) # shape (bsz, slen, start_n_top, hsz)
  4616. hidden_states_expanded = hidden_states.unsqueeze(2).expand_as(
  4617. start_states
  4618. ) # shape (bsz, slen, start_n_top, hsz)
  4619. p_mask = p_mask.unsqueeze(-1) if p_mask is not None else None
  4620. end_logits = self.end_logits(hidden_states_expanded, start_states=start_states, p_mask=p_mask)
  4621. end_log_probs = nn.functional.softmax(end_logits, dim=1) # shape (bsz, slen, start_n_top)
  4622. end_top_log_probs, end_top_index = torch.topk(
  4623. end_log_probs, self.end_n_top, dim=1
  4624. ) # shape (bsz, end_n_top, start_n_top)
  4625. end_top_log_probs = end_top_log_probs.view(-1, self.start_n_top * self.end_n_top)
  4626. end_top_index = end_top_index.view(-1, self.start_n_top * self.end_n_top)
  4627. start_states = torch.einsum("blh,bl->bh", hidden_states, start_log_probs)
  4628. cls_logits = self.answer_class(hidden_states, start_states=start_states, cls_index=cls_index)
  4629. if not return_dict:
  4630. return (start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits)
  4631. else:
  4632. return SquadHeadOutput(
  4633. start_top_log_probs=start_top_log_probs,
  4634. start_top_index=start_top_index,
  4635. end_top_log_probs=end_top_log_probs,
  4636. end_top_index=end_top_index,
  4637. cls_logits=cls_logits,
  4638. )
  4639. class SequenceSummary(nn.Module):
  4640. r"""
  4641. Compute a single vector summary of a sequence hidden states.
  4642. Args:
  4643. config ([`PretrainedConfig`]):
  4644. The config used by the model. Relevant arguments in the config class of the model are (refer to the actual
  4645. config class of your model for the default values it uses):
  4646. - **summary_type** (`str`) -- The method to use to make this summary. Accepted values are:
  4647. - `"last"` -- Take the last token hidden state (like XLNet)
  4648. - `"first"` -- Take the first token hidden state (like Bert)
  4649. - `"mean"` -- Take the mean of all tokens hidden states
  4650. - `"cls_index"` -- Supply a Tensor of classification token position (GPT/GPT-2)
  4651. - `"attn"` -- Not implemented now, use multi-head attention
  4652. - **summary_use_proj** (`bool`) -- Add a projection after the vector extraction.
  4653. - **summary_proj_to_labels** (`bool`) -- If `True`, the projection outputs to `config.num_labels` classes
  4654. (otherwise to `config.hidden_size`).
  4655. - **summary_activation** (`Optional[str]`) -- Set to `"tanh"` to add a tanh activation to the output,
  4656. another string or `None` will add no activation.
  4657. - **summary_first_dropout** (`float`) -- Optional dropout probability before the projection and activation.
  4658. - **summary_last_dropout** (`float`)-- Optional dropout probability after the projection and activation.
  4659. """
  4660. def __init__(self, config: PretrainedConfig):
  4661. super().__init__()
  4662. self.summary_type = getattr(config, "summary_type", "last")
  4663. if self.summary_type == "attn":
  4664. # We should use a standard multi-head attention module with absolute positional embedding for that.
  4665. # Cf. https://github.com/zihangdai/xlnet/blob/master/modeling.py#L253-L276
  4666. # We can probably just use the multi-head attention module of PyTorch >=1.1.0
  4667. raise NotImplementedError
  4668. self.summary = Identity()
  4669. if hasattr(config, "summary_use_proj") and config.summary_use_proj:
  4670. if hasattr(config, "summary_proj_to_labels") and config.summary_proj_to_labels and config.num_labels > 0:
  4671. num_classes = config.num_labels
  4672. else:
  4673. num_classes = config.hidden_size
  4674. self.summary = nn.Linear(config.hidden_size, num_classes)
  4675. activation_string = getattr(config, "summary_activation", None)
  4676. self.activation: Callable = get_activation(activation_string) if activation_string else Identity()
  4677. self.first_dropout = Identity()
  4678. if hasattr(config, "summary_first_dropout") and config.summary_first_dropout > 0:
  4679. self.first_dropout = nn.Dropout(config.summary_first_dropout)
  4680. self.last_dropout = Identity()
  4681. if hasattr(config, "summary_last_dropout") and config.summary_last_dropout > 0:
  4682. self.last_dropout = nn.Dropout(config.summary_last_dropout)
  4683. def forward(
  4684. self, hidden_states: torch.FloatTensor, cls_index: Optional[torch.LongTensor] = None
  4685. ) -> torch.FloatTensor:
  4686. """
  4687. Compute a single vector summary of a sequence hidden states.
  4688. Args:
  4689. hidden_states (`torch.FloatTensor` of shape `[batch_size, seq_len, hidden_size]`):
  4690. The hidden states of the last layer.
  4691. cls_index (`torch.LongTensor` of shape `[batch_size]` or `[batch_size, ...]` where ... are optional leading dimensions of `hidden_states`, *optional*):
  4692. Used if `summary_type == "cls_index"` and takes the last token of the sequence as classification token.
  4693. Returns:
  4694. `torch.FloatTensor`: The summary of the sequence hidden states.
  4695. """
  4696. if self.summary_type == "last":
  4697. output = hidden_states[:, -1]
  4698. elif self.summary_type == "first":
  4699. output = hidden_states[:, 0]
  4700. elif self.summary_type == "mean":
  4701. output = hidden_states.mean(dim=1)
  4702. elif self.summary_type == "cls_index":
  4703. if cls_index is None:
  4704. cls_index = torch.full_like(
  4705. hidden_states[..., :1, :],
  4706. hidden_states.shape[-2] - 1,
  4707. dtype=torch.long,
  4708. )
  4709. else:
  4710. cls_index = cls_index.unsqueeze(-1).unsqueeze(-1)
  4711. cls_index = cls_index.expand((-1,) * (cls_index.dim() - 1) + (hidden_states.size(-1),))
  4712. # shape of cls_index: (bsz, XX, 1, hidden_size) where XX are optional leading dim of hidden_states
  4713. output = hidden_states.gather(-2, cls_index).squeeze(-2) # shape (bsz, XX, hidden_size)
  4714. elif self.summary_type == "attn":
  4715. raise NotImplementedError
  4716. output = self.first_dropout(output)
  4717. output = self.summary(output)
  4718. output = self.activation(output)
  4719. output = self.last_dropout(output)
  4720. return output
  4721. def unwrap_model(model: nn.Module, recursive: bool = False) -> nn.Module:
  4722. """
  4723. Recursively unwraps a model from potential containers (as used in distributed training).
  4724. Args:
  4725. model (`torch.nn.Module`): The model to unwrap.
  4726. recursive (`bool`, *optional*, defaults to `False`):
  4727. Whether to recursively extract all cases of `module.module` from `model` as well as unwrap child sublayers
  4728. recursively, not just the top-level distributed containers.
  4729. """
  4730. # Use accelerate implementation if available (should always be the case when using torch)
  4731. # This is for pytorch, as we also have to handle things like dynamo
  4732. if is_accelerate_available():
  4733. kwargs = {}
  4734. if recursive:
  4735. if not is_accelerate_available("0.29.0"):
  4736. raise RuntimeError(
  4737. "Setting `recursive=True` to `unwrap_model` requires `accelerate` v0.29.0. Please upgrade your version of accelerate"
  4738. )
  4739. else:
  4740. kwargs["recursive"] = recursive
  4741. return extract_model_from_parallel(model, **kwargs)
  4742. else:
  4743. # since there could be multiple levels of wrapping, unwrap recursively
  4744. if hasattr(model, "module"):
  4745. return unwrap_model(model.module)
  4746. else:
  4747. return model
  4748. def expand_device_map(device_map, param_names, start_prefix):
  4749. """
  4750. Expand a device map to return the correspondance parameter name to device.
  4751. """
  4752. new_device_map = {}
  4753. param_names = [p[len(start_prefix) :] for p in param_names if p.startswith(start_prefix)]
  4754. for module, device in device_map.items():
  4755. new_device_map.update(
  4756. {p: device for p in param_names if p == module or p.startswith(f"{module}.") or module == ""}
  4757. )
  4758. return new_device_map
  4759. def get_disk_only_shard_files(device_map, sharded_metadata, start_prefix):
  4760. """
  4761. Returns the list of shard files containing only weights offloaded to disk.
  4762. """
  4763. weight_map = {
  4764. p[len(start_prefix) :]: v for p, v in sharded_metadata["weight_map"].items() if p.startswith(start_prefix)
  4765. }
  4766. files_content = collections.defaultdict(list)
  4767. for weight_name, filename in weight_map.items():
  4768. while len(weight_name) > 0 and weight_name not in device_map:
  4769. weight_name = ".".join(weight_name.split(".")[:-1])
  4770. files_content[filename].append(device_map[weight_name])
  4771. return [fname for fname, devices in files_content.items() if set(devices) == {"disk"}]