# coding=utf-8
# Copyright 2021 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
  15. """Factory function to build auto-model classes."""
  16. import copy
  17. import importlib
  18. import json
  19. import warnings
  20. from collections import OrderedDict
  21. from ...configuration_utils import PretrainedConfig
  22. from ...dynamic_module_utils import get_class_from_dynamic_module, resolve_trust_remote_code
  23. from ...utils import (
  24. CONFIG_NAME,
  25. cached_file,
  26. copy_func,
  27. extract_commit_hash,
  28. find_adapter_config_file,
  29. is_peft_available,
  30. is_torch_available,
  31. logging,
  32. requires_backends,
  33. )
  34. from .configuration_auto import AutoConfig, model_type_to_module_name, replace_list_option_in_docstrings
  35. if is_torch_available():
  36. from ...generation import GenerationMixin
  37. logger = logging.get_logger(__name__)

CLASS_DOCSTRING = """
    This is a generic model class that will be instantiated as one of the model classes of the library when created
    with the [`~BaseAutoModelClass.from_pretrained`] class method or the [`~BaseAutoModelClass.from_config`] class
    method.

    This class cannot be instantiated directly using `__init__()` (throws an error).
"""

FROM_CONFIG_DOCSTRING = """
        Instantiates one of the model classes of the library from a configuration.

        Note:
            Loading a model from its configuration file does **not** load the model weights. It only affects the
            model's configuration. Use [`~BaseAutoModelClass.from_pretrained`] to load the model weights.

        Args:
            config ([`PretrainedConfig`]):
                The model class to instantiate is selected based on the configuration class:

                List options
            attn_implementation (`str`, *optional*):
                The attention implementation to use in the model (if relevant). Can be any of `"eager"` (manual
                implementation of the attention), `"sdpa"` (using [`F.scaled_dot_product_attention`](https://pytorch.org/docs/master/generated/torch.nn.functional.scaled_dot_product_attention.html)),
                or `"flash_attention_2"` (using [Dao-AILab/flash-attention](https://github.com/Dao-AILab/flash-attention)).
                By default, if available, SDPA will be used for torch>=2.1.1. The default is otherwise the manual
                `"eager"` implementation.

        Examples:

        ```python
        >>> from transformers import AutoConfig, BaseAutoModelClass

        >>> # Download configuration from huggingface.co and cache.
        >>> config = AutoConfig.from_pretrained("checkpoint_placeholder")
        >>> model = BaseAutoModelClass.from_config(config)
        ```
"""

FROM_PRETRAINED_TORCH_DOCSTRING = """
        Instantiate one of the model classes of the library from a pretrained model.

        The model class to instantiate is selected based on the `model_type` property of the config object (either
        passed as an argument or loaded from `pretrained_model_name_or_path` if possible), or when it's missing, by
        falling back to using pattern matching on `pretrained_model_name_or_path`:

        List options

        The model is set in evaluation mode by default using `model.eval()` (so for instance, dropout modules are
        deactivated). To train the model, you should first set it back in training mode with `model.train()`.

        Args:
            pretrained_model_name_or_path (`str` or `os.PathLike`):
                Can be either:

                    - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
                    - A path to a *directory* containing model weights saved using
                      [`~PreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`.
                    - A path or url to a *tensorflow index checkpoint file* (e.g., `./tf_model/model.ckpt.index`). In
                      this case, `from_tf` should be set to `True` and a configuration object should be provided as
                      `config` argument. This loading path is slower than converting the TensorFlow checkpoint to a
                      PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards.
            model_args (additional positional arguments, *optional*):
                Will be passed along to the underlying model `__init__()` method.
            config ([`PretrainedConfig`], *optional*):
                Configuration for the model to use instead of an automatically loaded configuration. Configuration can
                be automatically loaded when:

                    - The model is a model provided by the library (loaded with the *model id* string of a pretrained
                      model).
                    - The model was saved using [`~PreTrainedModel.save_pretrained`] and is reloaded by supplying the
                      save directory.
                    - The model is loaded by supplying a local directory as `pretrained_model_name_or_path` and a
                      configuration JSON file named *config.json* is found in the directory.
            state_dict (*Dict[str, torch.Tensor]*, *optional*):
                A state dictionary to use instead of a state dictionary loaded from saved weights file.

                This option can be used if you want to create a model from a pretrained configuration but load your
                own weights. In this case though, you should check if using [`~PreTrainedModel.save_pretrained`] and
                [`~PreTrainedModel.from_pretrained`] is not a simpler option.
            cache_dir (`str` or `os.PathLike`, *optional*):
                Path to a directory in which a downloaded pretrained model configuration should be cached if the
                standard cache should not be used.
            from_tf (`bool`, *optional*, defaults to `False`):
                Load the model weights from a TensorFlow checkpoint save file (see docstring of
                `pretrained_model_name_or_path` argument).
            force_download (`bool`, *optional*, defaults to `False`):
                Whether or not to force the (re-)download of the model weights and configuration files, overriding the
                cached versions if they exist.
            resume_download:
                Deprecated and ignored. All downloads are now resumed by default when possible.
                Will be removed in v5 of Transformers.
            proxies (`Dict[str, str]`, *optional*):
                A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
                'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
            output_loading_info (`bool`, *optional*, defaults to `False`):
                Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
            local_files_only (`bool`, *optional*, defaults to `False`):
                Whether or not to only look at local files (e.g., not try downloading the model).
            revision (`str`, *optional*, defaults to `"main"`):
                The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
                git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
                identifier allowed by git.
            trust_remote_code (`bool`, *optional*, defaults to `False`):
                Whether or not to allow for custom models defined on the Hub in their own modeling files. This option
                should only be set to `True` for repositories you trust and in which you have read the code, as it
                will execute code present on the Hub on your local machine.
            code_revision (`str`, *optional*, defaults to `"main"`):
                The specific revision to use for the code on the Hub, if the code lives in a different repository than
                the rest of the model. It can be a branch name, a tag name, or a commit id, since we use a git-based
                system for storing models and other artifacts on huggingface.co, so `revision` can be any identifier
                allowed by git.
            kwargs (additional keyword arguments, *optional*):
                Can be used to update the configuration object (after it has been loaded) and initialize the model
                (e.g., `output_attentions=True`). Behaves differently depending on whether a `config` is provided or
                automatically loaded:

                    - If a configuration is provided with `config`, `**kwargs` will be directly passed to the
                      underlying model's `__init__` method (we assume all relevant updates to the configuration have
                      already been done).
                    - If a configuration is not provided, `kwargs` will be first passed to the configuration class
                      initialization function ([`~PretrainedConfig.from_pretrained`]). Each key of `kwargs` that
                      corresponds to a configuration attribute will be used to override said attribute with the
                      supplied `kwargs` value. Remaining keys that do not correspond to any configuration attribute
                      will be passed to the underlying model's `__init__` function.

        Examples:

        ```python
        >>> from transformers import AutoConfig, BaseAutoModelClass

        >>> # Download model and configuration from huggingface.co and cache.
        >>> model = BaseAutoModelClass.from_pretrained("checkpoint_placeholder")

        >>> # Update configuration during loading
        >>> model = BaseAutoModelClass.from_pretrained("checkpoint_placeholder", output_attentions=True)
        >>> model.config.output_attentions
        True

        >>> # Loading from a TF checkpoint file instead of a PyTorch model (slower)
        >>> config = AutoConfig.from_pretrained("./tf_model/shortcut_placeholder_tf_model_config.json")
        >>> model = BaseAutoModelClass.from_pretrained(
        ...     "./tf_model/shortcut_placeholder_tf_checkpoint.ckpt.index", from_tf=True, config=config
        ... )
        ```
"""

FROM_PRETRAINED_TF_DOCSTRING = """
        Instantiate one of the model classes of the library from a pretrained model.

        The model class to instantiate is selected based on the `model_type` property of the config object (either
        passed as an argument or loaded from `pretrained_model_name_or_path` if possible), or when it's missing, by
        falling back to using pattern matching on `pretrained_model_name_or_path`:

        List options

        Args:
            pretrained_model_name_or_path (`str` or `os.PathLike`):
                Can be either:

                    - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
                    - A path to a *directory* containing model weights saved using
                      [`~PreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`.
                    - A path or url to a *PyTorch state_dict save file* (e.g., `./pt_model/pytorch_model.bin`). In
                      this case, `from_pt` should be set to `True` and a configuration object should be provided as
                      `config` argument. This loading path is slower than converting the PyTorch model to a TensorFlow
                      model using the provided conversion scripts and loading the TensorFlow model afterwards.
            model_args (additional positional arguments, *optional*):
                Will be passed along to the underlying model `__init__()` method.
            config ([`PretrainedConfig`], *optional*):
                Configuration for the model to use instead of an automatically loaded configuration. Configuration can
                be automatically loaded when:

                    - The model is a model provided by the library (loaded with the *model id* string of a pretrained
                      model).
                    - The model was saved using [`~PreTrainedModel.save_pretrained`] and is reloaded by supplying the
                      save directory.
                    - The model is loaded by supplying a local directory as `pretrained_model_name_or_path` and a
                      configuration JSON file named *config.json* is found in the directory.
            cache_dir (`str` or `os.PathLike`, *optional*):
                Path to a directory in which a downloaded pretrained model configuration should be cached if the
                standard cache should not be used.
            from_pt (`bool`, *optional*, defaults to `False`):
                Load the model weights from a PyTorch checkpoint save file (see docstring of
                `pretrained_model_name_or_path` argument).
            force_download (`bool`, *optional*, defaults to `False`):
                Whether or not to force the (re-)download of the model weights and configuration files, overriding the
                cached versions if they exist.
            resume_download:
                Deprecated and ignored. All downloads are now resumed by default when possible.
                Will be removed in v5 of Transformers.
            proxies (`Dict[str, str]`, *optional*):
                A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
                'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
            output_loading_info (`bool`, *optional*, defaults to `False`):
                Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
            local_files_only (`bool`, *optional*, defaults to `False`):
                Whether or not to only look at local files (e.g., not try downloading the model).
            revision (`str`, *optional*, defaults to `"main"`):
                The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
                git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
                identifier allowed by git.
            trust_remote_code (`bool`, *optional*, defaults to `False`):
                Whether or not to allow for custom models defined on the Hub in their own modeling files. This option
                should only be set to `True` for repositories you trust and in which you have read the code, as it
                will execute code present on the Hub on your local machine.
            code_revision (`str`, *optional*, defaults to `"main"`):
                The specific revision to use for the code on the Hub, if the code lives in a different repository than
                the rest of the model. It can be a branch name, a tag name, or a commit id, since we use a git-based
                system for storing models and other artifacts on huggingface.co, so `revision` can be any identifier
                allowed by git.
            kwargs (additional keyword arguments, *optional*):
                Can be used to update the configuration object (after it has been loaded) and initialize the model
                (e.g., `output_attentions=True`). Behaves differently depending on whether a `config` is provided or
                automatically loaded:

                    - If a configuration is provided with `config`, `**kwargs` will be directly passed to the
                      underlying model's `__init__` method (we assume all relevant updates to the configuration have
                      already been done).
                    - If a configuration is not provided, `kwargs` will be first passed to the configuration class
                      initialization function ([`~PretrainedConfig.from_pretrained`]). Each key of `kwargs` that
                      corresponds to a configuration attribute will be used to override said attribute with the
                      supplied `kwargs` value. Remaining keys that do not correspond to any configuration attribute
                      will be passed to the underlying model's `__init__` function.

        Examples:

        ```python
        >>> from transformers import AutoConfig, BaseAutoModelClass

        >>> # Download model and configuration from huggingface.co and cache.
        >>> model = BaseAutoModelClass.from_pretrained("checkpoint_placeholder")

        >>> # Update configuration during loading
        >>> model = BaseAutoModelClass.from_pretrained("checkpoint_placeholder", output_attentions=True)
        >>> model.config.output_attentions
        True

        >>> # Loading from a PyTorch checkpoint file instead of a TensorFlow model (slower)
        >>> config = AutoConfig.from_pretrained("./pt_model/shortcut_placeholder_pt_model_config.json")
        >>> model = BaseAutoModelClass.from_pretrained(
        ...     "./pt_model/shortcut_placeholder_pytorch_model.bin", from_pt=True, config=config
        ... )
        ```
"""

FROM_PRETRAINED_FLAX_DOCSTRING = """
        Instantiate one of the model classes of the library from a pretrained model.

        The model class to instantiate is selected based on the `model_type` property of the config object (either
        passed as an argument or loaded from `pretrained_model_name_or_path` if possible), or when it's missing, by
        falling back to using pattern matching on `pretrained_model_name_or_path`:

        List options

        Args:
            pretrained_model_name_or_path (`str` or `os.PathLike`):
                Can be either:

                    - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
                    - A path to a *directory* containing model weights saved using
                      [`~PreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`.
                    - A path or url to a *PyTorch state_dict save file* (e.g., `./pt_model/pytorch_model.bin`). In
                      this case, `from_pt` should be set to `True` and a configuration object should be provided as
                      `config` argument. This loading path is slower than converting the PyTorch model to a Flax
                      model using the provided conversion scripts and loading the Flax model afterwards.
            model_args (additional positional arguments, *optional*):
                Will be passed along to the underlying model `__init__()` method.
            config ([`PretrainedConfig`], *optional*):
                Configuration for the model to use instead of an automatically loaded configuration. Configuration can
                be automatically loaded when:

                    - The model is a model provided by the library (loaded with the *model id* string of a pretrained
                      model).
                    - The model was saved using [`~PreTrainedModel.save_pretrained`] and is reloaded by supplying the
                      save directory.
                    - The model is loaded by supplying a local directory as `pretrained_model_name_or_path` and a
                      configuration JSON file named *config.json* is found in the directory.
            cache_dir (`str` or `os.PathLike`, *optional*):
                Path to a directory in which a downloaded pretrained model configuration should be cached if the
                standard cache should not be used.
            from_pt (`bool`, *optional*, defaults to `False`):
                Load the model weights from a PyTorch checkpoint save file (see docstring of
                `pretrained_model_name_or_path` argument).
            force_download (`bool`, *optional*, defaults to `False`):
                Whether or not to force the (re-)download of the model weights and configuration files, overriding the
                cached versions if they exist.
            resume_download:
                Deprecated and ignored. All downloads are now resumed by default when possible.
                Will be removed in v5 of Transformers.
            proxies (`Dict[str, str]`, *optional*):
                A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
                'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
            output_loading_info (`bool`, *optional*, defaults to `False`):
                Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
            local_files_only (`bool`, *optional*, defaults to `False`):
                Whether or not to only look at local files (e.g., not try downloading the model).
            revision (`str`, *optional*, defaults to `"main"`):
                The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
                git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
                identifier allowed by git.
            trust_remote_code (`bool`, *optional*, defaults to `False`):
                Whether or not to allow for custom models defined on the Hub in their own modeling files. This option
                should only be set to `True` for repositories you trust and in which you have read the code, as it
                will execute code present on the Hub on your local machine.
            code_revision (`str`, *optional*, defaults to `"main"`):
                The specific revision to use for the code on the Hub, if the code lives in a different repository than
                the rest of the model. It can be a branch name, a tag name, or a commit id, since we use a git-based
                system for storing models and other artifacts on huggingface.co, so `revision` can be any identifier
                allowed by git.
            kwargs (additional keyword arguments, *optional*):
                Can be used to update the configuration object (after it has been loaded) and initialize the model
                (e.g., `output_attentions=True`). Behaves differently depending on whether a `config` is provided or
                automatically loaded:

                    - If a configuration is provided with `config`, `**kwargs` will be directly passed to the
                      underlying model's `__init__` method (we assume all relevant updates to the configuration have
                      already been done).
                    - If a configuration is not provided, `kwargs` will be first passed to the configuration class
                      initialization function ([`~PretrainedConfig.from_pretrained`]). Each key of `kwargs` that
                      corresponds to a configuration attribute will be used to override said attribute with the
                      supplied `kwargs` value. Remaining keys that do not correspond to any configuration attribute
                      will be passed to the underlying model's `__init__` function.

        Examples:

        ```python
        >>> from transformers import AutoConfig, BaseAutoModelClass

        >>> # Download model and configuration from huggingface.co and cache.
        >>> model = BaseAutoModelClass.from_pretrained("checkpoint_placeholder")

        >>> # Update configuration during loading
        >>> model = BaseAutoModelClass.from_pretrained("checkpoint_placeholder", output_attentions=True)
        >>> model.config.output_attentions
        True

        >>> # Loading from a PyTorch checkpoint file instead of a Flax model (slower)
        >>> config = AutoConfig.from_pretrained("./pt_model/shortcut_placeholder_pt_model_config.json")
        >>> model = BaseAutoModelClass.from_pretrained(
        ...     "./pt_model/shortcut_placeholder_pytorch_model.bin", from_pt=True, config=config
        ... )
        ```
"""


def _get_model_class(config, model_mapping):
    supported_models = model_mapping[type(config)]
    if not isinstance(supported_models, (list, tuple)):
        return supported_models

    name_to_model = {model.__name__: model for model in supported_models}
    architectures = getattr(config, "architectures", [])
    for arch in architectures:
        if arch in name_to_model:
            return name_to_model[arch]
        elif f"TF{arch}" in name_to_model:
            return name_to_model[f"TF{arch}"]
        elif f"Flax{arch}" in name_to_model:
            return name_to_model[f"Flax{arch}"]

    # If no architecture is set in the config, or none matches the supported models, the first element of the tuple
    # is the default.
    return supported_models[0]
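

# Illustrative sketch (not part of the original file): with hypothetical stand-in
# classes, a config whose `architectures` is ["BertForMaskedLM"] resolves to the
# TF variant through the f"TF{arch}" fallback above.
#
#     class _FakeConfig:  # hypothetical config stand-in
#         architectures = ["BertForMaskedLM"]
#
#     class BertForMaskedLM: ...
#     class TFBertForMaskedLM: ...
#
#     mapping = {_FakeConfig: (TFBertForMaskedLM,)}
#     _get_model_class(_FakeConfig(), mapping)  # -> TFBertForMaskedLM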


class _BaseAutoModelClass:
    # Base class for auto models.
    _model_mapping = None

    def __init__(self, *args, **kwargs):
        raise EnvironmentError(
            f"{self.__class__.__name__} is designed to be instantiated "
            f"using the `{self.__class__.__name__}.from_pretrained(pretrained_model_name_or_path)` or "
            f"`{self.__class__.__name__}.from_config(config)` methods."
        )

    @classmethod
    def from_config(cls, config, **kwargs):
        trust_remote_code = kwargs.pop("trust_remote_code", None)
        has_remote_code = hasattr(config, "auto_map") and cls.__name__ in config.auto_map
        has_local_code = type(config) in cls._model_mapping.keys()
        trust_remote_code = resolve_trust_remote_code(
            trust_remote_code, config._name_or_path, has_local_code, has_remote_code
        )

        if has_remote_code and trust_remote_code:
            class_ref = config.auto_map[cls.__name__]
            if "--" in class_ref:
                repo_id, class_ref = class_ref.split("--")
            else:
                repo_id = config.name_or_path
            model_class = get_class_from_dynamic_module(class_ref, repo_id, **kwargs)
            cls.register(config.__class__, model_class, exist_ok=True)
            _ = kwargs.pop("code_revision", None)
            model_class = add_generation_mixin_to_remote_model(model_class)
            return model_class._from_config(config, **kwargs)
        elif type(config) in cls._model_mapping.keys():
            model_class = _get_model_class(config, cls._model_mapping)
            return model_class._from_config(config, **kwargs)

        raise ValueError(
            f"Unrecognized configuration class {config.__class__} for this kind of AutoModel: {cls.__name__}.\n"
            f"Model type should be one of {', '.join(c.__name__ for c in cls._model_mapping.keys())}."
        )

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        config = kwargs.pop("config", None)
        trust_remote_code = kwargs.pop("trust_remote_code", None)
        kwargs["_from_auto"] = True
        hub_kwargs_names = [
            "cache_dir",
            "force_download",
            "local_files_only",
            "proxies",
            "resume_download",
            "revision",
            "subfolder",
            "use_auth_token",
            "token",
        ]
        hub_kwargs = {name: kwargs.pop(name) for name in hub_kwargs_names if name in kwargs}
        code_revision = kwargs.pop("code_revision", None)
        commit_hash = kwargs.pop("_commit_hash", None)
        adapter_kwargs = kwargs.pop("adapter_kwargs", None)

        token = hub_kwargs.pop("token", None)
        use_auth_token = hub_kwargs.pop("use_auth_token", None)
        if use_auth_token is not None:
            warnings.warn(
                "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.",
                FutureWarning,
            )
            if token is not None:
                raise ValueError(
                    "`token` and `use_auth_token` are both specified. Please set only the argument `token`."
                )
            token = use_auth_token

        if token is not None:
            hub_kwargs["token"] = token

        if commit_hash is None:
            if not isinstance(config, PretrainedConfig):
                # We make a call to the config file first (which may be absent) to get the commit hash as soon as
                # possible.
                resolved_config_file = cached_file(
                    pretrained_model_name_or_path,
                    CONFIG_NAME,
                    _raise_exceptions_for_gated_repo=False,
                    _raise_exceptions_for_missing_entries=False,
                    _raise_exceptions_for_connection_errors=False,
                    **hub_kwargs,
                )
                commit_hash = extract_commit_hash(resolved_config_file, commit_hash)
            else:
                commit_hash = getattr(config, "_commit_hash", None)

        if is_peft_available():
            if adapter_kwargs is None:
                adapter_kwargs = {}
                if token is not None:
                    adapter_kwargs["token"] = token

            maybe_adapter_path = find_adapter_config_file(
                pretrained_model_name_or_path, _commit_hash=commit_hash, **adapter_kwargs
            )

            if maybe_adapter_path is not None:
                with open(maybe_adapter_path, "r", encoding="utf-8") as f:
                    adapter_config = json.load(f)

                    adapter_kwargs["_adapter_model_path"] = pretrained_model_name_or_path
                    pretrained_model_name_or_path = adapter_config["base_model_name_or_path"]

        if not isinstance(config, PretrainedConfig):
            kwargs_orig = copy.deepcopy(kwargs)
            # ensure not to pollute the config object with torch_dtype="auto" - since it's
            # meaningless in the context of the config object - torch.dtype values are acceptable
            if kwargs.get("torch_dtype", None) == "auto":
                _ = kwargs.pop("torch_dtype")
            # to not overwrite the quantization_config if config has a quantization_config
            if kwargs.get("quantization_config", None) is not None:
                _ = kwargs.pop("quantization_config")

            config, kwargs = AutoConfig.from_pretrained(
                pretrained_model_name_or_path,
                return_unused_kwargs=True,
                trust_remote_code=trust_remote_code,
                code_revision=code_revision,
                _commit_hash=commit_hash,
                **hub_kwargs,
                **kwargs,
            )

            # if torch_dtype=auto was passed here, ensure to pass it on
            if kwargs_orig.get("torch_dtype", None) == "auto":
                kwargs["torch_dtype"] = "auto"
            if kwargs_orig.get("quantization_config", None) is not None:
                kwargs["quantization_config"] = kwargs_orig["quantization_config"]

        has_remote_code = hasattr(config, "auto_map") and cls.__name__ in config.auto_map
        has_local_code = type(config) in cls._model_mapping.keys()
        trust_remote_code = resolve_trust_remote_code(
            trust_remote_code, pretrained_model_name_or_path, has_local_code, has_remote_code
        )

        # Set the adapter kwargs
        kwargs["adapter_kwargs"] = adapter_kwargs

        if has_remote_code and trust_remote_code:
            class_ref = config.auto_map[cls.__name__]
            model_class = get_class_from_dynamic_module(
                class_ref, pretrained_model_name_or_path, code_revision=code_revision, **hub_kwargs, **kwargs
            )
            _ = hub_kwargs.pop("code_revision", None)
            cls.register(config.__class__, model_class, exist_ok=True)
            model_class = add_generation_mixin_to_remote_model(model_class)
            return model_class.from_pretrained(
                pretrained_model_name_or_path, *model_args, config=config, **hub_kwargs, **kwargs
            )
        elif type(config) in cls._model_mapping.keys():
            model_class = _get_model_class(config, cls._model_mapping)
            return model_class.from_pretrained(
                pretrained_model_name_or_path, *model_args, config=config, **hub_kwargs, **kwargs
            )
        raise ValueError(
            f"Unrecognized configuration class {config.__class__} for this kind of AutoModel: {cls.__name__}.\n"
            f"Model type should be one of {', '.join(c.__name__ for c in cls._model_mapping.keys())}."
        )

    @classmethod
    def register(cls, config_class, model_class, exist_ok=False):
        """
        Register a new model for this class.

        Args:
            config_class ([`PretrainedConfig`]):
                The configuration corresponding to the model to register.
            model_class ([`PreTrainedModel`]):
                The model to register.
        """
        if hasattr(model_class, "config_class") and str(model_class.config_class) != str(config_class):
            raise ValueError(
                "The model class you are passing has a `config_class` attribute that is not consistent with the "
                f"config class you passed (model has {model_class.config_class} and you passed {config_class}). Fix "
                "one of those so they match!"
            )
        cls._model_mapping.register(config_class, model_class, exist_ok=exist_ok)
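

# Usage sketch (illustrative, not part of the original file): pairing a custom
# config with a custom model so the public auto classes can resolve it.
# `MyConfig` and `MyModel` are hypothetical names.
#
#     from transformers import AutoConfig, AutoModel, PretrainedConfig, PreTrainedModel
#
#     class MyConfig(PretrainedConfig):
#         model_type = "my-model"
#
#     class MyModel(PreTrainedModel):
#         config_class = MyConfig
#
#     AutoConfig.register("my-model", MyConfig)
#     AutoModel.register(MyConfig, MyModel)  # goes through `register` above
#
#     model = AutoModel.from_config(MyConfig())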


class _BaseAutoBackboneClass(_BaseAutoModelClass):
    # Base class for auto backbone models.
    _model_mapping = None

    @classmethod
    def _load_timm_backbone_from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        requires_backends(cls, ["vision", "timm"])
        from ...models.timm_backbone import TimmBackboneConfig

        config = kwargs.pop("config", TimmBackboneConfig())

        if kwargs.get("out_features", None) is not None:
            raise ValueError("Cannot specify `out_features` for timm backbones")

        if kwargs.get("output_loading_info", False):
            raise ValueError("Cannot specify `output_loading_info=True` when loading from timm")

        num_channels = kwargs.pop("num_channels", config.num_channels)
        features_only = kwargs.pop("features_only", config.features_only)
        use_pretrained_backbone = kwargs.pop("use_pretrained_backbone", config.use_pretrained_backbone)
        out_indices = kwargs.pop("out_indices", config.out_indices)
        config = TimmBackboneConfig(
            backbone=pretrained_model_name_or_path,
            num_channels=num_channels,
            features_only=features_only,
            use_pretrained_backbone=use_pretrained_backbone,
            out_indices=out_indices,
        )
        return super().from_config(config, **kwargs)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        use_timm_backbone = kwargs.pop("use_timm_backbone", False)
        if use_timm_backbone:
            return cls._load_timm_backbone_from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)

        return super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
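

# Usage sketch (illustrative): loading a timm checkpoint through `AutoBackbone`,
# which routes into `_load_timm_backbone_from_pretrained` above. The checkpoint
# name is an assumption, any timm model id would do.
#
#     from transformers import AutoBackbone
#
#     backbone = AutoBackbone.from_pretrained("resnet18", use_timm_backbone=True, out_indices=(1, 2, 3, 4))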


def insert_head_doc(docstring, head_doc=""):
    if len(head_doc) > 0:
        return docstring.replace(
            "one of the model classes of the library ",
            f"one of the model classes of the library (with a {head_doc} head) ",
        )
    return docstring.replace(
        "one of the model classes of the library ", "one of the base model classes of the library "
    )


def auto_class_update(cls, checkpoint_for_example="google-bert/bert-base-cased", head_doc=""):
    # Create a new class with the right name from the base class
    model_mapping = cls._model_mapping
    name = cls.__name__
    class_docstring = insert_head_doc(CLASS_DOCSTRING, head_doc=head_doc)
    cls.__doc__ = class_docstring.replace("BaseAutoModelClass", name)

    # Now we need to copy and re-register `from_config` and `from_pretrained` as class methods otherwise we can't
    # have specific docstrings for them.
    from_config = copy_func(_BaseAutoModelClass.from_config)
    from_config_docstring = insert_head_doc(FROM_CONFIG_DOCSTRING, head_doc=head_doc)
    from_config_docstring = from_config_docstring.replace("BaseAutoModelClass", name)
    from_config_docstring = from_config_docstring.replace("checkpoint_placeholder", checkpoint_for_example)
    from_config.__doc__ = from_config_docstring
    from_config = replace_list_option_in_docstrings(model_mapping._model_mapping, use_model_types=False)(from_config)
    cls.from_config = classmethod(from_config)

    if name.startswith("TF"):
        from_pretrained_docstring = FROM_PRETRAINED_TF_DOCSTRING
    elif name.startswith("Flax"):
        from_pretrained_docstring = FROM_PRETRAINED_FLAX_DOCSTRING
    else:
        from_pretrained_docstring = FROM_PRETRAINED_TORCH_DOCSTRING
    from_pretrained = copy_func(_BaseAutoModelClass.from_pretrained)
    from_pretrained_docstring = insert_head_doc(from_pretrained_docstring, head_doc=head_doc)
    from_pretrained_docstring = from_pretrained_docstring.replace("BaseAutoModelClass", name)
    from_pretrained_docstring = from_pretrained_docstring.replace("checkpoint_placeholder", checkpoint_for_example)
    shortcut = checkpoint_for_example.split("/")[-1].split("-")[0]
    from_pretrained_docstring = from_pretrained_docstring.replace("shortcut_placeholder", shortcut)
    from_pretrained.__doc__ = from_pretrained_docstring
    from_pretrained = replace_list_option_in_docstrings(model_mapping._model_mapping)(from_pretrained)
    cls.from_pretrained = classmethod(from_pretrained)
    return cls
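

# Usage sketch (illustrative): this is how the concrete auto classes are built
# elsewhere in the library, e.g. in models/auto/modeling_auto.py. `MODEL_MAPPING`
# stands in for a `_LazyAutoMapping` instance defined there.
#
#     class AutoModel(_BaseAutoModelClass):
#         _model_mapping = MODEL_MAPPING
#
#     AutoModel = auto_class_update(AutoModel)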


def get_values(model_mapping):
    result = []
    for model in model_mapping.values():
        if isinstance(model, (list, tuple)):
            result += list(model)
        else:
            result.append(model)

    return result


def getattribute_from_module(module, attr):
    if attr is None:
        return None
    if isinstance(attr, tuple):
        return tuple(getattribute_from_module(module, a) for a in attr)
    if hasattr(module, attr):
        return getattr(module, attr)
    # Some of the mappings have entries model_type -> object of another model type. In that case we try to grab the
    # object at the top level.
    transformers_module = importlib.import_module("transformers")

    if module != transformers_module:
        try:
            return getattribute_from_module(transformers_module, attr)
        except ValueError:
            raise ValueError(f"Could not find {attr} in either {module} or {transformers_module}!")
    else:
        raise ValueError(f"Could not find {attr} in {transformers_module}!")


def add_generation_mixin_to_remote_model(model_class):
    """
    Adds `GenerationMixin` to the inheritance of `model_class`, if `model_class` is a PyTorch model.

    This function is used for backwards compatibility purposes: in v4.45, we've started a deprecation cycle to make
    `PreTrainedModel` stop inheriting from `GenerationMixin`. Without this function, older models dynamically loaded
    from the Hub may not have the `generate` method after we remove the inheritance.
    """
    # 1. If it is not a PT model (i.e. doesn't inherit Module), do nothing
    if "torch.nn.modules.module.Module" not in str(model_class.__mro__):
        return model_class

    # 2. If it already **directly** inherits from GenerationMixin, do nothing
    if "GenerationMixin" in str(model_class.__bases__):
        return model_class

    # 3. Prior to v4.45, we could detect whether a model was `generate`-compatible if it had its own `generate`
    # and/or `prepare_inputs_for_generation` method.
    has_custom_generate = "GenerationMixin" not in str(getattr(model_class, "generate"))
    has_custom_prepare_inputs = "GenerationMixin" not in str(getattr(model_class, "prepare_inputs_for_generation"))
    if has_custom_generate or has_custom_prepare_inputs:
        model_class_with_generation_mixin = type(
            model_class.__name__, (model_class, GenerationMixin), {**model_class.__dict__}
        )
        return model_class_with_generation_mixin
    return model_class
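

# Illustration (hypothetical class, not from the library): a remote torch model
# that defines its own `prepare_inputs_for_generation` but does not directly
# inherit from `GenerationMixin` gets the mixin re-added dynamically.
#
#     from transformers import PreTrainedModel
#
#     class RemoteModel(PreTrainedModel):  # hypothetical Hub model class
#         def prepare_inputs_for_generation(self, input_ids, **kwargs):
#             return {"input_ids": input_ids}
#
#     patched = add_generation_mixin_to_remote_model(RemoteModel)
#     # `patched` now has `GenerationMixin` in its bases, so `generate` is available.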


class _LazyAutoMapping(OrderedDict):
    """
    A mapping from config classes to objects (models or tokenizers, for instance) that loads keys and values only
    when accessed.

    Args:
        - config_mapping: The map model type to config class
        - model_mapping: The map model type to model (or tokenizer) class
    """

    def __init__(self, config_mapping, model_mapping):
        self._config_mapping = config_mapping
        self._reverse_config_mapping = {v: k for k, v in config_mapping.items()}
        self._model_mapping = model_mapping
        self._model_mapping._model_mapping = self
        self._extra_content = {}
        self._modules = {}

    def __len__(self):
        common_keys = set(self._config_mapping.keys()).intersection(self._model_mapping.keys())
        return len(common_keys) + len(self._extra_content)

    def __getitem__(self, key):
        if key in self._extra_content:
            return self._extra_content[key]
        model_type = self._reverse_config_mapping[key.__name__]
        if model_type in self._model_mapping:
            model_name = self._model_mapping[model_type]
            return self._load_attr_from_module(model_type, model_name)

        # Maybe there were several model types associated with this config.
        model_types = [k for k, v in self._config_mapping.items() if v == key.__name__]
        for mtype in model_types:
            if mtype in self._model_mapping:
                model_name = self._model_mapping[mtype]
                return self._load_attr_from_module(mtype, model_name)
        raise KeyError(key)

    def _load_attr_from_module(self, model_type, attr):
        module_name = model_type_to_module_name(model_type)
        if module_name not in self._modules:
            self._modules[module_name] = importlib.import_module(f".{module_name}", "transformers.models")
        return getattribute_from_module(self._modules[module_name], attr)

    def keys(self):
        mapping_keys = [
            self._load_attr_from_module(key, name)
            for key, name in self._config_mapping.items()
            if key in self._model_mapping.keys()
        ]
        return mapping_keys + list(self._extra_content.keys())

    def get(self, key, default):
        try:
            return self.__getitem__(key)
        except KeyError:
            return default

    def __bool__(self):
        return bool(self.keys())

    def values(self):
        mapping_values = [
            self._load_attr_from_module(key, name)
            for key, name in self._model_mapping.items()
            if key in self._config_mapping.keys()
        ]
        return mapping_values + list(self._extra_content.values())

    def items(self):
        mapping_items = [
            (
                self._load_attr_from_module(key, self._config_mapping[key]),
                self._load_attr_from_module(key, self._model_mapping[key]),
            )
            for key in self._model_mapping.keys()
            if key in self._config_mapping.keys()
        ]
        return mapping_items + list(self._extra_content.items())

    def __iter__(self):
        return iter(self.keys())

    def __contains__(self, item):
        if item in self._extra_content:
            return True
        if not hasattr(item, "__name__") or item.__name__ not in self._reverse_config_mapping:
            return False
        model_type = self._reverse_config_mapping[item.__name__]
        return model_type in self._model_mapping

    def register(self, key, value, exist_ok=False):
        """
        Register a new model in this mapping.
        """
        if hasattr(key, "__name__") and key.__name__ in self._reverse_config_mapping:
            model_type = self._reverse_config_mapping[key.__name__]
            if model_type in self._model_mapping.keys() and not exist_ok:
                raise ValueError(f"'{key}' is already used by a Transformers model.")

        self._extra_content[key] = value
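

# Usage sketch (illustrative): elsewhere in the library the lazy mappings are
# built from name dictionaries, e.g. in models/auto/modeling_auto.py:
#
#     MODEL_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_MAPPING_NAMES)
#
# Keys (config classes) and values (model classes) are only imported when the
# mapping is actually indexed, which keeps `import transformers` cheap.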