# coding=utf-8
# Copyright 2018 The Google Flax Team Authors and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Auto Model class."""
from collections import OrderedDict

from ...utils import logging
from .auto_factory import _BaseAutoModelClass, _LazyAutoMapping, auto_class_update
from .configuration_auto import CONFIG_MAPPING_NAMES


# Module-level logger, obtained from the package's own `logging` utility
# (imported from `...utils`) and keyed by this module's name.
logger = logging.get_logger(__name__)
# Ordered table of (model-type key, Flax class name) string pairs for the
# bare (head-less) models. Values are class *names*, not class objects.
FLAX_MODEL_MAPPING_NAMES = OrderedDict(
    [
        # Base model mapping
        ("albert", "FlaxAlbertModel"),
        ("bart", "FlaxBartModel"),
        ("beit", "FlaxBeitModel"),
        ("bert", "FlaxBertModel"),
        ("big_bird", "FlaxBigBirdModel"),
        ("blenderbot", "FlaxBlenderbotModel"),
        ("blenderbot-small", "FlaxBlenderbotSmallModel"),
        ("bloom", "FlaxBloomModel"),
        ("clip", "FlaxCLIPModel"),
        ("dinov2", "FlaxDinov2Model"),
        ("distilbert", "FlaxDistilBertModel"),
        ("electra", "FlaxElectraModel"),
        ("gemma", "FlaxGemmaModel"),
        # "gpt-sw3" shares the GPT-2 class with "gpt2".
        ("gpt-sw3", "FlaxGPT2Model"),
        ("gpt2", "FlaxGPT2Model"),
        ("gpt_neo", "FlaxGPTNeoModel"),
        ("gptj", "FlaxGPTJModel"),
        ("llama", "FlaxLlamaModel"),
        ("longt5", "FlaxLongT5Model"),
        ("marian", "FlaxMarianModel"),
        ("mbart", "FlaxMBartModel"),
        ("mistral", "FlaxMistralModel"),
        ("mt5", "FlaxMT5Model"),
        ("opt", "FlaxOPTModel"),
        ("pegasus", "FlaxPegasusModel"),
        ("regnet", "FlaxRegNetModel"),
        ("resnet", "FlaxResNetModel"),
        ("roberta", "FlaxRobertaModel"),
        ("roberta-prelayernorm", "FlaxRobertaPreLayerNormModel"),
        ("roformer", "FlaxRoFormerModel"),
        ("t5", "FlaxT5Model"),
        ("vision-text-dual-encoder", "FlaxVisionTextDualEncoderModel"),
        ("vit", "FlaxViTModel"),
        ("wav2vec2", "FlaxWav2Vec2Model"),
        ("whisper", "FlaxWhisperModel"),
        ("xglm", "FlaxXGLMModel"),
        ("xlm-roberta", "FlaxXLMRobertaModel"),
    ]
)
# Ordered table of (model-type key, Flax class name) string pairs used for
# pre-training. Several entries reuse a conditional-generation or masked-LM
# class name rather than a dedicated "ForPreTraining" one.
FLAX_MODEL_FOR_PRETRAINING_MAPPING_NAMES = OrderedDict(
    [
        # Model for pre-training mapping
        ("albert", "FlaxAlbertForPreTraining"),
        ("bart", "FlaxBartForConditionalGeneration"),
        ("bert", "FlaxBertForPreTraining"),
        ("big_bird", "FlaxBigBirdForPreTraining"),
        ("electra", "FlaxElectraForPreTraining"),
        ("longt5", "FlaxLongT5ForConditionalGeneration"),
        ("mbart", "FlaxMBartForConditionalGeneration"),
        ("mt5", "FlaxMT5ForConditionalGeneration"),
        ("roberta", "FlaxRobertaForMaskedLM"),
        ("roberta-prelayernorm", "FlaxRobertaPreLayerNormForMaskedLM"),
        ("roformer", "FlaxRoFormerForMaskedLM"),
        ("t5", "FlaxT5ForConditionalGeneration"),
        ("wav2vec2", "FlaxWav2Vec2ForPreTraining"),
        ("whisper", "FlaxWhisperForConditionalGeneration"),
        ("xlm-roberta", "FlaxXLMRobertaForMaskedLM"),
    ]
)
# Ordered table of (model-type key, Flax class name) string pairs for
# masked language modeling.
FLAX_MODEL_FOR_MASKED_LM_MAPPING_NAMES = OrderedDict(
    [
        # Model for Masked LM mapping
        ("albert", "FlaxAlbertForMaskedLM"),
        ("bart", "FlaxBartForConditionalGeneration"),
        ("bert", "FlaxBertForMaskedLM"),
        ("big_bird", "FlaxBigBirdForMaskedLM"),
        ("distilbert", "FlaxDistilBertForMaskedLM"),
        ("electra", "FlaxElectraForMaskedLM"),
        ("mbart", "FlaxMBartForConditionalGeneration"),
        ("roberta", "FlaxRobertaForMaskedLM"),
        ("roberta-prelayernorm", "FlaxRobertaPreLayerNormForMaskedLM"),
        ("roformer", "FlaxRoFormerForMaskedLM"),
        ("xlm-roberta", "FlaxXLMRobertaForMaskedLM"),
    ]
)
# Ordered table of (model-type key, Flax class name) string pairs for
# sequence-to-sequence language modeling.
FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
    [
        # Model for Seq2Seq Causal LM mapping
        ("bart", "FlaxBartForConditionalGeneration"),
        ("blenderbot", "FlaxBlenderbotForConditionalGeneration"),
        ("blenderbot-small", "FlaxBlenderbotSmallForConditionalGeneration"),
        ("encoder-decoder", "FlaxEncoderDecoderModel"),
        ("longt5", "FlaxLongT5ForConditionalGeneration"),
        ("marian", "FlaxMarianMTModel"),
        ("mbart", "FlaxMBartForConditionalGeneration"),
        ("mt5", "FlaxMT5ForConditionalGeneration"),
        ("pegasus", "FlaxPegasusForConditionalGeneration"),
        ("t5", "FlaxT5ForConditionalGeneration"),
    ]
)
# Ordered table of (model-type key, Flax class name) string pairs for
# image classification.
FLAX_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
    [
        # Model for Image-classification
        ("beit", "FlaxBeitForImageClassification"),
        ("dinov2", "FlaxDinov2ForImageClassification"),
        ("regnet", "FlaxRegNetForImageClassification"),
        ("resnet", "FlaxResNetForImageClassification"),
        ("vit", "FlaxViTForImageClassification"),
    ]
)
# Ordered table of (model-type key, Flax class name) string pairs for
# vision-to-sequence models; currently a single entry.
FLAX_MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES = OrderedDict(
    [
        ("vision-encoder-decoder", "FlaxVisionEncoderDecoderModel"),
    ]
)
# Ordered table of (model-type key, Flax class name) string pairs for
# causal language modeling.
FLAX_MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
    [
        # Model for Causal LM mapping
        ("bart", "FlaxBartForCausalLM"),
        ("bert", "FlaxBertForCausalLM"),
        ("big_bird", "FlaxBigBirdForCausalLM"),
        ("bloom", "FlaxBloomForCausalLM"),
        ("electra", "FlaxElectraForCausalLM"),
        ("gemma", "FlaxGemmaForCausalLM"),
        # "gpt-sw3" shares the GPT-2 LM-head class with "gpt2".
        ("gpt-sw3", "FlaxGPT2LMHeadModel"),
        ("gpt2", "FlaxGPT2LMHeadModel"),
        ("gpt_neo", "FlaxGPTNeoForCausalLM"),
        ("gptj", "FlaxGPTJForCausalLM"),
        ("llama", "FlaxLlamaForCausalLM"),
        ("mistral", "FlaxMistralForCausalLM"),
        ("opt", "FlaxOPTForCausalLM"),
        ("roberta", "FlaxRobertaForCausalLM"),
        ("roberta-prelayernorm", "FlaxRobertaPreLayerNormForCausalLM"),
        ("xglm", "FlaxXGLMForCausalLM"),
        ("xlm-roberta", "FlaxXLMRobertaForCausalLM"),
    ]
)
# Ordered table of (model-type key, Flax class name) string pairs for
# sequence classification.
FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
    [
        # Model for Sequence Classification mapping
        ("albert", "FlaxAlbertForSequenceClassification"),
        ("bart", "FlaxBartForSequenceClassification"),
        ("bert", "FlaxBertForSequenceClassification"),
        ("big_bird", "FlaxBigBirdForSequenceClassification"),
        ("distilbert", "FlaxDistilBertForSequenceClassification"),
        ("electra", "FlaxElectraForSequenceClassification"),
        ("mbart", "FlaxMBartForSequenceClassification"),
        ("roberta", "FlaxRobertaForSequenceClassification"),
        ("roberta-prelayernorm", "FlaxRobertaPreLayerNormForSequenceClassification"),
        ("roformer", "FlaxRoFormerForSequenceClassification"),
        ("xlm-roberta", "FlaxXLMRobertaForSequenceClassification"),
    ]
)
# Ordered table of (model-type key, Flax class name) string pairs for
# question answering.
FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(
    [
        # Model for Question Answering mapping
        ("albert", "FlaxAlbertForQuestionAnswering"),
        ("bart", "FlaxBartForQuestionAnswering"),
        ("bert", "FlaxBertForQuestionAnswering"),
        ("big_bird", "FlaxBigBirdForQuestionAnswering"),
        ("distilbert", "FlaxDistilBertForQuestionAnswering"),
        ("electra", "FlaxElectraForQuestionAnswering"),
        ("mbart", "FlaxMBartForQuestionAnswering"),
        ("roberta", "FlaxRobertaForQuestionAnswering"),
        ("roberta-prelayernorm", "FlaxRobertaPreLayerNormForQuestionAnswering"),
        ("roformer", "FlaxRoFormerForQuestionAnswering"),
        ("xlm-roberta", "FlaxXLMRobertaForQuestionAnswering"),
    ]
)
# Ordered table of (model-type key, Flax class name) string pairs for
# token classification.
FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
    [
        # Model for Token Classification mapping
        ("albert", "FlaxAlbertForTokenClassification"),
        ("bert", "FlaxBertForTokenClassification"),
        ("big_bird", "FlaxBigBirdForTokenClassification"),
        ("distilbert", "FlaxDistilBertForTokenClassification"),
        ("electra", "FlaxElectraForTokenClassification"),
        ("roberta", "FlaxRobertaForTokenClassification"),
        ("roberta-prelayernorm", "FlaxRobertaPreLayerNormForTokenClassification"),
        ("roformer", "FlaxRoFormerForTokenClassification"),
        ("xlm-roberta", "FlaxXLMRobertaForTokenClassification"),
    ]
)
# Ordered table of (model-type key, Flax class name) string pairs for
# multiple choice.
FLAX_MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES = OrderedDict(
    [
        # Model for Multiple Choice mapping
        ("albert", "FlaxAlbertForMultipleChoice"),
        ("bert", "FlaxBertForMultipleChoice"),
        ("big_bird", "FlaxBigBirdForMultipleChoice"),
        ("distilbert", "FlaxDistilBertForMultipleChoice"),
        ("electra", "FlaxElectraForMultipleChoice"),
        ("roberta", "FlaxRobertaForMultipleChoice"),
        ("roberta-prelayernorm", "FlaxRobertaPreLayerNormForMultipleChoice"),
        ("roformer", "FlaxRoFormerForMultipleChoice"),
        ("xlm-roberta", "FlaxXLMRobertaForMultipleChoice"),
    ]
)
# Ordered table of (model-type key, Flax class name) string pairs for
# next sentence prediction; currently a single entry.
FLAX_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES = OrderedDict(
    [
        ("bert", "FlaxBertForNextSentencePrediction"),
    ]
)
# Ordered table of (model-type key, Flax class name) string pairs for
# speech sequence-to-sequence models.
FLAX_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES = OrderedDict(
    [
        ("speech-encoder-decoder", "FlaxSpeechEncoderDecoderModel"),
        ("whisper", "FlaxWhisperForConditionalGeneration"),
    ]
)
# Ordered table of (model-type key, Flax class name) string pairs for
# audio classification; currently a single entry.
FLAX_MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
    [
        ("whisper", "FlaxWhisperForAudioClassification"),
    ]
)
# One `_LazyAutoMapping` per task: each pairs CONFIG_MAPPING_NAMES with the
# matching `*_MAPPING_NAMES` table defined above.
# NOTE(review): `_LazyAutoMapping` lives in `.auto_factory`; presumably it
# defers resolving the class-name strings until lookup — confirm there.
FLAX_MODEL_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, FLAX_MODEL_MAPPING_NAMES)
FLAX_MODEL_FOR_PRETRAINING_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, FLAX_MODEL_FOR_PRETRAINING_MAPPING_NAMES)
FLAX_MODEL_FOR_MASKED_LM_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, FLAX_MODEL_FOR_MASKED_LM_MAPPING_NAMES)
FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES
)
FLAX_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, FLAX_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES
)
FLAX_MODEL_FOR_VISION_2_SEQ_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, FLAX_MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES)
FLAX_MODEL_FOR_CAUSAL_LM_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, FLAX_MODEL_FOR_CAUSAL_LM_MAPPING_NAMES)
FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES
)
FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES
)
FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES
)
FLAX_MODEL_FOR_MULTIPLE_CHOICE_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, FLAX_MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES
)
FLAX_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, FLAX_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES
)
FLAX_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, FLAX_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES
)
FLAX_MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, FLAX_MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES
)
# Auto class for the bare models: its `_model_mapping` is FLAX_MODEL_MAPPING.
class FlaxAutoModel(_BaseAutoModelClass):
    _model_mapping = FLAX_MODEL_MAPPING


# Rebind the name to the value returned by `auto_class_update` (no head_doc
# is passed for the base class).
FlaxAutoModel = auto_class_update(FlaxAutoModel)
# Auto class whose `_model_mapping` is FLAX_MODEL_FOR_PRETRAINING_MAPPING.
class FlaxAutoModelForPreTraining(_BaseAutoModelClass):
    _model_mapping = FLAX_MODEL_FOR_PRETRAINING_MAPPING


# Rebind the name to the value returned by `auto_class_update`.
FlaxAutoModelForPreTraining = auto_class_update(FlaxAutoModelForPreTraining, head_doc="pretraining")
# Auto class whose `_model_mapping` is FLAX_MODEL_FOR_CAUSAL_LM_MAPPING.
class FlaxAutoModelForCausalLM(_BaseAutoModelClass):
    _model_mapping = FLAX_MODEL_FOR_CAUSAL_LM_MAPPING


# Rebind the name to the value returned by `auto_class_update`.
FlaxAutoModelForCausalLM = auto_class_update(FlaxAutoModelForCausalLM, head_doc="causal language modeling")
# Auto class whose `_model_mapping` is FLAX_MODEL_FOR_MASKED_LM_MAPPING.
class FlaxAutoModelForMaskedLM(_BaseAutoModelClass):
    _model_mapping = FLAX_MODEL_FOR_MASKED_LM_MAPPING


# Rebind the name to the value returned by `auto_class_update`.
FlaxAutoModelForMaskedLM = auto_class_update(FlaxAutoModelForMaskedLM, head_doc="masked language modeling")
# Auto class whose `_model_mapping` is
# FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING.
class FlaxAutoModelForSeq2SeqLM(_BaseAutoModelClass):
    _model_mapping = FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING


# Rebind the name to the value returned by `auto_class_update`. This is the
# only call in the module that also passes `checkpoint_for_example`.
FlaxAutoModelForSeq2SeqLM = auto_class_update(
    FlaxAutoModelForSeq2SeqLM,
    head_doc="sequence-to-sequence language modeling",
    checkpoint_for_example="google-t5/t5-base",
)
# Auto class whose `_model_mapping` is
# FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING.
class FlaxAutoModelForSequenceClassification(_BaseAutoModelClass):
    _model_mapping = FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING


# Rebind the name to the value returned by `auto_class_update`.
FlaxAutoModelForSequenceClassification = auto_class_update(
    FlaxAutoModelForSequenceClassification, head_doc="sequence classification"
)
# Auto class whose `_model_mapping` is
# FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING.
class FlaxAutoModelForQuestionAnswering(_BaseAutoModelClass):
    _model_mapping = FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING


# Rebind the name to the value returned by `auto_class_update`.
FlaxAutoModelForQuestionAnswering = auto_class_update(FlaxAutoModelForQuestionAnswering, head_doc="question answering")
# Auto class whose `_model_mapping` is
# FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING.
class FlaxAutoModelForTokenClassification(_BaseAutoModelClass):
    _model_mapping = FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING


# Rebind the name to the value returned by `auto_class_update`.
FlaxAutoModelForTokenClassification = auto_class_update(
    FlaxAutoModelForTokenClassification, head_doc="token classification"
)
# Auto class whose `_model_mapping` is FLAX_MODEL_FOR_MULTIPLE_CHOICE_MAPPING.
class FlaxAutoModelForMultipleChoice(_BaseAutoModelClass):
    _model_mapping = FLAX_MODEL_FOR_MULTIPLE_CHOICE_MAPPING


# Rebind the name to the value returned by `auto_class_update`.
FlaxAutoModelForMultipleChoice = auto_class_update(FlaxAutoModelForMultipleChoice, head_doc="multiple choice")
# Auto class whose `_model_mapping` is
# FLAX_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING.
class FlaxAutoModelForNextSentencePrediction(_BaseAutoModelClass):
    _model_mapping = FLAX_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING


# Rebind the name to the value returned by `auto_class_update`.
FlaxAutoModelForNextSentencePrediction = auto_class_update(
    FlaxAutoModelForNextSentencePrediction, head_doc="next sentence prediction"
)
# Auto class whose `_model_mapping` is
# FLAX_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING.
class FlaxAutoModelForImageClassification(_BaseAutoModelClass):
    _model_mapping = FLAX_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING


# Rebind the name to the value returned by `auto_class_update`.
FlaxAutoModelForImageClassification = auto_class_update(
    FlaxAutoModelForImageClassification, head_doc="image classification"
)
# Auto class whose `_model_mapping` is FLAX_MODEL_FOR_VISION_2_SEQ_MAPPING.
class FlaxAutoModelForVision2Seq(_BaseAutoModelClass):
    _model_mapping = FLAX_MODEL_FOR_VISION_2_SEQ_MAPPING


# Rebind the name to the value returned by `auto_class_update`.
FlaxAutoModelForVision2Seq = auto_class_update(FlaxAutoModelForVision2Seq, head_doc="vision-to-text modeling")
# Auto class whose `_model_mapping` is FLAX_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING.
class FlaxAutoModelForSpeechSeq2Seq(_BaseAutoModelClass):
    _model_mapping = FLAX_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING


# Rebind the name to the value returned by `auto_class_update`.
FlaxAutoModelForSpeechSeq2Seq = auto_class_update(
    FlaxAutoModelForSpeechSeq2Seq, head_doc="sequence-to-sequence speech-to-text modeling"
)