| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
703704705706707708709710711712713714715716717718719720721722723724725726727 |
- # coding=utf-8
- # Copyright 2018 The HuggingFace Inc. team.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- """Auto Model class."""
- import warnings
- from collections import OrderedDict
- from ...utils import logging
- from .auto_factory import _BaseAutoModelClass, _LazyAutoMapping, auto_class_update
- from .configuration_auto import CONFIG_MAPPING_NAMES
# Module-level logger from the library's own ``utils.logging`` wrapper (imported above), not the stdlib module.
# NOTE(review): not referenced in the visible part of this module.
logger = logging.get_logger(__name__)
# Base model mapping: model type -> name of the TF base model class.
# Values are class-name strings; an entry may instead hold a tuple of class names (e.g. "funnel").
TF_MODEL_MAPPING_NAMES = OrderedDict(
    {
        "albert": "TFAlbertModel",
        "bart": "TFBartModel",
        "bert": "TFBertModel",
        "blenderbot": "TFBlenderbotModel",
        "blenderbot-small": "TFBlenderbotSmallModel",
        "blip": "TFBlipModel",
        "camembert": "TFCamembertModel",
        "clip": "TFCLIPModel",
        "convbert": "TFConvBertModel",
        "convnext": "TFConvNextModel",
        "convnextv2": "TFConvNextV2Model",
        "ctrl": "TFCTRLModel",
        "cvt": "TFCvtModel",
        "data2vec-vision": "TFData2VecVisionModel",
        "deberta": "TFDebertaModel",
        "deberta-v2": "TFDebertaV2Model",
        "deit": "TFDeiTModel",
        "distilbert": "TFDistilBertModel",
        "dpr": "TFDPRQuestionEncoder",
        "efficientformer": "TFEfficientFormerModel",
        "electra": "TFElectraModel",
        "esm": "TFEsmModel",
        "flaubert": "TFFlaubertModel",
        "funnel": ("TFFunnelModel", "TFFunnelBaseModel"),
        "gpt-sw3": "TFGPT2Model",
        "gpt2": "TFGPT2Model",
        "gptj": "TFGPTJModel",
        "groupvit": "TFGroupViTModel",
        "hubert": "TFHubertModel",
        "idefics": "TFIdeficsModel",
        "layoutlm": "TFLayoutLMModel",
        "layoutlmv3": "TFLayoutLMv3Model",
        "led": "TFLEDModel",
        "longformer": "TFLongformerModel",
        "lxmert": "TFLxmertModel",
        "marian": "TFMarianModel",
        "mbart": "TFMBartModel",
        "mistral": "TFMistralModel",
        "mobilebert": "TFMobileBertModel",
        "mobilevit": "TFMobileViTModel",
        "mpnet": "TFMPNetModel",
        "mt5": "TFMT5Model",
        "openai-gpt": "TFOpenAIGPTModel",
        "opt": "TFOPTModel",
        "pegasus": "TFPegasusModel",
        "regnet": "TFRegNetModel",
        "rembert": "TFRemBertModel",
        "resnet": "TFResNetModel",
        "roberta": "TFRobertaModel",
        "roberta-prelayernorm": "TFRobertaPreLayerNormModel",
        "roformer": "TFRoFormerModel",
        "sam": "TFSamModel",
        "segformer": "TFSegformerModel",
        "speech_to_text": "TFSpeech2TextModel",
        "swiftformer": "TFSwiftFormerModel",
        "swin": "TFSwinModel",
        "t5": "TFT5Model",
        "tapas": "TFTapasModel",
        "transfo-xl": "TFTransfoXLModel",
        "vision-text-dual-encoder": "TFVisionTextDualEncoderModel",
        "vit": "TFViTModel",
        "vit_mae": "TFViTMAEModel",
        "wav2vec2": "TFWav2Vec2Model",
        "whisper": "TFWhisperModel",
        "xglm": "TFXGLMModel",
        "xlm": "TFXLMModel",
        "xlm-roberta": "TFXLMRobertaModel",
        "xlnet": "TFXLNetModel",
    }
)
# Pre-training mapping: model type -> name of the TF class used for pre-training.
# Several model types reuse a masked-LM or conditional-generation class here.
TF_MODEL_FOR_PRETRAINING_MAPPING_NAMES = OrderedDict(
    {
        "albert": "TFAlbertForPreTraining",
        "bart": "TFBartForConditionalGeneration",
        "bert": "TFBertForPreTraining",
        "camembert": "TFCamembertForMaskedLM",
        "ctrl": "TFCTRLLMHeadModel",
        "distilbert": "TFDistilBertForMaskedLM",
        "electra": "TFElectraForPreTraining",
        "flaubert": "TFFlaubertWithLMHeadModel",
        "funnel": "TFFunnelForPreTraining",
        "gpt-sw3": "TFGPT2LMHeadModel",
        "gpt2": "TFGPT2LMHeadModel",
        "idefics": "TFIdeficsForVisionText2Text",
        "layoutlm": "TFLayoutLMForMaskedLM",
        "lxmert": "TFLxmertForPreTraining",
        "mobilebert": "TFMobileBertForPreTraining",
        "mpnet": "TFMPNetForMaskedLM",
        "openai-gpt": "TFOpenAIGPTLMHeadModel",
        "roberta": "TFRobertaForMaskedLM",
        "roberta-prelayernorm": "TFRobertaPreLayerNormForMaskedLM",
        "t5": "TFT5ForConditionalGeneration",
        "tapas": "TFTapasForMaskedLM",
        "transfo-xl": "TFTransfoXLLMHeadModel",
        "vit_mae": "TFViTMAEForPreTraining",
        "xlm": "TFXLMWithLMHeadModel",
        "xlm-roberta": "TFXLMRobertaForMaskedLM",
        "xlnet": "TFXLNetLMHeadModel",
    }
)
# LM-head mapping: model type -> name of the TF class carrying a language-modeling head.
# Mixes masked-LM, causal-LM and conditional-generation classes, depending on the model type.
TF_MODEL_WITH_LM_HEAD_MAPPING_NAMES = OrderedDict(
    {
        "albert": "TFAlbertForMaskedLM",
        "bart": "TFBartForConditionalGeneration",
        "bert": "TFBertForMaskedLM",
        "camembert": "TFCamembertForMaskedLM",
        "convbert": "TFConvBertForMaskedLM",
        "ctrl": "TFCTRLLMHeadModel",
        "distilbert": "TFDistilBertForMaskedLM",
        "electra": "TFElectraForMaskedLM",
        "esm": "TFEsmForMaskedLM",
        "flaubert": "TFFlaubertWithLMHeadModel",
        "funnel": "TFFunnelForMaskedLM",
        "gpt-sw3": "TFGPT2LMHeadModel",
        "gpt2": "TFGPT2LMHeadModel",
        "gptj": "TFGPTJForCausalLM",
        "layoutlm": "TFLayoutLMForMaskedLM",
        "led": "TFLEDForConditionalGeneration",
        "longformer": "TFLongformerForMaskedLM",
        "marian": "TFMarianMTModel",
        "mobilebert": "TFMobileBertForMaskedLM",
        "mpnet": "TFMPNetForMaskedLM",
        "openai-gpt": "TFOpenAIGPTLMHeadModel",
        "rembert": "TFRemBertForMaskedLM",
        "roberta": "TFRobertaForMaskedLM",
        "roberta-prelayernorm": "TFRobertaPreLayerNormForMaskedLM",
        "roformer": "TFRoFormerForMaskedLM",
        "speech_to_text": "TFSpeech2TextForConditionalGeneration",
        "t5": "TFT5ForConditionalGeneration",
        "tapas": "TFTapasForMaskedLM",
        "transfo-xl": "TFTransfoXLLMHeadModel",
        "whisper": "TFWhisperForConditionalGeneration",
        "xlm": "TFXLMWithLMHeadModel",
        "xlm-roberta": "TFXLMRobertaForMaskedLM",
        "xlnet": "TFXLNetLMHeadModel",
    }
)
# Causal-LM mapping: model type -> name of the TF causal language-modeling class.
TF_MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
    {
        "bert": "TFBertLMHeadModel",
        "camembert": "TFCamembertForCausalLM",
        "ctrl": "TFCTRLLMHeadModel",
        "gpt-sw3": "TFGPT2LMHeadModel",
        "gpt2": "TFGPT2LMHeadModel",
        "gptj": "TFGPTJForCausalLM",
        "mistral": "TFMistralForCausalLM",
        "openai-gpt": "TFOpenAIGPTLMHeadModel",
        "opt": "TFOPTForCausalLM",
        "rembert": "TFRemBertForCausalLM",
        "roberta": "TFRobertaForCausalLM",
        "roberta-prelayernorm": "TFRobertaPreLayerNormForCausalLM",
        "roformer": "TFRoFormerForCausalLM",
        "transfo-xl": "TFTransfoXLLMHeadModel",
        "xglm": "TFXGLMForCausalLM",
        "xlm": "TFXLMWithLMHeadModel",
        "xlm-roberta": "TFXLMRobertaForCausalLM",
        "xlnet": "TFXLNetLMHeadModel",
    }
)
# Masked-image-modeling mapping: model type -> name of the TF masked-image-modeling class.
TF_MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING_NAMES = OrderedDict(
    {
        "deit": "TFDeiTForMaskedImageModeling",
        "swin": "TFSwinForMaskedImageModeling",
    }
)
# Image-classification mapping: model type -> name of the TF image-classification class.
# "deit" and "efficientformer" map to a tuple of two class names (plain head and "WithTeacher" variant).
TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
    {
        "convnext": "TFConvNextForImageClassification",
        "convnextv2": "TFConvNextV2ForImageClassification",
        "cvt": "TFCvtForImageClassification",
        "data2vec-vision": "TFData2VecVisionForImageClassification",
        "deit": ("TFDeiTForImageClassification", "TFDeiTForImageClassificationWithTeacher"),
        "efficientformer": (
            "TFEfficientFormerForImageClassification",
            "TFEfficientFormerForImageClassificationWithTeacher",
        ),
        "mobilevit": "TFMobileViTForImageClassification",
        "regnet": "TFRegNetForImageClassification",
        "resnet": "TFResNetForImageClassification",
        "segformer": "TFSegformerForImageClassification",
        "swiftformer": "TFSwiftFormerForImageClassification",
        "swin": "TFSwinForImageClassification",
        "vit": "TFViTForImageClassification",
    }
)
# Zero-shot image-classification mapping: model type -> name of the TF class used for this task
# (the base dual-encoder classes are reused).
TF_MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
    {
        "blip": "TFBlipModel",
        "clip": "TFCLIPModel",
    }
)
# Semantic-segmentation mapping: model type -> name of the TF semantic-segmentation class.
TF_MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING_NAMES = OrderedDict(
    {
        "data2vec-vision": "TFData2VecVisionForSemanticSegmentation",
        "mobilevit": "TFMobileViTForSemanticSegmentation",
        "segformer": "TFSegformerForSemanticSegmentation",
    }
)
# Vision-to-sequence mapping: model type -> name of the TF image-to-text generation class.
TF_MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES = OrderedDict(
    {
        "blip": "TFBlipForConditionalGeneration",
        "vision-encoder-decoder": "TFVisionEncoderDecoderModel",
    }
)
# Masked-LM mapping: model type -> name of the TF masked language-modeling class.
TF_MODEL_FOR_MASKED_LM_MAPPING_NAMES = OrderedDict(
    {
        "albert": "TFAlbertForMaskedLM",
        "bert": "TFBertForMaskedLM",
        "camembert": "TFCamembertForMaskedLM",
        "convbert": "TFConvBertForMaskedLM",
        "deberta": "TFDebertaForMaskedLM",
        "deberta-v2": "TFDebertaV2ForMaskedLM",
        "distilbert": "TFDistilBertForMaskedLM",
        "electra": "TFElectraForMaskedLM",
        "esm": "TFEsmForMaskedLM",
        "flaubert": "TFFlaubertWithLMHeadModel",
        "funnel": "TFFunnelForMaskedLM",
        "layoutlm": "TFLayoutLMForMaskedLM",
        "longformer": "TFLongformerForMaskedLM",
        "mobilebert": "TFMobileBertForMaskedLM",
        "mpnet": "TFMPNetForMaskedLM",
        "rembert": "TFRemBertForMaskedLM",
        "roberta": "TFRobertaForMaskedLM",
        "roberta-prelayernorm": "TFRobertaPreLayerNormForMaskedLM",
        "roformer": "TFRoFormerForMaskedLM",
        "tapas": "TFTapasForMaskedLM",
        "xlm": "TFXLMWithLMHeadModel",
        "xlm-roberta": "TFXLMRobertaForMaskedLM",
    }
)
# Seq2seq LM mapping: model type -> name of the TF encoder-decoder generation class.
TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
    {
        "bart": "TFBartForConditionalGeneration",
        "blenderbot": "TFBlenderbotForConditionalGeneration",
        "blenderbot-small": "TFBlenderbotSmallForConditionalGeneration",
        "encoder-decoder": "TFEncoderDecoderModel",
        "led": "TFLEDForConditionalGeneration",
        "marian": "TFMarianMTModel",
        "mbart": "TFMBartForConditionalGeneration",
        "mt5": "TFMT5ForConditionalGeneration",
        "pegasus": "TFPegasusForConditionalGeneration",
        "t5": "TFT5ForConditionalGeneration",
    }
)
# Speech seq2seq mapping: model type -> name of the TF speech-to-text generation class.
TF_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES = OrderedDict(
    {
        "speech_to_text": "TFSpeech2TextForConditionalGeneration",
        "whisper": "TFWhisperForConditionalGeneration",
    }
)
# Sequence-classification mapping: model type -> name of the TF sequence-classification class.
TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
    {
        "albert": "TFAlbertForSequenceClassification",
        "bart": "TFBartForSequenceClassification",
        "bert": "TFBertForSequenceClassification",
        "camembert": "TFCamembertForSequenceClassification",
        "convbert": "TFConvBertForSequenceClassification",
        "ctrl": "TFCTRLForSequenceClassification",
        "deberta": "TFDebertaForSequenceClassification",
        "deberta-v2": "TFDebertaV2ForSequenceClassification",
        "distilbert": "TFDistilBertForSequenceClassification",
        "electra": "TFElectraForSequenceClassification",
        "esm": "TFEsmForSequenceClassification",
        "flaubert": "TFFlaubertForSequenceClassification",
        "funnel": "TFFunnelForSequenceClassification",
        "gpt-sw3": "TFGPT2ForSequenceClassification",
        "gpt2": "TFGPT2ForSequenceClassification",
        "gptj": "TFGPTJForSequenceClassification",
        "layoutlm": "TFLayoutLMForSequenceClassification",
        "layoutlmv3": "TFLayoutLMv3ForSequenceClassification",
        "longformer": "TFLongformerForSequenceClassification",
        "mistral": "TFMistralForSequenceClassification",
        "mobilebert": "TFMobileBertForSequenceClassification",
        "mpnet": "TFMPNetForSequenceClassification",
        "openai-gpt": "TFOpenAIGPTForSequenceClassification",
        "rembert": "TFRemBertForSequenceClassification",
        "roberta": "TFRobertaForSequenceClassification",
        "roberta-prelayernorm": "TFRobertaPreLayerNormForSequenceClassification",
        "roformer": "TFRoFormerForSequenceClassification",
        "tapas": "TFTapasForSequenceClassification",
        "transfo-xl": "TFTransfoXLForSequenceClassification",
        "xlm": "TFXLMForSequenceClassification",
        "xlm-roberta": "TFXLMRobertaForSequenceClassification",
        "xlnet": "TFXLNetForSequenceClassification",
    }
)
# Question-answering mapping: model type -> name of the TF extractive QA class
# ("...Simple" variants for flaubert, xlm and xlnet).
TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(
    {
        "albert": "TFAlbertForQuestionAnswering",
        "bert": "TFBertForQuestionAnswering",
        "camembert": "TFCamembertForQuestionAnswering",
        "convbert": "TFConvBertForQuestionAnswering",
        "deberta": "TFDebertaForQuestionAnswering",
        "deberta-v2": "TFDebertaV2ForQuestionAnswering",
        "distilbert": "TFDistilBertForQuestionAnswering",
        "electra": "TFElectraForQuestionAnswering",
        "flaubert": "TFFlaubertForQuestionAnsweringSimple",
        "funnel": "TFFunnelForQuestionAnswering",
        "gptj": "TFGPTJForQuestionAnswering",
        "layoutlmv3": "TFLayoutLMv3ForQuestionAnswering",
        "longformer": "TFLongformerForQuestionAnswering",
        "mobilebert": "TFMobileBertForQuestionAnswering",
        "mpnet": "TFMPNetForQuestionAnswering",
        "rembert": "TFRemBertForQuestionAnswering",
        "roberta": "TFRobertaForQuestionAnswering",
        "roberta-prelayernorm": "TFRobertaPreLayerNormForQuestionAnswering",
        "roformer": "TFRoFormerForQuestionAnswering",
        "xlm": "TFXLMForQuestionAnsweringSimple",
        "xlm-roberta": "TFXLMRobertaForQuestionAnswering",
        "xlnet": "TFXLNetForQuestionAnsweringSimple",
    }
)
# Audio-classification mapping: model type -> name of the TF audio classification class
# (wav2vec2 reuses its sequence-classification head).
TF_MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
    {
        "wav2vec2": "TFWav2Vec2ForSequenceClassification",
    }
)
# Document question-answering mapping: model type -> name of the TF QA class used on documents.
TF_MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(
    {
        "layoutlm": "TFLayoutLMForQuestionAnswering",
        "layoutlmv3": "TFLayoutLMv3ForQuestionAnswering",
    }
)
# Table question-answering mapping: model type -> name of the TF table QA class.
TF_MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(
    {
        "tapas": "TFTapasForQuestionAnswering",
    }
)
# Token-classification mapping: model type -> name of the TF token-classification class.
TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
    {
        "albert": "TFAlbertForTokenClassification",
        "bert": "TFBertForTokenClassification",
        "camembert": "TFCamembertForTokenClassification",
        "convbert": "TFConvBertForTokenClassification",
        "deberta": "TFDebertaForTokenClassification",
        "deberta-v2": "TFDebertaV2ForTokenClassification",
        "distilbert": "TFDistilBertForTokenClassification",
        "electra": "TFElectraForTokenClassification",
        "esm": "TFEsmForTokenClassification",
        "flaubert": "TFFlaubertForTokenClassification",
        "funnel": "TFFunnelForTokenClassification",
        "layoutlm": "TFLayoutLMForTokenClassification",
        "layoutlmv3": "TFLayoutLMv3ForTokenClassification",
        "longformer": "TFLongformerForTokenClassification",
        "mobilebert": "TFMobileBertForTokenClassification",
        "mpnet": "TFMPNetForTokenClassification",
        "rembert": "TFRemBertForTokenClassification",
        "roberta": "TFRobertaForTokenClassification",
        "roberta-prelayernorm": "TFRobertaPreLayerNormForTokenClassification",
        "roformer": "TFRoFormerForTokenClassification",
        "xlm": "TFXLMForTokenClassification",
        "xlm-roberta": "TFXLMRobertaForTokenClassification",
        "xlnet": "TFXLNetForTokenClassification",
    }
)
# Multiple-choice mapping: model type -> name of the TF multiple-choice class.
TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES = OrderedDict(
    {
        "albert": "TFAlbertForMultipleChoice",
        "bert": "TFBertForMultipleChoice",
        "camembert": "TFCamembertForMultipleChoice",
        "convbert": "TFConvBertForMultipleChoice",
        "deberta-v2": "TFDebertaV2ForMultipleChoice",
        "distilbert": "TFDistilBertForMultipleChoice",
        "electra": "TFElectraForMultipleChoice",
        "flaubert": "TFFlaubertForMultipleChoice",
        "funnel": "TFFunnelForMultipleChoice",
        "longformer": "TFLongformerForMultipleChoice",
        "mobilebert": "TFMobileBertForMultipleChoice",
        "mpnet": "TFMPNetForMultipleChoice",
        "rembert": "TFRemBertForMultipleChoice",
        "roberta": "TFRobertaForMultipleChoice",
        "roberta-prelayernorm": "TFRobertaPreLayerNormForMultipleChoice",
        "roformer": "TFRoFormerForMultipleChoice",
        "xlm": "TFXLMForMultipleChoice",
        "xlm-roberta": "TFXLMRobertaForMultipleChoice",
        "xlnet": "TFXLNetForMultipleChoice",
    }
)
# Next-sentence-prediction mapping: model type -> name of the TF NSP class.
TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES = OrderedDict(
    {
        "bert": "TFBertForNextSentencePrediction",
        "mobilebert": "TFMobileBertForNextSentencePrediction",
    }
)
# Mask-generation mapping: model type -> name of the TF class used for mask generation.
TF_MODEL_FOR_MASK_GENERATION_MAPPING_NAMES = OrderedDict(
    {
        "sam": "TFSamModel",
    }
)
# Text-encoding mapping: model type -> name of the TF class used as a text encoder.
# Encoder-decoder types ("mt5", "t5") point at their encoder-only classes.
TF_MODEL_FOR_TEXT_ENCODING_MAPPING_NAMES = OrderedDict(
    {
        "albert": "TFAlbertModel",
        "bert": "TFBertModel",
        "convbert": "TFConvBertModel",
        "deberta": "TFDebertaModel",
        "deberta-v2": "TFDebertaV2Model",
        "distilbert": "TFDistilBertModel",
        "electra": "TFElectraModel",
        "flaubert": "TFFlaubertModel",
        "longformer": "TFLongformerModel",
        "mobilebert": "TFMobileBertModel",
        "mt5": "TFMT5EncoderModel",
        "rembert": "TFRemBertModel",
        "roberta": "TFRobertaModel",
        "roberta-prelayernorm": "TFRobertaPreLayerNormModel",
        "roformer": "TFRoFormerModel",
        "t5": "TFT5EncoderModel",
        "xlm": "TFXLMModel",
        "xlm-roberta": "TFXLMRobertaModel",
    }
)
def _tf_lazy_mapping(model_mapping_names):
    """Return ``_LazyAutoMapping(CONFIG_MAPPING_NAMES, model_mapping_names)``.

    Construction helper for the ``TF_*_MAPPING`` assignments that follow. It is deleted right after them,
    so the module namespace only gains the public mapping names.
    """
    return _LazyAutoMapping(CONFIG_MAPPING_NAMES, model_mapping_names)


# One lazy mapping per ``*_MAPPING_NAMES`` table, built in the same order as the tables are consumed.
TF_MODEL_MAPPING = _tf_lazy_mapping(TF_MODEL_MAPPING_NAMES)
TF_MODEL_FOR_PRETRAINING_MAPPING = _tf_lazy_mapping(TF_MODEL_FOR_PRETRAINING_MAPPING_NAMES)
TF_MODEL_WITH_LM_HEAD_MAPPING = _tf_lazy_mapping(TF_MODEL_WITH_LM_HEAD_MAPPING_NAMES)
TF_MODEL_FOR_CAUSAL_LM_MAPPING = _tf_lazy_mapping(TF_MODEL_FOR_CAUSAL_LM_MAPPING_NAMES)
TF_MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING = _tf_lazy_mapping(TF_MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING_NAMES)
TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING = _tf_lazy_mapping(TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES)
TF_MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING = _tf_lazy_mapping(
    TF_MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING_NAMES
)
TF_MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING = _tf_lazy_mapping(TF_MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING_NAMES)
TF_MODEL_FOR_VISION_2_SEQ_MAPPING = _tf_lazy_mapping(TF_MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES)
TF_MODEL_FOR_MASKED_LM_MAPPING = _tf_lazy_mapping(TF_MODEL_FOR_MASKED_LM_MAPPING_NAMES)
TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING = _tf_lazy_mapping(TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES)
TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING = _tf_lazy_mapping(TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES)
TF_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING = _tf_lazy_mapping(TF_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES)
TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING = _tf_lazy_mapping(TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES)
TF_MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING = _tf_lazy_mapping(
    TF_MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING_NAMES
)
TF_MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING = _tf_lazy_mapping(TF_MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING_NAMES)
TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING = _tf_lazy_mapping(TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES)
TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING = _tf_lazy_mapping(TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES)
TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING = _tf_lazy_mapping(TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES)
TF_MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING = _tf_lazy_mapping(TF_MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES)
TF_MODEL_FOR_MASK_GENERATION_MAPPING = _tf_lazy_mapping(TF_MODEL_FOR_MASK_GENERATION_MAPPING_NAMES)
TF_MODEL_FOR_TEXT_ENCODING_MAPPING = _tf_lazy_mapping(TF_MODEL_FOR_TEXT_ENCODING_MAPPING_NAMES)

# The helper has served its purpose; remove it so it is not left behind as a module attribute.
del _tf_lazy_mapping
class TFAutoModelForMaskGeneration(_BaseAutoModelClass):
    # Auto class whose model lookup is driven by TF_MODEL_FOR_MASK_GENERATION_MAPPING.
    # Unlike most auto classes in this module, it is not passed through `auto_class_update`.
    _model_mapping = TF_MODEL_FOR_MASK_GENERATION_MAPPING
class TFAutoModelForTextEncoding(_BaseAutoModelClass):
    # Auto class whose model lookup is driven by TF_MODEL_FOR_TEXT_ENCODING_MAPPING.
    # Unlike most auto classes in this module, it is not passed through `auto_class_update`.
    _model_mapping = TF_MODEL_FOR_TEXT_ENCODING_MAPPING
class TFAutoModel(_BaseAutoModelClass):
    # Generic auto class backed by TF_MODEL_MAPPING, i.e. the base-model name table.
    _model_mapping = TF_MODEL_MAPPING


# Decorated with the default `auto_class_update` arguments: no task-specific `head_doc` for the base class.
TFAutoModel = auto_class_update(TFAutoModel)
class TFAutoModelForAudioClassification(_BaseAutoModelClass):
    # Auto class backed by TF_MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING.
    _model_mapping = TF_MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING


TFAutoModelForAudioClassification = auto_class_update(
    TFAutoModelForAudioClassification,
    head_doc="audio classification",
)
class TFAutoModelForPreTraining(_BaseAutoModelClass):
    # Auto class backed by TF_MODEL_FOR_PRETRAINING_MAPPING.
    _model_mapping = TF_MODEL_FOR_PRETRAINING_MAPPING


TFAutoModelForPreTraining = auto_class_update(
    TFAutoModelForPreTraining,
    head_doc="pretraining",
)
# Intentionally private: the public `TFAutoModelWithLMHead` subclass is the one that emits the deprecation warnings.
class _TFAutoModelWithLMHead(_BaseAutoModelClass):
    # Auto class backed by TF_MODEL_WITH_LM_HEAD_MAPPING.
    _model_mapping = TF_MODEL_WITH_LM_HEAD_MAPPING


_TFAutoModelWithLMHead = auto_class_update(
    _TFAutoModelWithLMHead,
    head_doc="language modeling",
)
class TFAutoModelForCausalLM(_BaseAutoModelClass):
    # Auto class backed by TF_MODEL_FOR_CAUSAL_LM_MAPPING.
    _model_mapping = TF_MODEL_FOR_CAUSAL_LM_MAPPING


TFAutoModelForCausalLM = auto_class_update(
    TFAutoModelForCausalLM,
    head_doc="causal language modeling",
)
class TFAutoModelForMaskedImageModeling(_BaseAutoModelClass):
    # Auto class backed by TF_MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING.
    _model_mapping = TF_MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING


TFAutoModelForMaskedImageModeling = auto_class_update(
    TFAutoModelForMaskedImageModeling,
    head_doc="masked image modeling",
)
class TFAutoModelForImageClassification(_BaseAutoModelClass):
    # Auto class backed by TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING.
    _model_mapping = TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING


TFAutoModelForImageClassification = auto_class_update(
    TFAutoModelForImageClassification,
    head_doc="image classification",
)
class TFAutoModelForZeroShotImageClassification(_BaseAutoModelClass):
    # Auto class backed by TF_MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING.
    _model_mapping = TF_MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING


TFAutoModelForZeroShotImageClassification = auto_class_update(
    TFAutoModelForZeroShotImageClassification,
    head_doc="zero-shot image classification",
)
class TFAutoModelForSemanticSegmentation(_BaseAutoModelClass):
    # Auto class backed by TF_MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING.
    _model_mapping = TF_MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING


TFAutoModelForSemanticSegmentation = auto_class_update(
    TFAutoModelForSemanticSegmentation,
    head_doc="semantic segmentation",
)
class TFAutoModelForVision2Seq(_BaseAutoModelClass):
    # Auto class backed by TF_MODEL_FOR_VISION_2_SEQ_MAPPING.
    _model_mapping = TF_MODEL_FOR_VISION_2_SEQ_MAPPING


TFAutoModelForVision2Seq = auto_class_update(
    TFAutoModelForVision2Seq,
    head_doc="vision-to-text modeling",
)
class TFAutoModelForMaskedLM(_BaseAutoModelClass):
    # Auto class backed by TF_MODEL_FOR_MASKED_LM_MAPPING.
    _model_mapping = TF_MODEL_FOR_MASKED_LM_MAPPING


TFAutoModelForMaskedLM = auto_class_update(
    TFAutoModelForMaskedLM,
    head_doc="masked language modeling",
)
class TFAutoModelForSeq2SeqLM(_BaseAutoModelClass):
    # Auto class backed by TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING.
    _model_mapping = TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING


# A T5 checkpoint replaces the default example checkpoint, since this head targets encoder-decoder models.
TFAutoModelForSeq2SeqLM = auto_class_update(
    TFAutoModelForSeq2SeqLM,
    checkpoint_for_example="google-t5/t5-base",
    head_doc="sequence-to-sequence language modeling",
)
class TFAutoModelForSequenceClassification(_BaseAutoModelClass):
    # Auto class backed by TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING.
    _model_mapping = TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING


TFAutoModelForSequenceClassification = auto_class_update(
    TFAutoModelForSequenceClassification,
    head_doc="sequence classification",
)
class TFAutoModelForQuestionAnswering(_BaseAutoModelClass):
    # Auto class backed by TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING.
    _model_mapping = TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING


TFAutoModelForQuestionAnswering = auto_class_update(
    TFAutoModelForQuestionAnswering,
    head_doc="question answering",
)
class TFAutoModelForDocumentQuestionAnswering(_BaseAutoModelClass):
    # Auto class backed by TF_MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING.
    _model_mapping = TF_MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING


# NOTE(review): the embedded quotes in `checkpoint_for_example` look deliberate. Presumably the docstring
# template wraps this value in double quotes, so it splices a `revision="52e01b3"` argument into the generated
# example call. Keep the literal exactly as-is unless that template changes.
TFAutoModelForDocumentQuestionAnswering = auto_class_update(
    TFAutoModelForDocumentQuestionAnswering,
    checkpoint_for_example='impira/layoutlm-document-qa", revision="52e01b3',
    head_doc="document question answering",
)
class TFAutoModelForTableQuestionAnswering(_BaseAutoModelClass):
    # Auto class backed by TF_MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING.
    _model_mapping = TF_MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING


# Uses a TAPAS checkpoint for the generated example, matching the only entry in the table-QA mapping.
TFAutoModelForTableQuestionAnswering = auto_class_update(
    TFAutoModelForTableQuestionAnswering,
    checkpoint_for_example="google/tapas-base-finetuned-wtq",
    head_doc="table question answering",
)
class TFAutoModelForTokenClassification(_BaseAutoModelClass):
    # Auto class backed by TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING.
    _model_mapping = TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING


TFAutoModelForTokenClassification = auto_class_update(
    TFAutoModelForTokenClassification,
    head_doc="token classification",
)
class TFAutoModelForMultipleChoice(_BaseAutoModelClass):
    # Auto class backed by TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING.
    _model_mapping = TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING


TFAutoModelForMultipleChoice = auto_class_update(
    TFAutoModelForMultipleChoice,
    head_doc="multiple choice",
)
class TFAutoModelForNextSentencePrediction(_BaseAutoModelClass):
    # Auto class backed by TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING.
    _model_mapping = TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING


TFAutoModelForNextSentencePrediction = auto_class_update(
    TFAutoModelForNextSentencePrediction,
    head_doc="next sentence prediction",
)
class TFAutoModelForSpeechSeq2Seq(_BaseAutoModelClass):
    # Auto class backed by TF_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING.
    _model_mapping = TF_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING


TFAutoModelForSpeechSeq2Seq = auto_class_update(
    TFAutoModelForSpeechSeq2Seq,
    head_doc="sequence-to-sequence speech-to-text modeling",
)
class TFAutoModelWithLMHead(_TFAutoModelWithLMHead):
    """Deprecated public alias of ``_TFAutoModelWithLMHead``.

    Model resolution is inherited unchanged from ``_TFAutoModelWithLMHead``. Every call to ``from_config``
    or ``from_pretrained`` first emits a ``FutureWarning`` recommending ``TFAutoModelForCausalLM``,
    ``TFAutoModelForMaskedLM`` or ``TFAutoModelForSeq2SeqLM`` instead.
    """

    # Single copy of the deprecation notice, shared by both entry points so they cannot drift apart.
    _DEPRECATION_MESSAGE = (
        "The class `TFAutoModelWithLMHead` is deprecated and will be removed in a future version. Please use"
        " `TFAutoModelForCausalLM` for causal language models, `TFAutoModelForMaskedLM` for masked language models"
        " and `TFAutoModelForSeq2SeqLM` for encoder-decoder models."
    )

    @classmethod
    def from_config(cls, config):
        """Emit the deprecation ``FutureWarning``, then delegate to the parent ``from_config``.

        Args:
            config: Forwarded unchanged to the parent implementation.

        Returns:
            Whatever the parent ``from_config`` returns.
        """
        # stacklevel=2 attributes the warning to the caller's line instead of to this library module.
        warnings.warn(cls._DEPRECATION_MESSAGE, FutureWarning, stacklevel=2)
        return super().from_config(config)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        """Emit the deprecation ``FutureWarning``, then delegate to the parent ``from_pretrained``.

        Args:
            pretrained_model_name_or_path: Forwarded unchanged to the parent implementation.
            *model_args: Forwarded unchanged to the parent implementation.
            **kwargs: Forwarded unchanged to the parent implementation.

        Returns:
            Whatever the parent ``from_pretrained`` returns.
        """
        # stacklevel=2 attributes the warning to the caller's line instead of to this library module.
        warnings.warn(cls._DEPRECATION_MESSAGE, FutureWarning, stacklevel=2)
        return super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|