| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
70370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232123312341235123612371238123912401241124212431244124512461247124812491250125112521253125412551256125712581259126012611262126312641265126612671268126912701271127212731274127512761
27712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432143314341435143614371438143914401441144214431444144514461447144814491450145114521453145414551456145714581459146014611462146314641465146614671468146914701471147214731474147514761477147814791480148114821483148414851486148714881489149014911492149314941495149614971498149915001501150215031504150515061507150815091510151115121513151415151516151715181519152015211522152315241525152615271528152915301531153215331534153515361537153815391540154115421543154415451546154715481549155015511552155315541555155615571558155915601561156215631564156515661567156815691570157115721573157415751576157715781579158015811582158315841585158615871588158915901591159215931594159515961597159815991600160116021603160416051606160716081609161016111612161316141615161616171618161916201621162216231624162516261627162816291630163116321633163416351636163716381639164016411642164316441645164616471648164916501651165216531654165516561657165816591660166116621663166416651666166716681669167016711672167316741675167616771678167916801681168216831684168516861687168816891690169116921693169416951696169716981699170017011702170317041705170617071708170917101711171217131714171517161717171817191720172117221723172417251726172717281729173017311732173317341735173617371738173917401741174217431744174517461747174817491750175117521753175417551756175717581759176017611762176317641765176617671768176917701771177217731774177517761
77717781779178017811782178317841785178617871788178917901791179217931794179517961797179817991800180118021803180418051806180718081809181018111812181318141815181618171818181918201821182218231824182518261827182818291830183118321833183418351836183718381839 |
- # coding=utf-8
- # Copyright 2018 The HuggingFace Inc. team.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- """Auto Model class."""
- import warnings
- from collections import OrderedDict
- from ...utils import logging
- from .auto_factory import (
- _BaseAutoBackboneClass,
- _BaseAutoModelClass,
- _LazyAutoMapping,
- auto_class_update,
- )
- from .configuration_auto import CONFIG_MAPPING_NAMES
# Module-level logger, keyed by this module's dotted name and obtained from
# the package's own logging helpers (imported from `...utils`).
logger = logging.get_logger(__name__)
# Model key -> name of the base model class for that key.
#
# Values are class *names* (strings), not class objects.  A tuple value lists
# several candidate class names for one key (see "funnel").  Some keys reuse
# another architecture's class on purpose, e.g. "code_llama" -> "LlamaModel"
# and "gpt-sw3" -> "GPT2Model".  A few keys point at a task-specific class
# rather than a plain "...Model" (e.g. "clvp", "mgp-str", "omdet-turbo");
# presumably because no separate base class exists for them -- verify.
# Keys are presumably the `model_type` strings that also index
# CONFIG_MAPPING_NAMES -- NOTE(review): confirm against that table.
# Entries are currently sorted alphabetically by key; keep that order.
MODEL_MAPPING_NAMES = OrderedDict(
    [
        # Base model mapping
        ("albert", "AlbertModel"),
        ("align", "AlignModel"),
        ("altclip", "AltCLIPModel"),
        ("audio-spectrogram-transformer", "ASTModel"),
        ("autoformer", "AutoformerModel"),
        ("bark", "BarkModel"),
        ("bart", "BartModel"),
        ("beit", "BeitModel"),
        ("bert", "BertModel"),
        ("bert-generation", "BertGenerationEncoder"),
        ("big_bird", "BigBirdModel"),
        ("bigbird_pegasus", "BigBirdPegasusModel"),
        ("biogpt", "BioGptModel"),
        ("bit", "BitModel"),
        ("blenderbot", "BlenderbotModel"),
        ("blenderbot-small", "BlenderbotSmallModel"),
        ("blip", "BlipModel"),
        ("blip-2", "Blip2Model"),
        ("bloom", "BloomModel"),
        ("bridgetower", "BridgeTowerModel"),
        ("bros", "BrosModel"),
        ("camembert", "CamembertModel"),
        ("canine", "CanineModel"),
        ("chameleon", "ChameleonModel"),
        ("chinese_clip", "ChineseCLIPModel"),
        ("chinese_clip_vision_model", "ChineseCLIPVisionModel"),
        ("clap", "ClapModel"),
        ("clip", "CLIPModel"),
        ("clip_text_model", "CLIPTextModel"),
        ("clip_vision_model", "CLIPVisionModel"),
        ("clipseg", "CLIPSegModel"),
        ("clvp", "ClvpModelForConditionalGeneration"),
        ("code_llama", "LlamaModel"),
        ("codegen", "CodeGenModel"),
        ("cohere", "CohereModel"),
        ("conditional_detr", "ConditionalDetrModel"),
        ("convbert", "ConvBertModel"),
        ("convnext", "ConvNextModel"),
        ("convnextv2", "ConvNextV2Model"),
        ("cpmant", "CpmAntModel"),
        ("ctrl", "CTRLModel"),
        ("cvt", "CvtModel"),
        ("dac", "DacModel"),
        ("data2vec-audio", "Data2VecAudioModel"),
        ("data2vec-text", "Data2VecTextModel"),
        ("data2vec-vision", "Data2VecVisionModel"),
        ("dbrx", "DbrxModel"),
        ("deberta", "DebertaModel"),
        ("deberta-v2", "DebertaV2Model"),
        ("decision_transformer", "DecisionTransformerModel"),
        ("deformable_detr", "DeformableDetrModel"),
        ("deit", "DeiTModel"),
        ("deta", "DetaModel"),
        ("detr", "DetrModel"),
        ("dinat", "DinatModel"),
        ("dinov2", "Dinov2Model"),
        ("distilbert", "DistilBertModel"),
        ("donut-swin", "DonutSwinModel"),
        ("dpr", "DPRQuestionEncoder"),
        ("dpt", "DPTModel"),
        ("efficientformer", "EfficientFormerModel"),
        ("efficientnet", "EfficientNetModel"),
        ("electra", "ElectraModel"),
        ("encodec", "EncodecModel"),
        ("ernie", "ErnieModel"),
        ("ernie_m", "ErnieMModel"),
        ("esm", "EsmModel"),
        ("falcon", "FalconModel"),
        ("falcon_mamba", "FalconMambaModel"),
        ("fastspeech2_conformer", "FastSpeech2ConformerModel"),
        ("flaubert", "FlaubertModel"),
        ("flava", "FlavaModel"),
        ("fnet", "FNetModel"),
        ("focalnet", "FocalNetModel"),
        ("fsmt", "FSMTModel"),
        ("funnel", ("FunnelModel", "FunnelBaseModel")),
        ("gemma", "GemmaModel"),
        ("gemma2", "Gemma2Model"),
        ("git", "GitModel"),
        ("glm", "GlmModel"),
        ("glpn", "GLPNModel"),
        ("gpt-sw3", "GPT2Model"),
        ("gpt2", "GPT2Model"),
        ("gpt_bigcode", "GPTBigCodeModel"),
        ("gpt_neo", "GPTNeoModel"),
        ("gpt_neox", "GPTNeoXModel"),
        ("gpt_neox_japanese", "GPTNeoXJapaneseModel"),
        ("gptj", "GPTJModel"),
        ("gptsan-japanese", "GPTSanJapaneseForConditionalGeneration"),
        ("granite", "GraniteModel"),
        ("granitemoe", "GraniteMoeModel"),
        ("graphormer", "GraphormerModel"),
        ("grounding-dino", "GroundingDinoModel"),
        ("groupvit", "GroupViTModel"),
        ("hiera", "HieraModel"),
        ("hubert", "HubertModel"),
        ("ibert", "IBertModel"),
        ("idefics", "IdeficsModel"),
        ("idefics2", "Idefics2Model"),
        ("idefics3", "Idefics3Model"),
        ("imagegpt", "ImageGPTModel"),
        ("informer", "InformerModel"),
        ("jamba", "JambaModel"),
        ("jetmoe", "JetMoeModel"),
        ("jukebox", "JukeboxModel"),
        ("kosmos-2", "Kosmos2Model"),
        ("layoutlm", "LayoutLMModel"),
        ("layoutlmv2", "LayoutLMv2Model"),
        ("layoutlmv3", "LayoutLMv3Model"),
        ("led", "LEDModel"),
        ("levit", "LevitModel"),
        ("lilt", "LiltModel"),
        ("llama", "LlamaModel"),
        ("longformer", "LongformerModel"),
        ("longt5", "LongT5Model"),
        ("luke", "LukeModel"),
        ("lxmert", "LxmertModel"),
        ("m2m_100", "M2M100Model"),
        ("mamba", "MambaModel"),
        ("mamba2", "Mamba2Model"),
        ("marian", "MarianModel"),
        ("markuplm", "MarkupLMModel"),
        ("mask2former", "Mask2FormerModel"),
        ("maskformer", "MaskFormerModel"),
        ("maskformer-swin", "MaskFormerSwinModel"),
        ("mbart", "MBartModel"),
        ("mctct", "MCTCTModel"),
        ("mega", "MegaModel"),
        ("megatron-bert", "MegatronBertModel"),
        ("mgp-str", "MgpstrForSceneTextRecognition"),
        ("mimi", "MimiModel"),
        ("mistral", "MistralModel"),
        ("mixtral", "MixtralModel"),
        ("mobilebert", "MobileBertModel"),
        ("mobilenet_v1", "MobileNetV1Model"),
        ("mobilenet_v2", "MobileNetV2Model"),
        ("mobilevit", "MobileViTModel"),
        ("mobilevitv2", "MobileViTV2Model"),
        ("moshi", "MoshiModel"),
        ("mpnet", "MPNetModel"),
        ("mpt", "MptModel"),
        ("mra", "MraModel"),
        ("mt5", "MT5Model"),
        ("musicgen", "MusicgenModel"),
        ("musicgen_melody", "MusicgenMelodyModel"),
        ("mvp", "MvpModel"),
        ("nat", "NatModel"),
        ("nemotron", "NemotronModel"),
        ("nezha", "NezhaModel"),
        ("nllb-moe", "NllbMoeModel"),
        ("nystromformer", "NystromformerModel"),
        ("olmo", "OlmoModel"),
        ("olmoe", "OlmoeModel"),
        ("omdet-turbo", "OmDetTurboForObjectDetection"),
        ("oneformer", "OneFormerModel"),
        ("open-llama", "OpenLlamaModel"),
        ("openai-gpt", "OpenAIGPTModel"),
        ("opt", "OPTModel"),
        ("owlv2", "Owlv2Model"),
        ("owlvit", "OwlViTModel"),
        ("patchtsmixer", "PatchTSMixerModel"),
        ("patchtst", "PatchTSTModel"),
        ("pegasus", "PegasusModel"),
        ("pegasus_x", "PegasusXModel"),
        ("perceiver", "PerceiverModel"),
        ("persimmon", "PersimmonModel"),
        ("phi", "PhiModel"),
        ("phi3", "Phi3Model"),
        ("phimoe", "PhimoeModel"),
        ("pixtral", "PixtralVisionModel"),
        ("plbart", "PLBartModel"),
        ("poolformer", "PoolFormerModel"),
        ("prophetnet", "ProphetNetModel"),
        ("pvt", "PvtModel"),
        ("pvt_v2", "PvtV2Model"),
        ("qdqbert", "QDQBertModel"),
        ("qwen2", "Qwen2Model"),
        ("qwen2_audio_encoder", "Qwen2AudioEncoder"),
        ("qwen2_moe", "Qwen2MoeModel"),
        ("qwen2_vl", "Qwen2VLModel"),
        ("recurrent_gemma", "RecurrentGemmaModel"),
        ("reformer", "ReformerModel"),
        ("regnet", "RegNetModel"),
        ("rembert", "RemBertModel"),
        ("resnet", "ResNetModel"),
        ("retribert", "RetriBertModel"),
        ("roberta", "RobertaModel"),
        ("roberta-prelayernorm", "RobertaPreLayerNormModel"),
        ("roc_bert", "RoCBertModel"),
        ("roformer", "RoFormerModel"),
        ("rt_detr", "RTDetrModel"),
        ("rwkv", "RwkvModel"),
        ("sam", "SamModel"),
        ("seamless_m4t", "SeamlessM4TModel"),
        ("seamless_m4t_v2", "SeamlessM4Tv2Model"),
        ("segformer", "SegformerModel"),
        ("seggpt", "SegGptModel"),
        ("sew", "SEWModel"),
        ("sew-d", "SEWDModel"),
        ("siglip", "SiglipModel"),
        ("siglip_vision_model", "SiglipVisionModel"),
        ("speech_to_text", "Speech2TextModel"),
        ("speecht5", "SpeechT5Model"),
        ("splinter", "SplinterModel"),
        ("squeezebert", "SqueezeBertModel"),
        ("stablelm", "StableLmModel"),
        ("starcoder2", "Starcoder2Model"),
        ("swiftformer", "SwiftFormerModel"),
        ("swin", "SwinModel"),
        ("swin2sr", "Swin2SRModel"),
        ("swinv2", "Swinv2Model"),
        ("switch_transformers", "SwitchTransformersModel"),
        ("t5", "T5Model"),
        ("table-transformer", "TableTransformerModel"),
        ("tapas", "TapasModel"),
        ("time_series_transformer", "TimeSeriesTransformerModel"),
        ("timesformer", "TimesformerModel"),
        ("timm_backbone", "TimmBackbone"),
        ("trajectory_transformer", "TrajectoryTransformerModel"),
        ("transfo-xl", "TransfoXLModel"),
        ("tvlt", "TvltModel"),
        ("tvp", "TvpModel"),
        ("udop", "UdopModel"),
        ("umt5", "UMT5Model"),
        ("unispeech", "UniSpeechModel"),
        ("unispeech-sat", "UniSpeechSatModel"),
        ("univnet", "UnivNetModel"),
        ("van", "VanModel"),
        ("videomae", "VideoMAEModel"),
        ("vilt", "ViltModel"),
        ("vision-text-dual-encoder", "VisionTextDualEncoderModel"),
        ("visual_bert", "VisualBertModel"),
        ("vit", "ViTModel"),
        ("vit_hybrid", "ViTHybridModel"),
        ("vit_mae", "ViTMAEModel"),
        ("vit_msn", "ViTMSNModel"),
        ("vitdet", "VitDetModel"),
        ("vits", "VitsModel"),
        ("vivit", "VivitModel"),
        ("wav2vec2", "Wav2Vec2Model"),
        ("wav2vec2-bert", "Wav2Vec2BertModel"),
        ("wav2vec2-conformer", "Wav2Vec2ConformerModel"),
        ("wavlm", "WavLMModel"),
        ("whisper", "WhisperModel"),
        ("xclip", "XCLIPModel"),
        ("xglm", "XGLMModel"),
        ("xlm", "XLMModel"),
        ("xlm-prophetnet", "XLMProphetNetModel"),
        ("xlm-roberta", "XLMRobertaModel"),
        ("xlm-roberta-xl", "XLMRobertaXLModel"),
        ("xlnet", "XLNetModel"),
        ("xmod", "XmodModel"),
        ("yolos", "YolosModel"),
        ("yoso", "YosoModel"),
        ("zamba", "ZambaModel"),
    ]
)
# Model key -> name of the class used for that model's pre-training setup.
#
# Not every value is a dedicated "...ForPreTraining" class.  Many keys reuse a
# masked-LM, causal-LM or conditional-generation class instead (e.g.
# "camembert" -> "CamembertForMaskedLM", "bloom" -> "BloomForCausalLM"), and
# "retribert" maps to its plain base class "RetriBertModel".
# Values are class names (strings), not class objects.
MODEL_FOR_PRETRAINING_MAPPING_NAMES = OrderedDict(
    [
        # Model for pre-training mapping
        ("albert", "AlbertForPreTraining"),
        ("bart", "BartForConditionalGeneration"),
        ("bert", "BertForPreTraining"),
        ("big_bird", "BigBirdForPreTraining"),
        ("bloom", "BloomForCausalLM"),
        ("camembert", "CamembertForMaskedLM"),
        ("ctrl", "CTRLLMHeadModel"),
        ("data2vec-text", "Data2VecTextForMaskedLM"),
        ("deberta", "DebertaForMaskedLM"),
        ("deberta-v2", "DebertaV2ForMaskedLM"),
        ("distilbert", "DistilBertForMaskedLM"),
        ("electra", "ElectraForPreTraining"),
        ("ernie", "ErnieForPreTraining"),
        ("falcon_mamba", "FalconMambaForCausalLM"),
        ("flaubert", "FlaubertWithLMHeadModel"),
        ("flava", "FlavaForPreTraining"),
        ("fnet", "FNetForPreTraining"),
        ("fsmt", "FSMTForConditionalGeneration"),
        ("funnel", "FunnelForPreTraining"),
        ("gpt-sw3", "GPT2LMHeadModel"),
        ("gpt2", "GPT2LMHeadModel"),
        ("gpt_bigcode", "GPTBigCodeForCausalLM"),
        ("gptsan-japanese", "GPTSanJapaneseForConditionalGeneration"),
        ("hiera", "HieraForPreTraining"),
        ("ibert", "IBertForMaskedLM"),
        ("idefics", "IdeficsForVisionText2Text"),
        ("idefics2", "Idefics2ForConditionalGeneration"),
        ("idefics3", "Idefics3ForConditionalGeneration"),
        ("layoutlm", "LayoutLMForMaskedLM"),
        ("llava", "LlavaForConditionalGeneration"),
        ("llava_next", "LlavaNextForConditionalGeneration"),
        ("llava_next_video", "LlavaNextVideoForConditionalGeneration"),
        ("llava_onevision", "LlavaOnevisionForConditionalGeneration"),
        ("longformer", "LongformerForMaskedLM"),
        ("luke", "LukeForMaskedLM"),
        ("lxmert", "LxmertForPreTraining"),
        ("mamba", "MambaForCausalLM"),
        ("mamba2", "Mamba2ForCausalLM"),
        ("mega", "MegaForMaskedLM"),
        ("megatron-bert", "MegatronBertForPreTraining"),
        ("mllama", "MllamaForConditionalGeneration"),
        ("mobilebert", "MobileBertForPreTraining"),
        ("mpnet", "MPNetForMaskedLM"),
        ("mpt", "MptForCausalLM"),
        ("mra", "MraForMaskedLM"),
        ("mvp", "MvpForConditionalGeneration"),
        ("nezha", "NezhaForPreTraining"),
        ("nllb-moe", "NllbMoeForConditionalGeneration"),
        ("openai-gpt", "OpenAIGPTLMHeadModel"),
        ("paligemma", "PaliGemmaForConditionalGeneration"),
        ("qwen2_audio", "Qwen2AudioForConditionalGeneration"),
        ("retribert", "RetriBertModel"),
        ("roberta", "RobertaForMaskedLM"),
        ("roberta-prelayernorm", "RobertaPreLayerNormForMaskedLM"),
        ("roc_bert", "RoCBertForPreTraining"),
        ("rwkv", "RwkvForCausalLM"),
        ("splinter", "SplinterForPreTraining"),
        ("squeezebert", "SqueezeBertForMaskedLM"),
        ("switch_transformers", "SwitchTransformersForConditionalGeneration"),
        ("t5", "T5ForConditionalGeneration"),
        ("tapas", "TapasForMaskedLM"),
        ("transfo-xl", "TransfoXLLMHeadModel"),
        ("tvlt", "TvltForPreTraining"),
        ("unispeech", "UniSpeechForPreTraining"),
        ("unispeech-sat", "UniSpeechSatForPreTraining"),
        ("video_llava", "VideoLlavaForConditionalGeneration"),
        ("videomae", "VideoMAEForPreTraining"),
        ("vipllava", "VipLlavaForConditionalGeneration"),
        ("visual_bert", "VisualBertForPreTraining"),
        ("vit_mae", "ViTMAEForPreTraining"),
        ("wav2vec2", "Wav2Vec2ForPreTraining"),
        ("wav2vec2-conformer", "Wav2Vec2ConformerForPreTraining"),
        ("xlm", "XLMWithLMHeadModel"),
        ("xlm-roberta", "XLMRobertaForMaskedLM"),
        ("xlm-roberta-xl", "XLMRobertaXLForMaskedLM"),
        ("xlnet", "XLNetLMHeadModel"),
        ("xmod", "XmodForMaskedLM"),
    ]
)
# Model key -> name of a class that carries a language-modelling head.
#
# The table mixes masked-LM, causal-LM and conditional-generation classes, and
# "encoder-decoder" maps to the generic "EncoderDecoderModel".  Values are
# class names (strings), not class objects.
MODEL_WITH_LM_HEAD_MAPPING_NAMES = OrderedDict(
    [
        # Model with LM heads mapping
        ("albert", "AlbertForMaskedLM"),
        ("bart", "BartForConditionalGeneration"),
        ("bert", "BertForMaskedLM"),
        ("big_bird", "BigBirdForMaskedLM"),
        ("bigbird_pegasus", "BigBirdPegasusForConditionalGeneration"),
        ("blenderbot-small", "BlenderbotSmallForConditionalGeneration"),
        ("bloom", "BloomForCausalLM"),
        ("camembert", "CamembertForMaskedLM"),
        ("codegen", "CodeGenForCausalLM"),
        ("convbert", "ConvBertForMaskedLM"),
        ("cpmant", "CpmAntForCausalLM"),
        ("ctrl", "CTRLLMHeadModel"),
        ("data2vec-text", "Data2VecTextForMaskedLM"),
        ("deberta", "DebertaForMaskedLM"),
        ("deberta-v2", "DebertaV2ForMaskedLM"),
        ("distilbert", "DistilBertForMaskedLM"),
        ("electra", "ElectraForMaskedLM"),
        ("encoder-decoder", "EncoderDecoderModel"),
        ("ernie", "ErnieForMaskedLM"),
        ("esm", "EsmForMaskedLM"),
        ("falcon_mamba", "FalconMambaForCausalLM"),
        ("flaubert", "FlaubertWithLMHeadModel"),
        ("fnet", "FNetForMaskedLM"),
        ("fsmt", "FSMTForConditionalGeneration"),
        ("funnel", "FunnelForMaskedLM"),
        ("git", "GitForCausalLM"),
        ("gpt-sw3", "GPT2LMHeadModel"),
        ("gpt2", "GPT2LMHeadModel"),
        ("gpt_bigcode", "GPTBigCodeForCausalLM"),
        ("gpt_neo", "GPTNeoForCausalLM"),
        ("gpt_neox", "GPTNeoXForCausalLM"),
        ("gpt_neox_japanese", "GPTNeoXJapaneseForCausalLM"),
        ("gptj", "GPTJForCausalLM"),
        ("gptsan-japanese", "GPTSanJapaneseForConditionalGeneration"),
        ("ibert", "IBertForMaskedLM"),
        ("layoutlm", "LayoutLMForMaskedLM"),
        ("led", "LEDForConditionalGeneration"),
        ("longformer", "LongformerForMaskedLM"),
        ("longt5", "LongT5ForConditionalGeneration"),
        ("luke", "LukeForMaskedLM"),
        ("m2m_100", "M2M100ForConditionalGeneration"),
        ("mamba", "MambaForCausalLM"),
        ("mamba2", "Mamba2ForCausalLM"),
        ("marian", "MarianMTModel"),
        ("mega", "MegaForMaskedLM"),
        ("megatron-bert", "MegatronBertForCausalLM"),
        ("mobilebert", "MobileBertForMaskedLM"),
        ("mpnet", "MPNetForMaskedLM"),
        ("mpt", "MptForCausalLM"),
        ("mra", "MraForMaskedLM"),
        ("mvp", "MvpForConditionalGeneration"),
        ("nezha", "NezhaForMaskedLM"),
        ("nllb-moe", "NllbMoeForConditionalGeneration"),
        ("nystromformer", "NystromformerForMaskedLM"),
        ("openai-gpt", "OpenAIGPTLMHeadModel"),
        ("pegasus_x", "PegasusXForConditionalGeneration"),
        ("plbart", "PLBartForConditionalGeneration"),
        ("pop2piano", "Pop2PianoForConditionalGeneration"),
        ("qdqbert", "QDQBertForMaskedLM"),
        ("reformer", "ReformerModelWithLMHead"),
        ("rembert", "RemBertForMaskedLM"),
        ("roberta", "RobertaForMaskedLM"),
        ("roberta-prelayernorm", "RobertaPreLayerNormForMaskedLM"),
        ("roc_bert", "RoCBertForMaskedLM"),
        ("roformer", "RoFormerForMaskedLM"),
        ("rwkv", "RwkvForCausalLM"),
        ("speech_to_text", "Speech2TextForConditionalGeneration"),
        ("squeezebert", "SqueezeBertForMaskedLM"),
        ("switch_transformers", "SwitchTransformersForConditionalGeneration"),
        ("t5", "T5ForConditionalGeneration"),
        ("tapas", "TapasForMaskedLM"),
        ("transfo-xl", "TransfoXLLMHeadModel"),
        ("wav2vec2", "Wav2Vec2ForMaskedLM"),
        ("whisper", "WhisperForConditionalGeneration"),
        ("xlm", "XLMWithLMHeadModel"),
        ("xlm-roberta", "XLMRobertaForMaskedLM"),
        ("xlm-roberta-xl", "XLMRobertaXLForMaskedLM"),
        ("xlnet", "XLNetLMHeadModel"),
        ("xmod", "XmodForMaskedLM"),
        ("yoso", "YosoForMaskedLM"),
    ]
)
# Model key -> name of the causal language-modelling class for that key.
#
# Several keys alias another architecture's class, e.g. "code_llama" ->
# "LlamaForCausalLM" and "gpt-sw3" -> "GPT2LMHeadModel".  Values are class
# names (strings), not class objects.
MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
    [
        # Model for Causal LM mapping
        ("bart", "BartForCausalLM"),
        ("bert", "BertLMHeadModel"),
        ("bert-generation", "BertGenerationDecoder"),
        ("big_bird", "BigBirdForCausalLM"),
        ("bigbird_pegasus", "BigBirdPegasusForCausalLM"),
        ("biogpt", "BioGptForCausalLM"),
        ("blenderbot", "BlenderbotForCausalLM"),
        ("blenderbot-small", "BlenderbotSmallForCausalLM"),
        ("bloom", "BloomForCausalLM"),
        ("camembert", "CamembertForCausalLM"),
        ("code_llama", "LlamaForCausalLM"),
        ("codegen", "CodeGenForCausalLM"),
        ("cohere", "CohereForCausalLM"),
        ("cpmant", "CpmAntForCausalLM"),
        ("ctrl", "CTRLLMHeadModel"),
        ("data2vec-text", "Data2VecTextForCausalLM"),
        ("dbrx", "DbrxForCausalLM"),
        ("electra", "ElectraForCausalLM"),
        ("ernie", "ErnieForCausalLM"),
        ("falcon", "FalconForCausalLM"),
        ("falcon_mamba", "FalconMambaForCausalLM"),
        ("fuyu", "FuyuForCausalLM"),
        ("gemma", "GemmaForCausalLM"),
        ("gemma2", "Gemma2ForCausalLM"),
        ("git", "GitForCausalLM"),
        ("glm", "GlmForCausalLM"),
        ("gpt-sw3", "GPT2LMHeadModel"),
        ("gpt2", "GPT2LMHeadModel"),
        ("gpt_bigcode", "GPTBigCodeForCausalLM"),
        ("gpt_neo", "GPTNeoForCausalLM"),
        ("gpt_neox", "GPTNeoXForCausalLM"),
        ("gpt_neox_japanese", "GPTNeoXJapaneseForCausalLM"),
        ("gptj", "GPTJForCausalLM"),
        ("granite", "GraniteForCausalLM"),
        ("granitemoe", "GraniteMoeForCausalLM"),
        ("jamba", "JambaForCausalLM"),
        ("jetmoe", "JetMoeForCausalLM"),
        ("llama", "LlamaForCausalLM"),
        ("mamba", "MambaForCausalLM"),
        ("mamba2", "Mamba2ForCausalLM"),
        ("marian", "MarianForCausalLM"),
        ("mbart", "MBartForCausalLM"),
        ("mega", "MegaForCausalLM"),
        ("megatron-bert", "MegatronBertForCausalLM"),
        ("mistral", "MistralForCausalLM"),
        ("mixtral", "MixtralForCausalLM"),
        ("mllama", "MllamaForCausalLM"),
        ("moshi", "MoshiForCausalLM"),
        ("mpt", "MptForCausalLM"),
        ("musicgen", "MusicgenForCausalLM"),
        ("musicgen_melody", "MusicgenMelodyForCausalLM"),
        ("mvp", "MvpForCausalLM"),
        ("nemotron", "NemotronForCausalLM"),
        ("olmo", "OlmoForCausalLM"),
        ("olmoe", "OlmoeForCausalLM"),
        ("open-llama", "OpenLlamaForCausalLM"),
        ("openai-gpt", "OpenAIGPTLMHeadModel"),
        ("opt", "OPTForCausalLM"),
        ("pegasus", "PegasusForCausalLM"),
        ("persimmon", "PersimmonForCausalLM"),
        ("phi", "PhiForCausalLM"),
        ("phi3", "Phi3ForCausalLM"),
        ("phimoe", "PhimoeForCausalLM"),
        ("plbart", "PLBartForCausalLM"),
        ("prophetnet", "ProphetNetForCausalLM"),
        ("qdqbert", "QDQBertLMHeadModel"),
        ("qwen2", "Qwen2ForCausalLM"),
        ("qwen2_moe", "Qwen2MoeForCausalLM"),
        ("recurrent_gemma", "RecurrentGemmaForCausalLM"),
        ("reformer", "ReformerModelWithLMHead"),
        ("rembert", "RemBertForCausalLM"),
        ("roberta", "RobertaForCausalLM"),
        ("roberta-prelayernorm", "RobertaPreLayerNormForCausalLM"),
        ("roc_bert", "RoCBertForCausalLM"),
        ("roformer", "RoFormerForCausalLM"),
        ("rwkv", "RwkvForCausalLM"),
        ("speech_to_text_2", "Speech2Text2ForCausalLM"),
        ("stablelm", "StableLmForCausalLM"),
        ("starcoder2", "Starcoder2ForCausalLM"),
        ("transfo-xl", "TransfoXLLMHeadModel"),
        ("trocr", "TrOCRForCausalLM"),
        ("whisper", "WhisperForCausalLM"),
        ("xglm", "XGLMForCausalLM"),
        ("xlm", "XLMWithLMHeadModel"),
        ("xlm-prophetnet", "XLMProphetNetForCausalLM"),
        ("xlm-roberta", "XLMRobertaForCausalLM"),
        ("xlm-roberta-xl", "XLMRobertaXLForCausalLM"),
        ("xlnet", "XLNetLMHeadModel"),
        ("xmod", "XmodForCausalLM"),
        ("zamba", "ZambaForCausalLM"),
    ]
)
# Model key -> name of the base model class used for image inputs.
#
# Mostly the same "...Model" class that MODEL_MAPPING_NAMES lists for the key.
# Note that "mllama" maps to its vision sub-model "MllamaVisionModel" and
# "timm_backbone" to "TimmBackbone".  Values are class names (strings).
MODEL_FOR_IMAGE_MAPPING_NAMES = OrderedDict(
    [
        # Model for Image mapping
        ("beit", "BeitModel"),
        ("bit", "BitModel"),
        ("conditional_detr", "ConditionalDetrModel"),
        ("convnext", "ConvNextModel"),
        ("convnextv2", "ConvNextV2Model"),
        ("data2vec-vision", "Data2VecVisionModel"),
        ("deformable_detr", "DeformableDetrModel"),
        ("deit", "DeiTModel"),
        ("deta", "DetaModel"),
        ("detr", "DetrModel"),
        ("dinat", "DinatModel"),
        ("dinov2", "Dinov2Model"),
        ("dpt", "DPTModel"),
        ("efficientformer", "EfficientFormerModel"),
        ("efficientnet", "EfficientNetModel"),
        ("focalnet", "FocalNetModel"),
        ("glpn", "GLPNModel"),
        ("hiera", "HieraModel"),
        ("imagegpt", "ImageGPTModel"),
        ("levit", "LevitModel"),
        ("mllama", "MllamaVisionModel"),
        ("mobilenet_v1", "MobileNetV1Model"),
        ("mobilenet_v2", "MobileNetV2Model"),
        ("mobilevit", "MobileViTModel"),
        ("mobilevitv2", "MobileViTV2Model"),
        ("nat", "NatModel"),
        ("poolformer", "PoolFormerModel"),
        ("pvt", "PvtModel"),
        ("regnet", "RegNetModel"),
        ("resnet", "ResNetModel"),
        ("segformer", "SegformerModel"),
        ("siglip_vision_model", "SiglipVisionModel"),
        ("swiftformer", "SwiftFormerModel"),
        ("swin", "SwinModel"),
        ("swin2sr", "Swin2SRModel"),
        ("swinv2", "Swinv2Model"),
        ("table-transformer", "TableTransformerModel"),
        ("timesformer", "TimesformerModel"),
        ("timm_backbone", "TimmBackbone"),
        ("van", "VanModel"),
        ("videomae", "VideoMAEModel"),
        ("vit", "ViTModel"),
        ("vit_hybrid", "ViTHybridModel"),
        ("vit_mae", "ViTMAEModel"),
        ("vit_msn", "ViTMSNModel"),
        ("vitdet", "VitDetModel"),
        ("vivit", "VivitModel"),
        ("yolos", "YolosModel"),
    ]
)
# Model key -> name of the masked-image-modeling class for that key.
MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING_NAMES = OrderedDict(
    [
        # Model for Masked Image Modeling mapping
        ("deit", "DeiTForMaskedImageModeling"),
        ("focalnet", "FocalNetForMaskedImageModeling"),
        ("swin", "SwinForMaskedImageModeling"),
        ("swinv2", "Swinv2ForMaskedImageModeling"),
        ("vit", "ViTForMaskedImageModeling"),
    ]
)
# Model key -> name of the causal (autoregressive) image-modeling class.
# Currently only ImageGPT provides one.
MODEL_FOR_CAUSAL_IMAGE_MODELING_MAPPING_NAMES = OrderedDict(
    [
        # Model for Causal Image Modeling mapping
        ("imagegpt", "ImageGPTForCausalImageModeling"),
    ]
)
# Model key -> image-classification class name(s) for that key.
#
# Tuple values list several alternative classes for one key: "deit",
# "efficientformer" and "levit" each add a "...WithTeacher" variant, and
# "perceiver" has three variants that differ in their input preprocessing
# (Learned / Fourier / ConvProcessing, per the class names).
MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
    [
        # Model for Image Classification mapping
        ("beit", "BeitForImageClassification"),
        ("bit", "BitForImageClassification"),
        ("clip", "CLIPForImageClassification"),
        ("convnext", "ConvNextForImageClassification"),
        ("convnextv2", "ConvNextV2ForImageClassification"),
        ("cvt", "CvtForImageClassification"),
        ("data2vec-vision", "Data2VecVisionForImageClassification"),
        (
            "deit",
            ("DeiTForImageClassification", "DeiTForImageClassificationWithTeacher"),
        ),
        ("dinat", "DinatForImageClassification"),
        ("dinov2", "Dinov2ForImageClassification"),
        (
            "efficientformer",
            (
                "EfficientFormerForImageClassification",
                "EfficientFormerForImageClassificationWithTeacher",
            ),
        ),
        ("efficientnet", "EfficientNetForImageClassification"),
        ("focalnet", "FocalNetForImageClassification"),
        ("hiera", "HieraForImageClassification"),
        ("imagegpt", "ImageGPTForImageClassification"),
        (
            "levit",
            ("LevitForImageClassification", "LevitForImageClassificationWithTeacher"),
        ),
        ("mobilenet_v1", "MobileNetV1ForImageClassification"),
        ("mobilenet_v2", "MobileNetV2ForImageClassification"),
        ("mobilevit", "MobileViTForImageClassification"),
        ("mobilevitv2", "MobileViTV2ForImageClassification"),
        ("nat", "NatForImageClassification"),
        (
            "perceiver",
            (
                "PerceiverForImageClassificationLearned",
                "PerceiverForImageClassificationFourier",
                "PerceiverForImageClassificationConvProcessing",
            ),
        ),
        ("poolformer", "PoolFormerForImageClassification"),
        ("pvt", "PvtForImageClassification"),
        ("pvt_v2", "PvtV2ForImageClassification"),
        ("regnet", "RegNetForImageClassification"),
        ("resnet", "ResNetForImageClassification"),
        ("segformer", "SegformerForImageClassification"),
        ("siglip", "SiglipForImageClassification"),
        ("swiftformer", "SwiftFormerForImageClassification"),
        ("swin", "SwinForImageClassification"),
        ("swinv2", "Swinv2ForImageClassification"),
        ("van", "VanForImageClassification"),
        ("vit", "ViTForImageClassification"),
        ("vit_hybrid", "ViTHybridForImageClassification"),
        ("vit_msn", "ViTMSNForImageClassification"),
    ]
)
# Model key -> legacy image-segmentation class name.  Frozen: only DETR is
# listed, and the inline note below asks that nothing new be added.
MODEL_FOR_IMAGE_SEGMENTATION_MAPPING_NAMES = OrderedDict(
    [
        # Do not add new models here, this class will be deprecated in the future.
        # Model for Image Segmentation mapping
        ("detr", "DetrForSegmentation"),
    ]
)
# Model key -> name of the semantic-segmentation class for that key.
MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING_NAMES = OrderedDict(
    [
        # Model for Semantic Segmentation mapping
        ("beit", "BeitForSemanticSegmentation"),
        ("data2vec-vision", "Data2VecVisionForSemanticSegmentation"),
        ("dpt", "DPTForSemanticSegmentation"),
        ("mobilenet_v2", "MobileNetV2ForSemanticSegmentation"),
        ("mobilevit", "MobileViTForSemanticSegmentation"),
        ("mobilevitv2", "MobileViTV2ForSemanticSegmentation"),
        ("segformer", "SegformerForSemanticSegmentation"),
        ("upernet", "UperNetForSemanticSegmentation"),
    ]
)
# Model key -> name of the instance-segmentation class.  The same MaskFormer
# class is also listed in MODEL_FOR_UNIVERSAL_SEGMENTATION_MAPPING_NAMES.
MODEL_FOR_INSTANCE_SEGMENTATION_MAPPING_NAMES = OrderedDict(
    [
        # Model for Instance Segmentation mapping
        # MaskFormerForInstanceSegmentation can be removed from this mapping in v5
        ("maskformer", "MaskFormerForInstanceSegmentation"),
    ]
)
# Model key -> name of the universal-segmentation class for that key.
# "detr" and "maskformer" reuse their generic / instance segmentation classes.
MODEL_FOR_UNIVERSAL_SEGMENTATION_MAPPING_NAMES = OrderedDict(
    [
        # Model for Universal Segmentation mapping
        ("detr", "DetrForSegmentation"),
        ("mask2former", "Mask2FormerForUniversalSegmentation"),
        ("maskformer", "MaskFormerForInstanceSegmentation"),
        ("oneformer", "OneFormerForUniversalSegmentation"),
    ]
)
# Model key -> name of the video-classification class for that key.
MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
    [
        # Model for Video Classification mapping
        ("timesformer", "TimesformerForVideoClassification"),
        ("videomae", "VideoMAEForVideoClassification"),
        ("vivit", "VivitForVideoClassification"),
    ]
)
# Model key -> name of the vision-to-sequence (image-conditioned generation)
# class for that key.  "git" reuses its causal-LM class and
# "vision-encoder-decoder" maps to the generic "VisionEncoderDecoderModel".
MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES = OrderedDict(
    [
        # Model for Vision-to-Sequence mapping
        ("blip", "BlipForConditionalGeneration"),
        ("blip-2", "Blip2ForConditionalGeneration"),
        ("chameleon", "ChameleonForConditionalGeneration"),
        ("git", "GitForCausalLM"),
        ("idefics2", "Idefics2ForConditionalGeneration"),
        ("idefics3", "Idefics3ForConditionalGeneration"),
        ("instructblip", "InstructBlipForConditionalGeneration"),
        ("instructblipvideo", "InstructBlipVideoForConditionalGeneration"),
        ("kosmos-2", "Kosmos2ForConditionalGeneration"),
        ("llava", "LlavaForConditionalGeneration"),
        ("llava_next", "LlavaNextForConditionalGeneration"),
        ("llava_next_video", "LlavaNextVideoForConditionalGeneration"),
        ("llava_onevision", "LlavaOnevisionForConditionalGeneration"),
        ("mllama", "MllamaForConditionalGeneration"),
        ("paligemma", "PaliGemmaForConditionalGeneration"),
        ("pix2struct", "Pix2StructForConditionalGeneration"),
        ("qwen2_vl", "Qwen2VLForConditionalGeneration"),
        ("video_llava", "VideoLlavaForConditionalGeneration"),
        ("vipllava", "VipLlavaForConditionalGeneration"),
        ("vision-encoder-decoder", "VisionEncoderDecoderModel"),
    ]
)
- MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES = OrderedDict(
- [
- ("blip", "BlipForConditionalGeneration"),
- ("blip-2", "Blip2ForConditionalGeneration"),
- ("chameleon", "ChameleonForConditionalGeneration"),
- ("fuyu", "FuyuForCausalLM"),
- ("git", "GitForCausalLM"),
- ("idefics", "IdeficsForVisionText2Text"),
- ("idefics2", "Idefics2ForConditionalGeneration"),
- ("idefics3", "Idefics3ForConditionalGeneration"),
- ("instructblip", "InstructBlipForConditionalGeneration"),
- ("kosmos-2", "Kosmos2ForConditionalGeneration"),
- ("llava", "LlavaForConditionalGeneration"),
- ("llava_next", "LlavaNextForConditionalGeneration"),
- ("llava_onevision", "LlavaOnevisionForConditionalGeneration"),
- ("mllama", "MllamaForConditionalGeneration"),
- ("paligemma", "PaliGemmaForConditionalGeneration"),
- ("pix2struct", "Pix2StructForConditionalGeneration"),
- ("pixtral", "LlavaForConditionalGeneration"),
- ("qwen2_vl", "Qwen2VLForConditionalGeneration"),
- ("udop", "UdopForConditionalGeneration"),
- ("vipllava", "VipLlavaForConditionalGeneration"),
- ("vision-encoder-decoder", "VisionEncoderDecoderModel"),
- ]
- )
- MODEL_FOR_MASKED_LM_MAPPING_NAMES = OrderedDict(
- [
- # Model for Masked LM mapping
- ("albert", "AlbertForMaskedLM"),
- ("bart", "BartForConditionalGeneration"),
- ("bert", "BertForMaskedLM"),
- ("big_bird", "BigBirdForMaskedLM"),
- ("camembert", "CamembertForMaskedLM"),
- ("convbert", "ConvBertForMaskedLM"),
- ("data2vec-text", "Data2VecTextForMaskedLM"),
- ("deberta", "DebertaForMaskedLM"),
- ("deberta-v2", "DebertaV2ForMaskedLM"),
- ("distilbert", "DistilBertForMaskedLM"),
- ("electra", "ElectraForMaskedLM"),
- ("ernie", "ErnieForMaskedLM"),
- ("esm", "EsmForMaskedLM"),
- ("flaubert", "FlaubertWithLMHeadModel"),
- ("fnet", "FNetForMaskedLM"),
- ("funnel", "FunnelForMaskedLM"),
- ("ibert", "IBertForMaskedLM"),
- ("layoutlm", "LayoutLMForMaskedLM"),
- ("longformer", "LongformerForMaskedLM"),
- ("luke", "LukeForMaskedLM"),
- ("mbart", "MBartForConditionalGeneration"),
- ("mega", "MegaForMaskedLM"),
- ("megatron-bert", "MegatronBertForMaskedLM"),
- ("mobilebert", "MobileBertForMaskedLM"),
- ("mpnet", "MPNetForMaskedLM"),
- ("mra", "MraForMaskedLM"),
- ("mvp", "MvpForConditionalGeneration"),
- ("nezha", "NezhaForMaskedLM"),
- ("nystromformer", "NystromformerForMaskedLM"),
- ("perceiver", "PerceiverForMaskedLM"),
- ("qdqbert", "QDQBertForMaskedLM"),
- ("reformer", "ReformerForMaskedLM"),
- ("rembert", "RemBertForMaskedLM"),
- ("roberta", "RobertaForMaskedLM"),
- ("roberta-prelayernorm", "RobertaPreLayerNormForMaskedLM"),
- ("roc_bert", "RoCBertForMaskedLM"),
- ("roformer", "RoFormerForMaskedLM"),
- ("squeezebert", "SqueezeBertForMaskedLM"),
- ("tapas", "TapasForMaskedLM"),
- ("wav2vec2", "Wav2Vec2ForMaskedLM"),
- ("xlm", "XLMWithLMHeadModel"),
- ("xlm-roberta", "XLMRobertaForMaskedLM"),
- ("xlm-roberta-xl", "XLMRobertaXLForMaskedLM"),
- ("xmod", "XmodForMaskedLM"),
- ("yoso", "YosoForMaskedLM"),
- ]
- )
- MODEL_FOR_OBJECT_DETECTION_MAPPING_NAMES = OrderedDict(
- [
- # Model for Object Detection mapping
- ("conditional_detr", "ConditionalDetrForObjectDetection"),
- ("deformable_detr", "DeformableDetrForObjectDetection"),
- ("deta", "DetaForObjectDetection"),
- ("detr", "DetrForObjectDetection"),
- ("rt_detr", "RTDetrForObjectDetection"),
- ("table-transformer", "TableTransformerForObjectDetection"),
- ("yolos", "YolosForObjectDetection"),
- ]
- )
- MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING_NAMES = OrderedDict(
- [
- # Model for Zero Shot Object Detection mapping
- ("grounding-dino", "GroundingDinoForObjectDetection"),
- ("omdet-turbo", "OmDetTurboForObjectDetection"),
- ("owlv2", "Owlv2ForObjectDetection"),
- ("owlvit", "OwlViTForObjectDetection"),
- ]
- )
- MODEL_FOR_DEPTH_ESTIMATION_MAPPING_NAMES = OrderedDict(
- [
- # Model for depth estimation mapping
- ("depth_anything", "DepthAnythingForDepthEstimation"),
- ("dpt", "DPTForDepthEstimation"),
- ("glpn", "GLPNForDepthEstimation"),
- ("zoedepth", "ZoeDepthForDepthEstimation"),
- ]
- )
- MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
- [
- # Model for Seq2Seq Causal LM mapping
- ("bart", "BartForConditionalGeneration"),
- ("bigbird_pegasus", "BigBirdPegasusForConditionalGeneration"),
- ("blenderbot", "BlenderbotForConditionalGeneration"),
- ("blenderbot-small", "BlenderbotSmallForConditionalGeneration"),
- ("encoder-decoder", "EncoderDecoderModel"),
- ("fsmt", "FSMTForConditionalGeneration"),
- ("gptsan-japanese", "GPTSanJapaneseForConditionalGeneration"),
- ("led", "LEDForConditionalGeneration"),
- ("longt5", "LongT5ForConditionalGeneration"),
- ("m2m_100", "M2M100ForConditionalGeneration"),
- ("marian", "MarianMTModel"),
- ("mbart", "MBartForConditionalGeneration"),
- ("mt5", "MT5ForConditionalGeneration"),
- ("mvp", "MvpForConditionalGeneration"),
- ("nllb-moe", "NllbMoeForConditionalGeneration"),
- ("pegasus", "PegasusForConditionalGeneration"),
- ("pegasus_x", "PegasusXForConditionalGeneration"),
- ("plbart", "PLBartForConditionalGeneration"),
- ("prophetnet", "ProphetNetForConditionalGeneration"),
- ("qwen2_audio", "Qwen2AudioForConditionalGeneration"),
- ("seamless_m4t", "SeamlessM4TForTextToText"),
- ("seamless_m4t_v2", "SeamlessM4Tv2ForTextToText"),
- ("switch_transformers", "SwitchTransformersForConditionalGeneration"),
- ("t5", "T5ForConditionalGeneration"),
- ("umt5", "UMT5ForConditionalGeneration"),
- ("xlm-prophetnet", "XLMProphetNetForConditionalGeneration"),
- ]
- )
- MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES = OrderedDict(
- [
- ("pop2piano", "Pop2PianoForConditionalGeneration"),
- ("seamless_m4t", "SeamlessM4TForSpeechToText"),
- ("seamless_m4t_v2", "SeamlessM4Tv2ForSpeechToText"),
- ("speech-encoder-decoder", "SpeechEncoderDecoderModel"),
- ("speech_to_text", "Speech2TextForConditionalGeneration"),
- ("speecht5", "SpeechT5ForSpeechToText"),
- ("whisper", "WhisperForConditionalGeneration"),
- ]
- )
- MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
- [
- # Model for Sequence Classification mapping
- ("albert", "AlbertForSequenceClassification"),
- ("bart", "BartForSequenceClassification"),
- ("bert", "BertForSequenceClassification"),
- ("big_bird", "BigBirdForSequenceClassification"),
- ("bigbird_pegasus", "BigBirdPegasusForSequenceClassification"),
- ("biogpt", "BioGptForSequenceClassification"),
- ("bloom", "BloomForSequenceClassification"),
- ("camembert", "CamembertForSequenceClassification"),
- ("canine", "CanineForSequenceClassification"),
- ("code_llama", "LlamaForSequenceClassification"),
- ("convbert", "ConvBertForSequenceClassification"),
- ("ctrl", "CTRLForSequenceClassification"),
- ("data2vec-text", "Data2VecTextForSequenceClassification"),
- ("deberta", "DebertaForSequenceClassification"),
- ("deberta-v2", "DebertaV2ForSequenceClassification"),
- ("distilbert", "DistilBertForSequenceClassification"),
- ("electra", "ElectraForSequenceClassification"),
- ("ernie", "ErnieForSequenceClassification"),
- ("ernie_m", "ErnieMForSequenceClassification"),
- ("esm", "EsmForSequenceClassification"),
- ("falcon", "FalconForSequenceClassification"),
- ("flaubert", "FlaubertForSequenceClassification"),
- ("fnet", "FNetForSequenceClassification"),
- ("funnel", "FunnelForSequenceClassification"),
- ("gemma", "GemmaForSequenceClassification"),
- ("gemma2", "Gemma2ForSequenceClassification"),
- ("glm", "GlmForSequenceClassification"),
- ("gpt-sw3", "GPT2ForSequenceClassification"),
- ("gpt2", "GPT2ForSequenceClassification"),
- ("gpt_bigcode", "GPTBigCodeForSequenceClassification"),
- ("gpt_neo", "GPTNeoForSequenceClassification"),
- ("gpt_neox", "GPTNeoXForSequenceClassification"),
- ("gptj", "GPTJForSequenceClassification"),
- ("ibert", "IBertForSequenceClassification"),
- ("jamba", "JambaForSequenceClassification"),
- ("jetmoe", "JetMoeForSequenceClassification"),
- ("layoutlm", "LayoutLMForSequenceClassification"),
- ("layoutlmv2", "LayoutLMv2ForSequenceClassification"),
- ("layoutlmv3", "LayoutLMv3ForSequenceClassification"),
- ("led", "LEDForSequenceClassification"),
- ("lilt", "LiltForSequenceClassification"),
- ("llama", "LlamaForSequenceClassification"),
- ("longformer", "LongformerForSequenceClassification"),
- ("luke", "LukeForSequenceClassification"),
- ("markuplm", "MarkupLMForSequenceClassification"),
- ("mbart", "MBartForSequenceClassification"),
- ("mega", "MegaForSequenceClassification"),
- ("megatron-bert", "MegatronBertForSequenceClassification"),
- ("mistral", "MistralForSequenceClassification"),
- ("mixtral", "MixtralForSequenceClassification"),
- ("mobilebert", "MobileBertForSequenceClassification"),
- ("mpnet", "MPNetForSequenceClassification"),
- ("mpt", "MptForSequenceClassification"),
- ("mra", "MraForSequenceClassification"),
- ("mt5", "MT5ForSequenceClassification"),
- ("mvp", "MvpForSequenceClassification"),
- ("nemotron", "NemotronForSequenceClassification"),
- ("nezha", "NezhaForSequenceClassification"),
- ("nystromformer", "NystromformerForSequenceClassification"),
- ("open-llama", "OpenLlamaForSequenceClassification"),
- ("openai-gpt", "OpenAIGPTForSequenceClassification"),
- ("opt", "OPTForSequenceClassification"),
- ("perceiver", "PerceiverForSequenceClassification"),
- ("persimmon", "PersimmonForSequenceClassification"),
- ("phi", "PhiForSequenceClassification"),
- ("phi3", "Phi3ForSequenceClassification"),
- ("phimoe", "PhimoeForSequenceClassification"),
- ("plbart", "PLBartForSequenceClassification"),
- ("qdqbert", "QDQBertForSequenceClassification"),
- ("qwen2", "Qwen2ForSequenceClassification"),
- ("qwen2_moe", "Qwen2MoeForSequenceClassification"),
- ("reformer", "ReformerForSequenceClassification"),
- ("rembert", "RemBertForSequenceClassification"),
- ("roberta", "RobertaForSequenceClassification"),
- ("roberta-prelayernorm", "RobertaPreLayerNormForSequenceClassification"),
- ("roc_bert", "RoCBertForSequenceClassification"),
- ("roformer", "RoFormerForSequenceClassification"),
- ("squeezebert", "SqueezeBertForSequenceClassification"),
- ("stablelm", "StableLmForSequenceClassification"),
- ("starcoder2", "Starcoder2ForSequenceClassification"),
- ("t5", "T5ForSequenceClassification"),
- ("tapas", "TapasForSequenceClassification"),
- ("transfo-xl", "TransfoXLForSequenceClassification"),
- ("umt5", "UMT5ForSequenceClassification"),
- ("xlm", "XLMForSequenceClassification"),
- ("xlm-roberta", "XLMRobertaForSequenceClassification"),
- ("xlm-roberta-xl", "XLMRobertaXLForSequenceClassification"),
- ("xlnet", "XLNetForSequenceClassification"),
- ("xmod", "XmodForSequenceClassification"),
- ("yoso", "YosoForSequenceClassification"),
- ("zamba", "ZambaForSequenceClassification"),
- ]
- )
- MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(
- [
- # Model for Question Answering mapping
- ("albert", "AlbertForQuestionAnswering"),
- ("bart", "BartForQuestionAnswering"),
- ("bert", "BertForQuestionAnswering"),
- ("big_bird", "BigBirdForQuestionAnswering"),
- ("bigbird_pegasus", "BigBirdPegasusForQuestionAnswering"),
- ("bloom", "BloomForQuestionAnswering"),
- ("camembert", "CamembertForQuestionAnswering"),
- ("canine", "CanineForQuestionAnswering"),
- ("convbert", "ConvBertForQuestionAnswering"),
- ("data2vec-text", "Data2VecTextForQuestionAnswering"),
- ("deberta", "DebertaForQuestionAnswering"),
- ("deberta-v2", "DebertaV2ForQuestionAnswering"),
- ("distilbert", "DistilBertForQuestionAnswering"),
- ("electra", "ElectraForQuestionAnswering"),
- ("ernie", "ErnieForQuestionAnswering"),
- ("ernie_m", "ErnieMForQuestionAnswering"),
- ("falcon", "FalconForQuestionAnswering"),
- ("flaubert", "FlaubertForQuestionAnsweringSimple"),
- ("fnet", "FNetForQuestionAnswering"),
- ("funnel", "FunnelForQuestionAnswering"),
- ("gpt2", "GPT2ForQuestionAnswering"),
- ("gpt_neo", "GPTNeoForQuestionAnswering"),
- ("gpt_neox", "GPTNeoXForQuestionAnswering"),
- ("gptj", "GPTJForQuestionAnswering"),
- ("ibert", "IBertForQuestionAnswering"),
- ("layoutlmv2", "LayoutLMv2ForQuestionAnswering"),
- ("layoutlmv3", "LayoutLMv3ForQuestionAnswering"),
- ("led", "LEDForQuestionAnswering"),
- ("lilt", "LiltForQuestionAnswering"),
- ("llama", "LlamaForQuestionAnswering"),
- ("longformer", "LongformerForQuestionAnswering"),
- ("luke", "LukeForQuestionAnswering"),
- ("lxmert", "LxmertForQuestionAnswering"),
- ("markuplm", "MarkupLMForQuestionAnswering"),
- ("mbart", "MBartForQuestionAnswering"),
- ("mega", "MegaForQuestionAnswering"),
- ("megatron-bert", "MegatronBertForQuestionAnswering"),
- ("mistral", "MistralForQuestionAnswering"),
- ("mixtral", "MixtralForQuestionAnswering"),
- ("mobilebert", "MobileBertForQuestionAnswering"),
- ("mpnet", "MPNetForQuestionAnswering"),
- ("mpt", "MptForQuestionAnswering"),
- ("mra", "MraForQuestionAnswering"),
- ("mt5", "MT5ForQuestionAnswering"),
- ("mvp", "MvpForQuestionAnswering"),
- ("nemotron", "NemotronForQuestionAnswering"),
- ("nezha", "NezhaForQuestionAnswering"),
- ("nystromformer", "NystromformerForQuestionAnswering"),
- ("opt", "OPTForQuestionAnswering"),
- ("qdqbert", "QDQBertForQuestionAnswering"),
- ("qwen2", "Qwen2ForQuestionAnswering"),
- ("qwen2_moe", "Qwen2MoeForQuestionAnswering"),
- ("reformer", "ReformerForQuestionAnswering"),
- ("rembert", "RemBertForQuestionAnswering"),
- ("roberta", "RobertaForQuestionAnswering"),
- ("roberta-prelayernorm", "RobertaPreLayerNormForQuestionAnswering"),
- ("roc_bert", "RoCBertForQuestionAnswering"),
- ("roformer", "RoFormerForQuestionAnswering"),
- ("splinter", "SplinterForQuestionAnswering"),
- ("squeezebert", "SqueezeBertForQuestionAnswering"),
- ("t5", "T5ForQuestionAnswering"),
- ("umt5", "UMT5ForQuestionAnswering"),
- ("xlm", "XLMForQuestionAnsweringSimple"),
- ("xlm-roberta", "XLMRobertaForQuestionAnswering"),
- ("xlm-roberta-xl", "XLMRobertaXLForQuestionAnswering"),
- ("xlnet", "XLNetForQuestionAnsweringSimple"),
- ("xmod", "XmodForQuestionAnswering"),
- ("yoso", "YosoForQuestionAnswering"),
- ]
- )
- MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(
- [
- # Model for Table Question Answering mapping
- ("tapas", "TapasForQuestionAnswering"),
- ]
- )
- MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(
- [
- ("blip", "BlipForQuestionAnswering"),
- ("blip-2", "Blip2ForConditionalGeneration"),
- ("vilt", "ViltForQuestionAnswering"),
- ]
- )
- MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(
- [
- ("layoutlm", "LayoutLMForQuestionAnswering"),
- ("layoutlmv2", "LayoutLMv2ForQuestionAnswering"),
- ("layoutlmv3", "LayoutLMv3ForQuestionAnswering"),
- ]
- )
- MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
- [
- # Model for Token Classification mapping
- ("albert", "AlbertForTokenClassification"),
- ("bert", "BertForTokenClassification"),
- ("big_bird", "BigBirdForTokenClassification"),
- ("biogpt", "BioGptForTokenClassification"),
- ("bloom", "BloomForTokenClassification"),
- ("bros", "BrosForTokenClassification"),
- ("camembert", "CamembertForTokenClassification"),
- ("canine", "CanineForTokenClassification"),
- ("convbert", "ConvBertForTokenClassification"),
- ("data2vec-text", "Data2VecTextForTokenClassification"),
- ("deberta", "DebertaForTokenClassification"),
- ("deberta-v2", "DebertaV2ForTokenClassification"),
- ("distilbert", "DistilBertForTokenClassification"),
- ("electra", "ElectraForTokenClassification"),
- ("ernie", "ErnieForTokenClassification"),
- ("ernie_m", "ErnieMForTokenClassification"),
- ("esm", "EsmForTokenClassification"),
- ("falcon", "FalconForTokenClassification"),
- ("flaubert", "FlaubertForTokenClassification"),
- ("fnet", "FNetForTokenClassification"),
- ("funnel", "FunnelForTokenClassification"),
- ("gemma", "GemmaForTokenClassification"),
- ("gemma2", "Gemma2ForTokenClassification"),
- ("glm", "GlmForTokenClassification"),
- ("gpt-sw3", "GPT2ForTokenClassification"),
- ("gpt2", "GPT2ForTokenClassification"),
- ("gpt_bigcode", "GPTBigCodeForTokenClassification"),
- ("gpt_neo", "GPTNeoForTokenClassification"),
- ("gpt_neox", "GPTNeoXForTokenClassification"),
- ("ibert", "IBertForTokenClassification"),
- ("layoutlm", "LayoutLMForTokenClassification"),
- ("layoutlmv2", "LayoutLMv2ForTokenClassification"),
- ("layoutlmv3", "LayoutLMv3ForTokenClassification"),
- ("lilt", "LiltForTokenClassification"),
- ("llama", "LlamaForTokenClassification"),
- ("longformer", "LongformerForTokenClassification"),
- ("luke", "LukeForTokenClassification"),
- ("markuplm", "MarkupLMForTokenClassification"),
- ("mega", "MegaForTokenClassification"),
- ("megatron-bert", "MegatronBertForTokenClassification"),
- ("mistral", "MistralForTokenClassification"),
- ("mixtral", "MixtralForTokenClassification"),
- ("mobilebert", "MobileBertForTokenClassification"),
- ("mpnet", "MPNetForTokenClassification"),
- ("mpt", "MptForTokenClassification"),
- ("mra", "MraForTokenClassification"),
- ("mt5", "MT5ForTokenClassification"),
- ("nemotron", "NemotronForTokenClassification"),
- ("nezha", "NezhaForTokenClassification"),
- ("nystromformer", "NystromformerForTokenClassification"),
- ("persimmon", "PersimmonForTokenClassification"),
- ("phi", "PhiForTokenClassification"),
- ("phi3", "Phi3ForTokenClassification"),
- ("qdqbert", "QDQBertForTokenClassification"),
- ("qwen2", "Qwen2ForTokenClassification"),
- ("qwen2_moe", "Qwen2MoeForTokenClassification"),
- ("rembert", "RemBertForTokenClassification"),
- ("roberta", "RobertaForTokenClassification"),
- ("roberta-prelayernorm", "RobertaPreLayerNormForTokenClassification"),
- ("roc_bert", "RoCBertForTokenClassification"),
- ("roformer", "RoFormerForTokenClassification"),
- ("squeezebert", "SqueezeBertForTokenClassification"),
- ("stablelm", "StableLmForTokenClassification"),
- ("starcoder2", "Starcoder2ForTokenClassification"),
- ("t5", "T5ForTokenClassification"),
- ("umt5", "UMT5ForTokenClassification"),
- ("xlm", "XLMForTokenClassification"),
- ("xlm-roberta", "XLMRobertaForTokenClassification"),
- ("xlm-roberta-xl", "XLMRobertaXLForTokenClassification"),
- ("xlnet", "XLNetForTokenClassification"),
- ("xmod", "XmodForTokenClassification"),
- ("yoso", "YosoForTokenClassification"),
- ]
- )
- MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES = OrderedDict(
- [
- # Model for Multiple Choice mapping
- ("albert", "AlbertForMultipleChoice"),
- ("bert", "BertForMultipleChoice"),
- ("big_bird", "BigBirdForMultipleChoice"),
- ("camembert", "CamembertForMultipleChoice"),
- ("canine", "CanineForMultipleChoice"),
- ("convbert", "ConvBertForMultipleChoice"),
- ("data2vec-text", "Data2VecTextForMultipleChoice"),
- ("deberta-v2", "DebertaV2ForMultipleChoice"),
- ("distilbert", "DistilBertForMultipleChoice"),
- ("electra", "ElectraForMultipleChoice"),
- ("ernie", "ErnieForMultipleChoice"),
- ("ernie_m", "ErnieMForMultipleChoice"),
- ("flaubert", "FlaubertForMultipleChoice"),
- ("fnet", "FNetForMultipleChoice"),
- ("funnel", "FunnelForMultipleChoice"),
- ("ibert", "IBertForMultipleChoice"),
- ("longformer", "LongformerForMultipleChoice"),
- ("luke", "LukeForMultipleChoice"),
- ("mega", "MegaForMultipleChoice"),
- ("megatron-bert", "MegatronBertForMultipleChoice"),
- ("mobilebert", "MobileBertForMultipleChoice"),
- ("mpnet", "MPNetForMultipleChoice"),
- ("mra", "MraForMultipleChoice"),
- ("nezha", "NezhaForMultipleChoice"),
- ("nystromformer", "NystromformerForMultipleChoice"),
- ("qdqbert", "QDQBertForMultipleChoice"),
- ("rembert", "RemBertForMultipleChoice"),
- ("roberta", "RobertaForMultipleChoice"),
- ("roberta-prelayernorm", "RobertaPreLayerNormForMultipleChoice"),
- ("roc_bert", "RoCBertForMultipleChoice"),
- ("roformer", "RoFormerForMultipleChoice"),
- ("squeezebert", "SqueezeBertForMultipleChoice"),
- ("xlm", "XLMForMultipleChoice"),
- ("xlm-roberta", "XLMRobertaForMultipleChoice"),
- ("xlm-roberta-xl", "XLMRobertaXLForMultipleChoice"),
- ("xlnet", "XLNetForMultipleChoice"),
- ("xmod", "XmodForMultipleChoice"),
- ("yoso", "YosoForMultipleChoice"),
- ]
- )
- MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES = OrderedDict(
- [
- ("bert", "BertForNextSentencePrediction"),
- ("ernie", "ErnieForNextSentencePrediction"),
- ("fnet", "FNetForNextSentencePrediction"),
- ("megatron-bert", "MegatronBertForNextSentencePrediction"),
- ("mobilebert", "MobileBertForNextSentencePrediction"),
- ("nezha", "NezhaForNextSentencePrediction"),
- ("qdqbert", "QDQBertForNextSentencePrediction"),
- ]
- )
- MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
- [
- # Model for Audio Classification mapping
- ("audio-spectrogram-transformer", "ASTForAudioClassification"),
- ("data2vec-audio", "Data2VecAudioForSequenceClassification"),
- ("hubert", "HubertForSequenceClassification"),
- ("sew", "SEWForSequenceClassification"),
- ("sew-d", "SEWDForSequenceClassification"),
- ("unispeech", "UniSpeechForSequenceClassification"),
- ("unispeech-sat", "UniSpeechSatForSequenceClassification"),
- ("wav2vec2", "Wav2Vec2ForSequenceClassification"),
- ("wav2vec2-bert", "Wav2Vec2BertForSequenceClassification"),
- ("wav2vec2-conformer", "Wav2Vec2ConformerForSequenceClassification"),
- ("wavlm", "WavLMForSequenceClassification"),
- ("whisper", "WhisperForAudioClassification"),
- ]
- )
- MODEL_FOR_CTC_MAPPING_NAMES = OrderedDict(
- [
- # Model for Connectionist temporal classification (CTC) mapping
- ("data2vec-audio", "Data2VecAudioForCTC"),
- ("hubert", "HubertForCTC"),
- ("mctct", "MCTCTForCTC"),
- ("sew", "SEWForCTC"),
- ("sew-d", "SEWDForCTC"),
- ("unispeech", "UniSpeechForCTC"),
- ("unispeech-sat", "UniSpeechSatForCTC"),
- ("wav2vec2", "Wav2Vec2ForCTC"),
- ("wav2vec2-bert", "Wav2Vec2BertForCTC"),
- ("wav2vec2-conformer", "Wav2Vec2ConformerForCTC"),
- ("wavlm", "WavLMForCTC"),
- ]
- )
- MODEL_FOR_AUDIO_FRAME_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
- [
- # Model for Audio Classification mapping
- ("data2vec-audio", "Data2VecAudioForAudioFrameClassification"),
- ("unispeech-sat", "UniSpeechSatForAudioFrameClassification"),
- ("wav2vec2", "Wav2Vec2ForAudioFrameClassification"),
- ("wav2vec2-bert", "Wav2Vec2BertForAudioFrameClassification"),
- ("wav2vec2-conformer", "Wav2Vec2ConformerForAudioFrameClassification"),
- ("wavlm", "WavLMForAudioFrameClassification"),
- ]
- )
- MODEL_FOR_AUDIO_XVECTOR_MAPPING_NAMES = OrderedDict(
- [
- # Model for Audio Classification mapping
- ("data2vec-audio", "Data2VecAudioForXVector"),
- ("unispeech-sat", "UniSpeechSatForXVector"),
- ("wav2vec2", "Wav2Vec2ForXVector"),
- ("wav2vec2-bert", "Wav2Vec2BertForXVector"),
- ("wav2vec2-conformer", "Wav2Vec2ConformerForXVector"),
- ("wavlm", "WavLMForXVector"),
- ]
- )
- MODEL_FOR_TEXT_TO_SPECTROGRAM_MAPPING_NAMES = OrderedDict(
- [
- # Model for Text-To-Spectrogram mapping
- ("fastspeech2_conformer", "FastSpeech2ConformerModel"),
- ("speecht5", "SpeechT5ForTextToSpeech"),
- ]
- )
- MODEL_FOR_TEXT_TO_WAVEFORM_MAPPING_NAMES = OrderedDict(
- [
- # Model for Text-To-Waveform mapping
- ("bark", "BarkModel"),
- ("fastspeech2_conformer", "FastSpeech2ConformerWithHifiGan"),
- ("musicgen", "MusicgenForConditionalGeneration"),
- ("musicgen_melody", "MusicgenMelodyForConditionalGeneration"),
- ("seamless_m4t", "SeamlessM4TForTextToSpeech"),
- ("seamless_m4t_v2", "SeamlessM4Tv2ForTextToSpeech"),
- ("vits", "VitsModel"),
- ]
- )
- MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
- [
- # Model for Zero Shot Image Classification mapping
- ("align", "AlignModel"),
- ("altclip", "AltCLIPModel"),
- ("blip", "BlipModel"),
- ("blip-2", "Blip2ForImageTextRetrieval"),
- ("chinese_clip", "ChineseCLIPModel"),
- ("clip", "CLIPModel"),
- ("clipseg", "CLIPSegModel"),
- ("siglip", "SiglipModel"),
- ]
- )
- MODEL_FOR_BACKBONE_MAPPING_NAMES = OrderedDict(
- [
- # Backbone mapping
- ("beit", "BeitBackbone"),
- ("bit", "BitBackbone"),
- ("convnext", "ConvNextBackbone"),
- ("convnextv2", "ConvNextV2Backbone"),
- ("dinat", "DinatBackbone"),
- ("dinov2", "Dinov2Backbone"),
- ("focalnet", "FocalNetBackbone"),
- ("hiera", "HieraBackbone"),
- ("maskformer-swin", "MaskFormerSwinBackbone"),
- ("nat", "NatBackbone"),
- ("pvt_v2", "PvtV2Backbone"),
- ("resnet", "ResNetBackbone"),
- ("rt_detr_resnet", "RTDetrResNetBackbone"),
- ("swin", "SwinBackbone"),
- ("swinv2", "Swinv2Backbone"),
- ("timm_backbone", "TimmBackbone"),
- ("vitdet", "VitDetBackbone"),
- ]
- )
- MODEL_FOR_MASK_GENERATION_MAPPING_NAMES = OrderedDict(
- [
- ("sam", "SamModel"),
- ]
- )
- MODEL_FOR_KEYPOINT_DETECTION_MAPPING_NAMES = OrderedDict(
- [
- ("superpoint", "SuperPointForKeypointDetection"),
- ]
- )
- MODEL_FOR_TEXT_ENCODING_MAPPING_NAMES = OrderedDict(
- [
- ("albert", "AlbertModel"),
- ("bert", "BertModel"),
- ("big_bird", "BigBirdModel"),
- ("clip_text_model", "CLIPTextModel"),
- ("data2vec-text", "Data2VecTextModel"),
- ("deberta", "DebertaModel"),
- ("deberta-v2", "DebertaV2Model"),
- ("distilbert", "DistilBertModel"),
- ("electra", "ElectraModel"),
- ("flaubert", "FlaubertModel"),
- ("ibert", "IBertModel"),
- ("longformer", "LongformerModel"),
- ("mllama", "MllamaTextModel"),
- ("mobilebert", "MobileBertModel"),
- ("mt5", "MT5EncoderModel"),
- ("nystromformer", "NystromformerModel"),
- ("reformer", "ReformerModel"),
- ("rembert", "RemBertModel"),
- ("roberta", "RobertaModel"),
- ("roberta-prelayernorm", "RobertaPreLayerNormModel"),
- ("roc_bert", "RoCBertModel"),
- ("roformer", "RoFormerModel"),
- ("squeezebert", "SqueezeBertModel"),
- ("t5", "T5EncoderModel"),
- ("umt5", "UMT5EncoderModel"),
- ("xlm", "XLMModel"),
- ("xlm-roberta", "XLMRobertaModel"),
- ("xlm-roberta-xl", "XLMRobertaXLModel"),
- ]
- )
- MODEL_FOR_TIME_SERIES_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
- [
- ("patchtsmixer", "PatchTSMixerForTimeSeriesClassification"),
- ("patchtst", "PatchTSTForClassification"),
- ]
- )
- MODEL_FOR_TIME_SERIES_REGRESSION_MAPPING_NAMES = OrderedDict(
- [
- ("patchtsmixer", "PatchTSMixerForRegression"),
- ("patchtst", "PatchTSTForRegression"),
- ]
- )
- MODEL_FOR_IMAGE_TO_IMAGE_MAPPING_NAMES = OrderedDict(
- [
- ("swin2sr", "Swin2SRForImageSuperResolution"),
- ]
- )
- MODEL_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_MAPPING_NAMES)
- MODEL_FOR_PRETRAINING_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_PRETRAINING_MAPPING_NAMES)
- MODEL_WITH_LM_HEAD_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_WITH_LM_HEAD_MAPPING_NAMES)
- MODEL_FOR_CAUSAL_LM_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_CAUSAL_LM_MAPPING_NAMES)
- MODEL_FOR_CAUSAL_IMAGE_MODELING_MAPPING = _LazyAutoMapping(
- CONFIG_MAPPING_NAMES, MODEL_FOR_CAUSAL_IMAGE_MODELING_MAPPING_NAMES
- )
- MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING = _LazyAutoMapping(
- CONFIG_MAPPING_NAMES, MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES
- )
- MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING = _LazyAutoMapping(
- CONFIG_MAPPING_NAMES, MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING_NAMES
- )
- MODEL_FOR_IMAGE_SEGMENTATION_MAPPING = _LazyAutoMapping(
- CONFIG_MAPPING_NAMES, MODEL_FOR_IMAGE_SEGMENTATION_MAPPING_NAMES
- )
- MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING = _LazyAutoMapping(
- CONFIG_MAPPING_NAMES, MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING_NAMES
- )
- MODEL_FOR_INSTANCE_SEGMENTATION_MAPPING = _LazyAutoMapping(
- CONFIG_MAPPING_NAMES, MODEL_FOR_INSTANCE_SEGMENTATION_MAPPING_NAMES
- )
- MODEL_FOR_UNIVERSAL_SEGMENTATION_MAPPING = _LazyAutoMapping(
- CONFIG_MAPPING_NAMES, MODEL_FOR_UNIVERSAL_SEGMENTATION_MAPPING_NAMES
- )
- MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING = _LazyAutoMapping(
- CONFIG_MAPPING_NAMES, MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING_NAMES
- )
- MODEL_FOR_VISION_2_SEQ_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES)
- MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING = _LazyAutoMapping(
- CONFIG_MAPPING_NAMES, MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES
- )
- MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING = _LazyAutoMapping(
- CONFIG_MAPPING_NAMES, MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING_NAMES
- )
- MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING = _LazyAutoMapping(
- CONFIG_MAPPING_NAMES, MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING_NAMES
- )
- MODEL_FOR_MASKED_LM_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_MASKED_LM_MAPPING_NAMES)
- MODEL_FOR_IMAGE_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_IMAGE_MAPPING_NAMES)
- MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING = _LazyAutoMapping(
- CONFIG_MAPPING_NAMES, MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING_NAMES
- )
- MODEL_FOR_OBJECT_DETECTION_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_OBJECT_DETECTION_MAPPING_NAMES)
- MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING = _LazyAutoMapping(
- CONFIG_MAPPING_NAMES, MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING_NAMES
- )
- MODEL_FOR_DEPTH_ESTIMATION_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_DEPTH_ESTIMATION_MAPPING_NAMES)
- MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING = _LazyAutoMapping(
- CONFIG_MAPPING_NAMES, MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES
- )
- MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING = _LazyAutoMapping(
- CONFIG_MAPPING_NAMES, MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES
- )
- MODEL_FOR_QUESTION_ANSWERING_MAPPING = _LazyAutoMapping(
- CONFIG_MAPPING_NAMES, MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES
- )
- MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING = _LazyAutoMapping(
- CONFIG_MAPPING_NAMES, MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING_NAMES
- )
# Task-specific lazy mappings. Each one pairs the shared CONFIG_MAPPING_NAMES table with a
# per-task *_MAPPING_NAMES table (both defined elsewhere in this module) and wraps them in
# `_LazyAutoMapping` -- presumably so the named classes are only resolved on lookup.
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES,
    MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES,
)
MODEL_FOR_MULTIPLE_CHOICE_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES,
    MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES,
)
MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES,
    MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES,
)

# Audio / speech heads.
MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES,
    MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES,
)
MODEL_FOR_CTC_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES,
    MODEL_FOR_CTC_MAPPING_NAMES,
)
MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES,
    MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES,
)
MODEL_FOR_AUDIO_FRAME_CLASSIFICATION_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES,
    MODEL_FOR_AUDIO_FRAME_CLASSIFICATION_MAPPING_NAMES,
)
MODEL_FOR_AUDIO_XVECTOR_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES,
    MODEL_FOR_AUDIO_XVECTOR_MAPPING_NAMES,
)
MODEL_FOR_TEXT_TO_SPECTROGRAM_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES,
    MODEL_FOR_TEXT_TO_SPECTROGRAM_MAPPING_NAMES,
)
MODEL_FOR_TEXT_TO_WAVEFORM_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES,
    MODEL_FOR_TEXT_TO_WAVEFORM_MAPPING_NAMES,
)

# Backbones, vision, text encoding and time-series heads.
MODEL_FOR_BACKBONE_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES,
    MODEL_FOR_BACKBONE_MAPPING_NAMES,
)
MODEL_FOR_MASK_GENERATION_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES,
    MODEL_FOR_MASK_GENERATION_MAPPING_NAMES,
)
MODEL_FOR_KEYPOINT_DETECTION_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES,
    MODEL_FOR_KEYPOINT_DETECTION_MAPPING_NAMES,
)
MODEL_FOR_TEXT_ENCODING_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES,
    MODEL_FOR_TEXT_ENCODING_MAPPING_NAMES,
)
MODEL_FOR_TIME_SERIES_CLASSIFICATION_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES,
    MODEL_FOR_TIME_SERIES_CLASSIFICATION_MAPPING_NAMES,
)
MODEL_FOR_TIME_SERIES_REGRESSION_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES,
    MODEL_FOR_TIME_SERIES_REGRESSION_MAPPING_NAMES,
)
MODEL_FOR_IMAGE_TO_IMAGE_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES,
    MODEL_FOR_IMAGE_TO_IMAGE_MAPPING_NAMES,
)
# Auto classes that only bind their task mapping. Unlike most auto classes around them,
# they are not passed through `auto_class_update` at their definition site.


class AutoModelForMaskGeneration(_BaseAutoModelClass):
    # Task lookup table consulted by the shared `_BaseAutoModelClass` machinery.
    _model_mapping = MODEL_FOR_MASK_GENERATION_MAPPING


class AutoModelForKeypointDetection(_BaseAutoModelClass):
    # Task lookup table consulted by the shared `_BaseAutoModelClass` machinery.
    _model_mapping = MODEL_FOR_KEYPOINT_DETECTION_MAPPING


class AutoModelForTextEncoding(_BaseAutoModelClass):
    # Task lookup table consulted by the shared `_BaseAutoModelClass` machinery.
    _model_mapping = MODEL_FOR_TEXT_ENCODING_MAPPING


class AutoModelForImageToImage(_BaseAutoModelClass):
    # Task lookup table consulted by the shared `_BaseAutoModelClass` machinery.
    _model_mapping = MODEL_FOR_IMAGE_TO_IMAGE_MAPPING
# Base auto class: registered without a `head_doc`, unlike the task-specific classes below.
class AutoModel(_BaseAutoModelClass):
    _model_mapping = MODEL_MAPPING


# Rebind the public name to whatever `auto_class_update` returns for the class.
AutoModel = auto_class_update(AutoModel)


class AutoModelForPreTraining(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_PRETRAINING_MAPPING


AutoModelForPreTraining = auto_class_update(
    AutoModelForPreTraining,
    head_doc="pretraining",
)
# Intentionally private: the public `AutoModelWithLMHead` subclass wraps this class and
# is the one that emits the deprecation warnings.
class _AutoModelWithLMHead(_BaseAutoModelClass):
    _model_mapping = MODEL_WITH_LM_HEAD_MAPPING


_AutoModelWithLMHead = auto_class_update(
    _AutoModelWithLMHead,
    head_doc="language modeling",
)
# Language-modeling heads. Each class binds its task mapping, then its public name is
# rebound to the result of `auto_class_update` with a human-readable `head_doc` label.
class AutoModelForCausalLM(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_CAUSAL_LM_MAPPING


AutoModelForCausalLM = auto_class_update(
    AutoModelForCausalLM,
    head_doc="causal language modeling",
)


class AutoModelForMaskedLM(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_MASKED_LM_MAPPING


AutoModelForMaskedLM = auto_class_update(
    AutoModelForMaskedLM,
    head_doc="masked language modeling",
)


class AutoModelForSeq2SeqLM(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING


# Overrides the example checkpoint passed to `auto_class_update` for this head.
AutoModelForSeq2SeqLM = auto_class_update(
    AutoModelForSeq2SeqLM,
    head_doc="sequence-to-sequence language modeling",
    checkpoint_for_example="google-t5/t5-base",
)
# Text classification / extractive QA heads.
class AutoModelForSequenceClassification(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING


AutoModelForSequenceClassification = auto_class_update(
    AutoModelForSequenceClassification,
    head_doc="sequence classification",
)


class AutoModelForQuestionAnswering(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_QUESTION_ANSWERING_MAPPING


AutoModelForQuestionAnswering = auto_class_update(
    AutoModelForQuestionAnswering,
    head_doc="question answering",
)


class AutoModelForTableQuestionAnswering(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING


# Uses a table-QA checkpoint as the example instead of the `auto_class_update` default.
AutoModelForTableQuestionAnswering = auto_class_update(
    AutoModelForTableQuestionAnswering,
    head_doc="table question answering",
    checkpoint_for_example="google/tapas-base-finetuned-wtq",
)
# Multimodal question-answering heads.
class AutoModelForVisualQuestionAnswering(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING


AutoModelForVisualQuestionAnswering = auto_class_update(
    AutoModelForVisualQuestionAnswering,
    head_doc="visual question answering",
    checkpoint_for_example="dandelin/vilt-b32-finetuned-vqa",
)


class AutoModelForDocumentQuestionAnswering(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING


# NOTE(review): `checkpoint_for_example` deliberately embeds `", revision="` inside the
# single-quoted string. Presumably the example template wraps the checkpoint in double
# quotes, so this expands to `"impira/layoutlm-document-qa", revision="52e01b3"` in the
# generated example -- confirm against that template before changing the quoting.
AutoModelForDocumentQuestionAnswering = auto_class_update(
    AutoModelForDocumentQuestionAnswering,
    head_doc="document question answering",
    checkpoint_for_example='impira/layoutlm-document-qa", revision="52e01b3',
)
# Token-level and sentence-pair heads.
class AutoModelForTokenClassification(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING


AutoModelForTokenClassification = auto_class_update(
    AutoModelForTokenClassification,
    head_doc="token classification",
)


class AutoModelForMultipleChoice(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_MULTIPLE_CHOICE_MAPPING


AutoModelForMultipleChoice = auto_class_update(
    AutoModelForMultipleChoice,
    head_doc="multiple choice",
)


class AutoModelForNextSentencePrediction(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING


AutoModelForNextSentencePrediction = auto_class_update(
    AutoModelForNextSentencePrediction,
    head_doc="next sentence prediction",
)
# Image classification and segmentation heads.
class AutoModelForImageClassification(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING


AutoModelForImageClassification = auto_class_update(
    AutoModelForImageClassification,
    head_doc="image classification",
)


class AutoModelForZeroShotImageClassification(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING


AutoModelForZeroShotImageClassification = auto_class_update(
    AutoModelForZeroShotImageClassification,
    head_doc="zero-shot image classification",
)


class AutoModelForImageSegmentation(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_IMAGE_SEGMENTATION_MAPPING


AutoModelForImageSegmentation = auto_class_update(
    AutoModelForImageSegmentation,
    head_doc="image segmentation",
)


class AutoModelForSemanticSegmentation(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING


AutoModelForSemanticSegmentation = auto_class_update(
    AutoModelForSemanticSegmentation,
    head_doc="semantic segmentation",
)
# Further segmentation, detection and depth heads.
class AutoModelForUniversalSegmentation(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_UNIVERSAL_SEGMENTATION_MAPPING


AutoModelForUniversalSegmentation = auto_class_update(
    AutoModelForUniversalSegmentation,
    head_doc="universal image segmentation",
)


class AutoModelForInstanceSegmentation(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_INSTANCE_SEGMENTATION_MAPPING


AutoModelForInstanceSegmentation = auto_class_update(
    AutoModelForInstanceSegmentation,
    head_doc="instance segmentation",
)


class AutoModelForObjectDetection(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_OBJECT_DETECTION_MAPPING


AutoModelForObjectDetection = auto_class_update(
    AutoModelForObjectDetection,
    head_doc="object detection",
)


class AutoModelForZeroShotObjectDetection(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING


AutoModelForZeroShotObjectDetection = auto_class_update(
    AutoModelForZeroShotObjectDetection,
    head_doc="zero-shot object detection",
)


class AutoModelForDepthEstimation(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_DEPTH_ESTIMATION_MAPPING


AutoModelForDepthEstimation = auto_class_update(
    AutoModelForDepthEstimation,
    head_doc="depth estimation",
)
# Video classification and vision/image-text generation heads.
class AutoModelForVideoClassification(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING


AutoModelForVideoClassification = auto_class_update(
    AutoModelForVideoClassification,
    head_doc="video classification",
)


class AutoModelForVision2Seq(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_VISION_2_SEQ_MAPPING


AutoModelForVision2Seq = auto_class_update(
    AutoModelForVision2Seq,
    head_doc="vision-to-text modeling",
)


class AutoModelForImageTextToText(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING


AutoModelForImageTextToText = auto_class_update(
    AutoModelForImageTextToText,
    head_doc="image-text-to-text modeling",
)
# Audio / speech heads.
class AutoModelForAudioClassification(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING


AutoModelForAudioClassification = auto_class_update(
    AutoModelForAudioClassification,
    head_doc="audio classification",
)


class AutoModelForCTC(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_CTC_MAPPING


AutoModelForCTC = auto_class_update(
    AutoModelForCTC,
    head_doc="connectionist temporal classification",
)


class AutoModelForSpeechSeq2Seq(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING


AutoModelForSpeechSeq2Seq = auto_class_update(
    AutoModelForSpeechSeq2Seq,
    head_doc="sequence-to-sequence speech-to-text modeling",
)


class AutoModelForAudioFrameClassification(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_AUDIO_FRAME_CLASSIFICATION_MAPPING


AutoModelForAudioFrameClassification = auto_class_update(
    AutoModelForAudioFrameClassification,
    head_doc="audio frame (token) classification",
)
class AutoModelForAudioXVector(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_AUDIO_XVECTOR_MAPPING


# Registered right after its definition; nothing between here and the following class
# definitions refers to this name, so the position relative to them does not matter.
AutoModelForAudioXVector = auto_class_update(
    AutoModelForAudioXVector,
    head_doc="audio retrieval via x-vector",
)


# The three classes below only bind their mapping and are not passed through
# `auto_class_update` at their definition site.
class AutoModelForTextToSpectrogram(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_TEXT_TO_SPECTROGRAM_MAPPING


class AutoModelForTextToWaveform(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_TEXT_TO_WAVEFORM_MAPPING


# Built on `_BaseAutoBackboneClass` rather than `_BaseAutoModelClass`.
class AutoBackbone(_BaseAutoBackboneClass):
    _model_mapping = MODEL_FOR_BACKBONE_MAPPING
class AutoModelForMaskedImageModeling(_BaseAutoModelClass):
    # Task lookup table consulted by the shared `_BaseAutoModelClass` machinery.
    _model_mapping = MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING


AutoModelForMaskedImageModeling = auto_class_update(
    AutoModelForMaskedImageModeling,
    head_doc="masked image modeling",
)
class AutoModelWithLMHead(_AutoModelWithLMHead):
    """Deprecated public alias of `_AutoModelWithLMHead`.

    Both constructors emit a ``FutureWarning`` pointing users to `AutoModelForCausalLM`,
    `AutoModelForMaskedLM` or `AutoModelForSeq2SeqLM`, then delegate unchanged to the
    parent implementation.
    """

    # Single source of truth for the deprecation text shared by both constructors.
    _DEPRECATION_MESSAGE = (
        "The class `AutoModelWithLMHead` is deprecated and will be removed in a future version. Please use "
        "`AutoModelForCausalLM` for causal language models, `AutoModelForMaskedLM` for masked language models and "
        "`AutoModelForSeq2SeqLM` for encoder-decoder models."
    )

    @classmethod
    def _warn_lm_head_deprecation(cls):
        """Emit the deprecation ``FutureWarning`` attributed to the user's call site."""
        # stacklevel=3 skips this helper and the public classmethod that called it, so the
        # warning is reported at the line in user code that invoked from_config/from_pretrained.
        warnings.warn(cls._DEPRECATION_MESSAGE, FutureWarning, stacklevel=3)

    @classmethod
    def from_config(cls, config, **kwargs):
        """Warn about the deprecation, then build the model via the parent ``from_config``.

        Args:
            config: Configuration object forwarded to the parent implementation.
            **kwargs: Extra keyword arguments, forwarded unchanged to the parent. Previously
                they were rejected with a TypeError by this override.
        """
        cls._warn_lm_head_deprecation()
        return super().from_config(config, **kwargs)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        """Warn about the deprecation, then load via the parent ``from_pretrained``.

        All arguments are forwarded unchanged to the parent implementation.
        """
        cls._warn_lm_head_deprecation()
        return super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|