# modeling_tf_auto.py (28 KB)
  1. # coding=utf-8
  2. # Copyright 2018 The HuggingFace Inc. team.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. """Auto Model class."""
  16. import warnings
  17. from collections import OrderedDict
  18. from ...utils import logging
  19. from .auto_factory import _BaseAutoModelClass, _LazyAutoMapping, auto_class_update
  20. from .configuration_auto import CONFIG_MAPPING_NAMES
  21. logger = logging.get_logger(__name__)
  22. TF_MODEL_MAPPING_NAMES = OrderedDict(
  23. [
  24. # Base model mapping
  25. ("albert", "TFAlbertModel"),
  26. ("bart", "TFBartModel"),
  27. ("bert", "TFBertModel"),
  28. ("blenderbot", "TFBlenderbotModel"),
  29. ("blenderbot-small", "TFBlenderbotSmallModel"),
  30. ("blip", "TFBlipModel"),
  31. ("camembert", "TFCamembertModel"),
  32. ("clip", "TFCLIPModel"),
  33. ("convbert", "TFConvBertModel"),
  34. ("convnext", "TFConvNextModel"),
  35. ("convnextv2", "TFConvNextV2Model"),
  36. ("ctrl", "TFCTRLModel"),
  37. ("cvt", "TFCvtModel"),
  38. ("data2vec-vision", "TFData2VecVisionModel"),
  39. ("deberta", "TFDebertaModel"),
  40. ("deberta-v2", "TFDebertaV2Model"),
  41. ("deit", "TFDeiTModel"),
  42. ("distilbert", "TFDistilBertModel"),
  43. ("dpr", "TFDPRQuestionEncoder"),
  44. ("efficientformer", "TFEfficientFormerModel"),
  45. ("electra", "TFElectraModel"),
  46. ("esm", "TFEsmModel"),
  47. ("flaubert", "TFFlaubertModel"),
  48. ("funnel", ("TFFunnelModel", "TFFunnelBaseModel")),
  49. ("gpt-sw3", "TFGPT2Model"),
  50. ("gpt2", "TFGPT2Model"),
  51. ("gptj", "TFGPTJModel"),
  52. ("groupvit", "TFGroupViTModel"),
  53. ("hubert", "TFHubertModel"),
  54. ("idefics", "TFIdeficsModel"),
  55. ("layoutlm", "TFLayoutLMModel"),
  56. ("layoutlmv3", "TFLayoutLMv3Model"),
  57. ("led", "TFLEDModel"),
  58. ("longformer", "TFLongformerModel"),
  59. ("lxmert", "TFLxmertModel"),
  60. ("marian", "TFMarianModel"),
  61. ("mbart", "TFMBartModel"),
  62. ("mistral", "TFMistralModel"),
  63. ("mobilebert", "TFMobileBertModel"),
  64. ("mobilevit", "TFMobileViTModel"),
  65. ("mpnet", "TFMPNetModel"),
  66. ("mt5", "TFMT5Model"),
  67. ("openai-gpt", "TFOpenAIGPTModel"),
  68. ("opt", "TFOPTModel"),
  69. ("pegasus", "TFPegasusModel"),
  70. ("regnet", "TFRegNetModel"),
  71. ("rembert", "TFRemBertModel"),
  72. ("resnet", "TFResNetModel"),
  73. ("roberta", "TFRobertaModel"),
  74. ("roberta-prelayernorm", "TFRobertaPreLayerNormModel"),
  75. ("roformer", "TFRoFormerModel"),
  76. ("sam", "TFSamModel"),
  77. ("segformer", "TFSegformerModel"),
  78. ("speech_to_text", "TFSpeech2TextModel"),
  79. ("swiftformer", "TFSwiftFormerModel"),
  80. ("swin", "TFSwinModel"),
  81. ("t5", "TFT5Model"),
  82. ("tapas", "TFTapasModel"),
  83. ("transfo-xl", "TFTransfoXLModel"),
  84. ("vision-text-dual-encoder", "TFVisionTextDualEncoderModel"),
  85. ("vit", "TFViTModel"),
  86. ("vit_mae", "TFViTMAEModel"),
  87. ("wav2vec2", "TFWav2Vec2Model"),
  88. ("whisper", "TFWhisperModel"),
  89. ("xglm", "TFXGLMModel"),
  90. ("xlm", "TFXLMModel"),
  91. ("xlm-roberta", "TFXLMRobertaModel"),
  92. ("xlnet", "TFXLNetModel"),
  93. ]
  94. )
  95. TF_MODEL_FOR_PRETRAINING_MAPPING_NAMES = OrderedDict(
  96. [
  97. # Model for pre-training mapping
  98. ("albert", "TFAlbertForPreTraining"),
  99. ("bart", "TFBartForConditionalGeneration"),
  100. ("bert", "TFBertForPreTraining"),
  101. ("camembert", "TFCamembertForMaskedLM"),
  102. ("ctrl", "TFCTRLLMHeadModel"),
  103. ("distilbert", "TFDistilBertForMaskedLM"),
  104. ("electra", "TFElectraForPreTraining"),
  105. ("flaubert", "TFFlaubertWithLMHeadModel"),
  106. ("funnel", "TFFunnelForPreTraining"),
  107. ("gpt-sw3", "TFGPT2LMHeadModel"),
  108. ("gpt2", "TFGPT2LMHeadModel"),
  109. ("idefics", "TFIdeficsForVisionText2Text"),
  110. ("layoutlm", "TFLayoutLMForMaskedLM"),
  111. ("lxmert", "TFLxmertForPreTraining"),
  112. ("mobilebert", "TFMobileBertForPreTraining"),
  113. ("mpnet", "TFMPNetForMaskedLM"),
  114. ("openai-gpt", "TFOpenAIGPTLMHeadModel"),
  115. ("roberta", "TFRobertaForMaskedLM"),
  116. ("roberta-prelayernorm", "TFRobertaPreLayerNormForMaskedLM"),
  117. ("t5", "TFT5ForConditionalGeneration"),
  118. ("tapas", "TFTapasForMaskedLM"),
  119. ("transfo-xl", "TFTransfoXLLMHeadModel"),
  120. ("vit_mae", "TFViTMAEForPreTraining"),
  121. ("xlm", "TFXLMWithLMHeadModel"),
  122. ("xlm-roberta", "TFXLMRobertaForMaskedLM"),
  123. ("xlnet", "TFXLNetLMHeadModel"),
  124. ]
  125. )
  126. TF_MODEL_WITH_LM_HEAD_MAPPING_NAMES = OrderedDict(
  127. [
  128. # Model with LM heads mapping
  129. ("albert", "TFAlbertForMaskedLM"),
  130. ("bart", "TFBartForConditionalGeneration"),
  131. ("bert", "TFBertForMaskedLM"),
  132. ("camembert", "TFCamembertForMaskedLM"),
  133. ("convbert", "TFConvBertForMaskedLM"),
  134. ("ctrl", "TFCTRLLMHeadModel"),
  135. ("distilbert", "TFDistilBertForMaskedLM"),
  136. ("electra", "TFElectraForMaskedLM"),
  137. ("esm", "TFEsmForMaskedLM"),
  138. ("flaubert", "TFFlaubertWithLMHeadModel"),
  139. ("funnel", "TFFunnelForMaskedLM"),
  140. ("gpt-sw3", "TFGPT2LMHeadModel"),
  141. ("gpt2", "TFGPT2LMHeadModel"),
  142. ("gptj", "TFGPTJForCausalLM"),
  143. ("layoutlm", "TFLayoutLMForMaskedLM"),
  144. ("led", "TFLEDForConditionalGeneration"),
  145. ("longformer", "TFLongformerForMaskedLM"),
  146. ("marian", "TFMarianMTModel"),
  147. ("mobilebert", "TFMobileBertForMaskedLM"),
  148. ("mpnet", "TFMPNetForMaskedLM"),
  149. ("openai-gpt", "TFOpenAIGPTLMHeadModel"),
  150. ("rembert", "TFRemBertForMaskedLM"),
  151. ("roberta", "TFRobertaForMaskedLM"),
  152. ("roberta-prelayernorm", "TFRobertaPreLayerNormForMaskedLM"),
  153. ("roformer", "TFRoFormerForMaskedLM"),
  154. ("speech_to_text", "TFSpeech2TextForConditionalGeneration"),
  155. ("t5", "TFT5ForConditionalGeneration"),
  156. ("tapas", "TFTapasForMaskedLM"),
  157. ("transfo-xl", "TFTransfoXLLMHeadModel"),
  158. ("whisper", "TFWhisperForConditionalGeneration"),
  159. ("xlm", "TFXLMWithLMHeadModel"),
  160. ("xlm-roberta", "TFXLMRobertaForMaskedLM"),
  161. ("xlnet", "TFXLNetLMHeadModel"),
  162. ]
  163. )
  164. TF_MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
  165. [
  166. # Model for Causal LM mapping
  167. ("bert", "TFBertLMHeadModel"),
  168. ("camembert", "TFCamembertForCausalLM"),
  169. ("ctrl", "TFCTRLLMHeadModel"),
  170. ("gpt-sw3", "TFGPT2LMHeadModel"),
  171. ("gpt2", "TFGPT2LMHeadModel"),
  172. ("gptj", "TFGPTJForCausalLM"),
  173. ("mistral", "TFMistralForCausalLM"),
  174. ("openai-gpt", "TFOpenAIGPTLMHeadModel"),
  175. ("opt", "TFOPTForCausalLM"),
  176. ("rembert", "TFRemBertForCausalLM"),
  177. ("roberta", "TFRobertaForCausalLM"),
  178. ("roberta-prelayernorm", "TFRobertaPreLayerNormForCausalLM"),
  179. ("roformer", "TFRoFormerForCausalLM"),
  180. ("transfo-xl", "TFTransfoXLLMHeadModel"),
  181. ("xglm", "TFXGLMForCausalLM"),
  182. ("xlm", "TFXLMWithLMHeadModel"),
  183. ("xlm-roberta", "TFXLMRobertaForCausalLM"),
  184. ("xlnet", "TFXLNetLMHeadModel"),
  185. ]
  186. )
  187. TF_MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING_NAMES = OrderedDict(
  188. [
  189. ("deit", "TFDeiTForMaskedImageModeling"),
  190. ("swin", "TFSwinForMaskedImageModeling"),
  191. ]
  192. )
  193. TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
  194. [
  195. # Model for Image-classsification
  196. ("convnext", "TFConvNextForImageClassification"),
  197. ("convnextv2", "TFConvNextV2ForImageClassification"),
  198. ("cvt", "TFCvtForImageClassification"),
  199. ("data2vec-vision", "TFData2VecVisionForImageClassification"),
  200. ("deit", ("TFDeiTForImageClassification", "TFDeiTForImageClassificationWithTeacher")),
  201. (
  202. "efficientformer",
  203. ("TFEfficientFormerForImageClassification", "TFEfficientFormerForImageClassificationWithTeacher"),
  204. ),
  205. ("mobilevit", "TFMobileViTForImageClassification"),
  206. ("regnet", "TFRegNetForImageClassification"),
  207. ("resnet", "TFResNetForImageClassification"),
  208. ("segformer", "TFSegformerForImageClassification"),
  209. ("swiftformer", "TFSwiftFormerForImageClassification"),
  210. ("swin", "TFSwinForImageClassification"),
  211. ("vit", "TFViTForImageClassification"),
  212. ]
  213. )
  214. TF_MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
  215. [
  216. # Model for Zero Shot Image Classification mapping
  217. ("blip", "TFBlipModel"),
  218. ("clip", "TFCLIPModel"),
  219. ]
  220. )
  221. TF_MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING_NAMES = OrderedDict(
  222. [
  223. # Model for Semantic Segmentation mapping
  224. ("data2vec-vision", "TFData2VecVisionForSemanticSegmentation"),
  225. ("mobilevit", "TFMobileViTForSemanticSegmentation"),
  226. ("segformer", "TFSegformerForSemanticSegmentation"),
  227. ]
  228. )
  229. TF_MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES = OrderedDict(
  230. [
  231. ("blip", "TFBlipForConditionalGeneration"),
  232. ("vision-encoder-decoder", "TFVisionEncoderDecoderModel"),
  233. ]
  234. )
  235. TF_MODEL_FOR_MASKED_LM_MAPPING_NAMES = OrderedDict(
  236. [
  237. # Model for Masked LM mapping
  238. ("albert", "TFAlbertForMaskedLM"),
  239. ("bert", "TFBertForMaskedLM"),
  240. ("camembert", "TFCamembertForMaskedLM"),
  241. ("convbert", "TFConvBertForMaskedLM"),
  242. ("deberta", "TFDebertaForMaskedLM"),
  243. ("deberta-v2", "TFDebertaV2ForMaskedLM"),
  244. ("distilbert", "TFDistilBertForMaskedLM"),
  245. ("electra", "TFElectraForMaskedLM"),
  246. ("esm", "TFEsmForMaskedLM"),
  247. ("flaubert", "TFFlaubertWithLMHeadModel"),
  248. ("funnel", "TFFunnelForMaskedLM"),
  249. ("layoutlm", "TFLayoutLMForMaskedLM"),
  250. ("longformer", "TFLongformerForMaskedLM"),
  251. ("mobilebert", "TFMobileBertForMaskedLM"),
  252. ("mpnet", "TFMPNetForMaskedLM"),
  253. ("rembert", "TFRemBertForMaskedLM"),
  254. ("roberta", "TFRobertaForMaskedLM"),
  255. ("roberta-prelayernorm", "TFRobertaPreLayerNormForMaskedLM"),
  256. ("roformer", "TFRoFormerForMaskedLM"),
  257. ("tapas", "TFTapasForMaskedLM"),
  258. ("xlm", "TFXLMWithLMHeadModel"),
  259. ("xlm-roberta", "TFXLMRobertaForMaskedLM"),
  260. ]
  261. )
  262. TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
  263. [
  264. # Model for Seq2Seq Causal LM mapping
  265. ("bart", "TFBartForConditionalGeneration"),
  266. ("blenderbot", "TFBlenderbotForConditionalGeneration"),
  267. ("blenderbot-small", "TFBlenderbotSmallForConditionalGeneration"),
  268. ("encoder-decoder", "TFEncoderDecoderModel"),
  269. ("led", "TFLEDForConditionalGeneration"),
  270. ("marian", "TFMarianMTModel"),
  271. ("mbart", "TFMBartForConditionalGeneration"),
  272. ("mt5", "TFMT5ForConditionalGeneration"),
  273. ("pegasus", "TFPegasusForConditionalGeneration"),
  274. ("t5", "TFT5ForConditionalGeneration"),
  275. ]
  276. )
  277. TF_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES = OrderedDict(
  278. [
  279. ("speech_to_text", "TFSpeech2TextForConditionalGeneration"),
  280. ("whisper", "TFWhisperForConditionalGeneration"),
  281. ]
  282. )
  283. TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
  284. [
  285. # Model for Sequence Classification mapping
  286. ("albert", "TFAlbertForSequenceClassification"),
  287. ("bart", "TFBartForSequenceClassification"),
  288. ("bert", "TFBertForSequenceClassification"),
  289. ("camembert", "TFCamembertForSequenceClassification"),
  290. ("convbert", "TFConvBertForSequenceClassification"),
  291. ("ctrl", "TFCTRLForSequenceClassification"),
  292. ("deberta", "TFDebertaForSequenceClassification"),
  293. ("deberta-v2", "TFDebertaV2ForSequenceClassification"),
  294. ("distilbert", "TFDistilBertForSequenceClassification"),
  295. ("electra", "TFElectraForSequenceClassification"),
  296. ("esm", "TFEsmForSequenceClassification"),
  297. ("flaubert", "TFFlaubertForSequenceClassification"),
  298. ("funnel", "TFFunnelForSequenceClassification"),
  299. ("gpt-sw3", "TFGPT2ForSequenceClassification"),
  300. ("gpt2", "TFGPT2ForSequenceClassification"),
  301. ("gptj", "TFGPTJForSequenceClassification"),
  302. ("layoutlm", "TFLayoutLMForSequenceClassification"),
  303. ("layoutlmv3", "TFLayoutLMv3ForSequenceClassification"),
  304. ("longformer", "TFLongformerForSequenceClassification"),
  305. ("mistral", "TFMistralForSequenceClassification"),
  306. ("mobilebert", "TFMobileBertForSequenceClassification"),
  307. ("mpnet", "TFMPNetForSequenceClassification"),
  308. ("openai-gpt", "TFOpenAIGPTForSequenceClassification"),
  309. ("rembert", "TFRemBertForSequenceClassification"),
  310. ("roberta", "TFRobertaForSequenceClassification"),
  311. ("roberta-prelayernorm", "TFRobertaPreLayerNormForSequenceClassification"),
  312. ("roformer", "TFRoFormerForSequenceClassification"),
  313. ("tapas", "TFTapasForSequenceClassification"),
  314. ("transfo-xl", "TFTransfoXLForSequenceClassification"),
  315. ("xlm", "TFXLMForSequenceClassification"),
  316. ("xlm-roberta", "TFXLMRobertaForSequenceClassification"),
  317. ("xlnet", "TFXLNetForSequenceClassification"),
  318. ]
  319. )
  320. TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(
  321. [
  322. # Model for Question Answering mapping
  323. ("albert", "TFAlbertForQuestionAnswering"),
  324. ("bert", "TFBertForQuestionAnswering"),
  325. ("camembert", "TFCamembertForQuestionAnswering"),
  326. ("convbert", "TFConvBertForQuestionAnswering"),
  327. ("deberta", "TFDebertaForQuestionAnswering"),
  328. ("deberta-v2", "TFDebertaV2ForQuestionAnswering"),
  329. ("distilbert", "TFDistilBertForQuestionAnswering"),
  330. ("electra", "TFElectraForQuestionAnswering"),
  331. ("flaubert", "TFFlaubertForQuestionAnsweringSimple"),
  332. ("funnel", "TFFunnelForQuestionAnswering"),
  333. ("gptj", "TFGPTJForQuestionAnswering"),
  334. ("layoutlmv3", "TFLayoutLMv3ForQuestionAnswering"),
  335. ("longformer", "TFLongformerForQuestionAnswering"),
  336. ("mobilebert", "TFMobileBertForQuestionAnswering"),
  337. ("mpnet", "TFMPNetForQuestionAnswering"),
  338. ("rembert", "TFRemBertForQuestionAnswering"),
  339. ("roberta", "TFRobertaForQuestionAnswering"),
  340. ("roberta-prelayernorm", "TFRobertaPreLayerNormForQuestionAnswering"),
  341. ("roformer", "TFRoFormerForQuestionAnswering"),
  342. ("xlm", "TFXLMForQuestionAnsweringSimple"),
  343. ("xlm-roberta", "TFXLMRobertaForQuestionAnswering"),
  344. ("xlnet", "TFXLNetForQuestionAnsweringSimple"),
  345. ]
  346. )
  347. TF_MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES = OrderedDict([("wav2vec2", "TFWav2Vec2ForSequenceClassification")])
  348. TF_MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(
  349. [
  350. ("layoutlm", "TFLayoutLMForQuestionAnswering"),
  351. ("layoutlmv3", "TFLayoutLMv3ForQuestionAnswering"),
  352. ]
  353. )
  354. TF_MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(
  355. [
  356. # Model for Table Question Answering mapping
  357. ("tapas", "TFTapasForQuestionAnswering"),
  358. ]
  359. )
  360. TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
  361. [
  362. # Model for Token Classification mapping
  363. ("albert", "TFAlbertForTokenClassification"),
  364. ("bert", "TFBertForTokenClassification"),
  365. ("camembert", "TFCamembertForTokenClassification"),
  366. ("convbert", "TFConvBertForTokenClassification"),
  367. ("deberta", "TFDebertaForTokenClassification"),
  368. ("deberta-v2", "TFDebertaV2ForTokenClassification"),
  369. ("distilbert", "TFDistilBertForTokenClassification"),
  370. ("electra", "TFElectraForTokenClassification"),
  371. ("esm", "TFEsmForTokenClassification"),
  372. ("flaubert", "TFFlaubertForTokenClassification"),
  373. ("funnel", "TFFunnelForTokenClassification"),
  374. ("layoutlm", "TFLayoutLMForTokenClassification"),
  375. ("layoutlmv3", "TFLayoutLMv3ForTokenClassification"),
  376. ("longformer", "TFLongformerForTokenClassification"),
  377. ("mobilebert", "TFMobileBertForTokenClassification"),
  378. ("mpnet", "TFMPNetForTokenClassification"),
  379. ("rembert", "TFRemBertForTokenClassification"),
  380. ("roberta", "TFRobertaForTokenClassification"),
  381. ("roberta-prelayernorm", "TFRobertaPreLayerNormForTokenClassification"),
  382. ("roformer", "TFRoFormerForTokenClassification"),
  383. ("xlm", "TFXLMForTokenClassification"),
  384. ("xlm-roberta", "TFXLMRobertaForTokenClassification"),
  385. ("xlnet", "TFXLNetForTokenClassification"),
  386. ]
  387. )
  388. TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES = OrderedDict(
  389. [
  390. # Model for Multiple Choice mapping
  391. ("albert", "TFAlbertForMultipleChoice"),
  392. ("bert", "TFBertForMultipleChoice"),
  393. ("camembert", "TFCamembertForMultipleChoice"),
  394. ("convbert", "TFConvBertForMultipleChoice"),
  395. ("deberta-v2", "TFDebertaV2ForMultipleChoice"),
  396. ("distilbert", "TFDistilBertForMultipleChoice"),
  397. ("electra", "TFElectraForMultipleChoice"),
  398. ("flaubert", "TFFlaubertForMultipleChoice"),
  399. ("funnel", "TFFunnelForMultipleChoice"),
  400. ("longformer", "TFLongformerForMultipleChoice"),
  401. ("mobilebert", "TFMobileBertForMultipleChoice"),
  402. ("mpnet", "TFMPNetForMultipleChoice"),
  403. ("rembert", "TFRemBertForMultipleChoice"),
  404. ("roberta", "TFRobertaForMultipleChoice"),
  405. ("roberta-prelayernorm", "TFRobertaPreLayerNormForMultipleChoice"),
  406. ("roformer", "TFRoFormerForMultipleChoice"),
  407. ("xlm", "TFXLMForMultipleChoice"),
  408. ("xlm-roberta", "TFXLMRobertaForMultipleChoice"),
  409. ("xlnet", "TFXLNetForMultipleChoice"),
  410. ]
  411. )
  412. TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES = OrderedDict(
  413. [
  414. ("bert", "TFBertForNextSentencePrediction"),
  415. ("mobilebert", "TFMobileBertForNextSentencePrediction"),
  416. ]
  417. )
  418. TF_MODEL_FOR_MASK_GENERATION_MAPPING_NAMES = OrderedDict(
  419. [
  420. ("sam", "TFSamModel"),
  421. ]
  422. )
  423. TF_MODEL_FOR_TEXT_ENCODING_MAPPING_NAMES = OrderedDict(
  424. [
  425. ("albert", "TFAlbertModel"),
  426. ("bert", "TFBertModel"),
  427. ("convbert", "TFConvBertModel"),
  428. ("deberta", "TFDebertaModel"),
  429. ("deberta-v2", "TFDebertaV2Model"),
  430. ("distilbert", "TFDistilBertModel"),
  431. ("electra", "TFElectraModel"),
  432. ("flaubert", "TFFlaubertModel"),
  433. ("longformer", "TFLongformerModel"),
  434. ("mobilebert", "TFMobileBertModel"),
  435. ("mt5", "TFMT5EncoderModel"),
  436. ("rembert", "TFRemBertModel"),
  437. ("roberta", "TFRobertaModel"),
  438. ("roberta-prelayernorm", "TFRobertaPreLayerNormModel"),
  439. ("roformer", "TFRoFormerModel"),
  440. ("t5", "TFT5EncoderModel"),
  441. ("xlm", "TFXLMModel"),
  442. ("xlm-roberta", "TFXLMRobertaModel"),
  443. ]
  444. )
  445. TF_MODEL_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, TF_MODEL_MAPPING_NAMES)
  446. TF_MODEL_FOR_PRETRAINING_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, TF_MODEL_FOR_PRETRAINING_MAPPING_NAMES)
  447. TF_MODEL_WITH_LM_HEAD_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, TF_MODEL_WITH_LM_HEAD_MAPPING_NAMES)
  448. TF_MODEL_FOR_CAUSAL_LM_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, TF_MODEL_FOR_CAUSAL_LM_MAPPING_NAMES)
  449. TF_MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING = _LazyAutoMapping(
  450. CONFIG_MAPPING_NAMES, TF_MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING_NAMES
  451. )
  452. TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING = _LazyAutoMapping(
  453. CONFIG_MAPPING_NAMES, TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES
  454. )
  455. TF_MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING = _LazyAutoMapping(
  456. CONFIG_MAPPING_NAMES, TF_MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING_NAMES
  457. )
  458. TF_MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING = _LazyAutoMapping(
  459. CONFIG_MAPPING_NAMES, TF_MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING_NAMES
  460. )
  461. TF_MODEL_FOR_VISION_2_SEQ_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, TF_MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES)
  462. TF_MODEL_FOR_MASKED_LM_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, TF_MODEL_FOR_MASKED_LM_MAPPING_NAMES)
  463. TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING = _LazyAutoMapping(
  464. CONFIG_MAPPING_NAMES, TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES
  465. )
  466. TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING = _LazyAutoMapping(
  467. CONFIG_MAPPING_NAMES, TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES
  468. )
  469. TF_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING = _LazyAutoMapping(
  470. CONFIG_MAPPING_NAMES, TF_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES
  471. )
  472. TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING = _LazyAutoMapping(
  473. CONFIG_MAPPING_NAMES, TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES
  474. )
  475. TF_MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING = _LazyAutoMapping(
  476. CONFIG_MAPPING_NAMES, TF_MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING_NAMES
  477. )
  478. TF_MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING = _LazyAutoMapping(
  479. CONFIG_MAPPING_NAMES, TF_MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING_NAMES
  480. )
  481. TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING = _LazyAutoMapping(
  482. CONFIG_MAPPING_NAMES, TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES
  483. )
  484. TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING = _LazyAutoMapping(
  485. CONFIG_MAPPING_NAMES, TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES
  486. )
  487. TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING = _LazyAutoMapping(
  488. CONFIG_MAPPING_NAMES, TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES
  489. )
  490. TF_MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING = _LazyAutoMapping(
  491. CONFIG_MAPPING_NAMES, TF_MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES
  492. )
  493. TF_MODEL_FOR_MASK_GENERATION_MAPPING = _LazyAutoMapping(
  494. CONFIG_MAPPING_NAMES, TF_MODEL_FOR_MASK_GENERATION_MAPPING_NAMES
  495. )
  496. TF_MODEL_FOR_TEXT_ENCODING_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, TF_MODEL_FOR_TEXT_ENCODING_MAPPING_NAMES)
  497. class TFAutoModelForMaskGeneration(_BaseAutoModelClass):
  498. _model_mapping = TF_MODEL_FOR_MASK_GENERATION_MAPPING
  499. class TFAutoModelForTextEncoding(_BaseAutoModelClass):
  500. _model_mapping = TF_MODEL_FOR_TEXT_ENCODING_MAPPING
  501. class TFAutoModel(_BaseAutoModelClass):
  502. _model_mapping = TF_MODEL_MAPPING
  503. TFAutoModel = auto_class_update(TFAutoModel)
  504. class TFAutoModelForAudioClassification(_BaseAutoModelClass):
  505. _model_mapping = TF_MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING
  506. TFAutoModelForAudioClassification = auto_class_update(
  507. TFAutoModelForAudioClassification, head_doc="audio classification"
  508. )
  509. class TFAutoModelForPreTraining(_BaseAutoModelClass):
  510. _model_mapping = TF_MODEL_FOR_PRETRAINING_MAPPING
  511. TFAutoModelForPreTraining = auto_class_update(TFAutoModelForPreTraining, head_doc="pretraining")
  512. # Private on purpose, the public class will add the deprecation warnings.
  513. class _TFAutoModelWithLMHead(_BaseAutoModelClass):
  514. _model_mapping = TF_MODEL_WITH_LM_HEAD_MAPPING
  515. _TFAutoModelWithLMHead = auto_class_update(_TFAutoModelWithLMHead, head_doc="language modeling")
  516. class TFAutoModelForCausalLM(_BaseAutoModelClass):
  517. _model_mapping = TF_MODEL_FOR_CAUSAL_LM_MAPPING
  518. TFAutoModelForCausalLM = auto_class_update(TFAutoModelForCausalLM, head_doc="causal language modeling")
  519. class TFAutoModelForMaskedImageModeling(_BaseAutoModelClass):
  520. _model_mapping = TF_MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING
  521. TFAutoModelForMaskedImageModeling = auto_class_update(
  522. TFAutoModelForMaskedImageModeling, head_doc="masked image modeling"
  523. )
  524. class TFAutoModelForImageClassification(_BaseAutoModelClass):
  525. _model_mapping = TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING
  526. TFAutoModelForImageClassification = auto_class_update(
  527. TFAutoModelForImageClassification, head_doc="image classification"
  528. )
  529. class TFAutoModelForZeroShotImageClassification(_BaseAutoModelClass):
  530. _model_mapping = TF_MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING
  531. TFAutoModelForZeroShotImageClassification = auto_class_update(
  532. TFAutoModelForZeroShotImageClassification, head_doc="zero-shot image classification"
  533. )
  534. class TFAutoModelForSemanticSegmentation(_BaseAutoModelClass):
  535. _model_mapping = TF_MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING
  536. TFAutoModelForSemanticSegmentation = auto_class_update(
  537. TFAutoModelForSemanticSegmentation, head_doc="semantic segmentation"
  538. )
  539. class TFAutoModelForVision2Seq(_BaseAutoModelClass):
  540. _model_mapping = TF_MODEL_FOR_VISION_2_SEQ_MAPPING
  541. TFAutoModelForVision2Seq = auto_class_update(TFAutoModelForVision2Seq, head_doc="vision-to-text modeling")
  542. class TFAutoModelForMaskedLM(_BaseAutoModelClass):
  543. _model_mapping = TF_MODEL_FOR_MASKED_LM_MAPPING
  544. TFAutoModelForMaskedLM = auto_class_update(TFAutoModelForMaskedLM, head_doc="masked language modeling")
  545. class TFAutoModelForSeq2SeqLM(_BaseAutoModelClass):
  546. _model_mapping = TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING
  547. TFAutoModelForSeq2SeqLM = auto_class_update(
  548. TFAutoModelForSeq2SeqLM,
  549. head_doc="sequence-to-sequence language modeling",
  550. checkpoint_for_example="google-t5/t5-base",
  551. )
  552. class TFAutoModelForSequenceClassification(_BaseAutoModelClass):
  553. _model_mapping = TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING
  554. TFAutoModelForSequenceClassification = auto_class_update(
  555. TFAutoModelForSequenceClassification, head_doc="sequence classification"
  556. )
  557. class TFAutoModelForQuestionAnswering(_BaseAutoModelClass):
  558. _model_mapping = TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING
  559. TFAutoModelForQuestionAnswering = auto_class_update(TFAutoModelForQuestionAnswering, head_doc="question answering")
  560. class TFAutoModelForDocumentQuestionAnswering(_BaseAutoModelClass):
  561. _model_mapping = TF_MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING
  562. TFAutoModelForDocumentQuestionAnswering = auto_class_update(
  563. TFAutoModelForDocumentQuestionAnswering,
  564. head_doc="document question answering",
  565. checkpoint_for_example='impira/layoutlm-document-qa", revision="52e01b3',
  566. )
  567. class TFAutoModelForTableQuestionAnswering(_BaseAutoModelClass):
  568. _model_mapping = TF_MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING
  569. TFAutoModelForTableQuestionAnswering = auto_class_update(
  570. TFAutoModelForTableQuestionAnswering,
  571. head_doc="table question answering",
  572. checkpoint_for_example="google/tapas-base-finetuned-wtq",
  573. )
  574. class TFAutoModelForTokenClassification(_BaseAutoModelClass):
  575. _model_mapping = TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING
  576. TFAutoModelForTokenClassification = auto_class_update(
  577. TFAutoModelForTokenClassification, head_doc="token classification"
  578. )
  579. class TFAutoModelForMultipleChoice(_BaseAutoModelClass):
  580. _model_mapping = TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING
  581. TFAutoModelForMultipleChoice = auto_class_update(TFAutoModelForMultipleChoice, head_doc="multiple choice")
  582. class TFAutoModelForNextSentencePrediction(_BaseAutoModelClass):
  583. _model_mapping = TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING
  584. TFAutoModelForNextSentencePrediction = auto_class_update(
  585. TFAutoModelForNextSentencePrediction, head_doc="next sentence prediction"
  586. )
  587. class TFAutoModelForSpeechSeq2Seq(_BaseAutoModelClass):
  588. _model_mapping = TF_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING
  589. TFAutoModelForSpeechSeq2Seq = auto_class_update(
  590. TFAutoModelForSpeechSeq2Seq, head_doc="sequence-to-sequence speech-to-text modeling"
  591. )
  592. class TFAutoModelWithLMHead(_TFAutoModelWithLMHead):
  593. @classmethod
  594. def from_config(cls, config):
  595. warnings.warn(
  596. "The class `TFAutoModelWithLMHead` is deprecated and will be removed in a future version. Please use"
  597. " `TFAutoModelForCausalLM` for causal language models, `TFAutoModelForMaskedLM` for masked language models"
  598. " and `TFAutoModelForSeq2SeqLM` for encoder-decoder models.",
  599. FutureWarning,
  600. )
  601. return super().from_config(config)
  602. @classmethod
  603. def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
  604. warnings.warn(
  605. "The class `TFAutoModelWithLMHead` is deprecated and will be removed in a future version. Please use"
  606. " `TFAutoModelForCausalLM` for causal language models, `TFAutoModelForMaskedLM` for masked language models"
  607. " and `TFAutoModelForSeq2SeqLM` for encoder-decoder models.",
  608. FutureWarning,
  609. )
  610. return super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)