# modeling_auto.py (73 KB)
# coding=utf-8
# Copyright 2018 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
  15. """Auto Model class."""
  16. import warnings
  17. from collections import OrderedDict
  18. from ...utils import logging
  19. from .auto_factory import (
  20. _BaseAutoBackboneClass,
  21. _BaseAutoModelClass,
  22. _LazyAutoMapping,
  23. auto_class_update,
  24. )
  25. from .configuration_auto import CONFIG_MAPPING_NAMES
  26. logger = logging.get_logger(__name__)
  27. MODEL_MAPPING_NAMES = OrderedDict(
  28. [
  29. # Base model mapping
  30. ("albert", "AlbertModel"),
  31. ("align", "AlignModel"),
  32. ("altclip", "AltCLIPModel"),
  33. ("audio-spectrogram-transformer", "ASTModel"),
  34. ("autoformer", "AutoformerModel"),
  35. ("bark", "BarkModel"),
  36. ("bart", "BartModel"),
  37. ("beit", "BeitModel"),
  38. ("bert", "BertModel"),
  39. ("bert-generation", "BertGenerationEncoder"),
  40. ("big_bird", "BigBirdModel"),
  41. ("bigbird_pegasus", "BigBirdPegasusModel"),
  42. ("biogpt", "BioGptModel"),
  43. ("bit", "BitModel"),
  44. ("blenderbot", "BlenderbotModel"),
  45. ("blenderbot-small", "BlenderbotSmallModel"),
  46. ("blip", "BlipModel"),
  47. ("blip-2", "Blip2Model"),
  48. ("bloom", "BloomModel"),
  49. ("bridgetower", "BridgeTowerModel"),
  50. ("bros", "BrosModel"),
  51. ("camembert", "CamembertModel"),
  52. ("canine", "CanineModel"),
  53. ("chameleon", "ChameleonModel"),
  54. ("chinese_clip", "ChineseCLIPModel"),
  55. ("chinese_clip_vision_model", "ChineseCLIPVisionModel"),
  56. ("clap", "ClapModel"),
  57. ("clip", "CLIPModel"),
  58. ("clip_text_model", "CLIPTextModel"),
  59. ("clip_vision_model", "CLIPVisionModel"),
  60. ("clipseg", "CLIPSegModel"),
  61. ("clvp", "ClvpModelForConditionalGeneration"),
  62. ("code_llama", "LlamaModel"),
  63. ("codegen", "CodeGenModel"),
  64. ("cohere", "CohereModel"),
  65. ("conditional_detr", "ConditionalDetrModel"),
  66. ("convbert", "ConvBertModel"),
  67. ("convnext", "ConvNextModel"),
  68. ("convnextv2", "ConvNextV2Model"),
  69. ("cpmant", "CpmAntModel"),
  70. ("ctrl", "CTRLModel"),
  71. ("cvt", "CvtModel"),
  72. ("dac", "DacModel"),
  73. ("data2vec-audio", "Data2VecAudioModel"),
  74. ("data2vec-text", "Data2VecTextModel"),
  75. ("data2vec-vision", "Data2VecVisionModel"),
  76. ("dbrx", "DbrxModel"),
  77. ("deberta", "DebertaModel"),
  78. ("deberta-v2", "DebertaV2Model"),
  79. ("decision_transformer", "DecisionTransformerModel"),
  80. ("deformable_detr", "DeformableDetrModel"),
  81. ("deit", "DeiTModel"),
  82. ("deta", "DetaModel"),
  83. ("detr", "DetrModel"),
  84. ("dinat", "DinatModel"),
  85. ("dinov2", "Dinov2Model"),
  86. ("distilbert", "DistilBertModel"),
  87. ("donut-swin", "DonutSwinModel"),
  88. ("dpr", "DPRQuestionEncoder"),
  89. ("dpt", "DPTModel"),
  90. ("efficientformer", "EfficientFormerModel"),
  91. ("efficientnet", "EfficientNetModel"),
  92. ("electra", "ElectraModel"),
  93. ("encodec", "EncodecModel"),
  94. ("ernie", "ErnieModel"),
  95. ("ernie_m", "ErnieMModel"),
  96. ("esm", "EsmModel"),
  97. ("falcon", "FalconModel"),
  98. ("falcon_mamba", "FalconMambaModel"),
  99. ("fastspeech2_conformer", "FastSpeech2ConformerModel"),
  100. ("flaubert", "FlaubertModel"),
  101. ("flava", "FlavaModel"),
  102. ("fnet", "FNetModel"),
  103. ("focalnet", "FocalNetModel"),
  104. ("fsmt", "FSMTModel"),
  105. ("funnel", ("FunnelModel", "FunnelBaseModel")),
  106. ("gemma", "GemmaModel"),
  107. ("gemma2", "Gemma2Model"),
  108. ("git", "GitModel"),
  109. ("glm", "GlmModel"),
  110. ("glpn", "GLPNModel"),
  111. ("gpt-sw3", "GPT2Model"),
  112. ("gpt2", "GPT2Model"),
  113. ("gpt_bigcode", "GPTBigCodeModel"),
  114. ("gpt_neo", "GPTNeoModel"),
  115. ("gpt_neox", "GPTNeoXModel"),
  116. ("gpt_neox_japanese", "GPTNeoXJapaneseModel"),
  117. ("gptj", "GPTJModel"),
  118. ("gptsan-japanese", "GPTSanJapaneseForConditionalGeneration"),
  119. ("granite", "GraniteModel"),
  120. ("granitemoe", "GraniteMoeModel"),
  121. ("graphormer", "GraphormerModel"),
  122. ("grounding-dino", "GroundingDinoModel"),
  123. ("groupvit", "GroupViTModel"),
  124. ("hiera", "HieraModel"),
  125. ("hubert", "HubertModel"),
  126. ("ibert", "IBertModel"),
  127. ("idefics", "IdeficsModel"),
  128. ("idefics2", "Idefics2Model"),
  129. ("idefics3", "Idefics3Model"),
  130. ("imagegpt", "ImageGPTModel"),
  131. ("informer", "InformerModel"),
  132. ("jamba", "JambaModel"),
  133. ("jetmoe", "JetMoeModel"),
  134. ("jukebox", "JukeboxModel"),
  135. ("kosmos-2", "Kosmos2Model"),
  136. ("layoutlm", "LayoutLMModel"),
  137. ("layoutlmv2", "LayoutLMv2Model"),
  138. ("layoutlmv3", "LayoutLMv3Model"),
  139. ("led", "LEDModel"),
  140. ("levit", "LevitModel"),
  141. ("lilt", "LiltModel"),
  142. ("llama", "LlamaModel"),
  143. ("longformer", "LongformerModel"),
  144. ("longt5", "LongT5Model"),
  145. ("luke", "LukeModel"),
  146. ("lxmert", "LxmertModel"),
  147. ("m2m_100", "M2M100Model"),
  148. ("mamba", "MambaModel"),
  149. ("mamba2", "Mamba2Model"),
  150. ("marian", "MarianModel"),
  151. ("markuplm", "MarkupLMModel"),
  152. ("mask2former", "Mask2FormerModel"),
  153. ("maskformer", "MaskFormerModel"),
  154. ("maskformer-swin", "MaskFormerSwinModel"),
  155. ("mbart", "MBartModel"),
  156. ("mctct", "MCTCTModel"),
  157. ("mega", "MegaModel"),
  158. ("megatron-bert", "MegatronBertModel"),
  159. ("mgp-str", "MgpstrForSceneTextRecognition"),
  160. ("mimi", "MimiModel"),
  161. ("mistral", "MistralModel"),
  162. ("mixtral", "MixtralModel"),
  163. ("mobilebert", "MobileBertModel"),
  164. ("mobilenet_v1", "MobileNetV1Model"),
  165. ("mobilenet_v2", "MobileNetV2Model"),
  166. ("mobilevit", "MobileViTModel"),
  167. ("mobilevitv2", "MobileViTV2Model"),
  168. ("moshi", "MoshiModel"),
  169. ("mpnet", "MPNetModel"),
  170. ("mpt", "MptModel"),
  171. ("mra", "MraModel"),
  172. ("mt5", "MT5Model"),
  173. ("musicgen", "MusicgenModel"),
  174. ("musicgen_melody", "MusicgenMelodyModel"),
  175. ("mvp", "MvpModel"),
  176. ("nat", "NatModel"),
  177. ("nemotron", "NemotronModel"),
  178. ("nezha", "NezhaModel"),
  179. ("nllb-moe", "NllbMoeModel"),
  180. ("nystromformer", "NystromformerModel"),
  181. ("olmo", "OlmoModel"),
  182. ("olmoe", "OlmoeModel"),
  183. ("omdet-turbo", "OmDetTurboForObjectDetection"),
  184. ("oneformer", "OneFormerModel"),
  185. ("open-llama", "OpenLlamaModel"),
  186. ("openai-gpt", "OpenAIGPTModel"),
  187. ("opt", "OPTModel"),
  188. ("owlv2", "Owlv2Model"),
  189. ("owlvit", "OwlViTModel"),
  190. ("patchtsmixer", "PatchTSMixerModel"),
  191. ("patchtst", "PatchTSTModel"),
  192. ("pegasus", "PegasusModel"),
  193. ("pegasus_x", "PegasusXModel"),
  194. ("perceiver", "PerceiverModel"),
  195. ("persimmon", "PersimmonModel"),
  196. ("phi", "PhiModel"),
  197. ("phi3", "Phi3Model"),
  198. ("phimoe", "PhimoeModel"),
  199. ("pixtral", "PixtralVisionModel"),
  200. ("plbart", "PLBartModel"),
  201. ("poolformer", "PoolFormerModel"),
  202. ("prophetnet", "ProphetNetModel"),
  203. ("pvt", "PvtModel"),
  204. ("pvt_v2", "PvtV2Model"),
  205. ("qdqbert", "QDQBertModel"),
  206. ("qwen2", "Qwen2Model"),
  207. ("qwen2_audio_encoder", "Qwen2AudioEncoder"),
  208. ("qwen2_moe", "Qwen2MoeModel"),
  209. ("qwen2_vl", "Qwen2VLModel"),
  210. ("recurrent_gemma", "RecurrentGemmaModel"),
  211. ("reformer", "ReformerModel"),
  212. ("regnet", "RegNetModel"),
  213. ("rembert", "RemBertModel"),
  214. ("resnet", "ResNetModel"),
  215. ("retribert", "RetriBertModel"),
  216. ("roberta", "RobertaModel"),
  217. ("roberta-prelayernorm", "RobertaPreLayerNormModel"),
  218. ("roc_bert", "RoCBertModel"),
  219. ("roformer", "RoFormerModel"),
  220. ("rt_detr", "RTDetrModel"),
  221. ("rwkv", "RwkvModel"),
  222. ("sam", "SamModel"),
  223. ("seamless_m4t", "SeamlessM4TModel"),
  224. ("seamless_m4t_v2", "SeamlessM4Tv2Model"),
  225. ("segformer", "SegformerModel"),
  226. ("seggpt", "SegGptModel"),
  227. ("sew", "SEWModel"),
  228. ("sew-d", "SEWDModel"),
  229. ("siglip", "SiglipModel"),
  230. ("siglip_vision_model", "SiglipVisionModel"),
  231. ("speech_to_text", "Speech2TextModel"),
  232. ("speecht5", "SpeechT5Model"),
  233. ("splinter", "SplinterModel"),
  234. ("squeezebert", "SqueezeBertModel"),
  235. ("stablelm", "StableLmModel"),
  236. ("starcoder2", "Starcoder2Model"),
  237. ("swiftformer", "SwiftFormerModel"),
  238. ("swin", "SwinModel"),
  239. ("swin2sr", "Swin2SRModel"),
  240. ("swinv2", "Swinv2Model"),
  241. ("switch_transformers", "SwitchTransformersModel"),
  242. ("t5", "T5Model"),
  243. ("table-transformer", "TableTransformerModel"),
  244. ("tapas", "TapasModel"),
  245. ("time_series_transformer", "TimeSeriesTransformerModel"),
  246. ("timesformer", "TimesformerModel"),
  247. ("timm_backbone", "TimmBackbone"),
  248. ("trajectory_transformer", "TrajectoryTransformerModel"),
  249. ("transfo-xl", "TransfoXLModel"),
  250. ("tvlt", "TvltModel"),
  251. ("tvp", "TvpModel"),
  252. ("udop", "UdopModel"),
  253. ("umt5", "UMT5Model"),
  254. ("unispeech", "UniSpeechModel"),
  255. ("unispeech-sat", "UniSpeechSatModel"),
  256. ("univnet", "UnivNetModel"),
  257. ("van", "VanModel"),
  258. ("videomae", "VideoMAEModel"),
  259. ("vilt", "ViltModel"),
  260. ("vision-text-dual-encoder", "VisionTextDualEncoderModel"),
  261. ("visual_bert", "VisualBertModel"),
  262. ("vit", "ViTModel"),
  263. ("vit_hybrid", "ViTHybridModel"),
  264. ("vit_mae", "ViTMAEModel"),
  265. ("vit_msn", "ViTMSNModel"),
  266. ("vitdet", "VitDetModel"),
  267. ("vits", "VitsModel"),
  268. ("vivit", "VivitModel"),
  269. ("wav2vec2", "Wav2Vec2Model"),
  270. ("wav2vec2-bert", "Wav2Vec2BertModel"),
  271. ("wav2vec2-conformer", "Wav2Vec2ConformerModel"),
  272. ("wavlm", "WavLMModel"),
  273. ("whisper", "WhisperModel"),
  274. ("xclip", "XCLIPModel"),
  275. ("xglm", "XGLMModel"),
  276. ("xlm", "XLMModel"),
  277. ("xlm-prophetnet", "XLMProphetNetModel"),
  278. ("xlm-roberta", "XLMRobertaModel"),
  279. ("xlm-roberta-xl", "XLMRobertaXLModel"),
  280. ("xlnet", "XLNetModel"),
  281. ("xmod", "XmodModel"),
  282. ("yolos", "YolosModel"),
  283. ("yoso", "YosoModel"),
  284. ("zamba", "ZambaModel"),
  285. ]
  286. )
  287. MODEL_FOR_PRETRAINING_MAPPING_NAMES = OrderedDict(
  288. [
  289. # Model for pre-training mapping
  290. ("albert", "AlbertForPreTraining"),
  291. ("bart", "BartForConditionalGeneration"),
  292. ("bert", "BertForPreTraining"),
  293. ("big_bird", "BigBirdForPreTraining"),
  294. ("bloom", "BloomForCausalLM"),
  295. ("camembert", "CamembertForMaskedLM"),
  296. ("ctrl", "CTRLLMHeadModel"),
  297. ("data2vec-text", "Data2VecTextForMaskedLM"),
  298. ("deberta", "DebertaForMaskedLM"),
  299. ("deberta-v2", "DebertaV2ForMaskedLM"),
  300. ("distilbert", "DistilBertForMaskedLM"),
  301. ("electra", "ElectraForPreTraining"),
  302. ("ernie", "ErnieForPreTraining"),
  303. ("falcon_mamba", "FalconMambaForCausalLM"),
  304. ("flaubert", "FlaubertWithLMHeadModel"),
  305. ("flava", "FlavaForPreTraining"),
  306. ("fnet", "FNetForPreTraining"),
  307. ("fsmt", "FSMTForConditionalGeneration"),
  308. ("funnel", "FunnelForPreTraining"),
  309. ("gpt-sw3", "GPT2LMHeadModel"),
  310. ("gpt2", "GPT2LMHeadModel"),
  311. ("gpt_bigcode", "GPTBigCodeForCausalLM"),
  312. ("gptsan-japanese", "GPTSanJapaneseForConditionalGeneration"),
  313. ("hiera", "HieraForPreTraining"),
  314. ("ibert", "IBertForMaskedLM"),
  315. ("idefics", "IdeficsForVisionText2Text"),
  316. ("idefics2", "Idefics2ForConditionalGeneration"),
  317. ("idefics3", "Idefics3ForConditionalGeneration"),
  318. ("layoutlm", "LayoutLMForMaskedLM"),
  319. ("llava", "LlavaForConditionalGeneration"),
  320. ("llava_next", "LlavaNextForConditionalGeneration"),
  321. ("llava_next_video", "LlavaNextVideoForConditionalGeneration"),
  322. ("llava_onevision", "LlavaOnevisionForConditionalGeneration"),
  323. ("longformer", "LongformerForMaskedLM"),
  324. ("luke", "LukeForMaskedLM"),
  325. ("lxmert", "LxmertForPreTraining"),
  326. ("mamba", "MambaForCausalLM"),
  327. ("mamba2", "Mamba2ForCausalLM"),
  328. ("mega", "MegaForMaskedLM"),
  329. ("megatron-bert", "MegatronBertForPreTraining"),
  330. ("mllama", "MllamaForConditionalGeneration"),
  331. ("mobilebert", "MobileBertForPreTraining"),
  332. ("mpnet", "MPNetForMaskedLM"),
  333. ("mpt", "MptForCausalLM"),
  334. ("mra", "MraForMaskedLM"),
  335. ("mvp", "MvpForConditionalGeneration"),
  336. ("nezha", "NezhaForPreTraining"),
  337. ("nllb-moe", "NllbMoeForConditionalGeneration"),
  338. ("openai-gpt", "OpenAIGPTLMHeadModel"),
  339. ("paligemma", "PaliGemmaForConditionalGeneration"),
  340. ("qwen2_audio", "Qwen2AudioForConditionalGeneration"),
  341. ("retribert", "RetriBertModel"),
  342. ("roberta", "RobertaForMaskedLM"),
  343. ("roberta-prelayernorm", "RobertaPreLayerNormForMaskedLM"),
  344. ("roc_bert", "RoCBertForPreTraining"),
  345. ("rwkv", "RwkvForCausalLM"),
  346. ("splinter", "SplinterForPreTraining"),
  347. ("squeezebert", "SqueezeBertForMaskedLM"),
  348. ("switch_transformers", "SwitchTransformersForConditionalGeneration"),
  349. ("t5", "T5ForConditionalGeneration"),
  350. ("tapas", "TapasForMaskedLM"),
  351. ("transfo-xl", "TransfoXLLMHeadModel"),
  352. ("tvlt", "TvltForPreTraining"),
  353. ("unispeech", "UniSpeechForPreTraining"),
  354. ("unispeech-sat", "UniSpeechSatForPreTraining"),
  355. ("video_llava", "VideoLlavaForConditionalGeneration"),
  356. ("videomae", "VideoMAEForPreTraining"),
  357. ("vipllava", "VipLlavaForConditionalGeneration"),
  358. ("visual_bert", "VisualBertForPreTraining"),
  359. ("vit_mae", "ViTMAEForPreTraining"),
  360. ("wav2vec2", "Wav2Vec2ForPreTraining"),
  361. ("wav2vec2-conformer", "Wav2Vec2ConformerForPreTraining"),
  362. ("xlm", "XLMWithLMHeadModel"),
  363. ("xlm-roberta", "XLMRobertaForMaskedLM"),
  364. ("xlm-roberta-xl", "XLMRobertaXLForMaskedLM"),
  365. ("xlnet", "XLNetLMHeadModel"),
  366. ("xmod", "XmodForMaskedLM"),
  367. ]
  368. )
  369. MODEL_WITH_LM_HEAD_MAPPING_NAMES = OrderedDict(
  370. [
  371. # Model with LM heads mapping
  372. ("albert", "AlbertForMaskedLM"),
  373. ("bart", "BartForConditionalGeneration"),
  374. ("bert", "BertForMaskedLM"),
  375. ("big_bird", "BigBirdForMaskedLM"),
  376. ("bigbird_pegasus", "BigBirdPegasusForConditionalGeneration"),
  377. ("blenderbot-small", "BlenderbotSmallForConditionalGeneration"),
  378. ("bloom", "BloomForCausalLM"),
  379. ("camembert", "CamembertForMaskedLM"),
  380. ("codegen", "CodeGenForCausalLM"),
  381. ("convbert", "ConvBertForMaskedLM"),
  382. ("cpmant", "CpmAntForCausalLM"),
  383. ("ctrl", "CTRLLMHeadModel"),
  384. ("data2vec-text", "Data2VecTextForMaskedLM"),
  385. ("deberta", "DebertaForMaskedLM"),
  386. ("deberta-v2", "DebertaV2ForMaskedLM"),
  387. ("distilbert", "DistilBertForMaskedLM"),
  388. ("electra", "ElectraForMaskedLM"),
  389. ("encoder-decoder", "EncoderDecoderModel"),
  390. ("ernie", "ErnieForMaskedLM"),
  391. ("esm", "EsmForMaskedLM"),
  392. ("falcon_mamba", "FalconMambaForCausalLM"),
  393. ("flaubert", "FlaubertWithLMHeadModel"),
  394. ("fnet", "FNetForMaskedLM"),
  395. ("fsmt", "FSMTForConditionalGeneration"),
  396. ("funnel", "FunnelForMaskedLM"),
  397. ("git", "GitForCausalLM"),
  398. ("gpt-sw3", "GPT2LMHeadModel"),
  399. ("gpt2", "GPT2LMHeadModel"),
  400. ("gpt_bigcode", "GPTBigCodeForCausalLM"),
  401. ("gpt_neo", "GPTNeoForCausalLM"),
  402. ("gpt_neox", "GPTNeoXForCausalLM"),
  403. ("gpt_neox_japanese", "GPTNeoXJapaneseForCausalLM"),
  404. ("gptj", "GPTJForCausalLM"),
  405. ("gptsan-japanese", "GPTSanJapaneseForConditionalGeneration"),
  406. ("ibert", "IBertForMaskedLM"),
  407. ("layoutlm", "LayoutLMForMaskedLM"),
  408. ("led", "LEDForConditionalGeneration"),
  409. ("longformer", "LongformerForMaskedLM"),
  410. ("longt5", "LongT5ForConditionalGeneration"),
  411. ("luke", "LukeForMaskedLM"),
  412. ("m2m_100", "M2M100ForConditionalGeneration"),
  413. ("mamba", "MambaForCausalLM"),
  414. ("mamba2", "Mamba2ForCausalLM"),
  415. ("marian", "MarianMTModel"),
  416. ("mega", "MegaForMaskedLM"),
  417. ("megatron-bert", "MegatronBertForCausalLM"),
  418. ("mobilebert", "MobileBertForMaskedLM"),
  419. ("mpnet", "MPNetForMaskedLM"),
  420. ("mpt", "MptForCausalLM"),
  421. ("mra", "MraForMaskedLM"),
  422. ("mvp", "MvpForConditionalGeneration"),
  423. ("nezha", "NezhaForMaskedLM"),
  424. ("nllb-moe", "NllbMoeForConditionalGeneration"),
  425. ("nystromformer", "NystromformerForMaskedLM"),
  426. ("openai-gpt", "OpenAIGPTLMHeadModel"),
  427. ("pegasus_x", "PegasusXForConditionalGeneration"),
  428. ("plbart", "PLBartForConditionalGeneration"),
  429. ("pop2piano", "Pop2PianoForConditionalGeneration"),
  430. ("qdqbert", "QDQBertForMaskedLM"),
  431. ("reformer", "ReformerModelWithLMHead"),
  432. ("rembert", "RemBertForMaskedLM"),
  433. ("roberta", "RobertaForMaskedLM"),
  434. ("roberta-prelayernorm", "RobertaPreLayerNormForMaskedLM"),
  435. ("roc_bert", "RoCBertForMaskedLM"),
  436. ("roformer", "RoFormerForMaskedLM"),
  437. ("rwkv", "RwkvForCausalLM"),
  438. ("speech_to_text", "Speech2TextForConditionalGeneration"),
  439. ("squeezebert", "SqueezeBertForMaskedLM"),
  440. ("switch_transformers", "SwitchTransformersForConditionalGeneration"),
  441. ("t5", "T5ForConditionalGeneration"),
  442. ("tapas", "TapasForMaskedLM"),
  443. ("transfo-xl", "TransfoXLLMHeadModel"),
  444. ("wav2vec2", "Wav2Vec2ForMaskedLM"),
  445. ("whisper", "WhisperForConditionalGeneration"),
  446. ("xlm", "XLMWithLMHeadModel"),
  447. ("xlm-roberta", "XLMRobertaForMaskedLM"),
  448. ("xlm-roberta-xl", "XLMRobertaXLForMaskedLM"),
  449. ("xlnet", "XLNetLMHeadModel"),
  450. ("xmod", "XmodForMaskedLM"),
  451. ("yoso", "YosoForMaskedLM"),
  452. ]
  453. )
  454. MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
  455. [
  456. # Model for Causal LM mapping
  457. ("bart", "BartForCausalLM"),
  458. ("bert", "BertLMHeadModel"),
  459. ("bert-generation", "BertGenerationDecoder"),
  460. ("big_bird", "BigBirdForCausalLM"),
  461. ("bigbird_pegasus", "BigBirdPegasusForCausalLM"),
  462. ("biogpt", "BioGptForCausalLM"),
  463. ("blenderbot", "BlenderbotForCausalLM"),
  464. ("blenderbot-small", "BlenderbotSmallForCausalLM"),
  465. ("bloom", "BloomForCausalLM"),
  466. ("camembert", "CamembertForCausalLM"),
  467. ("code_llama", "LlamaForCausalLM"),
  468. ("codegen", "CodeGenForCausalLM"),
  469. ("cohere", "CohereForCausalLM"),
  470. ("cpmant", "CpmAntForCausalLM"),
  471. ("ctrl", "CTRLLMHeadModel"),
  472. ("data2vec-text", "Data2VecTextForCausalLM"),
  473. ("dbrx", "DbrxForCausalLM"),
  474. ("electra", "ElectraForCausalLM"),
  475. ("ernie", "ErnieForCausalLM"),
  476. ("falcon", "FalconForCausalLM"),
  477. ("falcon_mamba", "FalconMambaForCausalLM"),
  478. ("fuyu", "FuyuForCausalLM"),
  479. ("gemma", "GemmaForCausalLM"),
  480. ("gemma2", "Gemma2ForCausalLM"),
  481. ("git", "GitForCausalLM"),
  482. ("glm", "GlmForCausalLM"),
  483. ("gpt-sw3", "GPT2LMHeadModel"),
  484. ("gpt2", "GPT2LMHeadModel"),
  485. ("gpt_bigcode", "GPTBigCodeForCausalLM"),
  486. ("gpt_neo", "GPTNeoForCausalLM"),
  487. ("gpt_neox", "GPTNeoXForCausalLM"),
  488. ("gpt_neox_japanese", "GPTNeoXJapaneseForCausalLM"),
  489. ("gptj", "GPTJForCausalLM"),
  490. ("granite", "GraniteForCausalLM"),
  491. ("granitemoe", "GraniteMoeForCausalLM"),
  492. ("jamba", "JambaForCausalLM"),
  493. ("jetmoe", "JetMoeForCausalLM"),
  494. ("llama", "LlamaForCausalLM"),
  495. ("mamba", "MambaForCausalLM"),
  496. ("mamba2", "Mamba2ForCausalLM"),
  497. ("marian", "MarianForCausalLM"),
  498. ("mbart", "MBartForCausalLM"),
  499. ("mega", "MegaForCausalLM"),
  500. ("megatron-bert", "MegatronBertForCausalLM"),
  501. ("mistral", "MistralForCausalLM"),
  502. ("mixtral", "MixtralForCausalLM"),
  503. ("mllama", "MllamaForCausalLM"),
  504. ("moshi", "MoshiForCausalLM"),
  505. ("mpt", "MptForCausalLM"),
  506. ("musicgen", "MusicgenForCausalLM"),
  507. ("musicgen_melody", "MusicgenMelodyForCausalLM"),
  508. ("mvp", "MvpForCausalLM"),
  509. ("nemotron", "NemotronForCausalLM"),
  510. ("olmo", "OlmoForCausalLM"),
  511. ("olmoe", "OlmoeForCausalLM"),
  512. ("open-llama", "OpenLlamaForCausalLM"),
  513. ("openai-gpt", "OpenAIGPTLMHeadModel"),
  514. ("opt", "OPTForCausalLM"),
  515. ("pegasus", "PegasusForCausalLM"),
  516. ("persimmon", "PersimmonForCausalLM"),
  517. ("phi", "PhiForCausalLM"),
  518. ("phi3", "Phi3ForCausalLM"),
  519. ("phimoe", "PhimoeForCausalLM"),
  520. ("plbart", "PLBartForCausalLM"),
  521. ("prophetnet", "ProphetNetForCausalLM"),
  522. ("qdqbert", "QDQBertLMHeadModel"),
  523. ("qwen2", "Qwen2ForCausalLM"),
  524. ("qwen2_moe", "Qwen2MoeForCausalLM"),
  525. ("recurrent_gemma", "RecurrentGemmaForCausalLM"),
  526. ("reformer", "ReformerModelWithLMHead"),
  527. ("rembert", "RemBertForCausalLM"),
  528. ("roberta", "RobertaForCausalLM"),
  529. ("roberta-prelayernorm", "RobertaPreLayerNormForCausalLM"),
  530. ("roc_bert", "RoCBertForCausalLM"),
  531. ("roformer", "RoFormerForCausalLM"),
  532. ("rwkv", "RwkvForCausalLM"),
  533. ("speech_to_text_2", "Speech2Text2ForCausalLM"),
  534. ("stablelm", "StableLmForCausalLM"),
  535. ("starcoder2", "Starcoder2ForCausalLM"),
  536. ("transfo-xl", "TransfoXLLMHeadModel"),
  537. ("trocr", "TrOCRForCausalLM"),
  538. ("whisper", "WhisperForCausalLM"),
  539. ("xglm", "XGLMForCausalLM"),
  540. ("xlm", "XLMWithLMHeadModel"),
  541. ("xlm-prophetnet", "XLMProphetNetForCausalLM"),
  542. ("xlm-roberta", "XLMRobertaForCausalLM"),
  543. ("xlm-roberta-xl", "XLMRobertaXLForCausalLM"),
  544. ("xlnet", "XLNetLMHeadModel"),
  545. ("xmod", "XmodForCausalLM"),
  546. ("zamba", "ZambaForCausalLM"),
  547. ]
  548. )
  549. MODEL_FOR_IMAGE_MAPPING_NAMES = OrderedDict(
  550. [
  551. # Model for Image mapping
  552. ("beit", "BeitModel"),
  553. ("bit", "BitModel"),
  554. ("conditional_detr", "ConditionalDetrModel"),
  555. ("convnext", "ConvNextModel"),
  556. ("convnextv2", "ConvNextV2Model"),
  557. ("data2vec-vision", "Data2VecVisionModel"),
  558. ("deformable_detr", "DeformableDetrModel"),
  559. ("deit", "DeiTModel"),
  560. ("deta", "DetaModel"),
  561. ("detr", "DetrModel"),
  562. ("dinat", "DinatModel"),
  563. ("dinov2", "Dinov2Model"),
  564. ("dpt", "DPTModel"),
  565. ("efficientformer", "EfficientFormerModel"),
  566. ("efficientnet", "EfficientNetModel"),
  567. ("focalnet", "FocalNetModel"),
  568. ("glpn", "GLPNModel"),
  569. ("hiera", "HieraModel"),
  570. ("imagegpt", "ImageGPTModel"),
  571. ("levit", "LevitModel"),
  572. ("mllama", "MllamaVisionModel"),
  573. ("mobilenet_v1", "MobileNetV1Model"),
  574. ("mobilenet_v2", "MobileNetV2Model"),
  575. ("mobilevit", "MobileViTModel"),
  576. ("mobilevitv2", "MobileViTV2Model"),
  577. ("nat", "NatModel"),
  578. ("poolformer", "PoolFormerModel"),
  579. ("pvt", "PvtModel"),
  580. ("regnet", "RegNetModel"),
  581. ("resnet", "ResNetModel"),
  582. ("segformer", "SegformerModel"),
  583. ("siglip_vision_model", "SiglipVisionModel"),
  584. ("swiftformer", "SwiftFormerModel"),
  585. ("swin", "SwinModel"),
  586. ("swin2sr", "Swin2SRModel"),
  587. ("swinv2", "Swinv2Model"),
  588. ("table-transformer", "TableTransformerModel"),
  589. ("timesformer", "TimesformerModel"),
  590. ("timm_backbone", "TimmBackbone"),
  591. ("van", "VanModel"),
  592. ("videomae", "VideoMAEModel"),
  593. ("vit", "ViTModel"),
  594. ("vit_hybrid", "ViTHybridModel"),
  595. ("vit_mae", "ViTMAEModel"),
  596. ("vit_msn", "ViTMSNModel"),
  597. ("vitdet", "VitDetModel"),
  598. ("vivit", "VivitModel"),
  599. ("yolos", "YolosModel"),
  600. ]
  601. )
  602. MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING_NAMES = OrderedDict(
  603. [
  604. ("deit", "DeiTForMaskedImageModeling"),
  605. ("focalnet", "FocalNetForMaskedImageModeling"),
  606. ("swin", "SwinForMaskedImageModeling"),
  607. ("swinv2", "Swinv2ForMaskedImageModeling"),
  608. ("vit", "ViTForMaskedImageModeling"),
  609. ]
  610. )
  611. MODEL_FOR_CAUSAL_IMAGE_MODELING_MAPPING_NAMES = OrderedDict(
  612. # Model for Causal Image Modeling mapping
  613. [
  614. ("imagegpt", "ImageGPTForCausalImageModeling"),
  615. ]
  616. )
  617. MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
  618. [
  619. # Model for Image Classification mapping
  620. ("beit", "BeitForImageClassification"),
  621. ("bit", "BitForImageClassification"),
  622. ("clip", "CLIPForImageClassification"),
  623. ("convnext", "ConvNextForImageClassification"),
  624. ("convnextv2", "ConvNextV2ForImageClassification"),
  625. ("cvt", "CvtForImageClassification"),
  626. ("data2vec-vision", "Data2VecVisionForImageClassification"),
  627. (
  628. "deit",
  629. ("DeiTForImageClassification", "DeiTForImageClassificationWithTeacher"),
  630. ),
  631. ("dinat", "DinatForImageClassification"),
  632. ("dinov2", "Dinov2ForImageClassification"),
  633. (
  634. "efficientformer",
  635. (
  636. "EfficientFormerForImageClassification",
  637. "EfficientFormerForImageClassificationWithTeacher",
  638. ),
  639. ),
  640. ("efficientnet", "EfficientNetForImageClassification"),
  641. ("focalnet", "FocalNetForImageClassification"),
  642. ("hiera", "HieraForImageClassification"),
  643. ("imagegpt", "ImageGPTForImageClassification"),
  644. (
  645. "levit",
  646. ("LevitForImageClassification", "LevitForImageClassificationWithTeacher"),
  647. ),
  648. ("mobilenet_v1", "MobileNetV1ForImageClassification"),
  649. ("mobilenet_v2", "MobileNetV2ForImageClassification"),
  650. ("mobilevit", "MobileViTForImageClassification"),
  651. ("mobilevitv2", "MobileViTV2ForImageClassification"),
  652. ("nat", "NatForImageClassification"),
  653. (
  654. "perceiver",
  655. (
  656. "PerceiverForImageClassificationLearned",
  657. "PerceiverForImageClassificationFourier",
  658. "PerceiverForImageClassificationConvProcessing",
  659. ),
  660. ),
  661. ("poolformer", "PoolFormerForImageClassification"),
  662. ("pvt", "PvtForImageClassification"),
  663. ("pvt_v2", "PvtV2ForImageClassification"),
  664. ("regnet", "RegNetForImageClassification"),
  665. ("resnet", "ResNetForImageClassification"),
  666. ("segformer", "SegformerForImageClassification"),
  667. ("siglip", "SiglipForImageClassification"),
  668. ("swiftformer", "SwiftFormerForImageClassification"),
  669. ("swin", "SwinForImageClassification"),
  670. ("swinv2", "Swinv2ForImageClassification"),
  671. ("van", "VanForImageClassification"),
  672. ("vit", "ViTForImageClassification"),
  673. ("vit_hybrid", "ViTHybridForImageClassification"),
  674. ("vit_msn", "ViTMSNForImageClassification"),
  675. ]
  676. )
  677. MODEL_FOR_IMAGE_SEGMENTATION_MAPPING_NAMES = OrderedDict(
  678. [
  679. # Do not add new models here, this class will be deprecated in the future.
  680. # Model for Image Segmentation mapping
  681. ("detr", "DetrForSegmentation"),
  682. ]
  683. )
  684. MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING_NAMES = OrderedDict(
  685. [
  686. # Model for Semantic Segmentation mapping
  687. ("beit", "BeitForSemanticSegmentation"),
  688. ("data2vec-vision", "Data2VecVisionForSemanticSegmentation"),
  689. ("dpt", "DPTForSemanticSegmentation"),
  690. ("mobilenet_v2", "MobileNetV2ForSemanticSegmentation"),
  691. ("mobilevit", "MobileViTForSemanticSegmentation"),
  692. ("mobilevitv2", "MobileViTV2ForSemanticSegmentation"),
  693. ("segformer", "SegformerForSemanticSegmentation"),
  694. ("upernet", "UperNetForSemanticSegmentation"),
  695. ]
  696. )
  697. MODEL_FOR_INSTANCE_SEGMENTATION_MAPPING_NAMES = OrderedDict(
  698. [
  699. # Model for Instance Segmentation mapping
  700. # MaskFormerForInstanceSegmentation can be removed from this mapping in v5
  701. ("maskformer", "MaskFormerForInstanceSegmentation"),
  702. ]
  703. )
  704. MODEL_FOR_UNIVERSAL_SEGMENTATION_MAPPING_NAMES = OrderedDict(
  705. [
  706. # Model for Universal Segmentation mapping
  707. ("detr", "DetrForSegmentation"),
  708. ("mask2former", "Mask2FormerForUniversalSegmentation"),
  709. ("maskformer", "MaskFormerForInstanceSegmentation"),
  710. ("oneformer", "OneFormerForUniversalSegmentation"),
  711. ]
  712. )
  713. MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
  714. [
  715. ("timesformer", "TimesformerForVideoClassification"),
  716. ("videomae", "VideoMAEForVideoClassification"),
  717. ("vivit", "VivitForVideoClassification"),
  718. ]
  719. )
  720. MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES = OrderedDict(
  721. [
  722. ("blip", "BlipForConditionalGeneration"),
  723. ("blip-2", "Blip2ForConditionalGeneration"),
  724. ("chameleon", "ChameleonForConditionalGeneration"),
  725. ("git", "GitForCausalLM"),
  726. ("idefics2", "Idefics2ForConditionalGeneration"),
  727. ("idefics3", "Idefics3ForConditionalGeneration"),
  728. ("instructblip", "InstructBlipForConditionalGeneration"),
  729. ("instructblipvideo", "InstructBlipVideoForConditionalGeneration"),
  730. ("kosmos-2", "Kosmos2ForConditionalGeneration"),
  731. ("llava", "LlavaForConditionalGeneration"),
  732. ("llava_next", "LlavaNextForConditionalGeneration"),
  733. ("llava_next_video", "LlavaNextVideoForConditionalGeneration"),
  734. ("llava_onevision", "LlavaOnevisionForConditionalGeneration"),
  735. ("mllama", "MllamaForConditionalGeneration"),
  736. ("paligemma", "PaliGemmaForConditionalGeneration"),
  737. ("pix2struct", "Pix2StructForConditionalGeneration"),
  738. ("qwen2_vl", "Qwen2VLForConditionalGeneration"),
  739. ("video_llava", "VideoLlavaForConditionalGeneration"),
  740. ("vipllava", "VipLlavaForConditionalGeneration"),
  741. ("vision-encoder-decoder", "VisionEncoderDecoderModel"),
  742. ]
  743. )
  744. MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES = OrderedDict(
  745. [
  746. ("blip", "BlipForConditionalGeneration"),
  747. ("blip-2", "Blip2ForConditionalGeneration"),
  748. ("chameleon", "ChameleonForConditionalGeneration"),
  749. ("fuyu", "FuyuForCausalLM"),
  750. ("git", "GitForCausalLM"),
  751. ("idefics", "IdeficsForVisionText2Text"),
  752. ("idefics2", "Idefics2ForConditionalGeneration"),
  753. ("idefics3", "Idefics3ForConditionalGeneration"),
  754. ("instructblip", "InstructBlipForConditionalGeneration"),
  755. ("kosmos-2", "Kosmos2ForConditionalGeneration"),
  756. ("llava", "LlavaForConditionalGeneration"),
  757. ("llava_next", "LlavaNextForConditionalGeneration"),
  758. ("llava_onevision", "LlavaOnevisionForConditionalGeneration"),
  759. ("mllama", "MllamaForConditionalGeneration"),
  760. ("paligemma", "PaliGemmaForConditionalGeneration"),
  761. ("pix2struct", "Pix2StructForConditionalGeneration"),
  762. ("pixtral", "LlavaForConditionalGeneration"),
  763. ("qwen2_vl", "Qwen2VLForConditionalGeneration"),
  764. ("udop", "UdopForConditionalGeneration"),
  765. ("vipllava", "VipLlavaForConditionalGeneration"),
  766. ("vision-encoder-decoder", "VisionEncoderDecoderModel"),
  767. ]
  768. )
  769. MODEL_FOR_MASKED_LM_MAPPING_NAMES = OrderedDict(
  770. [
  771. # Model for Masked LM mapping
  772. ("albert", "AlbertForMaskedLM"),
  773. ("bart", "BartForConditionalGeneration"),
  774. ("bert", "BertForMaskedLM"),
  775. ("big_bird", "BigBirdForMaskedLM"),
  776. ("camembert", "CamembertForMaskedLM"),
  777. ("convbert", "ConvBertForMaskedLM"),
  778. ("data2vec-text", "Data2VecTextForMaskedLM"),
  779. ("deberta", "DebertaForMaskedLM"),
  780. ("deberta-v2", "DebertaV2ForMaskedLM"),
  781. ("distilbert", "DistilBertForMaskedLM"),
  782. ("electra", "ElectraForMaskedLM"),
  783. ("ernie", "ErnieForMaskedLM"),
  784. ("esm", "EsmForMaskedLM"),
  785. ("flaubert", "FlaubertWithLMHeadModel"),
  786. ("fnet", "FNetForMaskedLM"),
  787. ("funnel", "FunnelForMaskedLM"),
  788. ("ibert", "IBertForMaskedLM"),
  789. ("layoutlm", "LayoutLMForMaskedLM"),
  790. ("longformer", "LongformerForMaskedLM"),
  791. ("luke", "LukeForMaskedLM"),
  792. ("mbart", "MBartForConditionalGeneration"),
  793. ("mega", "MegaForMaskedLM"),
  794. ("megatron-bert", "MegatronBertForMaskedLM"),
  795. ("mobilebert", "MobileBertForMaskedLM"),
  796. ("mpnet", "MPNetForMaskedLM"),
  797. ("mra", "MraForMaskedLM"),
  798. ("mvp", "MvpForConditionalGeneration"),
  799. ("nezha", "NezhaForMaskedLM"),
  800. ("nystromformer", "NystromformerForMaskedLM"),
  801. ("perceiver", "PerceiverForMaskedLM"),
  802. ("qdqbert", "QDQBertForMaskedLM"),
  803. ("reformer", "ReformerForMaskedLM"),
  804. ("rembert", "RemBertForMaskedLM"),
  805. ("roberta", "RobertaForMaskedLM"),
  806. ("roberta-prelayernorm", "RobertaPreLayerNormForMaskedLM"),
  807. ("roc_bert", "RoCBertForMaskedLM"),
  808. ("roformer", "RoFormerForMaskedLM"),
  809. ("squeezebert", "SqueezeBertForMaskedLM"),
  810. ("tapas", "TapasForMaskedLM"),
  811. ("wav2vec2", "Wav2Vec2ForMaskedLM"),
  812. ("xlm", "XLMWithLMHeadModel"),
  813. ("xlm-roberta", "XLMRobertaForMaskedLM"),
  814. ("xlm-roberta-xl", "XLMRobertaXLForMaskedLM"),
  815. ("xmod", "XmodForMaskedLM"),
  816. ("yoso", "YosoForMaskedLM"),
  817. ]
  818. )
  819. MODEL_FOR_OBJECT_DETECTION_MAPPING_NAMES = OrderedDict(
  820. [
  821. # Model for Object Detection mapping
  822. ("conditional_detr", "ConditionalDetrForObjectDetection"),
  823. ("deformable_detr", "DeformableDetrForObjectDetection"),
  824. ("deta", "DetaForObjectDetection"),
  825. ("detr", "DetrForObjectDetection"),
  826. ("rt_detr", "RTDetrForObjectDetection"),
  827. ("table-transformer", "TableTransformerForObjectDetection"),
  828. ("yolos", "YolosForObjectDetection"),
  829. ]
  830. )
  831. MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING_NAMES = OrderedDict(
  832. [
  833. # Model for Zero Shot Object Detection mapping
  834. ("grounding-dino", "GroundingDinoForObjectDetection"),
  835. ("omdet-turbo", "OmDetTurboForObjectDetection"),
  836. ("owlv2", "Owlv2ForObjectDetection"),
  837. ("owlvit", "OwlViTForObjectDetection"),
  838. ]
  839. )
  840. MODEL_FOR_DEPTH_ESTIMATION_MAPPING_NAMES = OrderedDict(
  841. [
  842. # Model for depth estimation mapping
  843. ("depth_anything", "DepthAnythingForDepthEstimation"),
  844. ("dpt", "DPTForDepthEstimation"),
  845. ("glpn", "GLPNForDepthEstimation"),
  846. ("zoedepth", "ZoeDepthForDepthEstimation"),
  847. ]
  848. )
  849. MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
  850. [
  851. # Model for Seq2Seq Causal LM mapping
  852. ("bart", "BartForConditionalGeneration"),
  853. ("bigbird_pegasus", "BigBirdPegasusForConditionalGeneration"),
  854. ("blenderbot", "BlenderbotForConditionalGeneration"),
  855. ("blenderbot-small", "BlenderbotSmallForConditionalGeneration"),
  856. ("encoder-decoder", "EncoderDecoderModel"),
  857. ("fsmt", "FSMTForConditionalGeneration"),
  858. ("gptsan-japanese", "GPTSanJapaneseForConditionalGeneration"),
  859. ("led", "LEDForConditionalGeneration"),
  860. ("longt5", "LongT5ForConditionalGeneration"),
  861. ("m2m_100", "M2M100ForConditionalGeneration"),
  862. ("marian", "MarianMTModel"),
  863. ("mbart", "MBartForConditionalGeneration"),
  864. ("mt5", "MT5ForConditionalGeneration"),
  865. ("mvp", "MvpForConditionalGeneration"),
  866. ("nllb-moe", "NllbMoeForConditionalGeneration"),
  867. ("pegasus", "PegasusForConditionalGeneration"),
  868. ("pegasus_x", "PegasusXForConditionalGeneration"),
  869. ("plbart", "PLBartForConditionalGeneration"),
  870. ("prophetnet", "ProphetNetForConditionalGeneration"),
  871. ("qwen2_audio", "Qwen2AudioForConditionalGeneration"),
  872. ("seamless_m4t", "SeamlessM4TForTextToText"),
  873. ("seamless_m4t_v2", "SeamlessM4Tv2ForTextToText"),
  874. ("switch_transformers", "SwitchTransformersForConditionalGeneration"),
  875. ("t5", "T5ForConditionalGeneration"),
  876. ("umt5", "UMT5ForConditionalGeneration"),
  877. ("xlm-prophetnet", "XLMProphetNetForConditionalGeneration"),
  878. ]
  879. )
  880. MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES = OrderedDict(
  881. [
  882. ("pop2piano", "Pop2PianoForConditionalGeneration"),
  883. ("seamless_m4t", "SeamlessM4TForSpeechToText"),
  884. ("seamless_m4t_v2", "SeamlessM4Tv2ForSpeechToText"),
  885. ("speech-encoder-decoder", "SpeechEncoderDecoderModel"),
  886. ("speech_to_text", "Speech2TextForConditionalGeneration"),
  887. ("speecht5", "SpeechT5ForSpeechToText"),
  888. ("whisper", "WhisperForConditionalGeneration"),
  889. ]
  890. )
  891. MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
  892. [
  893. # Model for Sequence Classification mapping
  894. ("albert", "AlbertForSequenceClassification"),
  895. ("bart", "BartForSequenceClassification"),
  896. ("bert", "BertForSequenceClassification"),
  897. ("big_bird", "BigBirdForSequenceClassification"),
  898. ("bigbird_pegasus", "BigBirdPegasusForSequenceClassification"),
  899. ("biogpt", "BioGptForSequenceClassification"),
  900. ("bloom", "BloomForSequenceClassification"),
  901. ("camembert", "CamembertForSequenceClassification"),
  902. ("canine", "CanineForSequenceClassification"),
  903. ("code_llama", "LlamaForSequenceClassification"),
  904. ("convbert", "ConvBertForSequenceClassification"),
  905. ("ctrl", "CTRLForSequenceClassification"),
  906. ("data2vec-text", "Data2VecTextForSequenceClassification"),
  907. ("deberta", "DebertaForSequenceClassification"),
  908. ("deberta-v2", "DebertaV2ForSequenceClassification"),
  909. ("distilbert", "DistilBertForSequenceClassification"),
  910. ("electra", "ElectraForSequenceClassification"),
  911. ("ernie", "ErnieForSequenceClassification"),
  912. ("ernie_m", "ErnieMForSequenceClassification"),
  913. ("esm", "EsmForSequenceClassification"),
  914. ("falcon", "FalconForSequenceClassification"),
  915. ("flaubert", "FlaubertForSequenceClassification"),
  916. ("fnet", "FNetForSequenceClassification"),
  917. ("funnel", "FunnelForSequenceClassification"),
  918. ("gemma", "GemmaForSequenceClassification"),
  919. ("gemma2", "Gemma2ForSequenceClassification"),
  920. ("glm", "GlmForSequenceClassification"),
  921. ("gpt-sw3", "GPT2ForSequenceClassification"),
  922. ("gpt2", "GPT2ForSequenceClassification"),
  923. ("gpt_bigcode", "GPTBigCodeForSequenceClassification"),
  924. ("gpt_neo", "GPTNeoForSequenceClassification"),
  925. ("gpt_neox", "GPTNeoXForSequenceClassification"),
  926. ("gptj", "GPTJForSequenceClassification"),
  927. ("ibert", "IBertForSequenceClassification"),
  928. ("jamba", "JambaForSequenceClassification"),
  929. ("jetmoe", "JetMoeForSequenceClassification"),
  930. ("layoutlm", "LayoutLMForSequenceClassification"),
  931. ("layoutlmv2", "LayoutLMv2ForSequenceClassification"),
  932. ("layoutlmv3", "LayoutLMv3ForSequenceClassification"),
  933. ("led", "LEDForSequenceClassification"),
  934. ("lilt", "LiltForSequenceClassification"),
  935. ("llama", "LlamaForSequenceClassification"),
  936. ("longformer", "LongformerForSequenceClassification"),
  937. ("luke", "LukeForSequenceClassification"),
  938. ("markuplm", "MarkupLMForSequenceClassification"),
  939. ("mbart", "MBartForSequenceClassification"),
  940. ("mega", "MegaForSequenceClassification"),
  941. ("megatron-bert", "MegatronBertForSequenceClassification"),
  942. ("mistral", "MistralForSequenceClassification"),
  943. ("mixtral", "MixtralForSequenceClassification"),
  944. ("mobilebert", "MobileBertForSequenceClassification"),
  945. ("mpnet", "MPNetForSequenceClassification"),
  946. ("mpt", "MptForSequenceClassification"),
  947. ("mra", "MraForSequenceClassification"),
  948. ("mt5", "MT5ForSequenceClassification"),
  949. ("mvp", "MvpForSequenceClassification"),
  950. ("nemotron", "NemotronForSequenceClassification"),
  951. ("nezha", "NezhaForSequenceClassification"),
  952. ("nystromformer", "NystromformerForSequenceClassification"),
  953. ("open-llama", "OpenLlamaForSequenceClassification"),
  954. ("openai-gpt", "OpenAIGPTForSequenceClassification"),
  955. ("opt", "OPTForSequenceClassification"),
  956. ("perceiver", "PerceiverForSequenceClassification"),
  957. ("persimmon", "PersimmonForSequenceClassification"),
  958. ("phi", "PhiForSequenceClassification"),
  959. ("phi3", "Phi3ForSequenceClassification"),
  960. ("phimoe", "PhimoeForSequenceClassification"),
  961. ("plbart", "PLBartForSequenceClassification"),
  962. ("qdqbert", "QDQBertForSequenceClassification"),
  963. ("qwen2", "Qwen2ForSequenceClassification"),
  964. ("qwen2_moe", "Qwen2MoeForSequenceClassification"),
  965. ("reformer", "ReformerForSequenceClassification"),
  966. ("rembert", "RemBertForSequenceClassification"),
  967. ("roberta", "RobertaForSequenceClassification"),
  968. ("roberta-prelayernorm", "RobertaPreLayerNormForSequenceClassification"),
  969. ("roc_bert", "RoCBertForSequenceClassification"),
  970. ("roformer", "RoFormerForSequenceClassification"),
  971. ("squeezebert", "SqueezeBertForSequenceClassification"),
  972. ("stablelm", "StableLmForSequenceClassification"),
  973. ("starcoder2", "Starcoder2ForSequenceClassification"),
  974. ("t5", "T5ForSequenceClassification"),
  975. ("tapas", "TapasForSequenceClassification"),
  976. ("transfo-xl", "TransfoXLForSequenceClassification"),
  977. ("umt5", "UMT5ForSequenceClassification"),
  978. ("xlm", "XLMForSequenceClassification"),
  979. ("xlm-roberta", "XLMRobertaForSequenceClassification"),
  980. ("xlm-roberta-xl", "XLMRobertaXLForSequenceClassification"),
  981. ("xlnet", "XLNetForSequenceClassification"),
  982. ("xmod", "XmodForSequenceClassification"),
  983. ("yoso", "YosoForSequenceClassification"),
  984. ("zamba", "ZambaForSequenceClassification"),
  985. ]
  986. )
  987. MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(
  988. [
  989. # Model for Question Answering mapping
  990. ("albert", "AlbertForQuestionAnswering"),
  991. ("bart", "BartForQuestionAnswering"),
  992. ("bert", "BertForQuestionAnswering"),
  993. ("big_bird", "BigBirdForQuestionAnswering"),
  994. ("bigbird_pegasus", "BigBirdPegasusForQuestionAnswering"),
  995. ("bloom", "BloomForQuestionAnswering"),
  996. ("camembert", "CamembertForQuestionAnswering"),
  997. ("canine", "CanineForQuestionAnswering"),
  998. ("convbert", "ConvBertForQuestionAnswering"),
  999. ("data2vec-text", "Data2VecTextForQuestionAnswering"),
  1000. ("deberta", "DebertaForQuestionAnswering"),
  1001. ("deberta-v2", "DebertaV2ForQuestionAnswering"),
  1002. ("distilbert", "DistilBertForQuestionAnswering"),
  1003. ("electra", "ElectraForQuestionAnswering"),
  1004. ("ernie", "ErnieForQuestionAnswering"),
  1005. ("ernie_m", "ErnieMForQuestionAnswering"),
  1006. ("falcon", "FalconForQuestionAnswering"),
  1007. ("flaubert", "FlaubertForQuestionAnsweringSimple"),
  1008. ("fnet", "FNetForQuestionAnswering"),
  1009. ("funnel", "FunnelForQuestionAnswering"),
  1010. ("gpt2", "GPT2ForQuestionAnswering"),
  1011. ("gpt_neo", "GPTNeoForQuestionAnswering"),
  1012. ("gpt_neox", "GPTNeoXForQuestionAnswering"),
  1013. ("gptj", "GPTJForQuestionAnswering"),
  1014. ("ibert", "IBertForQuestionAnswering"),
  1015. ("layoutlmv2", "LayoutLMv2ForQuestionAnswering"),
  1016. ("layoutlmv3", "LayoutLMv3ForQuestionAnswering"),
  1017. ("led", "LEDForQuestionAnswering"),
  1018. ("lilt", "LiltForQuestionAnswering"),
  1019. ("llama", "LlamaForQuestionAnswering"),
  1020. ("longformer", "LongformerForQuestionAnswering"),
  1021. ("luke", "LukeForQuestionAnswering"),
  1022. ("lxmert", "LxmertForQuestionAnswering"),
  1023. ("markuplm", "MarkupLMForQuestionAnswering"),
  1024. ("mbart", "MBartForQuestionAnswering"),
  1025. ("mega", "MegaForQuestionAnswering"),
  1026. ("megatron-bert", "MegatronBertForQuestionAnswering"),
  1027. ("mistral", "MistralForQuestionAnswering"),
  1028. ("mixtral", "MixtralForQuestionAnswering"),
  1029. ("mobilebert", "MobileBertForQuestionAnswering"),
  1030. ("mpnet", "MPNetForQuestionAnswering"),
  1031. ("mpt", "MptForQuestionAnswering"),
  1032. ("mra", "MraForQuestionAnswering"),
  1033. ("mt5", "MT5ForQuestionAnswering"),
  1034. ("mvp", "MvpForQuestionAnswering"),
  1035. ("nemotron", "NemotronForQuestionAnswering"),
  1036. ("nezha", "NezhaForQuestionAnswering"),
  1037. ("nystromformer", "NystromformerForQuestionAnswering"),
  1038. ("opt", "OPTForQuestionAnswering"),
  1039. ("qdqbert", "QDQBertForQuestionAnswering"),
  1040. ("qwen2", "Qwen2ForQuestionAnswering"),
  1041. ("qwen2_moe", "Qwen2MoeForQuestionAnswering"),
  1042. ("reformer", "ReformerForQuestionAnswering"),
  1043. ("rembert", "RemBertForQuestionAnswering"),
  1044. ("roberta", "RobertaForQuestionAnswering"),
  1045. ("roberta-prelayernorm", "RobertaPreLayerNormForQuestionAnswering"),
  1046. ("roc_bert", "RoCBertForQuestionAnswering"),
  1047. ("roformer", "RoFormerForQuestionAnswering"),
  1048. ("splinter", "SplinterForQuestionAnswering"),
  1049. ("squeezebert", "SqueezeBertForQuestionAnswering"),
  1050. ("t5", "T5ForQuestionAnswering"),
  1051. ("umt5", "UMT5ForQuestionAnswering"),
  1052. ("xlm", "XLMForQuestionAnsweringSimple"),
  1053. ("xlm-roberta", "XLMRobertaForQuestionAnswering"),
  1054. ("xlm-roberta-xl", "XLMRobertaXLForQuestionAnswering"),
  1055. ("xlnet", "XLNetForQuestionAnsweringSimple"),
  1056. ("xmod", "XmodForQuestionAnswering"),
  1057. ("yoso", "YosoForQuestionAnswering"),
  1058. ]
  1059. )
  1060. MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(
  1061. [
  1062. # Model for Table Question Answering mapping
  1063. ("tapas", "TapasForQuestionAnswering"),
  1064. ]
  1065. )
  1066. MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(
  1067. [
  1068. ("blip", "BlipForQuestionAnswering"),
  1069. ("blip-2", "Blip2ForConditionalGeneration"),
  1070. ("vilt", "ViltForQuestionAnswering"),
  1071. ]
  1072. )
  1073. MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(
  1074. [
  1075. ("layoutlm", "LayoutLMForQuestionAnswering"),
  1076. ("layoutlmv2", "LayoutLMv2ForQuestionAnswering"),
  1077. ("layoutlmv3", "LayoutLMv3ForQuestionAnswering"),
  1078. ]
  1079. )
  1080. MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
  1081. [
  1082. # Model for Token Classification mapping
  1083. ("albert", "AlbertForTokenClassification"),
  1084. ("bert", "BertForTokenClassification"),
  1085. ("big_bird", "BigBirdForTokenClassification"),
  1086. ("biogpt", "BioGptForTokenClassification"),
  1087. ("bloom", "BloomForTokenClassification"),
  1088. ("bros", "BrosForTokenClassification"),
  1089. ("camembert", "CamembertForTokenClassification"),
  1090. ("canine", "CanineForTokenClassification"),
  1091. ("convbert", "ConvBertForTokenClassification"),
  1092. ("data2vec-text", "Data2VecTextForTokenClassification"),
  1093. ("deberta", "DebertaForTokenClassification"),
  1094. ("deberta-v2", "DebertaV2ForTokenClassification"),
  1095. ("distilbert", "DistilBertForTokenClassification"),
  1096. ("electra", "ElectraForTokenClassification"),
  1097. ("ernie", "ErnieForTokenClassification"),
  1098. ("ernie_m", "ErnieMForTokenClassification"),
  1099. ("esm", "EsmForTokenClassification"),
  1100. ("falcon", "FalconForTokenClassification"),
  1101. ("flaubert", "FlaubertForTokenClassification"),
  1102. ("fnet", "FNetForTokenClassification"),
  1103. ("funnel", "FunnelForTokenClassification"),
  1104. ("gemma", "GemmaForTokenClassification"),
  1105. ("gemma2", "Gemma2ForTokenClassification"),
  1106. ("glm", "GlmForTokenClassification"),
  1107. ("gpt-sw3", "GPT2ForTokenClassification"),
  1108. ("gpt2", "GPT2ForTokenClassification"),
  1109. ("gpt_bigcode", "GPTBigCodeForTokenClassification"),
  1110. ("gpt_neo", "GPTNeoForTokenClassification"),
  1111. ("gpt_neox", "GPTNeoXForTokenClassification"),
  1112. ("ibert", "IBertForTokenClassification"),
  1113. ("layoutlm", "LayoutLMForTokenClassification"),
  1114. ("layoutlmv2", "LayoutLMv2ForTokenClassification"),
  1115. ("layoutlmv3", "LayoutLMv3ForTokenClassification"),
  1116. ("lilt", "LiltForTokenClassification"),
  1117. ("llama", "LlamaForTokenClassification"),
  1118. ("longformer", "LongformerForTokenClassification"),
  1119. ("luke", "LukeForTokenClassification"),
  1120. ("markuplm", "MarkupLMForTokenClassification"),
  1121. ("mega", "MegaForTokenClassification"),
  1122. ("megatron-bert", "MegatronBertForTokenClassification"),
  1123. ("mistral", "MistralForTokenClassification"),
  1124. ("mixtral", "MixtralForTokenClassification"),
  1125. ("mobilebert", "MobileBertForTokenClassification"),
  1126. ("mpnet", "MPNetForTokenClassification"),
  1127. ("mpt", "MptForTokenClassification"),
  1128. ("mra", "MraForTokenClassification"),
  1129. ("mt5", "MT5ForTokenClassification"),
  1130. ("nemotron", "NemotronForTokenClassification"),
  1131. ("nezha", "NezhaForTokenClassification"),
  1132. ("nystromformer", "NystromformerForTokenClassification"),
  1133. ("persimmon", "PersimmonForTokenClassification"),
  1134. ("phi", "PhiForTokenClassification"),
  1135. ("phi3", "Phi3ForTokenClassification"),
  1136. ("qdqbert", "QDQBertForTokenClassification"),
  1137. ("qwen2", "Qwen2ForTokenClassification"),
  1138. ("qwen2_moe", "Qwen2MoeForTokenClassification"),
  1139. ("rembert", "RemBertForTokenClassification"),
  1140. ("roberta", "RobertaForTokenClassification"),
  1141. ("roberta-prelayernorm", "RobertaPreLayerNormForTokenClassification"),
  1142. ("roc_bert", "RoCBertForTokenClassification"),
  1143. ("roformer", "RoFormerForTokenClassification"),
  1144. ("squeezebert", "SqueezeBertForTokenClassification"),
  1145. ("stablelm", "StableLmForTokenClassification"),
  1146. ("starcoder2", "Starcoder2ForTokenClassification"),
  1147. ("t5", "T5ForTokenClassification"),
  1148. ("umt5", "UMT5ForTokenClassification"),
  1149. ("xlm", "XLMForTokenClassification"),
  1150. ("xlm-roberta", "XLMRobertaForTokenClassification"),
  1151. ("xlm-roberta-xl", "XLMRobertaXLForTokenClassification"),
  1152. ("xlnet", "XLNetForTokenClassification"),
  1153. ("xmod", "XmodForTokenClassification"),
  1154. ("yoso", "YosoForTokenClassification"),
  1155. ]
  1156. )
  1157. MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES = OrderedDict(
  1158. [
  1159. # Model for Multiple Choice mapping
  1160. ("albert", "AlbertForMultipleChoice"),
  1161. ("bert", "BertForMultipleChoice"),
  1162. ("big_bird", "BigBirdForMultipleChoice"),
  1163. ("camembert", "CamembertForMultipleChoice"),
  1164. ("canine", "CanineForMultipleChoice"),
  1165. ("convbert", "ConvBertForMultipleChoice"),
  1166. ("data2vec-text", "Data2VecTextForMultipleChoice"),
  1167. ("deberta-v2", "DebertaV2ForMultipleChoice"),
  1168. ("distilbert", "DistilBertForMultipleChoice"),
  1169. ("electra", "ElectraForMultipleChoice"),
  1170. ("ernie", "ErnieForMultipleChoice"),
  1171. ("ernie_m", "ErnieMForMultipleChoice"),
  1172. ("flaubert", "FlaubertForMultipleChoice"),
  1173. ("fnet", "FNetForMultipleChoice"),
  1174. ("funnel", "FunnelForMultipleChoice"),
  1175. ("ibert", "IBertForMultipleChoice"),
  1176. ("longformer", "LongformerForMultipleChoice"),
  1177. ("luke", "LukeForMultipleChoice"),
  1178. ("mega", "MegaForMultipleChoice"),
  1179. ("megatron-bert", "MegatronBertForMultipleChoice"),
  1180. ("mobilebert", "MobileBertForMultipleChoice"),
  1181. ("mpnet", "MPNetForMultipleChoice"),
  1182. ("mra", "MraForMultipleChoice"),
  1183. ("nezha", "NezhaForMultipleChoice"),
  1184. ("nystromformer", "NystromformerForMultipleChoice"),
  1185. ("qdqbert", "QDQBertForMultipleChoice"),
  1186. ("rembert", "RemBertForMultipleChoice"),
  1187. ("roberta", "RobertaForMultipleChoice"),
  1188. ("roberta-prelayernorm", "RobertaPreLayerNormForMultipleChoice"),
  1189. ("roc_bert", "RoCBertForMultipleChoice"),
  1190. ("roformer", "RoFormerForMultipleChoice"),
  1191. ("squeezebert", "SqueezeBertForMultipleChoice"),
  1192. ("xlm", "XLMForMultipleChoice"),
  1193. ("xlm-roberta", "XLMRobertaForMultipleChoice"),
  1194. ("xlm-roberta-xl", "XLMRobertaXLForMultipleChoice"),
  1195. ("xlnet", "XLNetForMultipleChoice"),
  1196. ("xmod", "XmodForMultipleChoice"),
  1197. ("yoso", "YosoForMultipleChoice"),
  1198. ]
  1199. )
  1200. MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES = OrderedDict(
  1201. [
  1202. ("bert", "BertForNextSentencePrediction"),
  1203. ("ernie", "ErnieForNextSentencePrediction"),
  1204. ("fnet", "FNetForNextSentencePrediction"),
  1205. ("megatron-bert", "MegatronBertForNextSentencePrediction"),
  1206. ("mobilebert", "MobileBertForNextSentencePrediction"),
  1207. ("nezha", "NezhaForNextSentencePrediction"),
  1208. ("qdqbert", "QDQBertForNextSentencePrediction"),
  1209. ]
  1210. )
  1211. MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
  1212. [
  1213. # Model for Audio Classification mapping
  1214. ("audio-spectrogram-transformer", "ASTForAudioClassification"),
  1215. ("data2vec-audio", "Data2VecAudioForSequenceClassification"),
  1216. ("hubert", "HubertForSequenceClassification"),
  1217. ("sew", "SEWForSequenceClassification"),
  1218. ("sew-d", "SEWDForSequenceClassification"),
  1219. ("unispeech", "UniSpeechForSequenceClassification"),
  1220. ("unispeech-sat", "UniSpeechSatForSequenceClassification"),
  1221. ("wav2vec2", "Wav2Vec2ForSequenceClassification"),
  1222. ("wav2vec2-bert", "Wav2Vec2BertForSequenceClassification"),
  1223. ("wav2vec2-conformer", "Wav2Vec2ConformerForSequenceClassification"),
  1224. ("wavlm", "WavLMForSequenceClassification"),
  1225. ("whisper", "WhisperForAudioClassification"),
  1226. ]
  1227. )
  1228. MODEL_FOR_CTC_MAPPING_NAMES = OrderedDict(
  1229. [
  1230. # Model for Connectionist temporal classification (CTC) mapping
  1231. ("data2vec-audio", "Data2VecAudioForCTC"),
  1232. ("hubert", "HubertForCTC"),
  1233. ("mctct", "MCTCTForCTC"),
  1234. ("sew", "SEWForCTC"),
  1235. ("sew-d", "SEWDForCTC"),
  1236. ("unispeech", "UniSpeechForCTC"),
  1237. ("unispeech-sat", "UniSpeechSatForCTC"),
  1238. ("wav2vec2", "Wav2Vec2ForCTC"),
  1239. ("wav2vec2-bert", "Wav2Vec2BertForCTC"),
  1240. ("wav2vec2-conformer", "Wav2Vec2ConformerForCTC"),
  1241. ("wavlm", "WavLMForCTC"),
  1242. ]
  1243. )
  1244. MODEL_FOR_AUDIO_FRAME_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
  1245. [
  1246. # Model for Audio Classification mapping
  1247. ("data2vec-audio", "Data2VecAudioForAudioFrameClassification"),
  1248. ("unispeech-sat", "UniSpeechSatForAudioFrameClassification"),
  1249. ("wav2vec2", "Wav2Vec2ForAudioFrameClassification"),
  1250. ("wav2vec2-bert", "Wav2Vec2BertForAudioFrameClassification"),
  1251. ("wav2vec2-conformer", "Wav2Vec2ConformerForAudioFrameClassification"),
  1252. ("wavlm", "WavLMForAudioFrameClassification"),
  1253. ]
  1254. )
  1255. MODEL_FOR_AUDIO_XVECTOR_MAPPING_NAMES = OrderedDict(
  1256. [
  1257. # Model for Audio Classification mapping
  1258. ("data2vec-audio", "Data2VecAudioForXVector"),
  1259. ("unispeech-sat", "UniSpeechSatForXVector"),
  1260. ("wav2vec2", "Wav2Vec2ForXVector"),
  1261. ("wav2vec2-bert", "Wav2Vec2BertForXVector"),
  1262. ("wav2vec2-conformer", "Wav2Vec2ConformerForXVector"),
  1263. ("wavlm", "WavLMForXVector"),
  1264. ]
  1265. )
  1266. MODEL_FOR_TEXT_TO_SPECTROGRAM_MAPPING_NAMES = OrderedDict(
  1267. [
  1268. # Model for Text-To-Spectrogram mapping
  1269. ("fastspeech2_conformer", "FastSpeech2ConformerModel"),
  1270. ("speecht5", "SpeechT5ForTextToSpeech"),
  1271. ]
  1272. )
  1273. MODEL_FOR_TEXT_TO_WAVEFORM_MAPPING_NAMES = OrderedDict(
  1274. [
  1275. # Model for Text-To-Waveform mapping
  1276. ("bark", "BarkModel"),
  1277. ("fastspeech2_conformer", "FastSpeech2ConformerWithHifiGan"),
  1278. ("musicgen", "MusicgenForConditionalGeneration"),
  1279. ("musicgen_melody", "MusicgenMelodyForConditionalGeneration"),
  1280. ("seamless_m4t", "SeamlessM4TForTextToSpeech"),
  1281. ("seamless_m4t_v2", "SeamlessM4Tv2ForTextToSpeech"),
  1282. ("vits", "VitsModel"),
  1283. ]
  1284. )
  1285. MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
  1286. [
  1287. # Model for Zero Shot Image Classification mapping
  1288. ("align", "AlignModel"),
  1289. ("altclip", "AltCLIPModel"),
  1290. ("blip", "BlipModel"),
  1291. ("blip-2", "Blip2ForImageTextRetrieval"),
  1292. ("chinese_clip", "ChineseCLIPModel"),
  1293. ("clip", "CLIPModel"),
  1294. ("clipseg", "CLIPSegModel"),
  1295. ("siglip", "SiglipModel"),
  1296. ]
  1297. )
  1298. MODEL_FOR_BACKBONE_MAPPING_NAMES = OrderedDict(
  1299. [
  1300. # Backbone mapping
  1301. ("beit", "BeitBackbone"),
  1302. ("bit", "BitBackbone"),
  1303. ("convnext", "ConvNextBackbone"),
  1304. ("convnextv2", "ConvNextV2Backbone"),
  1305. ("dinat", "DinatBackbone"),
  1306. ("dinov2", "Dinov2Backbone"),
  1307. ("focalnet", "FocalNetBackbone"),
  1308. ("hiera", "HieraBackbone"),
  1309. ("maskformer-swin", "MaskFormerSwinBackbone"),
  1310. ("nat", "NatBackbone"),
  1311. ("pvt_v2", "PvtV2Backbone"),
  1312. ("resnet", "ResNetBackbone"),
  1313. ("rt_detr_resnet", "RTDetrResNetBackbone"),
  1314. ("swin", "SwinBackbone"),
  1315. ("swinv2", "Swinv2Backbone"),
  1316. ("timm_backbone", "TimmBackbone"),
  1317. ("vitdet", "VitDetBackbone"),
  1318. ]
  1319. )
  1320. MODEL_FOR_MASK_GENERATION_MAPPING_NAMES = OrderedDict(
  1321. [
  1322. ("sam", "SamModel"),
  1323. ]
  1324. )
  1325. MODEL_FOR_KEYPOINT_DETECTION_MAPPING_NAMES = OrderedDict(
  1326. [
  1327. ("superpoint", "SuperPointForKeypointDetection"),
  1328. ]
  1329. )
  1330. MODEL_FOR_TEXT_ENCODING_MAPPING_NAMES = OrderedDict(
  1331. [
  1332. ("albert", "AlbertModel"),
  1333. ("bert", "BertModel"),
  1334. ("big_bird", "BigBirdModel"),
  1335. ("clip_text_model", "CLIPTextModel"),
  1336. ("data2vec-text", "Data2VecTextModel"),
  1337. ("deberta", "DebertaModel"),
  1338. ("deberta-v2", "DebertaV2Model"),
  1339. ("distilbert", "DistilBertModel"),
  1340. ("electra", "ElectraModel"),
  1341. ("flaubert", "FlaubertModel"),
  1342. ("ibert", "IBertModel"),
  1343. ("longformer", "LongformerModel"),
  1344. ("mllama", "MllamaTextModel"),
  1345. ("mobilebert", "MobileBertModel"),
  1346. ("mt5", "MT5EncoderModel"),
  1347. ("nystromformer", "NystromformerModel"),
  1348. ("reformer", "ReformerModel"),
  1349. ("rembert", "RemBertModel"),
  1350. ("roberta", "RobertaModel"),
  1351. ("roberta-prelayernorm", "RobertaPreLayerNormModel"),
  1352. ("roc_bert", "RoCBertModel"),
  1353. ("roformer", "RoFormerModel"),
  1354. ("squeezebert", "SqueezeBertModel"),
  1355. ("t5", "T5EncoderModel"),
  1356. ("umt5", "UMT5EncoderModel"),
  1357. ("xlm", "XLMModel"),
  1358. ("xlm-roberta", "XLMRobertaModel"),
  1359. ("xlm-roberta-xl", "XLMRobertaXLModel"),
  1360. ]
  1361. )
  1362. MODEL_FOR_TIME_SERIES_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
  1363. [
  1364. ("patchtsmixer", "PatchTSMixerForTimeSeriesClassification"),
  1365. ("patchtst", "PatchTSTForClassification"),
  1366. ]
  1367. )
  1368. MODEL_FOR_TIME_SERIES_REGRESSION_MAPPING_NAMES = OrderedDict(
  1369. [
  1370. ("patchtsmixer", "PatchTSMixerForRegression"),
  1371. ("patchtst", "PatchTSTForRegression"),
  1372. ]
  1373. )
  1374. MODEL_FOR_IMAGE_TO_IMAGE_MAPPING_NAMES = OrderedDict(
  1375. [
  1376. ("swin2sr", "Swin2SRForImageSuperResolution"),
  1377. ]
  1378. )
  1379. MODEL_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_MAPPING_NAMES)
  1380. MODEL_FOR_PRETRAINING_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_PRETRAINING_MAPPING_NAMES)
  1381. MODEL_WITH_LM_HEAD_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_WITH_LM_HEAD_MAPPING_NAMES)
  1382. MODEL_FOR_CAUSAL_LM_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_CAUSAL_LM_MAPPING_NAMES)
  1383. MODEL_FOR_CAUSAL_IMAGE_MODELING_MAPPING = _LazyAutoMapping(
  1384. CONFIG_MAPPING_NAMES, MODEL_FOR_CAUSAL_IMAGE_MODELING_MAPPING_NAMES
  1385. )
  1386. MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING = _LazyAutoMapping(
  1387. CONFIG_MAPPING_NAMES, MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES
  1388. )
  1389. MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING = _LazyAutoMapping(
  1390. CONFIG_MAPPING_NAMES, MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING_NAMES
  1391. )
  1392. MODEL_FOR_IMAGE_SEGMENTATION_MAPPING = _LazyAutoMapping(
  1393. CONFIG_MAPPING_NAMES, MODEL_FOR_IMAGE_SEGMENTATION_MAPPING_NAMES
  1394. )
  1395. MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING = _LazyAutoMapping(
  1396. CONFIG_MAPPING_NAMES, MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING_NAMES
  1397. )
  1398. MODEL_FOR_INSTANCE_SEGMENTATION_MAPPING = _LazyAutoMapping(
  1399. CONFIG_MAPPING_NAMES, MODEL_FOR_INSTANCE_SEGMENTATION_MAPPING_NAMES
  1400. )
  1401. MODEL_FOR_UNIVERSAL_SEGMENTATION_MAPPING = _LazyAutoMapping(
  1402. CONFIG_MAPPING_NAMES, MODEL_FOR_UNIVERSAL_SEGMENTATION_MAPPING_NAMES
  1403. )
  1404. MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING = _LazyAutoMapping(
  1405. CONFIG_MAPPING_NAMES, MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING_NAMES
  1406. )
  1407. MODEL_FOR_VISION_2_SEQ_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES)
  1408. MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING = _LazyAutoMapping(
  1409. CONFIG_MAPPING_NAMES, MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES
  1410. )
  1411. MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING = _LazyAutoMapping(
  1412. CONFIG_MAPPING_NAMES, MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING_NAMES
  1413. )
  1414. MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING = _LazyAutoMapping(
  1415. CONFIG_MAPPING_NAMES, MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING_NAMES
  1416. )
  1417. MODEL_FOR_MASKED_LM_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_MASKED_LM_MAPPING_NAMES)
  1418. MODEL_FOR_IMAGE_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_IMAGE_MAPPING_NAMES)
  1419. MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING = _LazyAutoMapping(
  1420. CONFIG_MAPPING_NAMES, MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING_NAMES
  1421. )
  1422. MODEL_FOR_OBJECT_DETECTION_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_OBJECT_DETECTION_MAPPING_NAMES)
  1423. MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING = _LazyAutoMapping(
  1424. CONFIG_MAPPING_NAMES, MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING_NAMES
  1425. )
  1426. MODEL_FOR_DEPTH_ESTIMATION_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_DEPTH_ESTIMATION_MAPPING_NAMES)
  1427. MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING = _LazyAutoMapping(
  1428. CONFIG_MAPPING_NAMES, MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES
  1429. )
  1430. MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING = _LazyAutoMapping(
  1431. CONFIG_MAPPING_NAMES, MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES
  1432. )
  1433. MODEL_FOR_QUESTION_ANSWERING_MAPPING = _LazyAutoMapping(
  1434. CONFIG_MAPPING_NAMES, MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES
  1435. )
  1436. MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING = _LazyAutoMapping(
  1437. CONFIG_MAPPING_NAMES, MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING_NAMES
  1438. )
  1439. MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING = _LazyAutoMapping(
  1440. CONFIG_MAPPING_NAMES, MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES
  1441. )
  1442. MODEL_FOR_MULTIPLE_CHOICE_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES)
  1443. MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING = _LazyAutoMapping(
  1444. CONFIG_MAPPING_NAMES, MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES
  1445. )
  1446. MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING = _LazyAutoMapping(
  1447. CONFIG_MAPPING_NAMES, MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES
  1448. )
  1449. MODEL_FOR_CTC_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_CTC_MAPPING_NAMES)
  1450. MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES)
  1451. MODEL_FOR_AUDIO_FRAME_CLASSIFICATION_MAPPING = _LazyAutoMapping(
  1452. CONFIG_MAPPING_NAMES, MODEL_FOR_AUDIO_FRAME_CLASSIFICATION_MAPPING_NAMES
  1453. )
  1454. MODEL_FOR_AUDIO_XVECTOR_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_AUDIO_XVECTOR_MAPPING_NAMES)
  1455. MODEL_FOR_TEXT_TO_SPECTROGRAM_MAPPING = _LazyAutoMapping(
  1456. CONFIG_MAPPING_NAMES, MODEL_FOR_TEXT_TO_SPECTROGRAM_MAPPING_NAMES
  1457. )
  1458. MODEL_FOR_TEXT_TO_WAVEFORM_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_TEXT_TO_WAVEFORM_MAPPING_NAMES)
  1459. MODEL_FOR_BACKBONE_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_BACKBONE_MAPPING_NAMES)
  1460. MODEL_FOR_MASK_GENERATION_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_MASK_GENERATION_MAPPING_NAMES)
  1461. MODEL_FOR_KEYPOINT_DETECTION_MAPPING = _LazyAutoMapping(
  1462. CONFIG_MAPPING_NAMES, MODEL_FOR_KEYPOINT_DETECTION_MAPPING_NAMES
  1463. )
  1464. MODEL_FOR_TEXT_ENCODING_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_TEXT_ENCODING_MAPPING_NAMES)
  1465. MODEL_FOR_TIME_SERIES_CLASSIFICATION_MAPPING = _LazyAutoMapping(
  1466. CONFIG_MAPPING_NAMES, MODEL_FOR_TIME_SERIES_CLASSIFICATION_MAPPING_NAMES
  1467. )
  1468. MODEL_FOR_TIME_SERIES_REGRESSION_MAPPING = _LazyAutoMapping(
  1469. CONFIG_MAPPING_NAMES, MODEL_FOR_TIME_SERIES_REGRESSION_MAPPING_NAMES
  1470. )
  1471. MODEL_FOR_IMAGE_TO_IMAGE_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_IMAGE_TO_IMAGE_MAPPING_NAMES)
  1472. class AutoModelForMaskGeneration(_BaseAutoModelClass):
  1473. _model_mapping = MODEL_FOR_MASK_GENERATION_MAPPING
  1474. class AutoModelForKeypointDetection(_BaseAutoModelClass):
  1475. _model_mapping = MODEL_FOR_KEYPOINT_DETECTION_MAPPING
  1476. class AutoModelForTextEncoding(_BaseAutoModelClass):
  1477. _model_mapping = MODEL_FOR_TEXT_ENCODING_MAPPING
  1478. class AutoModelForImageToImage(_BaseAutoModelClass):
  1479. _model_mapping = MODEL_FOR_IMAGE_TO_IMAGE_MAPPING
  1480. class AutoModel(_BaseAutoModelClass):
  1481. _model_mapping = MODEL_MAPPING
  1482. AutoModel = auto_class_update(AutoModel)
  1483. class AutoModelForPreTraining(_BaseAutoModelClass):
  1484. _model_mapping = MODEL_FOR_PRETRAINING_MAPPING
  1485. AutoModelForPreTraining = auto_class_update(AutoModelForPreTraining, head_doc="pretraining")
  1486. # Private on purpose, the public class will add the deprecation warnings.
  1487. class _AutoModelWithLMHead(_BaseAutoModelClass):
  1488. _model_mapping = MODEL_WITH_LM_HEAD_MAPPING
  1489. _AutoModelWithLMHead = auto_class_update(_AutoModelWithLMHead, head_doc="language modeling")
  1490. class AutoModelForCausalLM(_BaseAutoModelClass):
  1491. _model_mapping = MODEL_FOR_CAUSAL_LM_MAPPING
  1492. AutoModelForCausalLM = auto_class_update(AutoModelForCausalLM, head_doc="causal language modeling")
  1493. class AutoModelForMaskedLM(_BaseAutoModelClass):
  1494. _model_mapping = MODEL_FOR_MASKED_LM_MAPPING
  1495. AutoModelForMaskedLM = auto_class_update(AutoModelForMaskedLM, head_doc="masked language modeling")
  1496. class AutoModelForSeq2SeqLM(_BaseAutoModelClass):
  1497. _model_mapping = MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING
  1498. AutoModelForSeq2SeqLM = auto_class_update(
  1499. AutoModelForSeq2SeqLM,
  1500. head_doc="sequence-to-sequence language modeling",
  1501. checkpoint_for_example="google-t5/t5-base",
  1502. )
  1503. class AutoModelForSequenceClassification(_BaseAutoModelClass):
  1504. _model_mapping = MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING
  1505. AutoModelForSequenceClassification = auto_class_update(
  1506. AutoModelForSequenceClassification, head_doc="sequence classification"
  1507. )
  1508. class AutoModelForQuestionAnswering(_BaseAutoModelClass):
  1509. _model_mapping = MODEL_FOR_QUESTION_ANSWERING_MAPPING
  1510. AutoModelForQuestionAnswering = auto_class_update(AutoModelForQuestionAnswering, head_doc="question answering")
  1511. class AutoModelForTableQuestionAnswering(_BaseAutoModelClass):
  1512. _model_mapping = MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING
  1513. AutoModelForTableQuestionAnswering = auto_class_update(
  1514. AutoModelForTableQuestionAnswering,
  1515. head_doc="table question answering",
  1516. checkpoint_for_example="google/tapas-base-finetuned-wtq",
  1517. )
  1518. class AutoModelForVisualQuestionAnswering(_BaseAutoModelClass):
  1519. _model_mapping = MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING
  1520. AutoModelForVisualQuestionAnswering = auto_class_update(
  1521. AutoModelForVisualQuestionAnswering,
  1522. head_doc="visual question answering",
  1523. checkpoint_for_example="dandelin/vilt-b32-finetuned-vqa",
  1524. )
  1525. class AutoModelForDocumentQuestionAnswering(_BaseAutoModelClass):
  1526. _model_mapping = MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING
  1527. AutoModelForDocumentQuestionAnswering = auto_class_update(
  1528. AutoModelForDocumentQuestionAnswering,
  1529. head_doc="document question answering",
  1530. checkpoint_for_example='impira/layoutlm-document-qa", revision="52e01b3',
  1531. )
  1532. class AutoModelForTokenClassification(_BaseAutoModelClass):
  1533. _model_mapping = MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING
  1534. AutoModelForTokenClassification = auto_class_update(AutoModelForTokenClassification, head_doc="token classification")
  1535. class AutoModelForMultipleChoice(_BaseAutoModelClass):
  1536. _model_mapping = MODEL_FOR_MULTIPLE_CHOICE_MAPPING
  1537. AutoModelForMultipleChoice = auto_class_update(AutoModelForMultipleChoice, head_doc="multiple choice")
  1538. class AutoModelForNextSentencePrediction(_BaseAutoModelClass):
  1539. _model_mapping = MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING
  1540. AutoModelForNextSentencePrediction = auto_class_update(
  1541. AutoModelForNextSentencePrediction, head_doc="next sentence prediction"
  1542. )
  1543. class AutoModelForImageClassification(_BaseAutoModelClass):
  1544. _model_mapping = MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING
  1545. AutoModelForImageClassification = auto_class_update(AutoModelForImageClassification, head_doc="image classification")
  1546. class AutoModelForZeroShotImageClassification(_BaseAutoModelClass):
  1547. _model_mapping = MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING
  1548. AutoModelForZeroShotImageClassification = auto_class_update(
  1549. AutoModelForZeroShotImageClassification, head_doc="zero-shot image classification"
  1550. )
  1551. class AutoModelForImageSegmentation(_BaseAutoModelClass):
  1552. _model_mapping = MODEL_FOR_IMAGE_SEGMENTATION_MAPPING
  1553. AutoModelForImageSegmentation = auto_class_update(AutoModelForImageSegmentation, head_doc="image segmentation")
  1554. class AutoModelForSemanticSegmentation(_BaseAutoModelClass):
  1555. _model_mapping = MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING
  1556. AutoModelForSemanticSegmentation = auto_class_update(
  1557. AutoModelForSemanticSegmentation, head_doc="semantic segmentation"
  1558. )
  1559. class AutoModelForUniversalSegmentation(_BaseAutoModelClass):
  1560. _model_mapping = MODEL_FOR_UNIVERSAL_SEGMENTATION_MAPPING
  1561. AutoModelForUniversalSegmentation = auto_class_update(
  1562. AutoModelForUniversalSegmentation, head_doc="universal image segmentation"
  1563. )
  1564. class AutoModelForInstanceSegmentation(_BaseAutoModelClass):
  1565. _model_mapping = MODEL_FOR_INSTANCE_SEGMENTATION_MAPPING
  1566. AutoModelForInstanceSegmentation = auto_class_update(
  1567. AutoModelForInstanceSegmentation, head_doc="instance segmentation"
  1568. )
  1569. class AutoModelForObjectDetection(_BaseAutoModelClass):
  1570. _model_mapping = MODEL_FOR_OBJECT_DETECTION_MAPPING
  1571. AutoModelForObjectDetection = auto_class_update(AutoModelForObjectDetection, head_doc="object detection")
  1572. class AutoModelForZeroShotObjectDetection(_BaseAutoModelClass):
  1573. _model_mapping = MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING
  1574. AutoModelForZeroShotObjectDetection = auto_class_update(
  1575. AutoModelForZeroShotObjectDetection, head_doc="zero-shot object detection"
  1576. )
  1577. class AutoModelForDepthEstimation(_BaseAutoModelClass):
  1578. _model_mapping = MODEL_FOR_DEPTH_ESTIMATION_MAPPING
  1579. AutoModelForDepthEstimation = auto_class_update(AutoModelForDepthEstimation, head_doc="depth estimation")
  1580. class AutoModelForVideoClassification(_BaseAutoModelClass):
  1581. _model_mapping = MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING
  1582. AutoModelForVideoClassification = auto_class_update(AutoModelForVideoClassification, head_doc="video classification")
  1583. class AutoModelForVision2Seq(_BaseAutoModelClass):
  1584. _model_mapping = MODEL_FOR_VISION_2_SEQ_MAPPING
  1585. AutoModelForVision2Seq = auto_class_update(AutoModelForVision2Seq, head_doc="vision-to-text modeling")
  1586. class AutoModelForImageTextToText(_BaseAutoModelClass):
  1587. _model_mapping = MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING
  1588. AutoModelForImageTextToText = auto_class_update(AutoModelForImageTextToText, head_doc="image-text-to-text modeling")
  1589. class AutoModelForAudioClassification(_BaseAutoModelClass):
  1590. _model_mapping = MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING
  1591. AutoModelForAudioClassification = auto_class_update(AutoModelForAudioClassification, head_doc="audio classification")
  1592. class AutoModelForCTC(_BaseAutoModelClass):
  1593. _model_mapping = MODEL_FOR_CTC_MAPPING
  1594. AutoModelForCTC = auto_class_update(AutoModelForCTC, head_doc="connectionist temporal classification")
  1595. class AutoModelForSpeechSeq2Seq(_BaseAutoModelClass):
  1596. _model_mapping = MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING
  1597. AutoModelForSpeechSeq2Seq = auto_class_update(
  1598. AutoModelForSpeechSeq2Seq, head_doc="sequence-to-sequence speech-to-text modeling"
  1599. )
  1600. class AutoModelForAudioFrameClassification(_BaseAutoModelClass):
  1601. _model_mapping = MODEL_FOR_AUDIO_FRAME_CLASSIFICATION_MAPPING
  1602. AutoModelForAudioFrameClassification = auto_class_update(
  1603. AutoModelForAudioFrameClassification, head_doc="audio frame (token) classification"
  1604. )
  1605. class AutoModelForAudioXVector(_BaseAutoModelClass):
  1606. _model_mapping = MODEL_FOR_AUDIO_XVECTOR_MAPPING
  1607. class AutoModelForTextToSpectrogram(_BaseAutoModelClass):
  1608. _model_mapping = MODEL_FOR_TEXT_TO_SPECTROGRAM_MAPPING
  1609. class AutoModelForTextToWaveform(_BaseAutoModelClass):
  1610. _model_mapping = MODEL_FOR_TEXT_TO_WAVEFORM_MAPPING
  1611. class AutoBackbone(_BaseAutoBackboneClass):
  1612. _model_mapping = MODEL_FOR_BACKBONE_MAPPING
  1613. AutoModelForAudioXVector = auto_class_update(AutoModelForAudioXVector, head_doc="audio retrieval via x-vector")
  1614. class AutoModelForMaskedImageModeling(_BaseAutoModelClass):
  1615. _model_mapping = MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING
  1616. AutoModelForMaskedImageModeling = auto_class_update(AutoModelForMaskedImageModeling, head_doc="masked image modeling")
  1617. class AutoModelWithLMHead(_AutoModelWithLMHead):
  1618. @classmethod
  1619. def from_config(cls, config):
  1620. warnings.warn(
  1621. "The class `AutoModelWithLMHead` is deprecated and will be removed in a future version. Please use "
  1622. "`AutoModelForCausalLM` for causal language models, `AutoModelForMaskedLM` for masked language models and "
  1623. "`AutoModelForSeq2SeqLM` for encoder-decoder models.",
  1624. FutureWarning,
  1625. )
  1626. return super().from_config(config)
  1627. @classmethod
  1628. def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
  1629. warnings.warn(
  1630. "The class `AutoModelWithLMHead` is deprecated and will be removed in a future version. Please use "
  1631. "`AutoModelForCausalLM` for causal language models, `AutoModelForMaskedLM` for masked language models and "
  1632. "`AutoModelForSeq2SeqLM` for encoder-decoder models.",
  1633. FutureWarning,
  1634. )
  1635. return super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)