modeling_flax_auto.py 14 KB


# coding=utf-8
# Copyright 2018 The Google Flax Team Authors and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
  15. """Auto Model class."""
  16. from collections import OrderedDict
  17. from ...utils import logging
  18. from .auto_factory import _BaseAutoModelClass, _LazyAutoMapping, auto_class_update
  19. from .configuration_auto import CONFIG_MAPPING_NAMES
# Module-level logger named after this module.
# NOTE(review): not referenced anywhere else in this file; presumably kept for
# parity with sibling modules or external importers — verify before removing.
logger = logging.get_logger(__name__)
# Model-type key -> name of the Flax base-model class (no task-specific head).
# Paired with CONFIG_MAPPING_NAMES to build FLAX_MODEL_MAPPING, which backs
# FlaxAutoModel. Values are class names as strings, not class objects.
# Keys presumably match the model-type keys of CONFIG_MAPPING_NAMES — confirm in
# configuration_auto. Several keys may share one class: "gpt-sw3" and "gpt2" both
# map to FlaxGPT2Model. Entries are kept in ASCII-sorted key order; keep new
# entries sorted and one tuple per line.
FLAX_MODEL_MAPPING_NAMES = OrderedDict(
    [
        # Base model mapping
        ("albert", "FlaxAlbertModel"),
        ("bart", "FlaxBartModel"),
        ("beit", "FlaxBeitModel"),
        ("bert", "FlaxBertModel"),
        ("big_bird", "FlaxBigBirdModel"),
        ("blenderbot", "FlaxBlenderbotModel"),
        ("blenderbot-small", "FlaxBlenderbotSmallModel"),
        ("bloom", "FlaxBloomModel"),
        ("clip", "FlaxCLIPModel"),
        ("dinov2", "FlaxDinov2Model"),
        ("distilbert", "FlaxDistilBertModel"),
        ("electra", "FlaxElectraModel"),
        ("gemma", "FlaxGemmaModel"),
        ("gpt-sw3", "FlaxGPT2Model"),
        ("gpt2", "FlaxGPT2Model"),
        ("gpt_neo", "FlaxGPTNeoModel"),
        ("gptj", "FlaxGPTJModel"),
        ("llama", "FlaxLlamaModel"),
        ("longt5", "FlaxLongT5Model"),
        ("marian", "FlaxMarianModel"),
        ("mbart", "FlaxMBartModel"),
        ("mistral", "FlaxMistralModel"),
        ("mt5", "FlaxMT5Model"),
        ("opt", "FlaxOPTModel"),
        ("pegasus", "FlaxPegasusModel"),
        ("regnet", "FlaxRegNetModel"),
        ("resnet", "FlaxResNetModel"),
        ("roberta", "FlaxRobertaModel"),
        ("roberta-prelayernorm", "FlaxRobertaPreLayerNormModel"),
        ("roformer", "FlaxRoFormerModel"),
        ("t5", "FlaxT5Model"),
        ("vision-text-dual-encoder", "FlaxVisionTextDualEncoderModel"),
        ("vit", "FlaxViTModel"),
        ("wav2vec2", "FlaxWav2Vec2Model"),
        ("whisper", "FlaxWhisperModel"),
        ("xglm", "FlaxXGLMModel"),
        ("xlm-roberta", "FlaxXLMRobertaModel"),
    ]
)
# Model-type key -> Flax class used for pre-training; backs
# FLAX_MODEL_FOR_PRETRAINING_MAPPING / FlaxAutoModelForPreTraining.
# Not every key has a dedicated *ForPreTraining class: roberta,
# roberta-prelayernorm, roformer and xlm-roberta map to their *ForMaskedLM class,
# and bart, longt5, mbart, mt5, t5 and whisper map to *ForConditionalGeneration.
# Entries are kept in ASCII-sorted key order.
FLAX_MODEL_FOR_PRETRAINING_MAPPING_NAMES = OrderedDict(
    [
        # Model for pre-training mapping
        ("albert", "FlaxAlbertForPreTraining"),
        ("bart", "FlaxBartForConditionalGeneration"),
        ("bert", "FlaxBertForPreTraining"),
        ("big_bird", "FlaxBigBirdForPreTraining"),
        ("electra", "FlaxElectraForPreTraining"),
        ("longt5", "FlaxLongT5ForConditionalGeneration"),
        ("mbart", "FlaxMBartForConditionalGeneration"),
        ("mt5", "FlaxMT5ForConditionalGeneration"),
        ("roberta", "FlaxRobertaForMaskedLM"),
        ("roberta-prelayernorm", "FlaxRobertaPreLayerNormForMaskedLM"),
        ("roformer", "FlaxRoFormerForMaskedLM"),
        ("t5", "FlaxT5ForConditionalGeneration"),
        ("wav2vec2", "FlaxWav2Vec2ForPreTraining"),
        ("whisper", "FlaxWhisperForConditionalGeneration"),
        ("xlm-roberta", "FlaxXLMRobertaForMaskedLM"),
    ]
)
# Model-type key -> Flax masked-language-modeling class; backs
# FLAX_MODEL_FOR_MASKED_LM_MAPPING / FlaxAutoModelForMaskedLM.
# bart and mbart map to their *ForConditionalGeneration class, since no
# *ForMaskedLM class is listed for them. Entries are kept in ASCII-sorted key order.
FLAX_MODEL_FOR_MASKED_LM_MAPPING_NAMES = OrderedDict(
    [
        # Model for Masked LM mapping
        ("albert", "FlaxAlbertForMaskedLM"),
        ("bart", "FlaxBartForConditionalGeneration"),
        ("bert", "FlaxBertForMaskedLM"),
        ("big_bird", "FlaxBigBirdForMaskedLM"),
        ("distilbert", "FlaxDistilBertForMaskedLM"),
        ("electra", "FlaxElectraForMaskedLM"),
        ("mbart", "FlaxMBartForConditionalGeneration"),
        ("roberta", "FlaxRobertaForMaskedLM"),
        ("roberta-prelayernorm", "FlaxRobertaPreLayerNormForMaskedLM"),
        ("roformer", "FlaxRoFormerForMaskedLM"),
        ("xlm-roberta", "FlaxXLMRobertaForMaskedLM"),
    ]
)
# Model-type key -> Flax sequence-to-sequence LM class. Despite "CAUSAL" in the
# name, this table backs FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING, which is
# the mapping of FlaxAutoModelForSeq2SeqLM. "encoder-decoder" maps to the generic
# composite FlaxEncoderDecoderModel. marian's class is named FlaxMarianMTModel
# rather than *ForConditionalGeneration. Entries are kept in ASCII-sorted key order.
FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
    [
        # Model for Seq2Seq Causal LM mapping
        ("bart", "FlaxBartForConditionalGeneration"),
        ("blenderbot", "FlaxBlenderbotForConditionalGeneration"),
        ("blenderbot-small", "FlaxBlenderbotSmallForConditionalGeneration"),
        ("encoder-decoder", "FlaxEncoderDecoderModel"),
        ("longt5", "FlaxLongT5ForConditionalGeneration"),
        ("marian", "FlaxMarianMTModel"),
        ("mbart", "FlaxMBartForConditionalGeneration"),
        ("mt5", "FlaxMT5ForConditionalGeneration"),
        ("pegasus", "FlaxPegasusForConditionalGeneration"),
        ("t5", "FlaxT5ForConditionalGeneration"),
    ]
)
# Model-type key -> Flax image-classification class; backs
# FLAX_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING / FlaxAutoModelForImageClassification.
# Entries are kept in ASCII-sorted key order.
FLAX_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
    [
        # Model for Image Classification mapping
        ("beit", "FlaxBeitForImageClassification"),
        ("dinov2", "FlaxDinov2ForImageClassification"),
        ("regnet", "FlaxRegNetForImageClassification"),
        ("resnet", "FlaxResNetForImageClassification"),
        ("vit", "FlaxViTForImageClassification"),
    ]
)
# Model-type key -> Flax vision-to-text class; backs
# FLAX_MODEL_FOR_VISION_2_SEQ_MAPPING / FlaxAutoModelForVision2Seq.
# The only entry is the generic composite "vision-encoder-decoder" model.
FLAX_MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES = OrderedDict(
    [
        # Model for Vision-to-Seq mapping
        ("vision-encoder-decoder", "FlaxVisionEncoderDecoderModel"),
    ]
)
# Model-type key -> Flax causal (left-to-right) language-modeling class; backs
# FLAX_MODEL_FOR_CAUSAL_LM_MAPPING / FlaxAutoModelForCausalLM.
# "gpt-sw3" and "gpt2" both map to FlaxGPT2LMHeadModel, whose name does not
# follow the *ForCausalLM pattern. Entries are kept in ASCII-sorted key order.
FLAX_MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
    [
        # Model for Causal LM mapping
        ("bart", "FlaxBartForCausalLM"),
        ("bert", "FlaxBertForCausalLM"),
        ("big_bird", "FlaxBigBirdForCausalLM"),
        ("bloom", "FlaxBloomForCausalLM"),
        ("electra", "FlaxElectraForCausalLM"),
        ("gemma", "FlaxGemmaForCausalLM"),
        ("gpt-sw3", "FlaxGPT2LMHeadModel"),
        ("gpt2", "FlaxGPT2LMHeadModel"),
        ("gpt_neo", "FlaxGPTNeoForCausalLM"),
        ("gptj", "FlaxGPTJForCausalLM"),
        ("llama", "FlaxLlamaForCausalLM"),
        ("mistral", "FlaxMistralForCausalLM"),
        ("opt", "FlaxOPTForCausalLM"),
        ("roberta", "FlaxRobertaForCausalLM"),
        ("roberta-prelayernorm", "FlaxRobertaPreLayerNormForCausalLM"),
        ("xglm", "FlaxXGLMForCausalLM"),
        ("xlm-roberta", "FlaxXLMRobertaForCausalLM"),
    ]
)
# Model-type key -> Flax sequence-classification class; backs
# FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING /
# FlaxAutoModelForSequenceClassification. Entries are kept in ASCII-sorted key order.
FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
    [
        # Model for Sequence Classification mapping
        ("albert", "FlaxAlbertForSequenceClassification"),
        ("bart", "FlaxBartForSequenceClassification"),
        ("bert", "FlaxBertForSequenceClassification"),
        ("big_bird", "FlaxBigBirdForSequenceClassification"),
        ("distilbert", "FlaxDistilBertForSequenceClassification"),
        ("electra", "FlaxElectraForSequenceClassification"),
        ("mbart", "FlaxMBartForSequenceClassification"),
        ("roberta", "FlaxRobertaForSequenceClassification"),
        ("roberta-prelayernorm", "FlaxRobertaPreLayerNormForSequenceClassification"),
        ("roformer", "FlaxRoFormerForSequenceClassification"),
        ("xlm-roberta", "FlaxXLMRobertaForSequenceClassification"),
    ]
)
# Model-type key -> Flax question-answering class; backs
# FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING / FlaxAutoModelForQuestionAnswering.
# Entries are kept in ASCII-sorted key order.
FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(
    [
        # Model for Question Answering mapping
        ("albert", "FlaxAlbertForQuestionAnswering"),
        ("bart", "FlaxBartForQuestionAnswering"),
        ("bert", "FlaxBertForQuestionAnswering"),
        ("big_bird", "FlaxBigBirdForQuestionAnswering"),
        ("distilbert", "FlaxDistilBertForQuestionAnswering"),
        ("electra", "FlaxElectraForQuestionAnswering"),
        ("mbart", "FlaxMBartForQuestionAnswering"),
        ("roberta", "FlaxRobertaForQuestionAnswering"),
        ("roberta-prelayernorm", "FlaxRobertaPreLayerNormForQuestionAnswering"),
        ("roformer", "FlaxRoFormerForQuestionAnswering"),
        ("xlm-roberta", "FlaxXLMRobertaForQuestionAnswering"),
    ]
)
# Model-type key -> Flax token-classification class; backs
# FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING / FlaxAutoModelForTokenClassification.
# Entries are kept in ASCII-sorted key order.
FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
    [
        # Model for Token Classification mapping
        ("albert", "FlaxAlbertForTokenClassification"),
        ("bert", "FlaxBertForTokenClassification"),
        ("big_bird", "FlaxBigBirdForTokenClassification"),
        ("distilbert", "FlaxDistilBertForTokenClassification"),
        ("electra", "FlaxElectraForTokenClassification"),
        ("roberta", "FlaxRobertaForTokenClassification"),
        ("roberta-prelayernorm", "FlaxRobertaPreLayerNormForTokenClassification"),
        ("roformer", "FlaxRoFormerForTokenClassification"),
        ("xlm-roberta", "FlaxXLMRobertaForTokenClassification"),
    ]
)
# Model-type key -> Flax multiple-choice class; backs
# FLAX_MODEL_FOR_MULTIPLE_CHOICE_MAPPING / FlaxAutoModelForMultipleChoice.
# Entries are kept in ASCII-sorted key order.
FLAX_MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES = OrderedDict(
    [
        # Model for Multiple Choice mapping
        ("albert", "FlaxAlbertForMultipleChoice"),
        ("bert", "FlaxBertForMultipleChoice"),
        ("big_bird", "FlaxBigBirdForMultipleChoice"),
        ("distilbert", "FlaxDistilBertForMultipleChoice"),
        ("electra", "FlaxElectraForMultipleChoice"),
        ("roberta", "FlaxRobertaForMultipleChoice"),
        ("roberta-prelayernorm", "FlaxRobertaPreLayerNormForMultipleChoice"),
        ("roformer", "FlaxRoFormerForMultipleChoice"),
        ("xlm-roberta", "FlaxXLMRobertaForMultipleChoice"),
    ]
)
# Model-type key -> Flax next-sentence-prediction class; backs
# FLAX_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING /
# FlaxAutoModelForNextSentencePrediction. Only bert is registered.
FLAX_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES = OrderedDict(
    [
        # Model for Next Sentence Prediction mapping
        ("bert", "FlaxBertForNextSentencePrediction"),
    ]
)
# Model-type key -> Flax speech-to-text sequence-to-sequence class; backs
# FLAX_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING / FlaxAutoModelForSpeechSeq2Seq.
# Contains the generic composite "speech-encoder-decoder" plus whisper.
FLAX_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES = OrderedDict(
    [
        # Model for Speech Seq2Seq mapping
        ("speech-encoder-decoder", "FlaxSpeechEncoderDecoderModel"),
        ("whisper", "FlaxWhisperForConditionalGeneration"),
    ]
)
# Model-type key -> Flax audio-classification class; paired with
# CONFIG_MAPPING_NAMES into FLAX_MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING.
# NOTE(review): no FlaxAutoModelFor* class in this file is built on that mapping;
# presumably it is consumed by other modules — verify before changing.
FLAX_MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
    [
        # Model for Audio Classification mapping
        ("whisper", "FlaxWhisperForAudioClassification"),
    ]
)
  227. FLAX_MODEL_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, FLAX_MODEL_MAPPING_NAMES)
  228. FLAX_MODEL_FOR_PRETRAINING_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, FLAX_MODEL_FOR_PRETRAINING_MAPPING_NAMES)
  229. FLAX_MODEL_FOR_MASKED_LM_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, FLAX_MODEL_FOR_MASKED_LM_MAPPING_NAMES)
  230. FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING = _LazyAutoMapping(
  231. CONFIG_MAPPING_NAMES, FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES
  232. )
  233. FLAX_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING = _LazyAutoMapping(
  234. CONFIG_MAPPING_NAMES, FLAX_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES
  235. )
  236. FLAX_MODEL_FOR_VISION_2_SEQ_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, FLAX_MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES)
  237. FLAX_MODEL_FOR_CAUSAL_LM_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, FLAX_MODEL_FOR_CAUSAL_LM_MAPPING_NAMES)
  238. FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING = _LazyAutoMapping(
  239. CONFIG_MAPPING_NAMES, FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES
  240. )
  241. FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING = _LazyAutoMapping(
  242. CONFIG_MAPPING_NAMES, FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES
  243. )
  244. FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING = _LazyAutoMapping(
  245. CONFIG_MAPPING_NAMES, FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES
  246. )
  247. FLAX_MODEL_FOR_MULTIPLE_CHOICE_MAPPING = _LazyAutoMapping(
  248. CONFIG_MAPPING_NAMES, FLAX_MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES
  249. )
  250. FLAX_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING = _LazyAutoMapping(
  251. CONFIG_MAPPING_NAMES, FLAX_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES
  252. )
  253. FLAX_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING = _LazyAutoMapping(
  254. CONFIG_MAPPING_NAMES, FLAX_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES
  255. )
  256. FLAX_MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING = _LazyAutoMapping(
  257. CONFIG_MAPPING_NAMES, FLAX_MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES
  258. )
  259. class FlaxAutoModel(_BaseAutoModelClass):
  260. _model_mapping = FLAX_MODEL_MAPPING
  261. FlaxAutoModel = auto_class_update(FlaxAutoModel)
  262. class FlaxAutoModelForPreTraining(_BaseAutoModelClass):
  263. _model_mapping = FLAX_MODEL_FOR_PRETRAINING_MAPPING
  264. FlaxAutoModelForPreTraining = auto_class_update(FlaxAutoModelForPreTraining, head_doc="pretraining")
  265. class FlaxAutoModelForCausalLM(_BaseAutoModelClass):
  266. _model_mapping = FLAX_MODEL_FOR_CAUSAL_LM_MAPPING
  267. FlaxAutoModelForCausalLM = auto_class_update(FlaxAutoModelForCausalLM, head_doc="causal language modeling")
  268. class FlaxAutoModelForMaskedLM(_BaseAutoModelClass):
  269. _model_mapping = FLAX_MODEL_FOR_MASKED_LM_MAPPING
  270. FlaxAutoModelForMaskedLM = auto_class_update(FlaxAutoModelForMaskedLM, head_doc="masked language modeling")
  271. class FlaxAutoModelForSeq2SeqLM(_BaseAutoModelClass):
  272. _model_mapping = FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING
  273. FlaxAutoModelForSeq2SeqLM = auto_class_update(
  274. FlaxAutoModelForSeq2SeqLM,
  275. head_doc="sequence-to-sequence language modeling",
  276. checkpoint_for_example="google-t5/t5-base",
  277. )
  278. class FlaxAutoModelForSequenceClassification(_BaseAutoModelClass):
  279. _model_mapping = FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING
  280. FlaxAutoModelForSequenceClassification = auto_class_update(
  281. FlaxAutoModelForSequenceClassification, head_doc="sequence classification"
  282. )
  283. class FlaxAutoModelForQuestionAnswering(_BaseAutoModelClass):
  284. _model_mapping = FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING
  285. FlaxAutoModelForQuestionAnswering = auto_class_update(FlaxAutoModelForQuestionAnswering, head_doc="question answering")
  286. class FlaxAutoModelForTokenClassification(_BaseAutoModelClass):
  287. _model_mapping = FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING
  288. FlaxAutoModelForTokenClassification = auto_class_update(
  289. FlaxAutoModelForTokenClassification, head_doc="token classification"
  290. )
  291. class FlaxAutoModelForMultipleChoice(_BaseAutoModelClass):
  292. _model_mapping = FLAX_MODEL_FOR_MULTIPLE_CHOICE_MAPPING
  293. FlaxAutoModelForMultipleChoice = auto_class_update(FlaxAutoModelForMultipleChoice, head_doc="multiple choice")
  294. class FlaxAutoModelForNextSentencePrediction(_BaseAutoModelClass):
  295. _model_mapping = FLAX_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING
  296. FlaxAutoModelForNextSentencePrediction = auto_class_update(
  297. FlaxAutoModelForNextSentencePrediction, head_doc="next sentence prediction"
  298. )
  299. class FlaxAutoModelForImageClassification(_BaseAutoModelClass):
  300. _model_mapping = FLAX_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING
  301. FlaxAutoModelForImageClassification = auto_class_update(
  302. FlaxAutoModelForImageClassification, head_doc="image classification"
  303. )
  304. class FlaxAutoModelForVision2Seq(_BaseAutoModelClass):
  305. _model_mapping = FLAX_MODEL_FOR_VISION_2_SEQ_MAPPING
  306. FlaxAutoModelForVision2Seq = auto_class_update(FlaxAutoModelForVision2Seq, head_doc="vision-to-text modeling")
  307. class FlaxAutoModelForSpeechSeq2Seq(_BaseAutoModelClass):
  308. _model_mapping = FLAX_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING
  309. FlaxAutoModelForSpeechSeq2Seq = auto_class_update(
  310. FlaxAutoModelForSpeechSeq2Seq, head_doc="sequence-to-sequence speech-to-text modeling"
  311. )