# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""TF general model utils."""
from __future__ import annotations

import functools
import gc
import inspect
import json
import os
import pickle
import re
import warnings
from collections.abc import Mapping
from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union

import h5py
import numpy as np
import tensorflow as tf
from packaging.version import parse

from . import DataCollatorWithPadding, DefaultDataCollator
from .activations_tf import get_tf_activation
from .configuration_utils import PretrainedConfig
from .dynamic_module_utils import custom_object_save
from .generation import GenerationConfig, TFGenerationMixin
from .tf_utils import (
    convert_batch_encoding,
    expand_1d,
    load_attributes_from_hdf5_group,
    save_attributes_to_hdf5_group,
    shape_list,
)
from .utils import (
    SAFE_WEIGHTS_INDEX_NAME,
    SAFE_WEIGHTS_NAME,
    TF2_WEIGHTS_INDEX_NAME,
    TF2_WEIGHTS_NAME,
    TF_WEIGHTS_NAME,
    WEIGHTS_INDEX_NAME,
    WEIGHTS_NAME,
    ModelOutput,
    PushToHubMixin,
    cached_file,
    download_url,
    find_labels,
    has_file,
    is_offline_mode,
    is_remote_url,
    is_safetensors_available,
    is_tf_symbolic_tensor,
    logging,
    requires_backends,
    working_or_temp_dir,
)
from .utils.hub import convert_file_size_to_int, get_checkpoint_shard_files

if is_safetensors_available():
    from safetensors import safe_open
    from safetensors.tensorflow import save_file as safe_save_file

if TYPE_CHECKING:
    from . import PreTrainedTokenizerBase

logger = logging.get_logger(__name__)

if "TF_USE_LEGACY_KERAS" not in os.environ:
    os.environ["TF_USE_LEGACY_KERAS"] = "1"  # Compatibility fix to make sure tf.keras stays at Keras 2
elif os.environ["TF_USE_LEGACY_KERAS"] != "1":
    logger.warning(
        "Transformers is only compatible with Keras 2, but you have explicitly set `TF_USE_LEGACY_KERAS` to `0`. "
        "This may result in unexpected behaviour or errors if Keras 3 objects are passed to Transformers models."
    )

try:
    import tf_keras as keras
    from tf_keras import backend as K
except (ModuleNotFoundError, ImportError):
    import keras
    from keras import backend as K

    if parse(keras.__version__).major > 2:
        raise ValueError(
            "Your currently installed version of Keras is Keras 3, but this is not yet supported in "
            "Transformers. Please install the backwards-compatible tf-keras package with "
            "`pip install tf-keras`."
        )

tf_logger = tf.get_logger()

TFModelInputType = Union[
    List[tf.Tensor],
    List[np.ndarray],
    Dict[str, tf.Tensor],
    Dict[str, np.ndarray],
    tf.Tensor,
    np.ndarray,
]

def dummy_loss(y_true, y_pred):
    if y_pred.shape.rank <= 1:
        return y_pred
    else:
        reduction_axes = list(range(1, y_pred.shape.rank))
        return tf.reduce_mean(y_pred, axis=reduction_axes)

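# Sketch of what `dummy_loss` does (illustrative): when the model already computes a per-sample
# loss, Keras still expects a "loss function", so this one ignores `y_true` and simply averages
# the precomputed loss over every non-batch dimension.
#
#   >>> y_pred = tf.constant([[1.0, 3.0], [2.0, 4.0]])  # e.g. per-token losses, shape (batch, seq)
#   >>> dummy_loss(None, y_pred)  # -> [2.0, 3.0], one value per sample
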
class TFModelUtilsMixin:
    """
    A few utilities for `keras.Model`, to be used as a mixin.
    """

    def num_parameters(self, only_trainable: bool = False) -> int:
        """
        Get the number of (optionally, trainable) parameters in the model.

        Args:
            only_trainable (`bool`, *optional*, defaults to `False`):
                Whether or not to return only the number of trainable parameters

        Returns:
            `int`: The number of parameters.
        """
        if only_trainable:
            return int(sum(np.prod(w.shape.as_list()) for w in self.trainable_variables))
        else:
            return self.count_params()

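# Usage sketch for `num_parameters` (illustrative; assumes a built TF model such as `TFBertModel`):
#
#   >>> model = TFBertModel.from_pretrained("bert-base-uncased")
#   >>> model.num_parameters()  # every parameter, via `count_params()`
#   >>> model.num_parameters(only_trainable=True)  # trainable variables only
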
def keras_serializable(cls):
    """
    Decorate a Keras Layer class to support Keras serialization.

    This is done by:

    1. Adding a `config` dict to the Keras config dictionary in `get_config` (called by Keras at
       serialization time).
    2. Wrapping `__init__` to accept that `config` dict (passed by Keras at deserialization time) and
       convert it to a config object for the actual layer initializer.
    3. Registering the class as a custom object in Keras (if the Tensorflow version supports this), so that it does not
       need to be supplied in `custom_objects` in the call to `keras.models.load_model`.

    Args:
        cls (a `keras.layers.Layer` subclass):
            Typically a `TF.MainLayer` class in this project, in general must accept a `config` argument to its
            initializer.

    Returns:
        The same class object, with modifications for Keras deserialization.
    """
    initializer = cls.__init__

    config_class = getattr(cls, "config_class", None)
    if config_class is None:
        raise AttributeError("Must set `config_class` to use @keras_serializable")

    @functools.wraps(initializer)
    def wrapped_init(self, *args, **kwargs):
        config = args[0] if args and isinstance(args[0], PretrainedConfig) else kwargs.pop("config", None)

        if isinstance(config, dict):
            config = config_class.from_dict(config)
            initializer(self, config, *args, **kwargs)
        elif isinstance(config, PretrainedConfig):
            if len(args) > 0:
                initializer(self, *args, **kwargs)
            else:
                initializer(self, config, *args, **kwargs)
        else:
            raise ValueError("Must pass either `config` (PretrainedConfig) or `config` (dict)")

        self._config = config
        self._kwargs = kwargs

    cls.__init__ = wrapped_init

    if not hasattr(cls, "get_config"):
        raise TypeError("Only use @keras_serializable on keras.layers.Layer subclasses")
    if hasattr(cls.get_config, "_is_default"):

        def get_config(self):
            cfg = super(cls, self).get_config()
            cfg["config"] = self._config.to_dict()
            cfg.update(self._kwargs)
            return cfg

        cls.get_config = get_config

    cls._keras_serializable = True
    if hasattr(keras.utils, "register_keras_serializable"):
        cls = keras.utils.register_keras_serializable()(cls)
    return cls

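# Usage sketch (illustrative; `MyConfig` and `TFMyMainLayer` are assumed names). After decoration,
# `get_config()` embeds `config.to_dict()` under the "config" key, and `__init__` accepts either a
# `MyConfig` instance or that dict when Keras deserializes the layer.
#
#   @keras_serializable
#   class TFMyMainLayer(keras.layers.Layer):
#       config_class = MyConfig
#
#       def __init__(self, config: MyConfig, **kwargs):
#           super().__init__(**kwargs)
#           self.hidden_size = config.hidden_size
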
class TFCausalLanguageModelingLoss:
    """
    Loss function suitable for causal language modeling (CLM), that is, the task of guessing the next token.

    <Tip>

    Any label of -100 will be ignored (along with the corresponding logits) in the loss computation.

    </Tip>
    """

    def hf_compute_loss(self, labels, logits):
        loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True, reduction=keras.losses.Reduction.NONE)
        if self.config.tf_legacy_loss:
            # make sure only labels that are not equal to -100 affect the loss
            active_loss = tf.not_equal(tf.reshape(labels, (-1,)), -100)
            reduced_logits = tf.boolean_mask(tf.reshape(logits, (-1, shape_list(logits)[2])), active_loss)
            labels = tf.boolean_mask(tf.reshape(labels, (-1,)), active_loss)
            return loss_fn(labels, reduced_logits)

        # Clip negative labels to zero here to avoid NaNs and errors - those positions will get masked later anyway
        unmasked_loss = loss_fn(tf.nn.relu(labels), logits)
        # make sure only labels that are not equal to -100 affect the loss
        loss_mask = tf.cast(labels != -100, dtype=unmasked_loss.dtype)
        masked_loss = unmasked_loss * loss_mask
        reduced_masked_loss = tf.reduce_sum(masked_loss) / tf.reduce_sum(loss_mask)
        return tf.reshape(reduced_masked_loss, (1,))

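# Worked example of the non-legacy masking above (illustrative): labels equal to -100 contribute
# neither to the numerator nor to the denominator of the mean.
#
#   labels      = [[  5, -100,    7]]
#   loss_mask   = [[1.0,  0.0,  1.0]]                # tf.cast(labels != -100, ...)
#   masked_loss = unmasked_loss * loss_mask          # the -100 position is zeroed out
#   loss        = sum(masked_loss) / sum(loss_mask)  # mean over the 2 real tokens only
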
class TFQuestionAnsweringLoss:
    """
    Loss function suitable for question answering.
    """

    def hf_compute_loss(self, labels, logits):
        loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True, reduction=keras.losses.Reduction.NONE)
        start_loss = loss_fn(labels["start_position"], logits[0])
        end_loss = loss_fn(labels["end_position"], logits[1])

        return (start_loss + end_loss) / 2.0

class TFTokenClassificationLoss:
    """
    Loss function suitable for token classification.

    <Tip>

    Any label of -100 will be ignored (along with the corresponding logits) in the loss computation.

    </Tip>
    """

    def hf_compute_loss(self, labels, logits):
        loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True, reduction=keras.losses.Reduction.NONE)
        if tf.executing_eagerly():  # Data-dependent conditionals are forbidden in XLA
            if tf.math.reduce_any(labels == -1):
                tf.print("Using `-1` to mask the loss for the token is deprecated. Please use `-100` instead.")

        if self.config.tf_legacy_loss:
            # make sure only labels that are not equal to -100
            # are taken into account as loss
            if tf.math.reduce_any(labels == -1):
                tf.print("Using `-1` to mask the loss for the token is deprecated. Please use `-100` instead.")
                active_loss = tf.reshape(labels, (-1,)) != -1
            else:
                active_loss = tf.reshape(labels, (-1,)) != -100
            reduced_logits = tf.boolean_mask(tf.reshape(logits, (-1, shape_list(logits)[2])), active_loss)
            labels = tf.boolean_mask(tf.reshape(labels, (-1,)), active_loss)

            return loss_fn(labels, reduced_logits)

        # Clip negative labels to zero here to avoid NaNs and errors - those positions will get masked later anyway
        unmasked_loss = loss_fn(tf.nn.relu(labels), logits)
        # make sure only labels that are not equal to -100 or -1
        # are taken into account as loss
        loss_mask = tf.cast(labels >= 0, dtype=unmasked_loss.dtype)
        # Avoid possible division by zero later
        # Masked positions will have a loss of NaN because -100 and -1 are not valid labels
        masked_loss = unmasked_loss * loss_mask
        reduced_masked_loss = tf.reduce_sum(masked_loss) / tf.reduce_sum(loss_mask)
        return tf.reshape(reduced_masked_loss, (1,))

class TFSequenceClassificationLoss:
    """
    Loss function suitable for sequence classification.
    """

    def hf_compute_loss(self, labels, logits):
        if logits.shape.rank == 1 or logits.shape[1] == 1:
            loss_fn = keras.losses.MeanSquaredError(reduction=keras.losses.Reduction.NONE)
            if labels.shape.rank == 1:
                # MeanSquaredError returns a scalar loss if the labels are 1D, so avoid that
                labels = tf.expand_dims(labels, axis=-1)
        else:
            loss_fn = keras.losses.SparseCategoricalCrossentropy(
                from_logits=True, reduction=keras.losses.Reduction.NONE
            )

        return loss_fn(labels, logits)

class TFMultipleChoiceLoss:
    """Loss function suitable for multiple choice tasks."""

    def hf_compute_loss(self, labels, logits):
        loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True, reduction=keras.losses.Reduction.NONE)
        return loss_fn(labels, logits)

class TFMaskedLanguageModelingLoss(TFCausalLanguageModelingLoss):
    """
    Loss function suitable for masked language modeling (MLM), that is, the task of guessing the masked tokens.

    <Tip>

    Any label of -100 will be ignored (along with the corresponding logits) in the loss computation.

    </Tip>
    """

class TFNextSentencePredictionLoss:
    """
    Loss function suitable for next sentence prediction (NSP), that is, the task of guessing the next sentence.

    <Tip>

    Any label of -100 will be ignored (along with the corresponding logits) in the loss computation.

    </Tip>
    """

    def hf_compute_loss(self, labels, logits):
        loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True, reduction=keras.losses.Reduction.NONE)
        if self.config.tf_legacy_loss:
            # make sure only labels that are not equal to -100
            # are taken into account as loss
            next_sentence_active_loss = tf.not_equal(tf.reshape(labels, (-1,)), -100)
            next_sentence_reduced_logits = tf.boolean_mask(tf.reshape(logits, (-1, 2)), next_sentence_active_loss)
            next_sentence_label = tf.boolean_mask(tf.reshape(labels, (-1,)), next_sentence_active_loss)

            return loss_fn(next_sentence_label, next_sentence_reduced_logits)

        # make sure only labels that are not equal to -100
        # are taken into account as loss

        # Clip negative labels to zero here to avoid NaNs and errors - those positions will get masked later anyway
        unmasked_ns_loss = loss_fn(y_true=tf.nn.relu(labels), y_pred=logits)
        ns_loss_mask = tf.cast(labels != -100, dtype=unmasked_ns_loss.dtype)
        # Just zero out samples where label is -100, no reduction
        masked_ns_loss = unmasked_ns_loss * ns_loss_mask

        return masked_ns_loss

def booleans_processing(config, **kwargs):
    """
    Process the input booleans of each model.

    Args:
        config ([`PretrainedConfig`]):
            The config of the running model.
        **kwargs:
            The boolean parameters

    Returns:
        A dictionary with the proper values for each boolean
    """
    final_booleans = {}

    # Pure conv models (such as ConvNext) do not have `output_attentions`. If the signature has
    # `output_attentions`, it will be present here in `kwargs`, even if unset (in that case, as `None`)
    if "output_attentions" in kwargs:
        final_booleans["output_attentions"] = (
            kwargs["output_attentions"] if kwargs["output_attentions"] is not None else config.output_attentions
        )
    final_booleans["output_hidden_states"] = (
        kwargs["output_hidden_states"] if kwargs["output_hidden_states"] is not None else config.output_hidden_states
    )
    final_booleans["return_dict"] = kwargs["return_dict"] if kwargs["return_dict"] is not None else config.return_dict

    if "use_cache" in kwargs:
        final_booleans["use_cache"] = (
            kwargs["use_cache"] if kwargs["use_cache"] is not None else getattr(config, "use_cache", None)
        )
    return final_booleans

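# Example of the resolution rule above (illustrative): an explicit keyword argument wins,
# otherwise the value stored on the config is used.
#
#   >>> booleans_processing(config, output_attentions=None, output_hidden_states=True, return_dict=None)
#   {"output_attentions": config.output_attentions, "output_hidden_states": True, "return_dict": config.return_dict}
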
def unpack_inputs(func):
    """
    Decorator that processes the inputs to a Keras layer, passing them to the layer as keyword arguments. This enables
    downstream use of the inputs by their variable name, even if they arrive packed as a dictionary in the first input
    (common case in Keras).

    Args:
        func (`callable`):
            The callable function of the TensorFlow model.

    Returns:
        A callable that wraps the original `func` with the behavior described above.
    """

    original_signature = inspect.signature(func)

    @functools.wraps(func)
    def run_call_with_unpacked_inputs(self, *args, **kwargs):
        # isolates the actual `**kwargs` for the decorated function
        kwargs_call = {key: val for key, val in kwargs.items() if key not in dict(original_signature.parameters)}
        fn_args_and_kwargs = {key: val for key, val in kwargs.items() if key not in kwargs_call}
        fn_args_and_kwargs.update({"kwargs_call": kwargs_call})

        # move any arg into kwargs, if they exist
        fn_args_and_kwargs.update(dict(zip(func.__code__.co_varnames[1:], args)))

        # Encoder Decoder models delegate the application of the configuration options to their inner models.
        if "EncoderDecoder" in self.__class__.__name__:
            config = None
        else:
            config = self.config

        unpacked_inputs = input_processing(func, config, **fn_args_and_kwargs)
        return func(self, **unpacked_inputs)

    # Keras enforces the first layer argument to be passed, and checks it through `inspect.getfullargspec()`. This
    # function does not follow wrapper chains (i.e. ignores `functools.wraps()`), meaning that without the line below
    # Keras would attempt to check the first argument against the literal signature of the wrapper.
    run_call_with_unpacked_inputs.__signature__ = original_signature

    return run_call_with_unpacked_inputs

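# Usage sketch (illustrative; `TFMyModel` is an assumed class). The decorator lets `call` receive
# either explicit keyword arguments or a single packed dictionary, and fills unset booleans from
# the config via `input_processing` below.
#
#   class TFMyModel(TFPreTrainedModel):
#       @unpack_inputs
#       def call(self, input_ids=None, attention_mask=None, output_attentions=None, training=False):
#           ...
#
#   # Both calls end up with the same keyword arguments inside `call`:
#   #   model(input_ids=ids, attention_mask=mask)
#   #   model({"input_ids": ids, "attention_mask": mask})
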
def input_processing(func, config, **kwargs):
    """
    Process the input of each TensorFlow model including the booleans. In case of a list of symbolic inputs, each input
    has to be named according to its parameter name, i.e. `input_ids = keras.Input(shape=(128,), dtype='int32',
    name="input_ids")`, otherwise the order of the tensors will not be guaranteed during the training.

    Args:
        func (`callable`):
            The callable function of the TensorFlow model.
        config ([`PretrainedConfig`]):
            The config of the running model.
        **kwargs:
            The inputs of the model.

    Returns:
        A dictionary mapping each parameter name to its processed input value.
    """
    signature = dict(inspect.signature(func).parameters)
    has_kwargs = bool(signature.pop("kwargs", None))
    signature.pop("self", None)
    parameter_names = list(signature.keys())
    main_input_name = parameter_names[0]
    main_input = kwargs.pop(main_input_name, None)
    output = {}
    allowed_types = (tf.Tensor, bool, int, ModelOutput, tuple, list, dict, np.ndarray)

    if "inputs" in kwargs["kwargs_call"]:
        warnings.warn(
            "The `inputs` argument is deprecated and will be removed in a future version, use `input_ids` instead.",
            FutureWarning,
        )
        output["input_ids"] = kwargs["kwargs_call"].pop("inputs")

    if "decoder_cached_states" in kwargs["kwargs_call"]:
        warnings.warn(
            "The `decoder_cached_states` argument is deprecated and will be removed in a future version, use"
            " `past_key_values` instead.",
            FutureWarning,
        )
        output["past_key_values"] = kwargs["kwargs_call"].pop("decoder_cached_states")

    if "past" in kwargs["kwargs_call"] and "past_key_values" in parameter_names:
        warnings.warn(
            "The `past` argument is deprecated and will be removed in a future version, use `past_key_values`"
            " instead.",
            FutureWarning,
        )
        kwargs["past_key_values"] = kwargs["kwargs_call"].pop("past")
    elif "past_key_values" in kwargs["kwargs_call"] and "past" in parameter_names:
        kwargs["past"] = kwargs["kwargs_call"].pop("past_key_values")

    if has_kwargs:
        output["kwargs"] = kwargs.pop("kwargs_call", {})
    else:
        if len(kwargs["kwargs_call"]) > 0:
            raise ValueError(
                "The following keyword arguments are not supported by this model:"
                f" {list(kwargs['kwargs_call'].keys())}."
            )
        kwargs.pop("kwargs_call")

    for k, v in kwargs.items():
        if isinstance(v, allowed_types) or tf.is_tensor(v) or v is None:
            output[k] = v
        else:
            raise ValueError(f"Data of type {type(v)} is not allowed only {allowed_types} is accepted for {k}.")

    if isinstance(main_input, (tuple, list)):
        for i, input in enumerate(main_input):
            # EagerTensors don't allow to use the .name property so we check for a real Tensor
            if is_tf_symbolic_tensor(input):
                # Tensor names have always the pattern `name:id` then we check only the
                # `name` part
                tensor_name = input.name.split(":")[0]

                if tensor_name in parameter_names:
                    output[tensor_name] = input
                else:
                    output[parameter_names[i]] = input
            elif isinstance(input, allowed_types) or input is None:
                output[parameter_names[i]] = input
            else:
                raise ValueError(
                    f"Data of type {type(input)} is not allowed only {allowed_types} is accepted for"
                    f" {parameter_names[i]}."
                )
    elif isinstance(main_input, Mapping):
        if "inputs" in main_input:
            warnings.warn(
                "The `inputs` argument is deprecated and will be removed in a future version, use `input_ids`"
                " instead.",
                FutureWarning,
            )

            output["input_ids"] = main_input.pop("inputs")

        if "decoder_cached_states" in main_input:
            warnings.warn(
                "The `decoder_cached_states` argument is deprecated and will be removed in a future version, use"
                " `past_key_values` instead.",
                FutureWarning,
            )
            output["past_key_values"] = main_input.pop("decoder_cached_states")

        for k, v in dict(main_input).items():
            if isinstance(v, allowed_types) or v is None:
                output[k] = v
            elif k not in parameter_names and "args" not in parameter_names:
                logger.warning(
                    f"The parameter {k} does not belongs to the parameter list {parameter_names} and will be ignored."
                )
                continue
            else:
                raise ValueError(f"Data of type {type(v)} is not allowed only {allowed_types} is accepted for {k}.")
    else:
        if tf.is_tensor(main_input) or main_input is None:
            output[main_input_name] = main_input
        else:
            raise ValueError(
                f"Data of type {type(main_input)} is not allowed only {allowed_types} is accepted for"
                f" {main_input_name}."
            )

    # Populates any unspecified argument with their default value, according to the signature.
    for name in parameter_names:
        if name not in list(output.keys()) and name != "args":
            output[name] = kwargs.pop(name, signature[name].default)

    # When creating a SavedModel TF calls the method with LayerCall.__call__(args, **kwargs)
    # So to respect the proper output we have to add this exception
    if "args" in output:
        if output["args"] is not None and is_tf_symbolic_tensor(output["args"]):
            tensor_name = output["args"].name.split(":")[0]
            output[tensor_name] = output["args"]
        else:
            # `args` in this case is always the first parameter, then `input_ids`
            output["input_ids"] = output["args"]

        del output["args"]

    if "kwargs" in output:
        del output["kwargs"]

    cast_output = {}
    for key, val in output.items():
        if isinstance(val, tf.Tensor) and val.dtype == tf.int64:
            cast_output[key] = tf.cast(val, tf.int32)
        elif isinstance(val, np.ndarray) and val.dtype == np.int64:
            cast_output[key] = val.astype(np.int32)
        else:
            cast_output[key] = val

    output = cast_output
    del cast_output

    if config is not None:
        boolean_dict = {
            k: v
            for k, v in output.items()
            if k in ["return_dict", "output_attentions", "output_hidden_states", "use_cache"]
        }

        output.update(
            booleans_processing(
                config=config,
                **boolean_dict,
            )
        )

    return output

def dtype_byte_size(dtype):
    """
    Returns the size (in bytes) occupied by one parameter of type `dtype`.

    Example:

    ```py
    >>> dtype_byte_size(tf.float32)
    4
    ```
    """
    if dtype == tf.bool:
        return 1 / 8
    bit_search = re.search(r"[^\d](\d+)$", dtype.name)
    if bit_search is None:
        raise ValueError(f"`dtype` is not a valid dtype: {dtype}.")
    bit_size = int(bit_search.groups()[0])
    return bit_size // 8

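# A couple more illustrative values, following the rule above (trailing bit-width divided by 8):
#
#   >>> dtype_byte_size(tf.bfloat16)
#   2
#   >>> dtype_byte_size(tf.bool)  # special-cased above, a fraction of a byte
#   0.125
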
def strip_model_name_and_prefix(name, _prefix=None):
    if _prefix is not None and name.startswith(_prefix):
        name = name[len(_prefix) :]
        if name.startswith("/"):
            name = name[1:]

    if "model." not in name and len(name.split("/")) > 1:
        name = "/".join(name.split("/")[1:])
    return name

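# Illustrative transformations (hypothetical weight names): an optional `_prefix` is stripped
# first, then the leading model-level scope is dropped unless the name contains "model.".
#
#   strip_model_name_and_prefix("tf_bert_model/bert/pooler/dense/kernel:0")
#   # -> "bert/pooler/dense/kernel:0"
#   strip_model_name_and_prefix("decoder/tf_bert_model/bert/dense/kernel:0", _prefix="decoder")
#   # -> "bert/dense/kernel:0"
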
def tf_shard_checkpoint(weights, max_shard_size="10GB", weights_name: str = TF2_WEIGHTS_NAME):
    """
    Splits a model state dictionary in sub-checkpoints so that the final size of each sub-checkpoint does not exceed a
    given size.

    The sub-checkpoints are determined by iterating through the `state_dict` in the order of its keys, so there is no
    optimization made to make each sub-checkpoint as close as possible to the maximum size passed. For example, if the
    limit is 10GB and we have weights of sizes [6GB, 6GB, 2GB, 6GB, 2GB, 2GB] they will get sharded as [6GB], [6+2GB],
    [6+2+2GB] and not [6+2+2GB], [6+2GB], [6GB].

    <Tip warning={true}>

    If one of the model's weights is bigger than `max_shard_size`, it will end up in its own sub-checkpoint which will
    have a size greater than `max_shard_size`.

    </Tip>

    Args:
        weights (`Dict[str, tf.ResourceVariable]`): The list of `tf.ResourceVariable` of a model to save.
        max_shard_size (`int` or `str`, *optional*, defaults to `"10GB"`):
            The maximum size of each sub-checkpoint. If expressed as a string, needs to be digits followed by a unit
            (like `"5MB"`).
    """
    max_shard_size = convert_file_size_to_int(max_shard_size)

    sharded_state_dicts = []
    current_block = []
    current_block_size = 0
    total_size = 0

    for item in weights:
        weight_size = item.numpy().size * dtype_byte_size(item.dtype)

        # If this weight is going to tip up over the maximal size, we split.
        if current_block_size + weight_size > max_shard_size:
            sharded_state_dicts.append(current_block)
            current_block = []
            current_block_size = 0

        current_block.append(item)
        current_block_size += weight_size
        total_size += weight_size

    # Add the last block
    sharded_state_dicts.append(current_block)

    # If we only have one shard, we return it
    if len(sharded_state_dicts) == 1:
        return {weights_name: sharded_state_dicts[0]}, None

    # Otherwise, let's build the index
    weight_map = {}
    shards = {}
    for idx, shard in enumerate(sharded_state_dicts):
        shard_file = weights_name.replace(".h5", f"-{idx+1:05d}-of-{len(sharded_state_dicts):05d}.h5")
        shard_file = shard_file.replace(
            ".safetensors", f"-{idx + 1:05d}-of-{len(sharded_state_dicts):05d}.safetensors"
        )
        shards[shard_file] = shard
        for weight in shard:
            weight_name = weight.name
            weight_map[weight_name] = shard_file

    # Add the metadata
    metadata = {"total_size": total_size}
    index = {"metadata": metadata, "weight_map": weight_map}
    return shards, index

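# Usage sketch (illustrative; `model` is an assumed built Keras model):
#
#   shards, index = tf_shard_checkpoint(model.weights, max_shard_size="5GB")
#   # `shards` maps file names such as "tf_model-00001-of-00002.h5" to lists of variables;
#   # `index` is None when everything fits in a single shard, otherwise
#   # {"metadata": {"total_size": ...}, "weight_map": {weight_name: shard_file, ...}}.
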
def load_tf_sharded_weights(model, shard_files, ignore_mismatched_sizes=False, strict=False, _prefix=None):
    """
    This is the same as `load_tf_weights` but for a sharded checkpoint. Detect missing and unexpected layers and load
    the TF weights from the shard files according to their names and shapes.

    This load is performed efficiently: each checkpoint shard is loaded one by one in RAM and deleted after being
    loaded in the model.

    Args:
        model (`keras.models.Model`): The model in which to load the checkpoint.
        shard_files (`str` or `os.PathLike`): A list containing the sharded checkpoint names.
        ignore_mismatched_sizes (`bool`, *optional*, defaults to `False`):
            Whether or not to ignore the mismatch between the sizes
        strict (`bool`, *optional*, defaults to `False`):
            Whether to strictly enforce that the keys in the model state dict match the keys in the sharded checkpoint.

    Returns:
        Three lists, one for the missing layers, another one for the unexpected layers, and a last one for the
        mismatched layers.
    """
    # Load the index
    unexpected_keys = set()
    saved_keys = set()
    mismatched_keys = set()

    # Since TF adds the name of the class to its weights, and uses the index and not the name of the layer to load
    # the weight, we have to get rid of the first prefix of the name of the layer.
    model_keys = set()
    model_layer_map = {}
    for i, k in enumerate(model.weights):
        layer_name = k.name
        if _prefix is not None and layer_name.startswith(_prefix):
            layer_name = layer_name[len(_prefix) :]
            layer_name = layer_name.lstrip("/")
        if not ("model." in layer_name or len(layer_name.split("/")) == 1):
            layer_name = "/".join(layer_name.split("/")[1:])
        model_keys.add(layer_name)
        model_layer_map[layer_name] = i

    for shard_file in shard_files:
        saved_weight_names_set, unexpected_keys_set, mismatched_keys_set = load_tf_shard(
            model,
            model_layer_map,
            shard_file,
            ignore_mismatched_sizes=ignore_mismatched_sizes,
            _prefix=_prefix,
        )
        saved_keys.update(saved_weight_names_set)
        unexpected_keys.update(unexpected_keys_set)
        mismatched_keys.update(mismatched_keys_set)
        gc.collect()

    missing_keys = model_keys - saved_keys
    if strict and (len(missing_keys) > 0 or len(unexpected_keys) > 0):
        error_message = f"Error(s) in loading state_dict for {model.__class__.__name__}"
        if len(missing_keys) > 0:
            str_missing_keys = ",".join([f'"{k}"' for k in missing_keys])
            error_message += f"\nMissing key(s): {str_missing_keys}."
        if len(unexpected_keys) > 0:
            str_unexpected_keys = ",".join([f'"{k}"' for k in unexpected_keys])
            error_message += f"\nUnexpected key(s): {str_unexpected_keys}."
        raise RuntimeError(error_message)

    return missing_keys, unexpected_keys, mismatched_keys

def load_tf_shard(model, model_layer_map, resolved_archive_file, ignore_mismatched_sizes=False, _prefix=None):
    """
    Loads a shard from a sharded checkpoint file. Can be either H5 or Safetensors.
    Handles missing keys and unexpected keys.

    Args:
        model (`keras.models.Model`): Model in which the weights are loaded
        model_layer_map (`Dict`): A dictionary mapping the layer name to the index of the layer in the model.
        resolved_archive_file (`str`): Path to the checkpoint file from which the weights will be loaded
        ignore_mismatched_sizes (`bool`, *optional*, defaults to `False`): Whether to ignore the mismatched keys

    Returns:
        Three lists, one for the layers that were found and successfully restored (from the shard file), one for the
        unexpected layers, and another one for the mismatched layers.
    """
    saved_weight_names_set = set()
    saved_weights = {}
    mismatched_keys = set()
    unexpected_keys = set()
    # Read the H5 file
    try:
        with h5py.File(resolved_archive_file, "r") as sharded_checkpoint_file:
            # Retrieve the name of each layer from the H5 file
            saved_h5_model_layers_name = set(load_attributes_from_hdf5_group(sharded_checkpoint_file, "layer_names"))
            weight_value_tuples = []

            # Compute missing and unexpected sub layers
            # Store the weights in list of tuples that looks like [(weight_object, value_of_weight),...]
            for layer_name in saved_h5_model_layers_name:
                h5_layer_object = sharded_checkpoint_file[layer_name]
                saved_weights[layer_name] = np.asarray(h5_layer_object)

                saved_weight_names_set.add(layer_name)

                if layer_name not in model_layer_map:
                    unexpected_keys.add(layer_name)
                else:
                    symbolic_weight = model.weights[model_layer_map[layer_name]]

                    saved_weight_value = saved_weights[layer_name]
                    # If the current weight is found
                    if saved_weight_value is not None:
                        # Check if the shape of the current weight and the one from the H5 file are different
                        if K.int_shape(symbolic_weight) != saved_weight_value.shape:
                            # If yes we reshape the weight from the H5 file accordingly to the current weight
                            # If the two shapes are not compatible we raise an issue
                            try:
                                array = np.reshape(saved_weight_value, K.int_shape(symbolic_weight))
                            except ValueError as e:
                                if ignore_mismatched_sizes:
                                    mismatched_keys.add(
                                        (layer_name, saved_weight_value.shape, K.int_shape(symbolic_weight))
                                    )
                                    continue
                                else:
                                    raise e
                        else:
                            array = saved_weight_value

                    # We create the tuple that will be loaded and add it to the final list
                    weight_value_tuples.append((symbolic_weight, array))

        K.batch_set_value(weight_value_tuples)

        return saved_weight_names_set, unexpected_keys, mismatched_keys

    except Exception as e:
        try:
            with open(resolved_archive_file) as f:
                if f.read().startswith("version"):
                    raise OSError(
                        "You seem to have cloned a repository without having git-lfs installed. Please install "
                        "git-lfs and run `git lfs install` followed by `git lfs pull` in the folder "
                        "you cloned."
                    )
                else:
                    raise ValueError(
                        f"Unable to locate the file {resolved_archive_file} which is necessary to load this pretrained"
                        " model. Make sure you have saved the model properly."
                    ) from e
        except (UnicodeDecodeError, ValueError):
            raise OSError(
                f"Unable to load weights from TF checkpoint file for '{resolved_archive_file}' "
                f"at '{resolved_archive_file}'. "
                "If you tried to load a TF model from a sharded checkpoint, you should try converting the model "
                "by loading it in pytorch and saving it locally. A conversion script should be released soon."
            )

def load_tf_sharded_weights_from_safetensors(
    model, shard_files, ignore_mismatched_sizes=False, strict=False, _prefix=None
):
    """
    This is the same as `load_tf_weights_from_safetensors` but for a sharded TF-format safetensors checkpoint.
    Detect missing and unexpected layers and load the TF weights from the shard files according to their names and
    shapes.

    This load is performed efficiently: each checkpoint shard is loaded one by one in RAM and deleted after being
    loaded in the model.

    Args:
        model (`keras.models.Model`): The model in which to load the checkpoint.
        shard_files (`str` or `os.PathLike`): A list containing the sharded checkpoint names.
        ignore_mismatched_sizes (`bool`, *optional*, defaults to `False`):
            Whether or not to ignore the mismatch between the sizes
        strict (`bool`, *optional*, defaults to `False`):
            Whether to strictly enforce that the keys in the model state dict match the keys in the sharded checkpoint.

    Returns:
        Three lists, one for the missing layers, another one for the unexpected layers, and a last one for the
        mismatched layers.
    """
    # Load the index
    unexpected_keys = set()
    all_missing_keys = []
    mismatched_keys = set()

    for shard_file in shard_files:
        missing_layers, unexpected_layers, mismatched_layers = load_tf_weights_from_safetensors(
            model,
            shard_file,
            ignore_mismatched_sizes=ignore_mismatched_sizes,
            _prefix=_prefix,
        )
        all_missing_keys.append(set(missing_layers))
        unexpected_keys.update(unexpected_layers)
        mismatched_keys.update(mismatched_layers)
        gc.collect()
    missing_keys = set.intersection(*all_missing_keys)

    if strict and (len(missing_keys) > 0 or len(unexpected_keys) > 0):
        error_message = f"Error(s) in loading state_dict for {model.__class__.__name__}"
        if len(missing_keys) > 0:
            str_missing_keys = ",".join([f'"{k}"' for k in missing_keys])
            error_message += f"\nMissing key(s): {str_missing_keys}."
        if len(unexpected_keys) > 0:
            str_unexpected_keys = ",".join([f'"{k}"' for k in unexpected_keys])
            error_message += f"\nUnexpected key(s): {str_unexpected_keys}."
        raise RuntimeError(error_message)

    return missing_keys, unexpected_keys, mismatched_keys

def load_tf_weights(model, resolved_archive_file, ignore_mismatched_sizes=False, _prefix=None):
    """
    Detect missing and unexpected layers and load the TF weights from the checkpoint file according to their names
    and shapes.

    Args:
        model (`keras.models.Model`):
            The model to load the weights into.
        resolved_archive_file (`str`):
            The location of the H5 file.
        ignore_mismatched_sizes (`bool`, *optional*, defaults to `False`):
            Whether or not to ignore weights with shapes that don't match between the checkpoint and the model.

    Returns:
        Three lists, one for the missing layers, another one for the unexpected layers, and a last one for the
        mismatched layers.
    """
    if resolved_archive_file.endswith(".safetensors"):
        load_function = load_tf_weights_from_safetensors
    else:
        load_function = load_tf_weights_from_h5

    return load_function(
        model, resolved_archive_file, ignore_mismatched_sizes=ignore_mismatched_sizes, _prefix=_prefix
    )

def load_tf_weights_from_h5(model, resolved_archive_file, ignore_mismatched_sizes=False, _prefix=None):
    mismatched_layers = []

    # Read the H5 file
    with h5py.File(resolved_archive_file, "r") as sharded_checkpoint_file:
        # Retrieve the name of each layer from the H5 file
        saved_h5_model_layers_name = set(load_attributes_from_hdf5_group(sharded_checkpoint_file, "layer_names"))

        # Find the missing layers from the high level list of layers
        missing_layers = list({layer.name for layer in model.layers} - saved_h5_model_layers_name)

        # Find the unexpected layers from the high level list of layers
        unexpected_layers = list(saved_h5_model_layers_name - {layer.name for layer in model.layers})
        saved_weight_names_set = set()
        symbolic_weights_names = set()
        weight_value_tuples = []

        # Compute missing and unexpected sub layers
        # Store the weights in list of tuples that looks like [(weight_object, value_of_weight),...]
        for layer in model.layers:
            # if layer_name from the H5 file belongs to the layers from the instantiated model
            if layer.name in saved_h5_model_layers_name:
                # Get the H5 layer object from its name
                h5_layer_object = sharded_checkpoint_file[layer.name]
                # Get all the weights as a list from the layer object
                symbolic_weights = layer.trainable_weights + layer.non_trainable_weights
                saved_weights = {}

                # Create a dict from the H5 saved model that looks like {"weight_name": weight_value}
                # And a set with only the names
                for weight_name in load_attributes_from_hdf5_group(h5_layer_object, "weight_names"):
                    # TF names always start with the model name so we ignore it
                    name = "/".join(weight_name.split("/")[1:])

                    if _prefix is not None:
                        name = _prefix + "/" + name

                    saved_weights[name] = np.asarray(h5_layer_object[weight_name])

                    # Add the updated name to the final list for computing missing/unexpected values
                    saved_weight_names_set.add(name)

                # Loop over each weights from the instantiated model and compare with the weights from the H5 file
                for symbolic_weight in symbolic_weights:
                    # TF names always start with the model name so we ignore it
                    if _prefix is not None:
                        delimeter = len(_prefix.split("/"))
                        symbolic_weight_name = "/".join(
                            symbolic_weight.name.split("/")[:delimeter]
                            + symbolic_weight.name.split("/")[delimeter + 1 :]
                        )
                    else:
                        symbolic_weight_name = "/".join(symbolic_weight.name.split("/")[1:])

                    # here we check if the current weight is among the weights from the H5 file
                    # If yes, get the weight_value of the corresponding weight from the H5 file
                    # If not, make the value to None
                    saved_weight_value = saved_weights.get(symbolic_weight_name, None)

                    # Retrocompatibility patch: some embeddings are stored with the weights name (e.g. Bart's
                    # `model.shared/embeddings:0` are stored as `model.shared/weights:0`)
                    if saved_weight_value is None and symbolic_weight_name.endswith("embeddings:0"):
                        symbolic_weight_name = symbolic_weight_name[:-12] + "weight:0"
                        saved_weight_value = saved_weights.get(symbolic_weight_name, None)

                    # Add the updated name to the final list for computing missing/unexpected values
                    symbolic_weights_names.add(symbolic_weight_name)

                    # If the current weight is found
                    if saved_weight_value is not None:
                        # Check if the shape of the current weight and the one from the H5 file are different
                        if K.int_shape(symbolic_weight) != saved_weight_value.shape:
                            # If yes we reshape the weight from the H5 file accordingly to the current weight
                            # If the two shapes are not compatible we raise an issue
                            try:
                                array = np.reshape(saved_weight_value, K.int_shape(symbolic_weight))
                            except ValueError as e:
                                if ignore_mismatched_sizes:
                                    mismatched_layers.append(
                                        (symbolic_weight_name, saved_weight_value.shape, K.int_shape(symbolic_weight))
                                    )
                                    continue
                                else:
                                    raise e
                        else:
                            array = saved_weight_value

                        # We create the tuple that will be loaded and add it to the final list
                        weight_value_tuples.append((symbolic_weight, array))

    # Load all the weights
    K.batch_set_value(weight_value_tuples)

    # Compute the missing and unexpected layers
    missing_layers.extend(list(symbolic_weights_names - saved_weight_names_set))
    unexpected_layers.extend(list(saved_weight_names_set - symbolic_weights_names))

    return missing_layers, unexpected_layers, mismatched_layers

  859. def load_tf_weights_from_safetensors(model, resolved_archive_file, ignore_mismatched_sizes=False, _prefix=None):
  860. # Read the safetensors file
  861. with safe_open(resolved_archive_file, framework="tf") as safetensors_archive:
  862. mismatched_layers = []
  863. weight_names = [strip_model_name_and_prefix(w.name, _prefix=_prefix) for w in model.weights]
  864. loaded_weight_names = list(safetensors_archive.keys())
  865. # Find the missing layers from the high level list of layers
  866. missing_layers = list(set(weight_names) - set(loaded_weight_names))
  867. # Find the unexpected layers from the high level list of layers
  868. unexpected_layers = list(set(loaded_weight_names) - set(weight_names))
  869. for weight in model.weights:
  870. weight_name = strip_model_name_and_prefix(weight.name, _prefix=_prefix)
  871. if weight_name in loaded_weight_names:
  872. weight_value = safetensors_archive.get_tensor(weight_name)
  873. # Check if the shape of the current weight and the one from the H5 file are different
  874. if K.int_shape(weight) != weight_value.shape:
  875. # If yes we reshape the weight from the H5 file accordingly to the current weight
  876. # If the two shapes are not compatible we raise an issue
  877. try:
  878. weight_value = tf.reshape(weight_value, K.int_shape(weight))
  879. except (ValueError, tf.errors.InvalidArgumentError) as e:
  880. if ignore_mismatched_sizes:
  881. mismatched_layers.append((weight_name, weight_value.shape, K.int_shape(weight)))
  882. continue
  883. else:
  884. raise e
  885. K.set_value(weight, weight_value) # weight.assign() might break if weight is a DTensor
  886. return missing_layers, unexpected_layers, mismatched_layers
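# Usage sketch for the loaders above: both return the same (missing, unexpected, mismatched) report,
# which callers can surface as warnings. Illustrative only; the file name is a placeholder.
#
#     missing, unexpected, mismatched = load_tf_weights_from_safetensors(
#         model, "model.safetensors", ignore_mismatched_sizes=True
#     )
#     if missing:
#         logger.warning(f"Weights not found in the checkpoint: {missing}")
#     if unexpected:
#         logger.warning(f"Checkpoint weights not used by the model: {unexpected}")
#     for name, ckpt_shape, model_shape in mismatched:
#         logger.warning(f"Shape mismatch for {name}: checkpoint {ckpt_shape} vs model {model_shape}")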
  887. def init_copy_embeddings(old_embeddings, new_num_tokens):
  888. r"""
  889. This function aims to reduce the embeddings in case new_num_tokens < old_num_tokens or to pad with -1 in case
  890. new_num_tokens > old_num_tokens. A mask is also computed in order to know which weight in the embeddings should be
  891. kept or not. Example:
  892. - if new_num_tokens=5 and old_num_tokens=4 and old_embeddings=[w1,w2,w3,w4]
  893. - mask=[True,True,True,True,False] and current_weights=[w1,w2,w3,w4,-1]
  894. - if new_num_tokens=4 and old_num_tokens=5 and old_embeddings=[w1,w2,w3,w4,w5]
  895. - mask=[True,True,True,True] and current_weights=[w1,w2,w3,w4]
  896. """
  897. old_num_tokens, old_embedding_dim = shape_list(old_embeddings)
  898. size_diff = new_num_tokens - old_num_tokens
  899. # initialize new embeddings
  900. # Copy token embeddings from the previous ones
  901. if tf.math.greater(size_diff, 0):
902. # if the new size is greater than the old one, we pad the current embeddings up to the new size
903. # and create a mask identifying the padded values, which will later be replaced by values from the
904. # newly created embeddings
  905. current_weights = tf.pad(
  906. old_embeddings.value(), tf.convert_to_tensor([[0, size_diff], [0, 0]]), constant_values=-1
  907. )
  908. num_tokens_to_copy = min(old_num_tokens, new_num_tokens)
  909. mask = tf.fill(tf.convert_to_tensor([num_tokens_to_copy, 1]), True)
  910. mask = tf.pad(mask, tf.convert_to_tensor([[0, size_diff], [0, 0]]), constant_values=False)
  911. else:
912. # if the new size is lower than the old one, we truncate the current embeddings to the new size
  913. current_weights = tf.slice(
  914. old_embeddings.value(),
  915. tf.convert_to_tensor([0, 0]),
  916. tf.convert_to_tensor([new_num_tokens, old_embedding_dim]),
  917. )
  918. mask = tf.fill(tf.convert_to_tensor([new_num_tokens, 1]), True)
  919. return mask, current_weights
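# Worked example (sketch) for init_copy_embeddings with new_num_tokens=5 and an old 4x2 matrix:
#
#     old = tf.Variable([[0.1, 0.2], [0.3, 0.4], [0.5, 0.6], [0.7, 0.8]])
#     mask, current = init_copy_embeddings(old, new_num_tokens=5)
#     # current -> [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6], [0.7, 0.8], [-1.0, -1.0]]
#     # mask    -> [[True], [True], [True], [True], [False]]
#
# The -1 padded row is a placeholder that callers replace with freshly initialized values via
# tf.where(mask, current, new_values).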
  920. class TFPreTrainedModel(keras.Model, TFModelUtilsMixin, TFGenerationMixin, PushToHubMixin):
  921. r"""
  922. Base class for all TF models.
  923. [`TFPreTrainedModel`] takes care of storing the configuration of the models and handles methods for loading,
  924. downloading and saving models as well as a few methods common to all models to:
  925. - resize the input embeddings,
  926. - prune heads in the self-attention heads.
  927. Class attributes (overridden by derived classes):
  928. - **config_class** ([`PretrainedConfig`]) -- A subclass of [`PretrainedConfig`] to use as configuration class
  929. for this model architecture.
  930. - **base_model_prefix** (`str`) -- A string indicating the attribute associated to the base model in derived
  931. classes of the same architecture adding modules on top of the base model.
  932. - **main_input_name** (`str`) -- The name of the principal input to the model (often `input_ids` for NLP
  933. models, `pixel_values` for vision models and `input_values` for speech models).
  934. """
  935. config_class = None
  936. base_model_prefix = ""
  937. main_input_name = "input_ids"
  938. _auto_class = None
  939. _using_dummy_loss = None
  940. _label_to_output_map = None
  941. # a list of re pattern of tensor names to ignore from the model when loading the model weights
  942. # (and avoid unnecessary warnings).
  943. _keys_to_ignore_on_load_missing = None
  944. # a list of re pattern of tensor names to ignore from the weights when loading the model weights
  945. # (and avoid unnecessary warnings).
  946. _keys_to_ignore_on_load_unexpected = None
  947. _requires_load_weight_prefix = False
  948. @property
  949. def dummy_inputs(self) -> Dict[str, tf.Tensor]:
  950. """
  951. Dummy inputs to build the network.
  952. Returns:
  953. `Dict[str, tf.Tensor]`: The dummy inputs.
  954. """
  955. dummies = {}
  956. for key, spec in self.input_signature.items():
  957. # 2 is the most correct arbitrary size. I will not be taking questions
  958. dummy_shape = [dim if dim is not None else 2 for dim in spec.shape]
  959. if spec.shape[0] is None:
  960. # But let's make the batch size 1 to save memory anyway
  961. dummy_shape[0] = 1
  962. dummies[key] = tf.ones(shape=dummy_shape, dtype=spec.dtype)
  963. if key == "token_type_ids":
  964. # Some models have token_type_ids but with a vocab_size of 1
  965. dummies[key] = tf.zeros_like(dummies[key])
  966. if self.config.add_cross_attention and "encoder_hidden_states" in inspect.signature(self.call).parameters:
  967. if "encoder_hidden_states" not in dummies:
  968. if self.main_input_name == "input_ids":
  969. dummies["encoder_hidden_states"] = tf.ones(
  970. shape=(1, 2, self.config.hidden_size), dtype=tf.float32, name="encoder_hidden_states"
  971. )
  972. else:
  973. raise NotImplementedError(
  974. "Model has cross-attention but we couldn't infer the shape for the encoder hidden states. Please manually override dummy_inputs!"
  975. )
  976. return dummies
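# Example (sketch): for a text model whose input_signature maps "input_ids" and "attention_mask"
# to tf.TensorSpec([None, None], tf.int32), the dummies built above come out roughly as:
#
#     {
#         "input_ids": tf.ones((1, 2), dtype=tf.int32),
#         "attention_mask": tf.ones((1, 2), dtype=tf.int32),
#     }
#
# i.e. an unknown batch dimension becomes 1 and every other unknown dimension becomes 2.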
  977. def build_in_name_scope(self):
  978. with tf.name_scope(self.name):
  979. self.build(input_shape=None)
  980. @property
  981. def framework(self) -> str:
  982. """
983. `str`: Identifies that this is a TensorFlow model.
  984. """
  985. return "tf"
  986. def build(self, input_shape=None):
  987. pass # This is just here to make sure we don't call the superclass build()
  988. def __init__(self, config, *inputs, **kwargs):
  989. super().__init__(*inputs, **kwargs)
  990. if not isinstance(config, PretrainedConfig):
  991. raise TypeError(
  992. f"Parameter config in `{self.__class__.__name__}(config)` should be an instance of class "
  993. "`PretrainedConfig`. To create a model from a pretrained model use "
  994. f"`model = {self.__class__.__name__}.from_pretrained(PRETRAINED_MODEL_NAME)`"
  995. )
  996. # Save config and origin of the pretrained weights if given in model
  997. self.config = config
  998. self.name_or_path = config.name_or_path
  999. self.generation_config = GenerationConfig.from_model_config(config) if self.can_generate() else None
  1000. self._set_save_spec(self.input_signature)
  1001. def get_config(self):
  1002. return self.config.to_dict()
  1003. @functools.wraps(keras.Model.fit)
  1004. def fit(self, *args, **kwargs):
  1005. args, kwargs = convert_batch_encoding(*args, **kwargs)
  1006. return super().fit(*args, **kwargs)
  1007. @functools.wraps(keras.Model.train_on_batch)
  1008. def train_on_batch(self, *args, **kwargs):
  1009. args, kwargs = convert_batch_encoding(*args, **kwargs)
  1010. return super().train_on_batch(*args, **kwargs)
  1011. @functools.wraps(keras.Model.test_on_batch)
  1012. def test_on_batch(self, *args, **kwargs):
  1013. args, kwargs = convert_batch_encoding(*args, **kwargs)
  1014. return super().test_on_batch(*args, **kwargs)
  1015. @functools.wraps(keras.Model.predict_on_batch)
  1016. def predict_on_batch(self, *args, **kwargs):
  1017. args, kwargs = convert_batch_encoding(*args, **kwargs)
  1018. return super().predict_on_batch(*args, **kwargs)
  1019. @functools.wraps(keras.Model.predict)
  1020. def predict(self, *args, **kwargs):
  1021. args, kwargs = convert_batch_encoding(*args, **kwargs)
  1022. return super().predict(*args, **kwargs)
  1023. @functools.wraps(keras.Model.evaluate)
  1024. def evaluate(self, *args, **kwargs):
  1025. args, kwargs = convert_batch_encoding(*args, **kwargs)
  1026. return super().evaluate(*args, **kwargs)
  1027. @classmethod
  1028. def from_config(cls, config, **kwargs):
  1029. if isinstance(config, PretrainedConfig):
  1030. return cls._from_config(config, **kwargs)
  1031. return cls._from_config(cls.config_class.from_dict(config, **kwargs))
  1032. @classmethod
  1033. def _from_config(cls, config, **kwargs):
  1034. """
  1035. All context managers that the model should be initialized under go here.
  1036. """
  1037. return cls(config, **kwargs)
  1038. def get_head_mask(self, head_mask: tf.Tensor | None, num_hidden_layers: int) -> tf.Tensor:
  1039. """
  1040. Prepare the head mask if needed.
  1041. Args:
  1042. head_mask (`tf.Tensor` with shape `[num_heads]` or `[num_hidden_layers x num_heads]`, *optional*):
  1043. The mask indicating if we should keep the heads or not (1.0 for keep, 0.0 for discard).
  1044. num_hidden_layers (`int`):
  1045. The number of hidden layers in the model.
  1046. Returns:
  1047. `tf.Tensor` with shape `[num_hidden_layers x batch x num_heads x seq_length x seq_length]` or list with
  1048. `[None]` for each layer.
  1049. """
  1050. if head_mask is not None:
  1051. head_mask = self._convert_head_mask_to_5d(head_mask, num_hidden_layers)
  1052. else:
  1053. head_mask = [None] * num_hidden_layers
  1054. return head_mask
  1055. def _convert_head_mask_to_5d(self, head_mask, num_hidden_layers):
  1056. """-> [num_hidden_layers x batch x num_heads x seq_length x seq_length]"""
  1057. if head_mask.shape.rank == 1:
  1058. head_mask = head_mask[None, None, :, None, None]
  1059. head_mask = tf.repeat(head_mask, repeats=num_hidden_layers, axis=0)
  1060. elif head_mask.shape.rank == 2:
  1061. head_mask = head_mask[:, None, :, None, None]
1062. assert head_mask.shape.rank == 5, f"head_mask rank should be 5, got {head_mask.shape.rank}"
  1063. head_mask = tf.cast(head_mask, tf.float32) # switch to float if need + fp16 compatibility
  1064. return head_mask
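# Shape sketch for the conversion above (model and sizes are placeholders):
#
#     mask = tf.ones((12,))  # one value per attention head
#     mask_5d = model._convert_head_mask_to_5d(mask, num_hidden_layers=6)
#     # mask_5d.shape == (6, 1, 12, 1, 1), broadcastable against per-layer attention scores
#     # of shape (batch, num_heads, seq_length, seq_length).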
  1065. @tf.function
  1066. def serving(self, inputs):
  1067. """
1068. Method used for serving the model. Does not have a specific signature, but will be specialized as concrete
1069. functions when saving with `save_pretrained`.
1070. Args:
1071. inputs (`Dict[str, tf.Tensor]`):
1072. The input of the saved model as a dictionary of tensors.
  1073. """
  1074. output = self.call(inputs)
  1075. return self.serving_output(output)
  1076. @property
  1077. def input_signature(self) -> Dict[str, tf.TensorSpec]:
  1078. """
  1079. This property should return a dict mapping input names to tf.TensorSpec objects, representing the expected
  1080. shape and dtype for model inputs. It is used for both serving and for generating dummy inputs.
  1081. """
  1082. model_inputs = list(inspect.signature(self.call).parameters)
  1083. sig = {}
  1084. if "input_ids" in model_inputs:
  1085. if self.__class__.__name__.endswith("ForMultipleChoice"):
  1086. text_dims = 3
  1087. else:
  1088. text_dims = 2
  1089. for input_name in (
  1090. "input_ids",
  1091. "attention_mask",
  1092. "token_type_ids",
  1093. "decoder_input_ids",
  1094. "decoder_attention_mask",
  1095. ):
  1096. if input_name in model_inputs:
  1097. sig[input_name] = tf.TensorSpec([None] * text_dims, tf.int32, name=input_name)
  1098. if "pixel_values" in model_inputs:
  1099. pixel_values_shape = [None, None, None, None]
  1100. if hasattr(self.config, "vision_config"):
  1101. vision_config = self.config.vision_config
  1102. else:
  1103. vision_config = self.config
  1104. if hasattr(vision_config, "num_channels"):
  1105. pixel_values_shape[1] = vision_config.num_channels
  1106. else:
  1107. raise NotImplementedError(
  1108. "Could not infer number of channels from config, please override input_signature to specify input shapes."
  1109. )
  1110. if hasattr(vision_config, "image_size"):
  1111. pixel_values_shape[2] = pixel_values_shape[3] = vision_config.image_size
  1112. elif hasattr(vision_config, "input_size"):
  1113. pixel_values_shape[2] = pixel_values_shape[3] = vision_config.input_size
  1114. else:
  1115. raise NotImplementedError(
  1116. "Could not infer input image shape from config, please override input_signature to specify input shapes."
  1117. )
  1118. sig["pixel_values"] = tf.TensorSpec(pixel_values_shape, tf.float32, name="pixel_values")
  1119. if "input_features" in model_inputs:
  1120. raise NotImplementedError("Audio models need a manually defined input_signature")
  1121. return sig
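# Example (sketch) of the signature produced above for a plain text model; multiple-choice models
# get an extra leading dimension ([None, None, None]) instead:
#
#     {
#         "input_ids": tf.TensorSpec([None, None], tf.int32, name="input_ids"),
#         "attention_mask": tf.TensorSpec([None, None], tf.int32, name="attention_mask"),
#     }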
  1122. def serving_output(self, output):
  1123. """
  1124. Prepare the output of the saved model. Can be overridden if specific serving modifications are required.
  1125. """
  1126. if not isinstance(output, ModelOutput):
  1127. return output
  1128. for key in output:
  1129. if key.endswith("hidden_states") and not getattr(self.config, "output_hidden_states", False):
  1130. output[key] = None
  1131. elif key.endswith("attentions") and not getattr(self.config, "output_attentions", False):
  1132. output[key] = None
  1133. elif key == "past_key_values" and not getattr(self.config, "use_cache", False):
  1134. output[key] = None
  1135. elif key == "cross_attentions" and not (
  1136. getattr(self.config, "output_attentions", False) and getattr(self.config, "add_cross_attention", False)
  1137. ):
  1138. output[key] = None
  1139. if isinstance(output[key], (tuple, list)):
  1140. try:
  1141. output[key] = tf.convert_to_tensor(output[key])
  1142. except (ValueError, tf.errors.InvalidArgumentError):
  1143. pass # Layers may not have the same dimensions
  1144. return output
  1145. @classmethod
  1146. def can_generate(cls) -> bool:
  1147. """
  1148. Returns whether this model can generate sequences with `.generate()`.
  1149. Returns:
  1150. `bool`: Whether this model can generate sequences with `.generate()`.
  1151. """
  1152. # Detects whether `prepare_inputs_for_generation` has been overwritten, which is a requirement for generation.
1153. # Alternatively, the model can also have a custom `generate` function.
  1154. if "GenerationMixin" in str(cls.prepare_inputs_for_generation) and "GenerationMixin" in str(cls.generate):
  1155. return False
  1156. return True
  1157. def get_input_embeddings(self) -> keras.layers.Layer:
  1158. """
  1159. Returns the model's input embeddings layer.
  1160. Returns:
  1161. `tf.Variable`: The embeddings layer mapping vocabulary to hidden states.
  1162. """
  1163. main_layer = getattr(self, self.base_model_prefix, self)
  1164. if main_layer is not self:
  1165. return main_layer.get_input_embeddings()
  1166. else:
  1167. raise NotImplementedError
  1168. def _save_checkpoint(self, checkpoint_dir, epoch):
  1169. if not os.path.isdir(checkpoint_dir):
  1170. os.mkdir(checkpoint_dir)
  1171. # We avoid tf.train.checkpoint or saving weights in TF format, even though that includes optimizer
  1172. # state for us, because it requires special handling for objects like custom losses, which we use
  1173. # internally and which users are likely to use too
  1174. weights_path = os.path.join(checkpoint_dir, "weights.h5")
  1175. self.save_weights(weights_path)
  1176. extra_data = {"epoch": epoch, "optimizer_state": self.optimizer.get_weights()}
  1177. extra_data_path = os.path.join(checkpoint_dir, "extra_data.pickle")
  1178. with open(extra_data_path, "wb") as f:
  1179. pickle.dump(extra_data, f)
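# Restore sketch for the checkpoint layout written above (weights.h5 + extra_data.pickle).
# Illustrative only; the optimizer must already be built before its state can be restored.
#
#     model.load_weights(os.path.join(checkpoint_dir, "weights.h5"))
#     with open(os.path.join(checkpoint_dir, "extra_data.pickle"), "rb") as f:
#         extra_data = pickle.load(f)
#     model.optimizer.set_weights(extra_data["optimizer_state"])
#     initial_epoch = extra_data["epoch"] + 1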
  1180. def prepare_tf_dataset(
  1181. self,
  1182. dataset: "datasets.Dataset", # noqa:F821
  1183. batch_size: int = 8,
  1184. shuffle: bool = True,
  1185. tokenizer: Optional["PreTrainedTokenizerBase"] = None,
  1186. collate_fn: Optional[Callable] = None,
  1187. collate_fn_args: Optional[Dict[str, Any]] = None,
  1188. drop_remainder: Optional[bool] = None,
  1189. prefetch: bool = True,
  1190. ):
  1191. """
  1192. Wraps a HuggingFace [`~datasets.Dataset`] as a `tf.data.Dataset` with collation and batching. This method is
  1193. designed to create a "ready-to-use" dataset that can be passed directly to Keras methods like `fit()` without
  1194. further modification. The method will drop columns from the dataset if they don't match input names for the
  1195. model. If you want to specify the column names to return rather than using the names that match this model, we
  1196. recommend using `Dataset.to_tf_dataset()` instead.
  1197. Args:
  1198. dataset (`Any`):
1199. A [`~datasets.Dataset`] to be wrapped as a `tf.data.Dataset`.
  1200. batch_size (`int`, *optional*, defaults to 8):
  1201. The size of batches to return.
  1202. shuffle (`bool`, defaults to `True`):
  1203. Whether to return samples from the dataset in random order. Usually `True` for training datasets and
  1204. `False` for validation/test datasets.
  1205. tokenizer ([`PreTrainedTokenizerBase`], *optional*):
  1206. A `PreTrainedTokenizer` that will be used to pad samples to create batches. Has no effect if a specific
  1207. `collate_fn` is passed instead.
  1208. collate_fn (`Callable`, *optional*):
  1209. A function that collates samples from the dataset into a single batch. Defaults to
  1210. `DefaultDataCollator` if no `tokenizer` is supplied or `DataCollatorWithPadding` if a `tokenizer` is
  1211. passed.
  1212. collate_fn_args (`Dict[str, Any]`, *optional*):
  1213. A dict of arguments to pass to the `collate_fn` alongside the list of samples.
  1214. drop_remainder (`bool`, *optional*):
  1215. Whether to drop the final batch, if the batch_size does not evenly divide the dataset length. Defaults
  1216. to the same setting as `shuffle`.
  1217. prefetch (`bool`, defaults to `True`):
  1218. Whether to add prefetching to the end of the `tf.data` pipeline. This is almost always beneficial for
  1219. performance, but can be disabled in edge cases.
  1220. Returns:
  1221. `Dataset`: A `tf.data.Dataset` which is ready to pass to the Keras API.
  1222. """
  1223. requires_backends(self, ["datasets"])
  1224. import datasets
  1225. if collate_fn is None:
  1226. if tokenizer is None:
  1227. collate_fn = DefaultDataCollator(return_tensors="np")
  1228. else:
  1229. collate_fn = DataCollatorWithPadding(tokenizer=tokenizer, return_tensors="np")
  1230. if collate_fn_args is None:
  1231. collate_fn_args = {}
  1232. if not isinstance(dataset, datasets.Dataset):
  1233. raise TypeError("Dataset argument should be a datasets.Dataset!")
  1234. model_inputs = list(inspect.signature(self.call).parameters)
  1235. model_labels = find_labels(self.__class__)
  1236. if "cols_to_retain" in list(inspect.signature(dataset._get_output_signature).parameters.keys()):
  1237. output_signature, _ = dataset._get_output_signature(
  1238. dataset,
  1239. batch_size=None,
  1240. collate_fn=collate_fn,
  1241. collate_fn_args=collate_fn_args,
  1242. cols_to_retain=model_inputs,
  1243. )
  1244. else:
  1245. # TODO Matt: This is a workaround for older versions of datasets that are missing the `cols_to_retain`
  1246. # argument. We should remove this once the minimum supported version of datasets is > 2.3.2
  1247. unwanted_columns = [
  1248. feature
  1249. for feature in dataset.features
  1250. if feature not in model_inputs and feature not in ("label_ids", "label")
  1251. ]
  1252. dataset = dataset.remove_columns(unwanted_columns)
  1253. output_signature, _ = dataset._get_output_signature(
  1254. dataset, batch_size=None, collate_fn=collate_fn, collate_fn_args=collate_fn_args
  1255. )
  1256. output_columns = list(output_signature.keys())
  1257. feature_cols = [col for col in output_columns if col in model_inputs and col not in model_labels]
  1258. label_cols = [col for col in output_columns if col in model_labels]
  1259. # Backwards compatibility for older versions of datasets. Previously, if `columns` or `label_cols`
  1260. # were a single element list, the returned element spec would be a single element. Now, passing [feature]
  1261. # will return a dict structure {"feature": feature}, and passing a single string will return a single element.
  1262. feature_cols = feature_cols[0] if len(feature_cols) == 1 else feature_cols
  1263. label_cols = label_cols[0] if len(label_cols) == 1 else label_cols
  1264. if drop_remainder is None:
  1265. drop_remainder = shuffle
  1266. tf_dataset = dataset.to_tf_dataset(
  1267. columns=feature_cols,
  1268. label_cols=label_cols,
  1269. batch_size=batch_size,
  1270. shuffle=shuffle,
  1271. drop_remainder=drop_remainder,
  1272. collate_fn=collate_fn,
  1273. collate_fn_args=collate_fn_args,
  1274. prefetch=prefetch,
  1275. )
  1276. return tf_dataset
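# Usage sketch (dataset and tokenizer are placeholders); the dataset should already be tokenized
# so its columns match the model's input names:
#
#     from datasets import load_dataset
#     raw = load_dataset("glue", "sst2", split="train")
#     tokenized = raw.map(lambda x: tokenizer(x["sentence"], truncation=True), batched=True)
#     train_set = model.prepare_tf_dataset(tokenized, batch_size=16, shuffle=True, tokenizer=tokenizer)
#     model.fit(train_set, epochs=3)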
  1277. def compile(
  1278. self,
  1279. optimizer="rmsprop",
  1280. loss="auto_with_warning",
  1281. metrics=None,
  1282. loss_weights=None,
  1283. weighted_metrics=None,
  1284. run_eagerly=None,
  1285. steps_per_execution=None,
  1286. **kwargs,
  1287. ):
  1288. """
  1289. This is a thin wrapper that sets the model's loss output head as the loss if the user does not specify a loss
  1290. function themselves.
  1291. """
  1292. if loss in ("auto_with_warning", "passthrough"): # "passthrough" for workflow backward compatibility
  1293. logger.info(
  1294. "No loss specified in compile() - the model's internal loss computation will be used as the "
  1295. "loss. Don't panic - this is a common way to train TensorFlow models in Transformers! "
  1296. "To disable this behaviour please pass a loss argument, or explicitly pass "
  1297. "`loss=None` if you do not want your model to compute a loss. You can also specify `loss='auto'` to "
  1298. "get the internal loss without printing this info string."
  1299. )
  1300. loss = "auto"
  1301. if loss == "auto":
  1302. loss = dummy_loss
  1303. self._using_dummy_loss = True
  1304. else:
  1305. self._using_dummy_loss = False
  1306. parent_args = list(inspect.signature(keras.Model.compile).parameters.keys())
  1307. # This argument got renamed, we need to support both versions
  1308. if "steps_per_execution" in parent_args:
  1309. super().compile(
  1310. optimizer=optimizer,
  1311. loss=loss,
  1312. metrics=metrics,
  1313. loss_weights=loss_weights,
  1314. weighted_metrics=weighted_metrics,
  1315. run_eagerly=run_eagerly,
  1316. steps_per_execution=steps_per_execution,
  1317. **kwargs,
  1318. )
  1319. else:
  1320. super().compile(
  1321. optimizer=optimizer,
  1322. loss=loss,
  1323. metrics=metrics,
  1324. loss_weights=loss_weights,
  1325. weighted_metrics=weighted_metrics,
  1326. run_eagerly=run_eagerly,
  1327. experimental_steps_per_execution=steps_per_execution,
  1328. **kwargs,
  1329. )
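# Usage sketch: compiling without an explicit loss routes training through the model's internal
# loss head (the dummy-loss path used by train_step/test_step below). Optimizer choice is illustrative.
#
#     model.compile(optimizer=tf.keras.optimizers.Adam(3e-5))               # loss defaults to "auto_with_warning"
#     model.compile(optimizer=tf.keras.optimizers.Adam(3e-5), loss="auto")  # same, without the info message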
  1330. def compute_loss(self, *args, **kwargs):
  1331. if hasattr(keras.Model, "compute_loss"):
  1332. # This will be true in TF 2.8 or greater
  1333. return super().compute_loss(*args, **kwargs)
  1334. else:
  1335. warnings.warn(
  1336. "The old compute_loss method is deprecated as it conflicts with the Keras compute_loss "
  1337. "method added in TF 2.8. If you want the original HF compute_loss, please call "
  1338. "hf_compute_loss() instead. From TF versions >= 2.8, or Transformers versions >= 5, "
  1339. "calling compute_loss() will get the Keras method instead.",
  1340. FutureWarning,
  1341. )
  1342. return self.hf_compute_loss(*args, **kwargs)
  1343. def get_label_to_output_name_mapping(self):
  1344. arg_names = list(inspect.signature(self.call).parameters)
  1345. if self._label_to_output_map is not None:
  1346. return self._label_to_output_map
  1347. elif "start_positions" in arg_names:
  1348. return {"start_positions": "start_logits", "end_positions": "end_logits"}
  1349. elif "sentence_order_label" in arg_names:
  1350. return {"labels": "prediction_logits", "sentence_order_label": "sop_logits"}
  1351. elif "next_sentence_label" in arg_names:
  1352. return {"labels": "prediction_logits", "next_sentence_label": "seq_relationship_logits"}
  1353. elif "mc_labels" in arg_names:
  1354. return {"labels": "logits", "mc_labels": "mc_logits"}
  1355. else:
  1356. return {}
  1357. def train_step(self, data):
  1358. """
  1359. A modification of Keras's default `train_step` that correctly handles matching outputs to labels for our models
  1360. and supports directly training on the loss output head. In addition, it ensures input keys are copied to the
  1361. labels where appropriate. It will also copy label keys into the input dict when using the dummy loss, to ensure
  1362. that they are available to the model during the forward pass.
  1363. """
  1364. # We hardcode the most common renamings; models with weirder names can set `self._label_to_output_map`
  1365. arg_names = list(inspect.signature(self.call).parameters)
  1366. label_kwargs = find_labels(self.__class__)
  1367. label_to_output = self.get_label_to_output_name_mapping()
  1368. output_to_label = {val: key for key, val in label_to_output.items()}
  1369. if not self._using_dummy_loss and parse(tf.__version__) < parse("2.11.0"):
  1370. # Newer TF train steps leave this out
  1371. data = expand_1d(data)
  1372. x, y, sample_weight = keras.utils.unpack_x_y_sample_weight(data)
  1373. # If the inputs are mutable dictionaries, make a shallow copy of them because we will modify
  1374. # them during input/label pre-processing. This avoids surprising the user by wrecking their data.
  1375. # In addition, modifying mutable Python inputs makes XLA compilation impossible.
  1376. if isinstance(x, dict):
  1377. x = x.copy()
  1378. if isinstance(y, dict):
  1379. y = y.copy()
  1380. # When using a dummy loss, we ensure that separate labels are copied to the correct model arguments,
  1381. # if those keys are not already present in the input dict
  1382. if self._using_dummy_loss and y is not None:
  1383. # If y is a tensor and the model only has one label-like input, map y to that input
  1384. if len(label_kwargs) == 1 and isinstance(y, tf.Tensor):
  1385. if isinstance(x, tf.Tensor):
  1386. x = {arg_names[0]: x}
  1387. label_kwarg = next(iter(label_kwargs))
  1388. if label_kwarg not in x:
  1389. x[label_kwarg] = y
  1390. # Otherwise, copy keys from y to x as long as they weren't already present in x
  1391. elif isinstance(y, dict):
  1392. if isinstance(x, tf.Tensor):
  1393. x = {arg_names[0]: x}
  1394. for key, val in y.items():
  1395. if key in arg_names and key not in x:
  1396. x[key] = val
  1397. elif output_to_label.get(key, None) in arg_names and key not in x:
  1398. x[output_to_label[key]] = val
  1399. if y is None:
  1400. y = {key: val for key, val in x.items() if key in label_kwargs}
  1401. if not y and not self._using_dummy_loss:
  1402. raise ValueError("Could not find label column(s) in input dict and no separate labels were provided!")
  1403. if isinstance(y, dict):
  1404. # Rename labels at this point to match output heads
  1405. y = {label_to_output.get(key, key): val for key, val in y.items()}
  1406. # Run forward pass.
  1407. with tf.GradientTape() as tape:
  1408. if self._using_dummy_loss and "return_loss" in arg_names:
  1409. y_pred = self(x, training=True, return_loss=True)
  1410. else:
  1411. y_pred = self(x, training=True)
  1412. if self._using_dummy_loss:
  1413. loss = self.compiled_loss(y_pred.loss, y_pred.loss, sample_weight, regularization_losses=self.losses)
  1414. else:
  1415. loss = None
  1416. # This next block matches outputs to label keys. Tensorflow's standard method for doing this
  1417. # can get very confused if any of the keys contain nested values (e.g. lists/tuples of Tensors)
  1418. if isinstance(y, dict) and len(y) == 1:
  1419. if list(y.keys())[0] in y_pred.keys():
  1420. y_pred = y_pred[list(y.keys())[0]]
  1421. elif list(y_pred.keys())[0] == "loss":
  1422. y_pred = y_pred[1]
  1423. else:
  1424. y_pred = y_pred[0]
  1425. _, y = y.popitem()
  1426. elif isinstance(y, dict):
  1427. # If the labels are a dict, match keys from the output by name
  1428. y_pred = {key: val for key, val in y_pred.items() if key in y}
  1429. elif isinstance(y, tuple) or isinstance(y, list):
  1430. # If the labels are a tuple/list, match keys to the output by order, skipping the loss.
  1431. if list(y_pred.keys())[0] == "loss":
  1432. y_pred = y_pred.to_tuple()[1:]
  1433. else:
  1434. y_pred = y_pred.to_tuple()
  1435. y_pred = y_pred[: len(y)] # Remove unused fields in case those cause problems
  1436. else:
  1437. # If the labels are a single tensor, match them to the first non-loss tensor in the output
  1438. if list(y_pred.keys())[0] == "loss":
  1439. y_pred = y_pred[1]
  1440. else:
  1441. y_pred = y_pred[0]
  1442. if loss is None:
  1443. loss = self.compiled_loss(y, y_pred, sample_weight, regularization_losses=self.losses)
  1444. # Run backwards pass.
  1445. self.optimizer.minimize(loss, self.trainable_variables, tape=tape)
  1446. self.compiled_metrics.update_state(y, y_pred, sample_weight)
  1447. # Collect metrics to return
  1448. return_metrics = {}
  1449. for metric in self.metrics:
  1450. result = metric.result()
  1451. if isinstance(result, dict):
  1452. return_metrics.update(result)
  1453. else:
  1454. return_metrics[metric.name] = result
  1455. return return_metrics
  1456. def test_step(self, data):
  1457. """
1458. A modification of Keras's default `test_step` that correctly handles matching outputs to labels for our models
1459. and supports directly evaluating on the loss output head. In addition, it ensures input keys are copied to the
1460. labels where appropriate. It will also copy label keys into the input dict when using the dummy loss, to ensure
1461. that they are available to the model during the forward pass.
  1462. """
  1463. # We hardcode the most common renamings; models with weirder names can set `self._label_to_output_map`
  1464. arg_names = list(inspect.signature(self.call).parameters)
  1465. label_kwargs = find_labels(self.__class__)
  1466. label_to_output = self.get_label_to_output_name_mapping()
  1467. output_to_label = {val: key for key, val in label_to_output.items()}
  1468. if not self._using_dummy_loss and parse(tf.__version__) < parse("2.11.0"):
  1469. # Newer versions leave this out
  1470. data = expand_1d(data)
  1471. x, y, sample_weight = keras.utils.unpack_x_y_sample_weight(data)
  1472. # If the inputs are mutable dictionaries, make a shallow copy of them because we will modify
  1473. # them during input/label pre-processing. This avoids surprising the user by wrecking their data.
  1474. # In addition, modifying mutable Python inputs makes XLA compilation impossible.
  1475. if isinstance(x, dict):
  1476. x = x.copy()
  1477. if isinstance(y, dict):
  1478. y = y.copy()
  1479. # When using a dummy loss, we ensure that separate labels are copied to the correct model arguments,
  1480. # if those keys are not already present in the input dict
  1481. if self._using_dummy_loss and y is not None:
  1482. arg_names = list(inspect.signature(self.call).parameters)
  1483. # If y is a tensor and the model only has one label-like input, map y to that input
  1484. if len(label_kwargs) == 1 and isinstance(y, tf.Tensor):
  1485. if isinstance(x, tf.Tensor):
  1486. x = {arg_names[0]: x}
  1487. label_kwarg = next(iter(label_kwargs))
  1488. if label_kwarg not in x:
  1489. x[label_kwarg] = y
  1490. # Otherwise, copy keys from y to x as long as they weren't already present in x
  1491. elif isinstance(y, dict):
  1492. if isinstance(x, tf.Tensor):
  1493. x = {arg_names[0]: x}
  1494. for key, val in y.items():
  1495. if key in arg_names and key not in x:
  1496. x[key] = val
  1497. elif output_to_label.get(key, None) in arg_names and key not in x:
  1498. x[output_to_label[key]] = val
  1499. if y is None:
  1500. y = {key: val for key, val in x.items() if key in label_kwargs}
  1501. if not y and not self._using_dummy_loss:
  1502. raise ValueError("Could not find label column(s) in input dict and no separate labels were provided!")
  1503. if isinstance(y, dict):
  1504. # Rename labels at this point to match output heads
  1505. y = {label_to_output.get(key, key): val for key, val in y.items()}
  1506. # Run forward pass.
  1507. if self._using_dummy_loss and "return_loss" in arg_names:
  1508. y_pred = self(x, return_loss=True, training=False)
  1509. else:
  1510. y_pred = self(x, training=False)
  1511. if self._using_dummy_loss:
  1512. loss = self.compiled_loss(y_pred.loss, y_pred.loss, sample_weight, regularization_losses=self.losses)
  1513. else:
  1514. loss = None
  1515. # This next block matches outputs to label keys. Tensorflow's standard method for doing this
  1516. # can get very confused if any of the keys contain nested values (e.g. lists/tuples of Tensors)
  1517. if isinstance(y, dict) and len(y) == 1:
  1518. if list(y.keys())[0] in y_pred.keys():
  1519. y_pred = y_pred[list(y.keys())[0]]
  1520. elif list(y_pred.keys())[0] == "loss":
  1521. y_pred = y_pred[1]
  1522. else:
  1523. y_pred = y_pred[0]
  1524. _, y = y.popitem()
  1525. elif isinstance(y, dict):
  1526. # If the labels are a dict, match keys from the output by name
  1527. y_pred = {key: val for key, val in y_pred.items() if key in y}
  1528. elif isinstance(y, tuple) or isinstance(y, list):
  1529. # If the labels are a tuple/list, match keys to the output by order, skipping the loss.
  1530. if list(y_pred.keys())[0] == "loss":
  1531. y_pred = y_pred.to_tuple()[1:]
  1532. else:
  1533. y_pred = y_pred.to_tuple()
  1534. y_pred = y_pred[: len(y)] # Remove unused fields in case those cause problems
  1535. else:
  1536. # If the labels are a single tensor, match them to the first non-loss tensor in the output
  1537. if list(y_pred.keys())[0] == "loss":
  1538. y_pred = y_pred[1]
  1539. else:
  1540. y_pred = y_pred[0]
  1541. if loss is None:
  1542. loss = self.compiled_loss(y, y_pred, sample_weight, regularization_losses=self.losses)
  1543. self.compiled_metrics.update_state(y, y_pred, sample_weight)
  1544. # Collect metrics to return
  1545. return_metrics = {}
  1546. for metric in self.metrics:
  1547. result = metric.result()
  1548. if isinstance(result, dict):
  1549. return_metrics.update(result)
  1550. else:
  1551. return_metrics[metric.name] = result
  1552. return return_metrics
  1553. def create_model_card(
  1554. self,
  1555. output_dir,
  1556. model_name: str,
  1557. language: Optional[str] = None,
  1558. license: Optional[str] = None,
  1559. tags: Optional[str] = None,
  1560. finetuned_from: Optional[str] = None,
  1561. tasks: Optional[str] = None,
  1562. dataset_tags: Optional[Union[str, List[str]]] = None,
  1563. dataset: Optional[Union[str, List[str]]] = None,
  1564. dataset_args: Optional[Union[str, List[str]]] = None,
  1565. ):
  1566. """
  1567. Creates a draft of a model card using the information available to the `Trainer`.
  1568. Args:
  1569. output_dir (`str` or `os.PathLike`):
  1570. The folder in which to create the model card.
  1571. model_name (`str`, *optional*):
  1572. The name of the model.
  1573. language (`str`, *optional*):
  1574. The language of the model (if applicable)
  1575. license (`str`, *optional*):
  1576. The license of the model. Will default to the license of the pretrained model used, if the original
  1577. model given to the `Trainer` comes from a repo on the Hub.
  1578. tags (`str` or `List[str]`, *optional*):
  1579. Some tags to be included in the metadata of the model card.
  1580. finetuned_from (`str`, *optional*):
  1581. The name of the model used to fine-tune this one (if applicable). Will default to the name of the repo
  1582. of the original model given to the `Trainer` (if it comes from the Hub).
  1583. tasks (`str` or `List[str]`, *optional*):
  1584. One or several task identifiers, to be included in the metadata of the model card.
  1585. dataset_tags (`str` or `List[str]`, *optional*):
  1586. One or several dataset tags, to be included in the metadata of the model card.
  1587. dataset (`str` or `List[str]`, *optional*):
  1588. One or several dataset identifiers, to be included in the metadata of the model card.
  1589. dataset_args (`str` or `List[str]`, *optional*):
  1590. One or several dataset arguments, to be included in the metadata of the model card.
  1591. """
  1592. # Avoids a circular import by doing this when necessary.
  1593. from .modelcard import TrainingSummary # tests_ignore
  1594. training_summary = TrainingSummary.from_keras(
  1595. self,
  1596. keras_history=self.history,
  1597. language=language,
  1598. license=license,
  1599. tags=tags,
  1600. model_name=model_name,
  1601. finetuned_from=finetuned_from,
  1602. tasks=tasks,
  1603. dataset_tags=dataset_tags,
  1604. dataset=dataset,
  1605. dataset_args=dataset_args,
  1606. )
  1607. model_card = training_summary.to_model_card()
  1608. with open(os.path.join(output_dir, "README.md"), "w") as f:
  1609. f.write(model_card)
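# Usage sketch (names are placeholders); typically called after fit() so that self.history exists:
#
#     model.create_model_card(
#         output_dir="./my-finetuned-model",
#         model_name="my-finetuned-model",
#         finetuned_from="distilbert-base-uncased",
#         tasks="text-classification",
#         dataset="imdb",
#     )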
  1610. def set_input_embeddings(self, value):
  1611. """
  1612. Set model's input embeddings
  1613. Args:
  1614. value (`tf.Variable`):
  1615. The new weights mapping hidden states to vocabulary.
  1616. """
1617. main_layer = getattr(self, self.base_model_prefix, None)
1618. if main_layer is None:
1619. raise NotImplementedError("The model does not implement the base_model_prefix attribute.")
  1620. try:
  1621. main_layer.set_input_embeddings(value)
  1622. except AttributeError:
  1623. logger.info("Building the model")
  1624. self.build_in_name_scope()
  1625. main_layer.set_input_embeddings(value)
  1626. def get_output_embeddings(self) -> Union[None, keras.layers.Layer]:
  1627. """
  1628. Returns the model's output embeddings
  1629. Returns:
  1630. `tf.Variable`: The new weights mapping vocabulary to hidden states.
  1631. """
  1632. if self.get_lm_head() is not None:
  1633. lm_head = self.get_lm_head()
  1634. try:
  1635. return lm_head.get_output_embeddings()
  1636. except AttributeError:
  1637. logger.info("Building the model")
  1638. self.build_in_name_scope()
1639. return lm_head.get_output_embeddings()
  1640. return None # Overwrite for models with output embeddings
  1641. def set_output_embeddings(self, value):
  1642. """
  1643. Set model's output embeddings
  1644. Args:
  1645. value (`tf.Variable`):
  1646. The new weights mapping hidden states to vocabulary.
  1647. """
  1648. if self.get_lm_head() is not None:
  1649. lm_head = self.get_lm_head()
  1650. try:
  1651. lm_head.set_output_embeddings(value)
  1652. except AttributeError:
  1653. logger.info("Building the model")
  1654. self.build_in_name_scope()
  1655. lm_head.set_output_embeddings(value)
  1656. def get_output_layer_with_bias(self) -> Union[None, keras.layers.Layer]:
  1657. """
  1658. Get the layer that handles a bias attribute in case the model has an LM head with weights tied to the
  1659. embeddings
  1660. Return:
  1661. `keras.layers.Layer`: The layer that handles the bias, None if not an LM model.
  1662. """
  1663. warnings.warn(
  1664. "The method get_output_layer_with_bias is deprecated. Please use `get_lm_head` instead.", FutureWarning
  1665. )
  1666. return self.get_lm_head()
  1667. def get_prefix_bias_name(self) -> Union[None, str]:
  1668. """
  1669. Get the concatenated _prefix name of the bias from the model name to the parent layer
  1670. Return:
  1671. `str`: The _prefix name of the bias.
  1672. """
  1673. warnings.warn("The method get_prefix_bias_name is deprecated. Please use `get_bias` instead.", FutureWarning)
  1674. return None
  1675. def get_bias(self) -> Union[None, Dict[str, tf.Variable]]:
  1676. """
  1677. Dict of bias attached to an LM head. The key represents the name of the bias attribute.
  1678. Return:
  1679. `tf.Variable`: The weights representing the bias, None if not an LM model.
  1680. """
  1681. if self.get_lm_head() is not None:
  1682. lm_head = self.get_lm_head()
  1683. try:
  1684. return lm_head.get_bias()
  1685. except AttributeError:
  1686. self.build_in_name_scope()
  1687. return lm_head.get_bias()
  1688. return None
  1689. def set_bias(self, value):
  1690. """
  1691. Set all the bias in the LM head.
  1692. Args:
  1693. value (`Dict[tf.Variable]`):
  1694. All the new bias attached to an LM head.
  1695. """
  1696. if self.get_lm_head() is not None:
  1697. lm_head = self.get_lm_head()
  1698. try:
  1699. lm_head.set_bias(value)
  1700. except AttributeError:
  1701. self.build_in_name_scope()
  1702. lm_head.set_bias(value)
  1703. def get_lm_head(self) -> keras.layers.Layer:
  1704. """
  1705. The LM Head layer. This method must be overwritten by all the models that have a lm head.
  1706. Return:
  1707. `keras.layers.Layer`: The LM head layer if the model has one, None if not.
  1708. """
  1709. return None
  1710. def resize_token_embeddings(
  1711. self, new_num_tokens: Optional[int] = None
  1712. ) -> Union[keras.layers.Embedding, tf.Variable]:
  1713. """
  1714. Resizes input token embeddings matrix of the model if `new_num_tokens != config.vocab_size`.
  1715. Takes care of tying weights embeddings afterwards if the model class has a `tie_weights()` method.
  1716. Arguments:
  1717. new_num_tokens (`int`, *optional*):
  1718. The number of new tokens in the embedding matrix. Increasing the size will add newly initialized
  1719. vectors at the end. Reducing the size will remove vectors from the end. If not provided or `None`, just
  1720. returns a pointer to the input tokens without doing anything.
  1721. Return:
  1722. `tf.Variable` or `keras.layers.Embedding`: Pointer to the input tokens of the model.
  1723. """
  1724. # TODO (joao): flagged for replacement (by `_v2_resized_token_embeddings`) due to embeddings refactor
  1725. # Run the new code path if the model has a keras embeddings layer
  1726. if isinstance(self.get_input_embeddings(), keras.layers.Embedding):
  1727. return self._v2_resized_token_embeddings(new_num_tokens)
  1728. if new_num_tokens is None or new_num_tokens == self.config.vocab_size:
  1729. return self._get_word_embedding_weight(self.get_input_embeddings())
  1730. model_embeds = self._resize_token_embeddings(new_num_tokens)
  1731. # Update base model and current model config
  1732. self.config.vocab_size = new_num_tokens
  1733. return model_embeds
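# Usage sketch (tokenizer and token strings are placeholders): resize after adding tokens so the
# embedding matrix matches the tokenizer's vocabulary size.
#
#     tokenizer.add_tokens(["<new_token_1>", "<new_token_2>"])
#     model.resize_token_embeddings(len(tokenizer))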
  1734. def _v2_resized_token_embeddings(self, new_num_tokens: Optional[int] = None) -> keras.layers.Embedding:
  1735. """
  1736. Resizes input token embeddings matrix of the model if `new_num_tokens != config.vocab_size`.
  1737. Arguments:
  1738. new_num_tokens (`int`, *optional*):
  1739. The number of new tokens in the embedding matrix. Increasing the size will add newly initialized
  1740. vectors at the end. Reducing the size will remove vectors from the end. If not provided or `None`, just
  1741. returns a pointer to the input tokens without doing anything.
  1742. Return:
  1743. `keras.layers.Embedding`: Pointer to the input tokens of the model.
  1744. """
  1745. if new_num_tokens is None or new_num_tokens == self.config.vocab_size:
  1746. return self.get_input_embeddings()
  1747. model_embeds = self._v2_resize_token_embeddings(new_num_tokens)
  1748. # Update base model and current model config
  1749. self.config.vocab_size = new_num_tokens
  1750. return model_embeds
  1751. def _get_word_embedding_weight(model, embedding_layer):
1752. # TODO (joao): flagged for deletion due to embeddings refactor
  1753. # If the variable holds the weights themselves, return them
  1754. if isinstance(embedding_layer, tf.Tensor):
  1755. return embedding_layer
  1756. # Otherwise, try to get them from the layer's attributes
  1757. embeds = getattr(embedding_layer, "weight", None)
  1758. if embeds is not None:
  1759. return embeds
  1760. embeds = getattr(embedding_layer, "decoder", None)
  1761. if embeds is not None:
  1762. return embeds
  1763. # The reason why the attributes don't exist might be
  1764. # because the model is not built, so retry getting
  1765. # the argument after building the model
  1766. model.build_in_name_scope()
  1767. embeds = getattr(embedding_layer, "weight", None)
  1768. if embeds is not None:
  1769. return embeds
  1770. embeds = getattr(embedding_layer, "decoder", None)
  1771. if embeds is not None:
  1772. return embeds
  1773. return None
  1774. def _resize_token_embeddings(self, new_num_tokens):
  1775. # TODO (joao): flagged for replacement (by `_v2_resize_token_embeddings`) due to embeddings refactor
  1776. old_embeddings = self._get_word_embedding_weight(self.get_input_embeddings())
  1777. new_embeddings = self._get_resized_embeddings(old_embeddings, new_num_tokens)
  1778. # if word embeddings are not tied, make sure that lm head bias is resized as well
  1779. if self.get_bias() is not None:
  1780. old_lm_head_bias = self.get_bias()
  1781. new_lm_head_bias = self._get_resized_lm_head_bias(old_lm_head_bias, new_num_tokens)
  1782. self.set_bias(new_lm_head_bias)
  1783. # if word embeddings are not tied, make sure that lm head decoder is resized as well
  1784. if self.get_output_embeddings() is not None:
  1785. old_lm_head_decoder = self._get_word_embedding_weight(self.get_output_embeddings())
  1786. new_lm_head_decoder = self._get_resized_lm_head_decoder(old_lm_head_decoder, new_num_tokens)
  1787. self.set_output_embeddings(new_lm_head_decoder)
  1788. self.set_input_embeddings(new_embeddings)
  1789. return self.get_input_embeddings()
  1790. def _v2_resize_token_embeddings(self, new_num_tokens):
  1791. old_embeddings = self.get_input_embeddings()
  1792. new_embeddings = self._v2_get_resized_embeddings(old_embeddings, new_num_tokens)
  1793. self.set_input_embeddings(new_embeddings)
  1794. # If word embeddings are not tied, make sure that lm head bias is resized as well
  1795. if self.get_bias() is not None:
  1796. old_lm_head_bias = self.get_bias()
  1797. new_lm_head_bias = self._v2_get_resized_lm_head_bias(old_lm_head_bias, new_num_tokens)
  1798. self.set_bias(new_lm_head_bias)
  1799. # If word embeddings are not tied, make sure that lm head decoder is resized as well.
  1800. tied_weights = self.get_input_embeddings() == self.get_output_embeddings()
  1801. if self.get_output_embeddings() is not None and not tied_weights:
  1802. old_lm_head_decoder = self._get_word_embedding_weight(self.get_output_embeddings())
  1803. # TODO (joao): this one probably needs a v2 version with other models
  1804. new_lm_head_decoder = self._get_resized_lm_head_decoder(old_lm_head_decoder, new_num_tokens)
  1805. self.set_output_embeddings(new_lm_head_decoder)
  1806. return self.get_input_embeddings()
  1807. def _get_resized_lm_head_bias(self, old_lm_head_bias, new_num_tokens):
  1808. """
  1809. Build a resized bias from the old ones. Increasing the size will add newly initialized vectors at the end.
  1810. Reducing the size will remove vectors from the end
  1811. Args:
  1812. old_lm_head_bias (`tf.Variable`):
  1813. Old lm head bias to be resized.
  1814. new_num_tokens (`int`, *optional*):
  1815. New number of tokens in the linear matrix.
  1816. Increasing the size will add newly initialized vectors at the end. Reducing the size will remove
  1817. vectors from the end. If not provided or `None`, just returns None
  1818. Return:
  1819. `tf.Variable`: Pointer to the resized bias.
  1820. """
  1821. # TODO (joao): flagged for replacement (by `_v2_get_resized_lm_head_bias`) due to embeddings refactor
  1822. new_lm_head_bias = {}
  1823. for attr, weight in old_lm_head_bias.items():
  1824. first_dim, old_num_tokens = (None, shape_list(weight)[0]) if tf.rank(weight) == 1 else shape_list(weight)
  1825. size_diff = new_num_tokens - old_num_tokens
  1826. final_shape = [new_num_tokens] if first_dim is None else [first_dim, new_num_tokens]
  1827. # initialize new bias
  1828. if tf.math.greater(size_diff, 0):
  1829. padding_shape = [[0, size_diff]] if first_dim is None else [[0, 0], [0, size_diff]]
  1830. current_bias = tf.pad(weight.value(), tf.convert_to_tensor(padding_shape), constant_values=-1)
  1831. num_tokens_to_copy = min(old_num_tokens, new_num_tokens)
  1832. mask_shape = [num_tokens_to_copy] if first_dim is None else [1, num_tokens_to_copy]
  1833. bias_mask = tf.fill(tf.convert_to_tensor(mask_shape), True)
  1834. bias_mask = tf.pad(bias_mask, tf.convert_to_tensor(padding_shape), constant_values=False)
  1835. else:
  1836. slice_from = [0] if first_dim is None else [0, 0]
  1837. current_bias = tf.slice(
  1838. weight.value(), tf.convert_to_tensor(slice_from), tf.convert_to_tensor(final_shape)
  1839. )
  1840. bias_mask = tf.fill(tf.convert_to_tensor(final_shape), True)
  1841. new_bias = self.add_weight(
  1842. shape=final_shape,
  1843. initializer="zeros",
  1844. trainable=True,
  1845. name=weight.name.split(":")[0],
  1846. )
  1847. init_bias = tf.where(bias_mask, current_bias, new_bias.value())
  1848. new_bias.assign(init_bias)
  1849. new_lm_head_bias[attr] = new_bias
  1850. return new_lm_head_bias
  1851. def _v2_get_resized_lm_head_bias(
  1852. self, old_lm_head_bias: Dict[str, tf.Variable], new_num_tokens: int
  1853. ) -> Dict[str, tf.Tensor]:
  1854. """
  1855. Build a resized bias from the old ones. Increasing the size will add newly initialized vectors at the end.
  1856. Reducing the size will remove vectors from the end
  1857. Args:
  1858. old_lm_head_bias (`Dict[str, tf.Variable]`):
  1859. Old lm head bias to be resized.
  1860. new_num_tokens (`int`):
  1861. New number of tokens in the linear matrix. Increasing the size will add newly initialized vectors at
  1862. the end. Reducing the size will remove vectors from the end.
  1863. Return:
  1864. `tf.Tensor`: Values for the resized bias.
  1865. """
  1866. new_lm_head_bias = {}
  1867. for attr, weight in old_lm_head_bias.items():
  1868. # Determine the size difference (depending on the shape)
  1869. first_dim, old_num_tokens = (None, shape_list(weight)[0]) if tf.rank(weight) == 1 else shape_list(weight)
  1870. size_diff = new_num_tokens - old_num_tokens
  1871. # Copy the old bias values to the new bias
  1872. if old_num_tokens > new_num_tokens:
  1873. new_bias = weight.value()[..., :new_num_tokens]
  1874. else:
  1875. padding_shape = [[0, size_diff]] if first_dim is None else [[0, 0], [0, size_diff]]
  1876. new_bias = tf.pad(weight.value(), tf.convert_to_tensor(padding_shape))
  1877. new_lm_head_bias[attr] = new_bias
  1878. return new_lm_head_bias
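# Shape sketch for the helper above: a rank-1 bias of shape [old_num_tokens] is right-padded with
# zeros to [new_num_tokens] when growing, or sliced to the first new_num_tokens entries when
# shrinking; a rank-2 bias of shape [first_dim, old_num_tokens] is padded/sliced along its last axis.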
  1879. def _get_resized_lm_head_decoder(self, old_lm_head_decoder, new_num_tokens):
  1880. """
  1881. Build a resized decoder from the old ones. Increasing the size will add newly initialized vectors at the end.
  1882. Reducing the size will remove vectors from the end
  1883. Args:
  1884. old_lm_head_decoder (`tf.Variable`):
  1885. Old lm head decoder to be resized.
  1886. new_num_tokens (`int`, *optional*):
  1887. New number of tokens in the linear matrix.
  1888. Increasing the size will add newly initialized vectors at the end. Reducing the size will remove
  1889. vectors from the end. If not provided or `None`, just returns None
  1890. Return:
1891. `tf.Variable`: Pointer to the resized decoder, or the old decoder unchanged (possibly `None`) if the output
1892. embeddings are tied to the input ones.
  1893. """
  1894. new_lm_head_decoder = old_lm_head_decoder
  1895. is_input_output_equals = tf.reduce_any(
  1896. self._get_word_embedding_weight(self.get_input_embeddings()) == old_lm_head_decoder
  1897. )
  1898. if old_lm_head_decoder is not None and not is_input_output_equals:
  1899. old_embedding_dim = shape_list(old_lm_head_decoder)[1]
  1900. decoder_mask, current_decoder = init_copy_embeddings(old_lm_head_decoder, new_num_tokens)
  1901. new_lm_head_decoder = self.add_weight(
  1902. shape=(new_num_tokens, old_embedding_dim),
  1903. initializer="zeros",
  1904. trainable=True,
  1905. name=old_lm_head_decoder.name.split(":")[0],
  1906. )
  1907. init_decoder = tf.where(decoder_mask, current_decoder, new_lm_head_decoder.value())
  1908. new_lm_head_decoder.assign(init_decoder)
  1909. return new_lm_head_decoder
  1910. def _get_resized_embeddings(self, old_embeddings, new_num_tokens=None) -> tf.Variable:
  1911. """
  1912. Build a resized Embedding weights from a provided token Embedding weights. Increasing the size will add newly
  1913. initialized vectors at the end. Reducing the size will remove vectors from the end
  1914. Args:
  1915. old_embeddings (`tf.Variable`):
  1916. Old embeddings to be resized.
  1917. new_num_tokens (`int`, *optional*):
  1918. New number of tokens in the embedding matrix.
  1919. Increasing the size will add newly initialized vectors at the end. Reducing the size will remove
  1920. vectors from the end. If not provided or `None`, just returns a pointer to the input tokens
  1921. `tf.Variable` module of the model without doing anything.
  1922. Return:
  1923. `tf.Variable`: Pointer to the resized Embedding Module or the old Embedding Module if `new_num_tokens` is
  1924. `None`
  1925. """
  1926. # TODO (joao): flagged for replacement (by `_v2_get_resized_embeddings`) due to embeddings refactor
  1927. old_embedding_dim = shape_list(old_embeddings)[1]
  1928. init_range = getattr(self.config, "initializer_range", 0.02)
  1929. embeddings_mask, current_embeddings = init_copy_embeddings(old_embeddings, new_num_tokens)
  1930. new_embeddings = self.add_weight(
  1931. name=old_embeddings.name.split(":")[0],
  1932. shape=[new_num_tokens, old_embedding_dim],
  1933. initializer=get_initializer(init_range),
  1934. dtype=tf.float32,
  1935. )
  1936. init_embeddings = tf.where(embeddings_mask, current_embeddings, new_embeddings.value())
  1937. new_embeddings.assign(init_embeddings)
  1938. return new_embeddings
  1939. def _v2_get_resized_embeddings(
  1940. self, old_embeddings: keras.layers.Embedding, new_num_tokens: int
  1941. ) -> keras.layers.Embedding:
  1942. """
  1943. Build a resized Embedding layer from a provided Embedding layer. Increasing the size will add newly initialized
  1944. vectors at the end. Reducing the size will remove vectors from the end.
  1945. Args:
  1946. old_embeddings (`keras.layers.Embedding`):
  1947. Old embeddings to be resized.
  1948. new_num_tokens (`int`, *optional*):
  1949. New number of tokens in the embedding matrix.
  1950. Return:
  1951. `keras.layers.Embedding`: Resized Embedding layer.
  1952. """
  1953. # Get the initialization range for the embeddings
  1954. init_range = 0.02 # default value
  1955. potential_initialization_variable_names = [
  1956. "initializer_range", # most common
  1957. "initializer_factor", # e.g. T5
  1958. "init_std", # e.g BART
  1959. ]
  1960. for var_name in potential_initialization_variable_names:
  1961. if hasattr(self.config, var_name):
  1962. init_range = getattr(self.config, var_name)
  1963. # Get a new (initialized) embeddings layer
  1964. new_embeddings = keras.layers.Embedding(
  1965. input_dim=new_num_tokens,
  1966. output_dim=old_embeddings.output_dim,
  1967. embeddings_initializer=keras.initializers.TruncatedNormal(stddev=init_range),
  1968. name=old_embeddings.embeddings.name[:-13], # exact same scoped name except "/embeddings:0"
  1969. )
  1970. new_embeddings(tf.constant([[0]]))
  1971. # Copy the old embeddings to the new embeddings
  1972. if old_embeddings.input_dim >= new_num_tokens:
  1973. init_embeddings = old_embeddings.embeddings[:new_num_tokens]
  1974. else:
  1975. init_embeddings = tf.concat(
  1976. [old_embeddings.embeddings, new_embeddings.embeddings[old_embeddings.input_dim :]], axis=0
  1977. )
  1978. new_embeddings.embeddings.assign(init_embeddings)
  1979. return new_embeddings
  1980. def prune_heads(self, heads_to_prune):
  1981. """
  1982. Prunes heads of the base model.
  1983. Arguments:
  1984. heads_to_prune (`Dict[int, List[int]]`):
  1985. Dictionary with keys being selected layer indices (`int`) and associated values being the list of heads
  1986. to prune in said layer (list of `int`). For instance {1: [0, 2], 2: [2, 3]} will prune heads 0 and 2 on
  1987. layer 1 and heads 2 and 3 on layer 2.
  1988. """
  1989. raise NotImplementedError
  1990. def save_pretrained(
  1991. self,
  1992. save_directory,
  1993. saved_model=False,
  1994. version=1,
  1995. push_to_hub=False,
  1996. signatures=None,
  1997. max_shard_size: Union[int, str] = "5GB",
  1998. create_pr: bool = False,
  1999. safe_serialization: bool = False,
  2000. token: Optional[Union[str, bool]] = None,
  2001. **kwargs,
  2002. ):
  2003. """
  2004. Save a model and its configuration file to a directory, so that it can be re-loaded using the
  2005. [`~TFPreTrainedModel.from_pretrained`] class method.
  2006. Arguments:
  2007. save_directory (`str`):
  2008. Directory to which to save. Will be created if it doesn't exist.
  2009. saved_model (`bool`, *optional*, defaults to `False`):
2010. Whether the model should also be saved in the TensorFlow SavedModel format.
  2011. version (`int`, *optional*, defaults to 1):
  2012. The version of the saved model. A saved model needs to be versioned in order to be properly loaded by
  2013. TensorFlow Serving as detailed in the official documentation
  2014. https://www.tensorflow.org/tfx/serving/serving_basic
  2015. push_to_hub (`bool`, *optional*, defaults to `False`):
  2016. Whether or not to push your model to the Hugging Face model hub after saving it. You can specify the
  2017. repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
  2018. namespace).
  2019. signatures (`dict` or `tf.function`, *optional*):
  2020. Model's signature used for serving. This will be passed to the `signatures` argument of model.save().
2021. max_shard_size (`int` or `str`, *optional*, defaults to `"5GB"`):
2022. The maximum size for a checkpoint before being sharded. Each checkpoint shard will then be of a size
2023. lower than this limit. If expressed as a string, it needs to be digits followed by a unit (like `"5MB"`).
  2024. <Tip warning={true}>
  2025. If a single weight of the model is bigger than `max_shard_size`, it will be in its own checkpoint shard
  2026. which will be bigger than `max_shard_size`.
  2027. </Tip>
  2028. create_pr (`bool`, *optional*, defaults to `False`):
  2029. Whether or not to create a PR with the uploaded files or directly commit.
  2030. safe_serialization (`bool`, *optional*, defaults to `False`):
  2031. Whether to save the model using `safetensors` or the traditional TensorFlow way (that uses `h5`).
  2032. token (`str` or `bool`, *optional*):
  2033. The token to use as HTTP bearer authorization for remote files. If `True`, or not specified, will use
  2034. the token generated when running `huggingface-cli login` (stored in `~/.huggingface`).
  2035. kwargs (`Dict[str, Any]`, *optional*):
2036. Additional keyword arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
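Example (a minimal sketch; the checkpoint name and the output directory are illustrative):
```python
>>> from transformers import TFBertModel

>>> model = TFBertModel.from_pretrained("google-bert/bert-base-uncased")
>>> # Saves the weights and the configuration file; pass `push_to_hub=True` to also upload them to the Hub.
>>> model.save_pretrained("./my_model_directory/")
>>> reloaded = TFBertModel.from_pretrained("./my_model_directory/")
```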
  2037. """
  2038. use_auth_token = kwargs.pop("use_auth_token", None)
  2039. if use_auth_token is not None:
  2040. warnings.warn(
  2041. "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.",
  2042. FutureWarning,
  2043. )
  2044. if token is not None:
  2045. raise ValueError(
  2046. "`token` and `use_auth_token` are both specified. Please set only the argument `token`."
  2047. )
  2048. token = use_auth_token
  2049. if token is not None:
  2050. kwargs["token"] = token
  2051. if os.path.isfile(save_directory):
  2052. logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
  2053. return
  2054. os.makedirs(save_directory, exist_ok=True)
  2055. if push_to_hub:
  2056. commit_message = kwargs.pop("commit_message", None)
  2057. repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1])
  2058. repo_id = self._create_repo(repo_id, **kwargs)
  2059. files_timestamps = self._get_files_timestamps(save_directory)
  2060. if saved_model:
  2061. # If `torch_dtype` is in the config with a torch dtype class as the value, we need to change it to string.
  2062. # (Although TF doesn't care about this attribute, we can't just remove it or set it to `None`.)
  2063. if getattr(self.config, "torch_dtype", None) is not None and not isinstance(self.config.torch_dtype, str):
  2064. self.config.torch_dtype = str(self.config.torch_dtype).split(".")[1]
  2065. if signatures is None:
  2066. serving_default = self.serving.get_concrete_function(self.input_signature)
  2067. if any(spec.dtype == tf.int32 for spec in self.input_signature.values()):
  2068. int64_spec = {
  2069. key: tf.TensorSpec(
  2070. shape=spec.shape, dtype=tf.int64 if spec.dtype == tf.int32 else spec.dtype, name=spec.name
  2071. )
  2072. for key, spec in self.input_signature.items()
  2073. }
  2074. int64_serving = self.serving.get_concrete_function(int64_spec)
  2075. signatures = {"serving_default": serving_default, "int64_serving": int64_serving}
  2076. else:
  2077. signatures = serving_default
  2078. saved_model_dir = os.path.join(save_directory, "saved_model", str(version))
  2079. self.save(saved_model_dir, include_optimizer=False, signatures=signatures)
  2080. logger.info(f"Saved model created in {saved_model_dir}")
  2081. # Save configuration file
  2082. self.config.architectures = [self.__class__.__name__[2:]]
  2083. # If we have a custom model, we copy the file defining it in the folder and set the attributes so it can be
  2084. # loaded from the Hub.
  2085. if self._auto_class is not None:
  2086. custom_object_save(self, save_directory, config=self.config)
  2087. self.config.save_pretrained(save_directory)
  2088. if self.can_generate():
  2089. self.generation_config.save_pretrained(save_directory)
  2090. # If we save using the predefined names, we can load using `from_pretrained`
  2091. weights_name = SAFE_WEIGHTS_NAME if safe_serialization else TF2_WEIGHTS_NAME
  2092. output_model_file = os.path.join(save_directory, weights_name)
  2093. shards, index = tf_shard_checkpoint(self.weights, max_shard_size, weights_name=weights_name)
  2094. # Clean the folder from a previous save
  2095. for filename in os.listdir(save_directory):
  2096. full_filename = os.path.join(save_directory, filename)
  2097. # If we have a shard file that is not going to be replaced, we delete it, but only from the main process
  2098. # in distributed settings to avoid race conditions.
  2099. weights_no_suffix = weights_name.replace(".bin", "").replace(".safetensors", "")
  2100. if (
  2101. filename.startswith(weights_no_suffix)
  2102. and os.path.isfile(full_filename)
  2103. and filename not in shards.keys()
  2104. ):
  2105. os.remove(full_filename)
  2106. if index is None:
  2107. if safe_serialization:
  2108. state_dict = {strip_model_name_and_prefix(w.name): w.value() for w in self.weights}
  2109. safe_save_file(state_dict, output_model_file, metadata={"format": "tf"})
  2110. else:
  2111. self.save_weights(output_model_file)
  2112. logger.info(f"Model weights saved in {output_model_file}")
  2113. else:
  2114. save_index_file = SAFE_WEIGHTS_INDEX_NAME if safe_serialization else TF2_WEIGHTS_INDEX_NAME
  2115. save_index_file = os.path.join(save_directory, save_index_file)
  2116. # Save the index as well
  2117. with open(save_index_file, "w", encoding="utf-8") as index_file:
  2118. content = json.dumps(index, indent=2, sort_keys=True) + "\n"
  2119. index_file.write(content)
  2120. logger.info(
  2121. f"The model is bigger than the maximum size per checkpoint ({max_shard_size}) and is going to be "
  2122. f"split in {len(shards)} checkpoint shards. You can find where each parameters has been saved in the "
  2123. f"index located at {save_index_file}."
  2124. )
  2125. for shard_file, shard in shards.items():
  2126. if safe_serialization:
  2127. shard_state_dict = {strip_model_name_and_prefix(w.name): w.value() for w in shard}
  2128. safe_save_file(
  2129. shard_state_dict, os.path.join(save_directory, shard_file), metadata={"format": "tf"}
  2130. )
  2131. else:
  2132. with h5py.File(os.path.join(save_directory, shard_file), mode="w") as shard_file:
  2133. layers = []
  2134. for layer in sorted(shard, key=lambda x: x.name):
  2135. if "model." in layer.name or len(layer.name.split("/")) == 1:
  2136. layer_name = layer.name
  2137. else:
  2138. layer_name = "/".join(layer.name.split("/")[1:])
  2139. param_dset = shard_file.create_dataset(
  2140. layer_name, layer.numpy().shape, dtype=layer.numpy().dtype
  2141. )
  2142. param_dset[:] = layer.numpy()
  2143. layers.append(layer_name.encode("utf8"))
  2144. save_attributes_to_hdf5_group(shard_file, "layer_names", layers)
  2145. if push_to_hub:
  2146. self._upload_modified_files(
  2147. save_directory,
  2148. repo_id,
  2149. files_timestamps,
  2150. commit_message=commit_message,
  2151. token=token,
  2152. )
  2153. @classmethod
  2154. def from_pretrained(
  2155. cls,
  2156. pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
  2157. *model_args,
  2158. config: Optional[Union[PretrainedConfig, str, os.PathLike]] = None,
  2159. cache_dir: Optional[Union[str, os.PathLike]] = None,
  2160. ignore_mismatched_sizes: bool = False,
  2161. force_download: bool = False,
  2162. local_files_only: bool = False,
  2163. token: Optional[Union[str, bool]] = None,
  2164. revision: str = "main",
  2165. use_safetensors: bool = None,
  2166. **kwargs,
  2167. ):
  2168. r"""
  2169. Instantiate a pretrained TF 2.0 model from a pre-trained model configuration.
  2170. The warning *Weights from XXX not initialized from pretrained model* means that the weights of XXX do not come
  2171. pretrained with the rest of the model. It is up to you to train those weights with a downstream fine-tuning
  2172. task.
  2173. The warning *Weights from XXX not used in YYY* means that the layer XXX is not used by YYY, therefore those
  2174. weights are discarded.
  2175. Parameters:
  2176. pretrained_model_name_or_path (`str`, *optional*):
  2177. Can be either:
  2178. - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
  2179. - A path to a *directory* containing model weights saved using
  2180. [`~TFPreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`.
  2181. - A path or url to a *PyTorch state_dict save file* (e.g, `./pt_model/pytorch_model.bin`). In this
  2182. case, `from_pt` should be set to `True` and a configuration object should be provided as `config`
  2183. argument. This loading path is slower than converting the PyTorch model in a TensorFlow model
  2184. using the provided conversion scripts and loading the TensorFlow model afterwards.
  2185. - `None` if you are both providing the configuration and state dictionary (resp. with keyword
  2186. arguments `config` and `state_dict`).
  2187. model_args (sequence of positional arguments, *optional*):
  2188. All remaining positional arguments will be passed to the underlying model's `__init__` method.
  2189. config (`Union[PretrainedConfig, str]`, *optional*):
  2190. Can be either:
  2191. - an instance of a class derived from [`PretrainedConfig`],
  2192. - a string valid as input to [`~PretrainedConfig.from_pretrained`].
  2193. Configuration for the model to use instead of an automatically loaded configuration. Configuration can
  2194. be automatically loaded when:
  2195. - The model is a model provided by the library (loaded with the *model id* string of a pretrained
  2196. model).
  2197. - The model was saved using [`~TFPreTrainedModel.save_pretrained`] and is reloaded by supplying the
  2198. save directory.
  2199. - The model is loaded by supplying a local directory as `pretrained_model_name_or_path` and a
  2200. configuration JSON file named *config.json* is found in the directory.
  2201. from_pt (`bool`, *optional*, defaults to `False`):
  2202. Load the model weights from a PyTorch state_dict save file (see docstring of
  2203. `pretrained_model_name_or_path` argument).
2204. ignore_mismatched_sizes (`bool`, *optional*, defaults to `False`):
2205. Whether or not to ignore weights whose size in the checkpoint does not match the size in the model
2206. (for instance, when instantiating a model with 10 labels from a checkpoint with 3 labels). Mismatched
2207. weights are discarded and newly initialized instead of raising an error.
  2208. cache_dir (`str`, *optional*):
  2209. Path to a directory in which a downloaded pretrained model configuration should be cached if the
  2210. standard cache should not be used.
  2211. force_download (`bool`, *optional*, defaults to `False`):
  2212. Whether or not to force the (re-)download of the model weights and configuration files, overriding the
  2213. cached versions if they exist.
  2214. resume_download:
  2215. Deprecated and ignored. All downloads are now resumed by default when possible.
  2216. Will be removed in v5 of Transformers.
2217. proxies (`Dict[str, str]`, *optional*):
2218. A dictionary of proxy servers to use by protocol or endpoint, e.g.,
2219. `{'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
2220. output_loading_info (`bool`, *optional*, defaults to `False`):
2221. Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
  2222. local_files_only(`bool`, *optional*, defaults to `False`):
  2223. Whether or not to only look at local files (e.g., not try downloading the model).
  2224. token (`str` or `bool`, *optional*):
  2225. The token to use as HTTP bearer authorization for remote files. If `True`, or not specified, will use
  2226. the token generated when running `huggingface-cli login` (stored in `~/.huggingface`).
  2227. revision (`str`, *optional*, defaults to `"main"`):
  2228. The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
  2229. git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
  2230. identifier allowed by git.
  2231. <Tip>
  2232. To test a pull request you made on the Hub, you can pass `revision="refs/pr/<pr_number>"`.
  2233. </Tip>
  2234. mirror (`str`, *optional*):
  2235. Mirror source to accelerate downloads in China. If you are from China and have an accessibility
  2236. problem, you can set this option to resolve it. Note that we do not guarantee the timeliness or safety.
  2237. Please refer to the mirror site for more information.
  2238. subfolder (`str`, *optional*, defaults to `""`):
  2239. In case the relevant files are located inside a subfolder of the model repo on huggingface.co, you can
  2240. specify the folder name here.
  2241. tf_to_pt_weight_rename (`Callable`, *optional*):
  2242. A function that is called to transform the names of weights during the PyTorch to TensorFlow
  2243. crossloading process. This is not necessary for most models, but is useful to allow composite models to
  2244. be crossloaded correctly.
  2245. use_safetensors (`bool`, *optional*, defaults to `None`):
  2246. Whether or not to use `safetensors` checkpoints. Defaults to `None`. If not specified and `safetensors`
  2247. is not installed, it will be set to `False`.
  2248. kwargs (remaining dictionary of keyword arguments, *optional*):
  2249. Can be used to update the configuration object (after it being loaded) and initiate the model (e.g.,
  2250. `output_attentions=True`). Behaves differently depending on whether a `config` is provided or
  2251. automatically loaded:
  2252. - If a configuration is provided with `config`, `**kwargs` will be directly passed to the
  2253. underlying model's `__init__` method (we assume all relevant updates to the configuration have
  2254. already been done)
  2255. - If a configuration is not provided, `kwargs` will be first passed to the configuration class
  2256. initialization function ([`~PretrainedConfig.from_pretrained`]). Each key of `kwargs` that
  2257. corresponds to a configuration attribute will be used to override said attribute with the
  2258. supplied `kwargs` value. Remaining keys that do not correspond to any configuration attribute
  2259. will be passed to the underlying model's `__init__` function.
  2260. Examples:
  2261. ```python
  2262. >>> from transformers import BertConfig, TFBertModel
  2263. >>> # Download model and configuration from huggingface.co and cache.
  2264. >>> model = TFBertModel.from_pretrained("google-bert/bert-base-uncased")
  2265. >>> # Model was saved using *save_pretrained('./test/saved_model/')* (for example purposes, not runnable).
  2266. >>> model = TFBertModel.from_pretrained("./test/saved_model/")
  2267. >>> # Update configuration during loading.
  2268. >>> model = TFBertModel.from_pretrained("google-bert/bert-base-uncased", output_attentions=True)
  2269. >>> assert model.config.output_attentions == True
  2270. >>> # Loading from a Pytorch model file instead of a TensorFlow checkpoint (slower, for example purposes, not runnable).
  2271. >>> config = BertConfig.from_json_file("./pt_model/my_pt_model_config.json")
  2272. >>> model = TFBertModel.from_pretrained("./pt_model/my_pytorch_model.bin", from_pt=True, config=config)
  2273. ```"""
  2274. from_pt = kwargs.pop("from_pt", False)
  2275. resume_download = kwargs.pop("resume_download", None)
  2276. proxies = kwargs.pop("proxies", None)
  2277. output_loading_info = kwargs.pop("output_loading_info", False)
  2278. use_auth_token = kwargs.pop("use_auth_token", None)
  2279. trust_remote_code = kwargs.pop("trust_remote_code", None)
  2280. _ = kwargs.pop("mirror", None)
  2281. load_weight_prefix = kwargs.pop("load_weight_prefix", None)
  2282. from_pipeline = kwargs.pop("_from_pipeline", None)
  2283. from_auto_class = kwargs.pop("_from_auto", False)
  2284. subfolder = kwargs.pop("subfolder", "")
  2285. commit_hash = kwargs.pop("_commit_hash", None)
  2286. tf_to_pt_weight_rename = kwargs.pop("tf_to_pt_weight_rename", None)
  2287. # Not relevant for TF models
  2288. _ = kwargs.pop("adapter_kwargs", None)
  2289. if use_auth_token is not None:
  2290. warnings.warn(
  2291. "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.",
  2292. FutureWarning,
  2293. )
  2294. if token is not None:
  2295. raise ValueError(
  2296. "`token` and `use_auth_token` are both specified. Please set only the argument `token`."
  2297. )
  2298. token = use_auth_token
  2299. if trust_remote_code is True:
  2300. logger.warning(
  2301. "The argument `trust_remote_code` is to be used with Auto classes. It has no effect here and is"
  2302. " ignored."
  2303. )
  2304. user_agent = {"file_type": "model", "framework": "tensorflow", "from_auto_class": from_auto_class}
  2305. if from_pipeline is not None:
  2306. user_agent["using_pipeline"] = from_pipeline
  2307. if is_offline_mode() and not local_files_only:
  2308. logger.info("Offline mode: forcing local_files_only=True")
  2309. local_files_only = True
  2310. if use_safetensors is None and not is_safetensors_available():
  2311. use_safetensors = False
  2312. # Load config if we don't provide a configuration
  2313. if not isinstance(config, PretrainedConfig):
  2314. config_path = config if config is not None else pretrained_model_name_or_path
  2315. config, model_kwargs = cls.config_class.from_pretrained(
  2316. config_path,
  2317. cache_dir=cache_dir,
  2318. return_unused_kwargs=True,
  2319. force_download=force_download,
  2320. resume_download=resume_download,
  2321. proxies=proxies,
  2322. local_files_only=local_files_only,
  2323. token=token,
  2324. revision=revision,
  2325. _from_auto=from_auto_class,
  2326. _from_pipeline=from_pipeline,
  2327. _commit_hash=commit_hash,
  2328. **kwargs,
  2329. )
  2330. else:
  2331. model_kwargs = kwargs
  2332. if commit_hash is None:
  2333. commit_hash = getattr(config, "_commit_hash", None)
  2334. # This variable will flag if we're loading a sharded checkpoint. In this case the archive file is just the
  2335. # index of the files.
  2336. is_sharded = False
  2337. # Load model
  2338. if pretrained_model_name_or_path is not None:
  2339. pretrained_model_name_or_path = str(pretrained_model_name_or_path)
  2340. is_local = os.path.isdir(pretrained_model_name_or_path)
  2341. if is_local:
  2342. if from_pt and os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)):
  2343. # Load from a PyTorch checkpoint in priority if from_pt
  2344. archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
  2345. elif from_pt and os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_INDEX_NAME)):
  2346. # Load from a sharded PyTorch checkpoint
  2347. archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_INDEX_NAME)
  2348. is_sharded = True
  2349. elif use_safetensors is not False and os.path.isfile(
  2350. os.path.join(pretrained_model_name_or_path, SAFE_WEIGHTS_NAME)
  2351. ):
  2352. # Load from a safetensors checkpoint
  2353. archive_file = os.path.join(pretrained_model_name_or_path, SAFE_WEIGHTS_NAME)
  2354. elif use_safetensors is not False and os.path.isfile(
  2355. os.path.join(pretrained_model_name_or_path, SAFE_WEIGHTS_INDEX_NAME)
  2356. ):
  2357. # Load from a sharded safetensors checkpoint
  2358. archive_file = os.path.join(pretrained_model_name_or_path, SAFE_WEIGHTS_INDEX_NAME)
  2359. is_sharded = True
  2360. elif os.path.isfile(os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_NAME)):
  2361. # Load from a TF 2.0 checkpoint
  2362. archive_file = os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_NAME)
  2363. elif os.path.isfile(os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_INDEX_NAME)):
  2364. # Load from a sharded TF 2.0 checkpoint
  2365. archive_file = os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_INDEX_NAME)
  2366. is_sharded = True
  2367. # At this stage we don't have a weight file so we will raise an error.
  2368. elif use_safetensors:
  2369. raise EnvironmentError(
  2370. f"Error no file named {SAFE_WEIGHTS_NAME} or {SAFE_WEIGHTS_INDEX_NAME} found in directory {pretrained_model_name_or_path}. "
  2371. f"Please make sure that the model has been saved with `safe_serialization=True` or do not "
  2372. f"set `use_safetensors=True`."
  2373. )
  2374. elif os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)) or os.path.isfile(
  2375. os.path.join(pretrained_model_name_or_path, WEIGHTS_INDEX_NAME)
  2376. ):
  2377. raise EnvironmentError(
  2378. f"Error no file named {TF2_WEIGHTS_NAME} or {SAFE_WEIGHTS_NAME} found in directory {pretrained_model_name_or_path} "
  2379. "but there is a file for PyTorch weights. Use `from_pt=True` to load this model from those "
  2380. "weights."
  2381. )
  2382. else:
  2383. raise EnvironmentError(
  2384. f"Error no file named {TF2_WEIGHTS_NAME}, {SAFE_WEIGHTS_NAME} or {WEIGHTS_NAME} found in directory "
  2385. f"{pretrained_model_name_or_path}."
  2386. )
  2387. elif os.path.isfile(pretrained_model_name_or_path):
  2388. archive_file = pretrained_model_name_or_path
  2389. is_local = True
  2390. elif os.path.isfile(pretrained_model_name_or_path + ".index"):
  2391. archive_file = pretrained_model_name_or_path + ".index"
  2392. is_local = True
  2393. elif is_remote_url(pretrained_model_name_or_path):
  2394. filename = pretrained_model_name_or_path
  2395. resolved_archive_file = download_url(pretrained_model_name_or_path)
  2396. else:
  2397. # set correct filename
  2398. if from_pt:
  2399. filename = WEIGHTS_NAME
  2400. elif use_safetensors is not False:
  2401. filename = SAFE_WEIGHTS_NAME
  2402. else:
  2403. filename = TF2_WEIGHTS_NAME
  2404. try:
  2405. # Load from URL or cache if already cached
  2406. cached_file_kwargs = {
  2407. "cache_dir": cache_dir,
  2408. "force_download": force_download,
  2409. "proxies": proxies,
  2410. "resume_download": resume_download,
  2411. "local_files_only": local_files_only,
  2412. "token": token,
  2413. "user_agent": user_agent,
  2414. "revision": revision,
  2415. "subfolder": subfolder,
  2416. "_raise_exceptions_for_gated_repo": False,
  2417. "_raise_exceptions_for_missing_entries": False,
  2418. "_commit_hash": commit_hash,
  2419. }
  2420. resolved_archive_file = cached_file(pretrained_model_name_or_path, filename, **cached_file_kwargs)
  2421. # Since we set _raise_exceptions_for_missing_entries=False, we don't get an exception but a None
  2422. # result when internet is up, the repo and revision exist, but the file does not.
  2423. if resolved_archive_file is None and filename == SAFE_WEIGHTS_NAME:
  2424. # Did not find the safetensors file, let's fallback to TF.
  2425. # No support for sharded safetensors yet, so we'll raise an error if that's all we find.
  2426. filename = TF2_WEIGHTS_NAME
  2427. resolved_archive_file = cached_file(
  2428. pretrained_model_name_or_path, TF2_WEIGHTS_NAME, **cached_file_kwargs
  2429. )
  2430. if resolved_archive_file is None and filename == TF2_WEIGHTS_NAME:
  2431. # Maybe the checkpoint is sharded, we try to grab the index name in this case.
  2432. resolved_archive_file = cached_file(
  2433. pretrained_model_name_or_path, TF2_WEIGHTS_INDEX_NAME, **cached_file_kwargs
  2434. )
  2435. if resolved_archive_file is not None:
  2436. is_sharded = True
  2437. if resolved_archive_file is None and filename == WEIGHTS_NAME:
  2438. # Maybe the checkpoint is sharded, we try to grab the index name in this case.
  2439. resolved_archive_file = cached_file(
  2440. pretrained_model_name_or_path, WEIGHTS_INDEX_NAME, **cached_file_kwargs
  2441. )
  2442. if resolved_archive_file is not None:
  2443. is_sharded = True
  2444. if resolved_archive_file is None:
  2445. # Otherwise, maybe there is a PyTorch or Flax model file. We try those to give a helpful error
  2446. # message.
  2447. has_file_kwargs = {
  2448. "revision": revision,
  2449. "proxies": proxies,
  2450. "token": token,
  2451. "cache_dir": cache_dir,
  2452. "local_files_only": local_files_only,
  2453. }
  2454. if has_file(pretrained_model_name_or_path, SAFE_WEIGHTS_INDEX_NAME, **has_file_kwargs):
  2455. is_sharded = True
  2456. elif has_file(pretrained_model_name_or_path, WEIGHTS_NAME, **has_file_kwargs):
  2457. raise EnvironmentError(
  2458. f"{pretrained_model_name_or_path} does not appear to have a file named"
  2459. f" {TF2_WEIGHTS_NAME} but there is a file for PyTorch weights. Use `from_pt=True` to"
  2460. " load this model from those weights."
  2461. )
  2462. else:
  2463. raise EnvironmentError(
  2464. f"{pretrained_model_name_or_path} does not appear to have a file named {WEIGHTS_NAME},"
  2465. f" {TF2_WEIGHTS_NAME} or {TF_WEIGHTS_NAME}"
  2466. )
  2467. except EnvironmentError:
  2468. # Raise any environment error raise by `cached_file`. It will have a helpful error message adapted
  2469. # to the original exception.
  2470. raise
  2471. except Exception:
  2472. # For any other exception, we throw a generic error.
  2473. raise EnvironmentError(
  2474. f"Can't load the model for '{pretrained_model_name_or_path}'. If you were trying to load it"
  2475. " from 'https://huggingface.co/models', make sure you don't have a local directory with the"
  2476. f" same name. Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a"
  2477. f" directory containing a file named {WEIGHTS_NAME}, {TF2_WEIGHTS_NAME} or {TF_WEIGHTS_NAME}"
  2478. )
  2479. if is_local:
  2480. logger.info(f"loading weights file {archive_file}")
  2481. resolved_archive_file = archive_file
  2482. filename = resolved_archive_file.split(os.path.sep)[-1]
  2483. else:
  2484. logger.info(f"loading weights file {filename} from cache at {resolved_archive_file}")
  2485. else:
  2486. resolved_archive_file = None
  2487. # We'll need to download and cache each checkpoint shard if the checkpoint is sharded.
  2488. if is_sharded:
  2489. # resolved_archive_file becomes a list of files that point to the different checkpoint shards in this case.
  2490. resolved_archive_file, sharded_metadata = get_checkpoint_shard_files(
  2491. pretrained_model_name_or_path,
  2492. resolved_archive_file,
  2493. cache_dir=cache_dir,
  2494. force_download=force_download,
  2495. proxies=proxies,
  2496. resume_download=resume_download,
  2497. local_files_only=local_files_only,
  2498. token=token,
  2499. user_agent=user_agent,
  2500. revision=revision,
  2501. _commit_hash=commit_hash,
  2502. )
  2503. safetensors_from_pt = False
  2504. if filename == SAFE_WEIGHTS_NAME:
  2505. with safe_open(resolved_archive_file, framework="tf") as f:
  2506. safetensors_metadata = f.metadata()
  2507. if safetensors_metadata is None or safetensors_metadata.get("format") not in ["pt", "tf", "flax", "mlx"]:
  2508. raise OSError(
  2509. f"The safetensors archive passed at {resolved_archive_file} does not contain the valid metadata."
  2510. " Make sure you save your model with the `save_pretrained` method."
  2511. )
  2512. safetensors_from_pt = safetensors_metadata.get("format") == "pt"
  2513. elif filename == SAFE_WEIGHTS_INDEX_NAME:
  2514. with safe_open(resolved_archive_file[0], framework="tf") as f:
  2515. safetensors_metadata = f.metadata()
  2516. if safetensors_metadata is None or safetensors_metadata.get("format") not in ["pt", "tf", "flax", "mlx"]:
  2517. raise OSError(
  2518. f"The safetensors archive passed at {resolved_archive_file} does not contain the valid metadata."
  2519. " Make sure you save your model with the `save_pretrained` method."
  2520. )
  2521. safetensors_from_pt = safetensors_metadata.get("format") == "pt"
  2522. config.name_or_path = pretrained_model_name_or_path
  2523. # composed models, *e.g.* TFRag, require special treatment when it comes to loading
  2524. # pre-trained weights.
  2525. if cls._requires_load_weight_prefix and model_kwargs.get("name") is not None:
  2526. model_kwargs["load_weight_prefix"] = load_weight_prefix + "/" + model_kwargs.get("name")
  2527. # Instantiate model.
  2528. model = cls(config, *model_args, **model_kwargs)
  2529. if tf_to_pt_weight_rename is None and hasattr(model, "tf_to_pt_weight_rename"):
  2530. # TODO Matt: This is a temporary workaround to allow weight renaming, but requires a method
  2531. # to be defined for each class that requires a rename. We can probably just have a class-level
  2532. # dict and a single top-level method or something and cut down a lot of boilerplate code
  2533. tf_to_pt_weight_rename = model.tf_to_pt_weight_rename
  2534. if from_pt:
  2535. from .modeling_tf_pytorch_utils import load_pytorch_checkpoint_in_tf2_model
  2536. # Load from a PyTorch checkpoint
  2537. return load_pytorch_checkpoint_in_tf2_model(
  2538. model,
  2539. resolved_archive_file,
  2540. allow_missing_keys=True,
  2541. output_loading_info=output_loading_info,
  2542. _prefix=load_weight_prefix,
  2543. tf_to_pt_weight_rename=tf_to_pt_weight_rename,
  2544. )
  2545. # we might need to extend the variable scope for composite models
  2546. if load_weight_prefix is not None:
  2547. with tf.compat.v1.variable_scope(load_weight_prefix):
  2548. model.build_in_name_scope() # build the network with dummy inputs
  2549. else:
  2550. model.build_in_name_scope() # build the network with dummy inputs
  2551. if safetensors_from_pt and not is_sharded:
  2552. from .modeling_tf_pytorch_utils import load_pytorch_state_dict_in_tf2_model
  2553. with safe_open(resolved_archive_file, framework="tf") as safetensors_archive:
  2554. # Load from a PyTorch safetensors checkpoint
  2555. # We load in TF format here because PT weights often need to be transposed, and this is much
  2556. # faster on GPU. Loading as numpy and transposing on CPU adds several seconds to load times.
  2557. return load_pytorch_state_dict_in_tf2_model(
  2558. model,
  2559. safetensors_archive,
  2560. tf_inputs=False, # No need to build the model again
  2561. allow_missing_keys=True,
  2562. output_loading_info=output_loading_info,
  2563. _prefix=load_weight_prefix,
  2564. ignore_mismatched_sizes=ignore_mismatched_sizes,
  2565. tf_to_pt_weight_rename=tf_to_pt_weight_rename,
  2566. )
  2567. elif safetensors_from_pt:
  2568. from .modeling_tf_pytorch_utils import load_sharded_pytorch_safetensors_in_tf2_model
  2569. return load_sharded_pytorch_safetensors_in_tf2_model(
  2570. model,
  2571. resolved_archive_file,
  2572. tf_inputs=False,
  2573. allow_missing_keys=True,
  2574. output_loading_info=output_loading_info,
  2575. _prefix=load_weight_prefix,
  2576. ignore_mismatched_sizes=ignore_mismatched_sizes,
  2577. tf_to_pt_weight_rename=tf_to_pt_weight_rename,
  2578. )
2579. # 'by_name' allows us to do transfer learning by skipping/adding layers
  2580. # see https://github.com/tensorflow/tensorflow/blob/00fad90125b18b80fe054de1055770cfb8fe4ba3/tensorflow/python/keras/engine/network.py#L1339-L1357
  2581. try:
  2582. if is_sharded:
  2583. for file in resolved_archive_file:
2584. assert os.path.isfile(file), f"Error retrieving files {file}"
  2585. if filename == SAFE_WEIGHTS_INDEX_NAME:
  2586. missing_keys, unexpected_keys, mismatched_keys = load_tf_sharded_weights_from_safetensors(
  2587. model,
  2588. resolved_archive_file,
  2589. ignore_mismatched_sizes=ignore_mismatched_sizes,
  2590. _prefix=load_weight_prefix,
  2591. )
  2592. else:
  2593. missing_keys, unexpected_keys, mismatched_keys = load_tf_sharded_weights(
  2594. model,
  2595. resolved_archive_file,
  2596. ignore_mismatched_sizes=ignore_mismatched_sizes,
  2597. _prefix=load_weight_prefix,
  2598. )
  2599. else:
  2600. # Handles both H5 and safetensors
  2601. missing_keys, unexpected_keys, mismatched_keys = load_tf_weights(
  2602. model,
  2603. resolved_archive_file,
  2604. ignore_mismatched_sizes=ignore_mismatched_sizes,
  2605. _prefix=load_weight_prefix,
  2606. )
  2607. except OSError as e:
  2608. try:
  2609. with open(resolved_archive_file) as f:
  2610. if f.read().startswith("version"):
  2611. raise OSError(
  2612. "You seem to have cloned a repository without having git-lfs installed. Please install "
  2613. "git-lfs and run `git lfs install` followed by `git lfs pull` in the folder "
  2614. "you cloned."
  2615. )
  2616. else:
  2617. raise ValueError from e
  2618. except (UnicodeDecodeError, ValueError):
  2619. raise OSError(
  2620. "Unable to load weights from h5 file. "
  2621. "If you tried to load a TF 2.0 model from a PyTorch checkpoint, please set from_pt=True. "
  2622. )
  2623. if cls._keys_to_ignore_on_load_missing is not None:
  2624. for pat in cls._keys_to_ignore_on_load_missing:
  2625. missing_keys = [k for k in missing_keys if re.search(pat, k) is None]
  2626. if cls._keys_to_ignore_on_load_unexpected is not None:
  2627. for pat in cls._keys_to_ignore_on_load_unexpected:
  2628. unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
  2629. if len(unexpected_keys) > 0:
  2630. logger.warning(
  2631. f"Some layers from the model checkpoint at {pretrained_model_name_or_path} were not used when"
  2632. f" initializing {model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are"
  2633. f" initializing {model.__class__.__name__} from the checkpoint of a model trained on another task or"
  2634. " with another architecture (e.g. initializing a BertForSequenceClassification model from a"
  2635. " BertForPreTraining model).\n- This IS NOT expected if you are initializing"
  2636. f" {model.__class__.__name__} from the checkpoint of a model that you expect to be exactly identical"
  2637. " (initializing a BertForSequenceClassification model from a BertForSequenceClassification model)."
  2638. )
  2639. else:
  2640. logger.warning(f"All model checkpoint layers were used when initializing {model.__class__.__name__}.\n")
  2641. if len(missing_keys) > 0:
  2642. logger.warning(
  2643. f"Some layers of {model.__class__.__name__} were not initialized from the model checkpoint at"
  2644. f" {pretrained_model_name_or_path} and are newly initialized: {missing_keys}\nYou should probably"
  2645. " TRAIN this model on a down-stream task to be able to use it for predictions and inference."
  2646. )
  2647. elif len(mismatched_keys) == 0:
  2648. logger.warning(
  2649. f"All the layers of {model.__class__.__name__} were initialized from the model checkpoint at"
  2650. f" {pretrained_model_name_or_path}.\nIf your task is similar to the task the model of the checkpoint"
  2651. f" was trained on, you can already use {model.__class__.__name__} for predictions without further"
  2652. " training."
  2653. )
  2654. if len(mismatched_keys) > 0:
  2655. mismatched_warning = "\n".join(
  2656. [
  2657. f"- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated"
  2658. for key, shape1, shape2 in mismatched_keys
  2659. ]
  2660. )
  2661. logger.warning(
  2662. f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
  2663. f" {pretrained_model_name_or_path} and are newly initialized because the shapes did not"
  2664. f" match:\n{mismatched_warning}\nYou should probably TRAIN this model on a down-stream task to be able"
  2665. " to use it for predictions and inference."
  2666. )
  2667. # If it is a model with generation capabilities, attempt to load the generation config
  2668. if model.can_generate():
  2669. try:
  2670. model.generation_config = GenerationConfig.from_pretrained(
  2671. pretrained_model_name_or_path,
  2672. cache_dir=cache_dir,
  2673. force_download=force_download,
  2674. resume_download=resume_download,
  2675. proxies=proxies,
  2676. local_files_only=local_files_only,
  2677. token=token,
  2678. revision=revision,
  2679. subfolder=subfolder,
  2680. _from_auto=from_auto_class,
  2681. _from_pipeline=from_pipeline,
  2682. **kwargs,
  2683. )
  2684. except OSError:
  2685. logger.info(
  2686. "Generation config file not found, using a generation config created from the model config."
  2687. )
  2688. pass
  2689. if output_loading_info:
  2690. loading_info = {
  2691. "missing_keys": missing_keys,
  2692. "unexpected_keys": unexpected_keys,
  2693. "mismatched_keys": mismatched_keys,
  2694. }
  2695. return model, loading_info
  2696. return model
  2697. def push_to_hub(
  2698. self,
  2699. repo_id: str,
  2700. use_temp_dir: Optional[bool] = None,
  2701. commit_message: Optional[str] = None,
  2702. private: Optional[bool] = None,
  2703. max_shard_size: Optional[Union[int, str]] = "10GB",
  2704. token: Optional[Union[bool, str]] = None,
  2705. # (`use_auth_token` is deprecated: we have to keep it here as we don't have **kwargs)
  2706. use_auth_token: Optional[Union[bool, str]] = None,
  2707. create_pr: bool = False,
  2708. **base_model_card_args,
  2709. ) -> str:
  2710. """
2711. Upload the model files to the 🤗 Model Hub while synchronizing a local clone of the repo in `repo_id`.
  2712. Parameters:
  2713. repo_id (`str`):
  2714. The name of the repository you want to push your model to. It should contain your organization name
  2715. when pushing to a given organization.
  2716. use_temp_dir (`bool`, *optional*):
  2717. Whether or not to use a temporary directory to store the files saved before they are pushed to the Hub.
  2718. Will default to `True` if there is no directory named like `repo_id`, `False` otherwise.
  2719. commit_message (`str`, *optional*):
  2720. Message to commit while pushing. Will default to `"Upload model"`.
  2721. private (`bool`, *optional*):
  2722. Whether or not the repository created should be private.
  2723. token (`bool` or `str`, *optional*):
  2724. The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
  2725. when running `huggingface-cli login` (stored in `~/.huggingface`). Will default to `True` if `repo_url`
  2726. is not specified.
  2727. max_shard_size (`int` or `str`, *optional*, defaults to `"10GB"`):
2728. Only applicable for models. The maximum size for a checkpoint before being sharded. Each checkpoint
2729. shard will then be of a size lower than this limit. If expressed as a string, it needs to be digits
2730. followed by a unit (like `"5MB"`).
  2731. create_pr (`bool`, *optional*, defaults to `False`):
  2732. Whether or not to create a PR with the uploaded files or directly commit.
  2733. Examples:
  2734. ```python
  2735. from transformers import TFAutoModel
  2736. model = TFAutoModel.from_pretrained("google-bert/bert-base-cased")
  2737. # Push the model to your namespace with the name "my-finetuned-bert".
  2738. model.push_to_hub("my-finetuned-bert")
  2739. # Push the model to an organization with the name "my-finetuned-bert".
  2740. model.push_to_hub("huggingface/my-finetuned-bert")
  2741. ```
  2742. """
  2743. if use_auth_token is not None:
  2744. warnings.warn(
  2745. "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.",
  2746. FutureWarning,
  2747. )
  2748. if token is not None:
  2749. raise ValueError(
  2750. "`token` and `use_auth_token` are both specified. Please set only the argument `token`."
  2751. )
  2752. token = use_auth_token
  2753. if "repo_path_or_name" in base_model_card_args:
  2754. warnings.warn(
  2755. "The `repo_path_or_name` argument is deprecated and will be removed in v5 of Transformers. Use "
  2756. "`repo_id` instead."
  2757. )
  2758. repo_id = base_model_card_args.pop("repo_path_or_name")
  2759. # Deprecation warning will be sent after for repo_url and organization
  2760. repo_url = base_model_card_args.pop("repo_url", None)
  2761. organization = base_model_card_args.pop("organization", None)
  2762. if os.path.isdir(repo_id):
  2763. working_dir = repo_id
  2764. repo_id = repo_id.split(os.path.sep)[-1]
  2765. else:
  2766. working_dir = repo_id.split("/")[-1]
  2767. repo_id = self._create_repo(
  2768. repo_id, private=private, token=token, repo_url=repo_url, organization=organization
  2769. )
  2770. if use_temp_dir is None:
  2771. use_temp_dir = not os.path.isdir(working_dir)
  2772. with working_or_temp_dir(working_dir=working_dir, use_temp_dir=use_temp_dir) as work_dir:
  2773. files_timestamps = self._get_files_timestamps(work_dir)
  2774. # Save all files.
  2775. self.save_pretrained(work_dir, max_shard_size=max_shard_size)
  2776. if hasattr(self, "history") and hasattr(self, "create_model_card"):
  2777. # This is a Keras model and we might be able to fish out its History and make a model card out of it
2778. model_card_kwargs = {
2779. "output_dir": work_dir,
2780. "model_name": Path(repo_id).name,
2781. }
2782. model_card_kwargs.update(base_model_card_args)
2783. self.create_model_card(**model_card_kwargs)
  2784. self._upload_modified_files(
  2785. work_dir,
  2786. repo_id,
  2787. files_timestamps,
  2788. commit_message=commit_message,
  2789. token=token,
  2790. create_pr=create_pr,
  2791. )
  2792. @classmethod
  2793. def register_for_auto_class(cls, auto_class="TFAutoModel"):
  2794. """
  2795. Register this class with a given auto class. This should only be used for custom models as the ones in the
  2796. library are already mapped with an auto class.
  2797. <Tip warning={true}>
  2798. This API is experimental and may have some slight breaking changes in the next releases.
  2799. </Tip>
  2800. Args:
  2801. auto_class (`str` or `type`, *optional*, defaults to `"TFAutoModel"`):
  2802. The auto class to register this new model with.
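Example (a sketch; `MyCustomTFModel` is a hypothetical `TFPreTrainedModel` subclass defined in your own module):
```python
>>> # Register the custom class so that, once saved with `save_pretrained`, it can be reloaded
>>> # through `TFAutoModel.from_pretrained(..., trust_remote_code=True)`.
>>> MyCustomTFModel.register_for_auto_class("TFAutoModel")
```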
  2803. """
  2804. if not isinstance(auto_class, str):
  2805. auto_class = auto_class.__name__
  2806. import transformers.models.auto as auto_module
  2807. if not hasattr(auto_module, auto_class):
  2808. raise ValueError(f"{auto_class} is not a valid auto class.")
  2809. cls._auto_class = auto_class
  2810. class TFConv1D(keras.layers.Layer):
  2811. """
  2812. 1D-convolutional layer as defined by Radford et al. for OpenAI GPT (and also used in GPT-2).
  2813. Basically works like a linear layer but the weights are transposed.
  2814. Args:
  2815. nf (`int`):
  2816. The number of output features.
  2817. nx (`int`):
  2818. The number of input features.
  2819. initializer_range (`float`, *optional*, defaults to 0.02):
  2820. The standard deviation to use to initialize the weights.
  2821. kwargs (`Dict[str, Any]`, *optional*):
  2822. Additional keyword arguments passed along to the `__init__` of `keras.layers.Layer`.
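Example (a minimal sketch showing the expected shapes; the sizes are illustrative):
```python
>>> import tensorflow as tf

>>> layer = TFConv1D(nf=768, nx=256)
>>> hidden_states = tf.random.normal((2, 10, 256))  # (batch_size, seq_length, nx)
>>> output = layer(hidden_states)  # (batch_size, seq_length, nf) == (2, 10, 768)
```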
  2823. """
  2824. def __init__(self, nf, nx, initializer_range=0.02, **kwargs):
  2825. super().__init__(**kwargs)
  2826. self.nf = nf
  2827. self.nx = nx
  2828. self.initializer_range = initializer_range
  2829. def build(self, input_shape):
  2830. if self.built:
  2831. return
  2832. self.built = True
  2833. self.weight = self.add_weight(
  2834. "weight", shape=[self.nx, self.nf], initializer=get_initializer(self.initializer_range)
  2835. )
  2836. self.bias = self.add_weight("bias", shape=[1, self.nf], initializer=tf.zeros_initializer())
  2837. def call(self, x):
  2838. bz, sl = shape_list(x)[:2]
  2839. x = tf.reshape(x, [-1, self.nx])
  2840. x = tf.matmul(x, self.weight) + self.bias
  2841. x = tf.reshape(x, [bz, sl, self.nf])
  2842. return x
  2843. class TFSharedEmbeddings(keras.layers.Layer):
  2844. r"""
  2845. Construct shared token embeddings.
2846. The weights of the embedding layer are usually shared with the weights of the linear decoder when doing language
  2847. modeling.
  2848. Args:
  2849. vocab_size (`int`):
  2850. The size of the vocabulary, e.g., the number of unique tokens.
  2851. hidden_size (`int`):
  2852. The size of the embedding vectors.
  2853. initializer_range (`float`, *optional*):
  2854. The standard deviation to use when initializing the weights. If no value is provided, it will default to
  2855. \\(1/\sqrt{hidden\_size}\\).
  2856. kwargs (`Dict[str, Any]`, *optional*):
  2857. Additional keyword arguments passed along to the `__init__` of `keras.layers.Layer`.
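Example (a minimal sketch of the two modes; this layer is deprecated in favor of `keras.layers.Embedding`):
```python
>>> import tensorflow as tf

>>> embeddings = TFSharedEmbeddings(vocab_size=100, hidden_size=16)
>>> token_ids = tf.constant([[1, 2, 3]])
>>> hidden_states = embeddings(token_ids, mode="embedding")  # shape (1, 3, 16)
>>> logits = embeddings(hidden_states, mode="linear")  # shape (1, 3, 100), reusing the same weight matrix
```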
  2858. """
2859. # TODO (joao): flagged for deletion due to embeddings refactor
  2860. def __init__(self, vocab_size: int, hidden_size: int, initializer_range: Optional[float] = None, **kwargs):
  2861. super().__init__(**kwargs)
  2862. self.vocab_size = vocab_size
  2863. self.hidden_size = hidden_size
  2864. self.initializer_range = hidden_size**-0.5 if initializer_range is None else initializer_range
  2865. warnings.warn(
  2866. "`TFSharedEmbeddings` is scheduled for deletion in v4.32, use `keras.layers.Embedding` instead.",
  2867. DeprecationWarning,
  2868. )
  2869. def build(self, input_shape):
  2870. """
  2871. Build shared token embedding layer Shared weights logic adapted from
  2872. https://github.com/tensorflow/models/blob/a009f4fb9d2fc4949e32192a944688925ef78659/official/transformer/v2/embedding_layer.py#L24
  2873. """
  2874. self.weight = self.add_weight(
  2875. "weight", shape=[self.vocab_size, self.hidden_size], initializer=get_initializer(self.initializer_range)
  2876. )
  2877. super().build(input_shape)
  2878. def get_config(self):
  2879. config = {
  2880. "vocab_size": self.vocab_size,
  2881. "hidden_size": self.hidden_size,
  2882. "initializer_range": self.initializer_range,
  2883. }
  2884. base_config = super().get_config()
  2885. return dict(list(base_config.items()) + list(config.items()))
  2886. def call(self, inputs: tf.Tensor, mode: str = "embedding") -> tf.Tensor:
  2887. """
  2888. Get token embeddings of inputs or decode final hidden state.
  2889. Args:
  2890. inputs (`tf.Tensor`):
  2891. In embedding mode, should be an int64 tensor with shape `[batch_size, length]`.
  2892. In linear mode, should be a float tensor with shape `[batch_size, length, hidden_size]`.
  2893. mode (`str`, defaults to `"embedding"`):
  2894. A valid value is either `"embedding"` or `"linear"`, the first one indicates that the layer should be
  2895. used as an embedding layer, the second one that the layer should be used as a linear decoder.
  2896. Returns:
  2897. `tf.Tensor`: In embedding mode, the output is a float32 embedding tensor, with shape `[batch_size, length,
  2898. embedding_size]`.
2899. In linear mode, the output is a float32 tensor with shape `[batch_size, length, vocab_size]`.
  2900. Raises:
  2901. ValueError: if `mode` is not valid.
  2902. Shared weights logic is adapted from
  2903. [here](https://github.com/tensorflow/models/blob/a009f4fb9d2fc4949e32192a944688925ef78659/official/transformer/v2/embedding_layer.py#L24).
  2904. """
  2905. if mode == "embedding":
  2906. return self._embedding(inputs)
  2907. elif mode == "linear":
  2908. return self._linear(inputs)
  2909. else:
  2910. raise ValueError(f"mode {mode} is not valid.")
  2911. def _embedding(self, input_ids):
  2912. """Applies embedding based on inputs tensor."""
  2913. return tf.gather(self.weight, input_ids)
  2914. def _linear(self, inputs):
  2915. """
  2916. Computes logits by running inputs through a linear layer.
  2917. Args:
  2918. inputs: A float32 tensor with shape [..., hidden_size]
  2919. Returns:
  2920. float32 tensor with shape [..., vocab_size].
  2921. """
  2922. first_dims = shape_list(inputs)[:-1]
  2923. x = tf.reshape(inputs, [-1, self.hidden_size])
  2924. logits = tf.matmul(x, self.weight, transpose_b=True)
  2925. return tf.reshape(logits, first_dims + [self.vocab_size])
  2926. class TFSequenceSummary(keras.layers.Layer):
  2927. """
  2928. Compute a single vector summary of a sequence hidden states.
  2929. Args:
  2930. config ([`PretrainedConfig`]):
  2931. The config used by the model. Relevant arguments in the config class of the model are (refer to the actual
  2932. config class of your model for the default values it uses):
  2933. - **summary_type** (`str`) -- The method to use to make this summary. Accepted values are:
  2934. - `"last"` -- Take the last token hidden state (like XLNet)
  2935. - `"first"` -- Take the first token hidden state (like Bert)
  2936. - `"mean"` -- Take the mean of all tokens hidden states
  2937. - `"cls_index"` -- Supply a Tensor of classification token position (GPT/GPT-2)
  2938. - `"attn"` -- Not implemented now, use multi-head attention
  2939. - **summary_use_proj** (`bool`) -- Add a projection after the vector extraction.
  2940. - **summary_proj_to_labels** (`bool`) -- If `True`, the projection outputs to `config.num_labels` classes
  2941. (otherwise to `config.hidden_size`).
  2942. - **summary_activation** (`Optional[str]`) -- Set to `"tanh"` to add a tanh activation to the output,
  2943. another string or `None` will add no activation.
  2944. - **summary_first_dropout** (`float`) -- Optional dropout probability before the projection and activation.
2945. - **summary_last_dropout** (`float`) -- Optional dropout probability after the projection and activation.
  2946. initializer_range (`float`, *optional*, defaults to 0.02): The standard deviation to use to initialize the weights.
  2947. kwargs (`Dict[str, Any]`, *optional*):
  2948. Additional keyword arguments passed along to the `__init__` of `keras.layers.Layer`.
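Example (a minimal sketch, assuming a GPT-2 style config, which defines the `summary_*` attributes listed above):
```python
>>> import tensorflow as tf
>>> from transformers import GPT2Config

>>> config = GPT2Config()
>>> summary = TFSequenceSummary(config, initializer_range=config.initializer_range)
>>> hidden_states = tf.random.normal((2, 8, config.hidden_size))
>>> pooled = summary(hidden_states)  # shape (2, config.num_labels) because GPT-2 projects the summary to the labels
```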
  2949. """
  2950. def __init__(self, config: PretrainedConfig, initializer_range: float = 0.02, **kwargs):
  2951. super().__init__(**kwargs)
2952. self.summary_type = config.summary_type if hasattr(config, "summary_type") else "last"
  2953. if self.summary_type == "attn":
  2954. # We should use a standard multi-head attention module with absolute positional embedding for that.
  2955. # Cf. https://github.com/zihangdai/xlnet/blob/master/modeling.py#L253-L276
  2956. # We can probably just use the multi-head attention module of PyTorch >=1.1.0
  2957. raise NotImplementedError
  2958. self.has_summary = hasattr(config, "summary_use_proj") and config.summary_use_proj
  2959. if self.has_summary:
  2960. if hasattr(config, "summary_proj_to_labels") and config.summary_proj_to_labels and config.num_labels > 0:
  2961. num_classes = config.num_labels
  2962. else:
  2963. num_classes = config.hidden_size
  2964. self.summary = keras.layers.Dense(
  2965. num_classes, kernel_initializer=get_initializer(initializer_range), name="summary"
  2966. )
  2967. self.has_activation = False
  2968. activation_string = getattr(config, "summary_activation", None)
  2969. if activation_string is not None:
  2970. self.has_activation = True
  2971. self.activation = get_tf_activation(activation_string)
  2972. self.has_first_dropout = hasattr(config, "summary_first_dropout") and config.summary_first_dropout > 0
  2973. if self.has_first_dropout:
  2974. self.first_dropout = keras.layers.Dropout(config.summary_first_dropout)
  2975. self.has_last_dropout = hasattr(config, "summary_last_dropout") and config.summary_last_dropout > 0
  2976. if self.has_last_dropout:
  2977. self.last_dropout = keras.layers.Dropout(config.summary_last_dropout)
  2978. self.hidden_size = config.hidden_size
  2979. def call(self, inputs, cls_index=None, training=False):
  2980. if not isinstance(inputs, (dict, tuple, list)):
  2981. hidden_states = inputs
  2982. elif isinstance(inputs, (tuple, list)):
  2983. hidden_states = inputs[0]
  2984. cls_index = inputs[1] if len(inputs) > 1 else None
  2985. assert len(inputs) <= 2, "Too many inputs."
  2986. else:
  2987. hidden_states = inputs.get("hidden_states")
  2988. cls_index = inputs.get("cls_index", None)
  2989. if self.summary_type == "last":
  2990. output = hidden_states[:, -1]
  2991. elif self.summary_type == "first":
  2992. output = hidden_states[:, 0]
  2993. elif self.summary_type == "mean":
  2994. output = tf.reduce_mean(hidden_states, axis=1)
  2995. elif self.summary_type == "cls_index":
  2996. hidden_shape = shape_list(hidden_states) # e.g. [batch, num choices, seq length, hidden dims]
  2997. if cls_index is None:
  2998. cls_index = tf.fill(
  2999. hidden_shape[:-2], hidden_shape[-2] - 1
3000. ) # A tensor of shape [batch] or [batch, num choices], filled with sequence_length - 1 (the last position)
  3001. cls_shape = shape_list(cls_index)
  3002. if len(cls_shape) <= len(hidden_shape) - 2:
  3003. cls_index = tf.expand_dims(cls_index, axis=-1)
  3004. # else:
  3005. # cls_index = cls_index[..., tf.newaxis]
  3006. # cls_index = cls_index.expand((-1,) * (cls_index.dim()-1) + (hidden_states.size(-1),))
  3007. # shape of cls_index: (bsz, XX, 1, hidden_size) where XX are optional leading dim of hidden_states
  3008. output = tf.gather(hidden_states, cls_index, batch_dims=len(hidden_shape) - 2)
  3009. output = tf.squeeze(
  3010. output, axis=len(hidden_shape) - 2
  3011. ) # shape of output: (batch, num choices, hidden_size)
  3012. elif self.summary_type == "attn":
  3013. raise NotImplementedError
  3014. if self.has_first_dropout:
  3015. output = self.first_dropout(output, training=training)
  3016. if self.has_summary:
  3017. output = self.summary(output)
  3018. if self.has_activation:
  3019. output = self.activation(output)
  3020. if self.has_last_dropout:
  3021. output = self.last_dropout(output, training=training)
  3022. return output
  3023. def build(self, input_shape):
  3024. if self.built:
  3025. return
  3026. self.built = True
  3027. if getattr(self, "summary", None) is not None:
  3028. with tf.name_scope("summary"):
  3029. self.summary.build(self.hidden_size)
  3030. def get_initializer(initializer_range: float = 0.02) -> keras.initializers.TruncatedNormal:
  3031. """
  3032. Creates a `keras.initializers.TruncatedNormal` with the given range.
  3033. Args:
  3034. initializer_range (*float*, defaults to 0.02): Standard deviation of the initializer range.
  3035. Returns:
  3036. `keras.initializers.TruncatedNormal`: The truncated normal initializer.
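Example (illustrative):
```python
>>> dense = keras.layers.Dense(64, kernel_initializer=get_initializer(0.02))
```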
  3037. """
  3038. return keras.initializers.TruncatedNormal(stddev=initializer_range)