| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835 |
- # coding=utf-8
- # Copyright 2021 The HuggingFace Inc. team.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- """Factory function to build auto-model classes."""
- import copy
- import importlib
- import json
- import warnings
- from collections import OrderedDict
- from ...configuration_utils import PretrainedConfig
- from ...dynamic_module_utils import get_class_from_dynamic_module, resolve_trust_remote_code
- from ...utils import (
- CONFIG_NAME,
- cached_file,
- copy_func,
- extract_commit_hash,
- find_adapter_config_file,
- is_peft_available,
- is_torch_available,
- logging,
- requires_backends,
- )
- from .configuration_auto import AutoConfig, model_type_to_module_name, replace_list_option_in_docstrings
- if is_torch_available():
- from ...generation import GenerationMixin
- logger = logging.get_logger(__name__)
- CLASS_DOCSTRING = """
- This is a generic model class that will be instantiated as one of the model classes of the library when created
- with the [`~BaseAutoModelClass.from_pretrained`] class method or the [`~BaseAutoModelClass.from_config`] class
- method.
- This class cannot be instantiated directly using `__init__()` (throws an error).
- """
- FROM_CONFIG_DOCSTRING = """
- Instantiates one of the model classes of the library from a configuration.
- Note:
- Loading a model from its configuration file does **not** load the model weights. It only affects the
- model's configuration. Use [`~BaseAutoModelClass.from_pretrained`] to load the model weights.
- Args:
- config ([`PretrainedConfig`]):
- The model class to instantiate is selected based on the configuration class:
- List options
- attn_implementation (`str`, *optional*):
- The attention implementation to use in the model (if relevant). Can be any of `"eager"` (manual implementation of the attention), `"sdpa"` (using [`F.scaled_dot_product_attention`](https://pytorch.org/docs/master/generated/torch.nn.functional.scaled_dot_product_attention.html)), or `"flash_attention_2"` (using [Dao-AILab/flash-attention](https://github.com/Dao-AILab/flash-attention)). By default, if available, SDPA will be used for torch>=2.1.1. The default is otherwise the manual `"eager"` implementation.
- Examples:
- ```python
- >>> from transformers import AutoConfig, BaseAutoModelClass
- >>> # Download configuration from huggingface.co and cache.
- >>> config = AutoConfig.from_pretrained("checkpoint_placeholder")
- >>> model = BaseAutoModelClass.from_config(config)
- ```
- """
- FROM_PRETRAINED_TORCH_DOCSTRING = """
- Instantiate one of the model classes of the library from a pretrained model.
- The model class to instantiate is selected based on the `model_type` property of the config object (either
- passed as an argument or loaded from `pretrained_model_name_or_path` if possible), or when it's missing, by
- falling back to using pattern matching on `pretrained_model_name_or_path`:
- List options
- The model is set in evaluation mode by default using `model.eval()` (so for instance, dropout modules are
- deactivated). To train the model, you should first set it back in training mode with `model.train()`
- Args:
- pretrained_model_name_or_path (`str` or `os.PathLike`):
- Can be either:
- - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
- - A path to a *directory* containing model weights saved using
- [`~PreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`.
- - A path or url to a *tensorflow index checkpoint file* (e.g, `./tf_model/model.ckpt.index`). In
- this case, `from_tf` should be set to `True` and a configuration object should be provided as
- `config` argument. This loading path is slower than converting the TensorFlow checkpoint in a
- PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards.
- model_args (additional positional arguments, *optional*):
- Will be passed along to the underlying model `__init__()` method.
- config ([`PretrainedConfig`], *optional*):
- Configuration for the model to use instead of an automatically loaded configuration. Configuration can
- be automatically loaded when:
- - The model is a model provided by the library (loaded with the *model id* string of a pretrained
- model).
- - The model was saved using [`~PreTrainedModel.save_pretrained`] and is reloaded by supplying the
- save directory.
- - The model is loaded by supplying a local directory as `pretrained_model_name_or_path` and a
- configuration JSON file named *config.json* is found in the directory.
- state_dict (*Dict[str, torch.Tensor]*, *optional*):
- A state dictionary to use instead of a state dictionary loaded from saved weights file.
- This option can be used if you want to create a model from a pretrained configuration but load your own
- weights. In this case though, you should check if using [`~PreTrainedModel.save_pretrained`] and
- [`~PreTrainedModel.from_pretrained`] is not a simpler option.
- cache_dir (`str` or `os.PathLike`, *optional*):
- Path to a directory in which a downloaded pretrained model configuration should be cached if the
- standard cache should not be used.
- from_tf (`bool`, *optional*, defaults to `False`):
- Load the model weights from a TensorFlow checkpoint save file (see docstring of
- `pretrained_model_name_or_path` argument).
- force_download (`bool`, *optional*, defaults to `False`):
- Whether or not to force the (re-)download of the model weights and configuration files, overriding the
- cached versions if they exist.
- resume_download:
- Deprecated and ignored. All downloads are now resumed by default when possible.
- Will be removed in v5 of Transformers.
- proxies (`Dict[str, str]`, *optional*):
- A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
- 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
- output_loading_info (`bool`, *optional*, defaults to `False`):
- Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
- local_files_only (`bool`, *optional*, defaults to `False`):
- Whether or not to only look at local files (e.g., not try downloading the model).
- revision (`str`, *optional*, defaults to `"main"`):
- The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
- git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
- identifier allowed by git.
- trust_remote_code (`bool`, *optional*, defaults to `False`):
- Whether or not to allow for custom models defined on the Hub in their own modeling files. This option
- should only be set to `True` for repositories you trust and in which you have read the code, as it will
- execute code present on the Hub on your local machine.
- code_revision (`str`, *optional*, defaults to `"main"`):
- The specific revision to use for the code on the Hub, if the code lives in a different repository than
- the rest of the model. It can be a branch name, a tag name, or a commit id, since we use a git-based
- system for storing models and other artifacts on huggingface.co, so `revision` can be any identifier
- allowed by git.
- kwargs (additional keyword arguments, *optional*):
- Can be used to update the configuration object (after it being loaded) and initiate the model (e.g.,
- `output_attentions=True`). Behaves differently depending on whether a `config` is provided or
- automatically loaded:
- - If a configuration is provided with `config`, `**kwargs` will be directly passed to the
- underlying model's `__init__` method (we assume all relevant updates to the configuration have
- already been done)
- - If a configuration is not provided, `kwargs` will be first passed to the configuration class
- initialization function ([`~PretrainedConfig.from_pretrained`]). Each key of `kwargs` that
- corresponds to a configuration attribute will be used to override said attribute with the
- supplied `kwargs` value. Remaining keys that do not correspond to any configuration attribute
- will be passed to the underlying model's `__init__` function.
- Examples:
- ```python
- >>> from transformers import AutoConfig, BaseAutoModelClass
- >>> # Download model and configuration from huggingface.co and cache.
- >>> model = BaseAutoModelClass.from_pretrained("checkpoint_placeholder")
- >>> # Update configuration during loading
- >>> model = BaseAutoModelClass.from_pretrained("checkpoint_placeholder", output_attentions=True)
- >>> model.config.output_attentions
- True
- >>> # Loading from a TF checkpoint file instead of a PyTorch model (slower)
- >>> config = AutoConfig.from_pretrained("./tf_model/shortcut_placeholder_tf_model_config.json")
- >>> model = BaseAutoModelClass.from_pretrained(
- ... "./tf_model/shortcut_placeholder_tf_checkpoint.ckpt.index", from_tf=True, config=config
- ... )
- ```
- """
- FROM_PRETRAINED_TF_DOCSTRING = """
- Instantiate one of the model classes of the library from a pretrained model.
- The model class to instantiate is selected based on the `model_type` property of the config object (either
- passed as an argument or loaded from `pretrained_model_name_or_path` if possible), or when it's missing, by
- falling back to using pattern matching on `pretrained_model_name_or_path`:
- List options
- Args:
- pretrained_model_name_or_path (`str` or `os.PathLike`):
- Can be either:
- - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
- - A path to a *directory* containing model weights saved using
- [`~PreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`.
- - A path or url to a *PyTorch state_dict save file* (e.g, `./pt_model/pytorch_model.bin`). In this
- case, `from_pt` should be set to `True` and a configuration object should be provided as `config`
- argument. This loading path is slower than converting the PyTorch model in a TensorFlow model
- using the provided conversion scripts and loading the TensorFlow model afterwards.
- model_args (additional positional arguments, *optional*):
- Will be passed along to the underlying model `__init__()` method.
- config ([`PretrainedConfig`], *optional*):
- Configuration for the model to use instead of an automatically loaded configuration. Configuration can
- be automatically loaded when:
- - The model is a model provided by the library (loaded with the *model id* string of a pretrained
- model).
- - The model was saved using [`~PreTrainedModel.save_pretrained`] and is reloaded by supplying the
- save directory.
- - The model is loaded by supplying a local directory as `pretrained_model_name_or_path` and a
- configuration JSON file named *config.json* is found in the directory.
- cache_dir (`str` or `os.PathLike`, *optional*):
- Path to a directory in which a downloaded pretrained model configuration should be cached if the
- standard cache should not be used.
- from_pt (`bool`, *optional*, defaults to `False`):
- Load the model weights from a PyTorch checkpoint save file (see docstring of
- `pretrained_model_name_or_path` argument).
- force_download (`bool`, *optional*, defaults to `False`):
- Whether or not to force the (re-)download of the model weights and configuration files, overriding the
- cached versions if they exist.
- resume_download:
- Deprecated and ignored. All downloads are now resumed by default when possible.
- Will be removed in v5 of Transformers.
- proxies (`Dict[str, str]`, *optional*):
- A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
- 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
- output_loading_info (`bool`, *optional*, defaults to `False`):
- Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
- local_files_only (`bool`, *optional*, defaults to `False`):
- Whether or not to only look at local files (e.g., not try downloading the model).
- revision (`str`, *optional*, defaults to `"main"`):
- The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
- git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
- identifier allowed by git.
- trust_remote_code (`bool`, *optional*, defaults to `False`):
- Whether or not to allow for custom models defined on the Hub in their own modeling files. This option
- should only be set to `True` for repositories you trust and in which you have read the code, as it will
- execute code present on the Hub on your local machine.
- code_revision (`str`, *optional*, defaults to `"main"`):
- The specific revision to use for the code on the Hub, if the code lives in a different repository than
- the rest of the model. It can be a branch name, a tag name, or a commit id, since we use a git-based
- system for storing models and other artifacts on huggingface.co, so `revision` can be any identifier
- allowed by git.
- kwargs (additional keyword arguments, *optional*):
- Can be used to update the configuration object (after it being loaded) and initiate the model (e.g.,
- `output_attentions=True`). Behaves differently depending on whether a `config` is provided or
- automatically loaded:
- - If a configuration is provided with `config`, `**kwargs` will be directly passed to the
- underlying model's `__init__` method (we assume all relevant updates to the configuration have
- already been done)
- - If a configuration is not provided, `kwargs` will be first passed to the configuration class
- initialization function ([`~PretrainedConfig.from_pretrained`]). Each key of `kwargs` that
- corresponds to a configuration attribute will be used to override said attribute with the
- supplied `kwargs` value. Remaining keys that do not correspond to any configuration attribute
- will be passed to the underlying model's `__init__` function.
- Examples:
- ```python
- >>> from transformers import AutoConfig, BaseAutoModelClass
- >>> # Download model and configuration from huggingface.co and cache.
- >>> model = BaseAutoModelClass.from_pretrained("checkpoint_placeholder")
- >>> # Update configuration during loading
- >>> model = BaseAutoModelClass.from_pretrained("checkpoint_placeholder", output_attentions=True)
- >>> model.config.output_attentions
- True
- >>> # Loading from a PyTorch checkpoint file instead of a TensorFlow model (slower)
- >>> config = AutoConfig.from_pretrained("./pt_model/shortcut_placeholder_pt_model_config.json")
- >>> model = BaseAutoModelClass.from_pretrained(
- ... "./pt_model/shortcut_placeholder_pytorch_model.bin", from_pt=True, config=config
- ... )
- ```
- """
- FROM_PRETRAINED_FLAX_DOCSTRING = """
- Instantiate one of the model classes of the library from a pretrained model.
- The model class to instantiate is selected based on the `model_type` property of the config object (either
- passed as an argument or loaded from `pretrained_model_name_or_path` if possible), or when it's missing, by
- falling back to using pattern matching on `pretrained_model_name_or_path`:
- List options
- Args:
- pretrained_model_name_or_path (`str` or `os.PathLike`):
- Can be either:
- - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
- - A path to a *directory* containing model weights saved using
- [`~PreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`.
- - A path or url to a *PyTorch state_dict save file* (e.g, `./pt_model/pytorch_model.bin`). In this
- case, `from_pt` should be set to `True` and a configuration object should be provided as `config`
- argument. This loading path is slower than converting the PyTorch model to a Flax model
- using the provided conversion scripts and loading the Flax model afterwards.
- model_args (additional positional arguments, *optional*):
- Will be passed along to the underlying model `__init__()` method.
- config ([`PretrainedConfig`], *optional*):
- Configuration for the model to use instead of an automatically loaded configuration. Configuration can
- be automatically loaded when:
- - The model is a model provided by the library (loaded with the *model id* string of a pretrained
- model).
- - The model was saved using [`~PreTrainedModel.save_pretrained`] and is reloaded by supplying the
- save directory.
- - The model is loaded by supplying a local directory as `pretrained_model_name_or_path` and a
- configuration JSON file named *config.json* is found in the directory.
- cache_dir (`str` or `os.PathLike`, *optional*):
- Path to a directory in which a downloaded pretrained model configuration should be cached if the
- standard cache should not be used.
- from_pt (`bool`, *optional*, defaults to `False`):
- Load the model weights from a PyTorch checkpoint save file (see docstring of
- `pretrained_model_name_or_path` argument).
- force_download (`bool`, *optional*, defaults to `False`):
- Whether or not to force the (re-)download of the model weights and configuration files, overriding the
- cached versions if they exist.
- resume_download:
- Deprecated and ignored. All downloads are now resumed by default when possible.
- Will be removed in v5 of Transformers.
- proxies (`Dict[str, str]`, *optional*):
- A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
- 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
- output_loading_info (`bool`, *optional*, defaults to `False`):
- Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
- local_files_only (`bool`, *optional*, defaults to `False`):
- Whether or not to only look at local files (e.g., not try downloading the model).
- revision (`str`, *optional*, defaults to `"main"`):
- The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
- git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
- identifier allowed by git.
- trust_remote_code (`bool`, *optional*, defaults to `False`):
- Whether or not to allow for custom models defined on the Hub in their own modeling files. This option
- should only be set to `True` for repositories you trust and in which you have read the code, as it will
- execute code present on the Hub on your local machine.
- code_revision (`str`, *optional*, defaults to `"main"`):
- The specific revision to use for the code on the Hub, if the code lives in a different repository than
- the rest of the model. It can be a branch name, a tag name, or a commit id, since we use a git-based
- system for storing models and other artifacts on huggingface.co, so `revision` can be any identifier
- allowed by git.
- kwargs (additional keyword arguments, *optional*):
- Can be used to update the configuration object (after it being loaded) and initiate the model (e.g.,
- `output_attentions=True`). Behaves differently depending on whether a `config` is provided or
- automatically loaded:
- - If a configuration is provided with `config`, `**kwargs` will be directly passed to the
- underlying model's `__init__` method (we assume all relevant updates to the configuration have
- already been done)
- - If a configuration is not provided, `kwargs` will be first passed to the configuration class
- initialization function ([`~PretrainedConfig.from_pretrained`]). Each key of `kwargs` that
- corresponds to a configuration attribute will be used to override said attribute with the
- supplied `kwargs` value. Remaining keys that do not correspond to any configuration attribute
- will be passed to the underlying model's `__init__` function.
- Examples:
- ```python
- >>> from transformers import AutoConfig, BaseAutoModelClass
- >>> # Download model and configuration from huggingface.co and cache.
- >>> model = BaseAutoModelClass.from_pretrained("checkpoint_placeholder")
- >>> # Update configuration during loading
- >>> model = BaseAutoModelClass.from_pretrained("checkpoint_placeholder", output_attentions=True)
- >>> model.config.output_attentions
- True
- >>> # Loading from a PyTorch checkpoint file instead of a Flax model (slower)
- >>> config = AutoConfig.from_pretrained("./pt_model/shortcut_placeholder_pt_model_config.json")
- >>> model = BaseAutoModelClass.from_pretrained(
- ... "./pt_model/shortcut_placeholder_pytorch_model.bin", from_pt=True, config=config
- ... )
- ```
- """
def _get_model_class(config, model_mapping):
    """
    Pick the model class to use for `config` out of `model_mapping`.

    Args:
        config:
            Configuration object. Its type is the lookup key into `model_mapping`; its optional `architectures`
            attribute (an iterable of class names, or `None`) disambiguates between several candidates.
        model_mapping:
            Mapping from configuration class to either a single model class or a list/tuple of model classes.

    Returns:
        The mapped entry itself when it is not a list/tuple. Otherwise the first candidate whose `__name__`
        equals an entry of `config.architectures`, either verbatim or prefixed with `TF` or `Flax`; when nothing
        matches, the first candidate of the list/tuple.

    Raises:
        Whatever `model_mapping[type(config)]` raises for an unknown configuration type (`KeyError` for a plain
        `dict`).
    """
    supported_models = model_mapping[type(config)]
    if not isinstance(supported_models, (list, tuple)):
        return supported_models

    name_to_model = {model.__name__: model for model in supported_models}
    # The attribute may exist but be `None`; treat that like "no architecture recorded" instead of failing
    # when iterating over it.
    architectures = getattr(config, "architectures", None) or []
    for arch in architectures:
        # Same priority as before: exact name first, then the TF-prefixed and Flax-prefixed variants.
        for candidate in (arch, f"TF{arch}", f"Flax{arch}"):
            if candidate in name_to_model:
                return name_to_model[candidate]

    # If no architecture is set in the config, or none matches the supported models, the first element of the
    # tuple is the default.
    return supported_models[0]
class _BaseAutoModelClass:
    # Base class for auto models.
    # Subclasses are expected to set `_model_mapping` to an object mapping configuration classes to model
    # classes; this class uses its `keys()` and `register(...)` methods and subscripting via `_get_model_class`.
    _model_mapping = None

    def __init__(self, *args, **kwargs):
        # Auto classes are factories only: direct instantiation is always an error.
        raise EnvironmentError(
            f"{self.__class__.__name__} is designed to be instantiated "
            f"using the `{self.__class__.__name__}.from_pretrained(pretrained_model_name_or_path)` or "
            f"`{self.__class__.__name__}.from_config(config)` methods."
        )

    @classmethod
    def from_config(cls, config, **kwargs):
        """
        Instantiate a model from a configuration object, picking the concrete class from the config.

        Args:
            config:
                Configuration object. Its class is looked up in `cls._model_mapping`; if it carries an `auto_map`
                entry for this auto class and remote code is trusted, the class is loaded dynamically instead.
            kwargs:
                `trust_remote_code` is consumed here; the remaining keyword arguments are forwarded to the
                dynamic-module loader (remote code only) and to the model class' `_from_config`.

        Returns:
            Whatever the selected model class' `_from_config` returns.

        Raises:
            ValueError: If the configuration class is neither registered locally nor served by trusted remote
                code.
        """
        trust_remote_code = kwargs.pop("trust_remote_code", None)
        has_remote_code = hasattr(config, "auto_map") and cls.__name__ in config.auto_map
        has_local_code = type(config) in cls._model_mapping.keys()
        trust_remote_code = resolve_trust_remote_code(
            trust_remote_code, config._name_or_path, has_local_code, has_remote_code
        )

        if has_remote_code and trust_remote_code:
            class_ref = config.auto_map[cls.__name__]
            # A reference of the form "<repo_id>--<class_ref>" points at code hosted in another repository.
            if "--" in class_ref:
                repo_id, class_ref = class_ref.split("--")
            else:
                repo_id = config.name_or_path
            model_class = get_class_from_dynamic_module(class_ref, repo_id, **kwargs)
            cls.register(config.__class__, model_class, exist_ok=True)
            # `code_revision` is only meaningful for fetching the code; do not pass it on to the model.
            _ = kwargs.pop("code_revision", None)
            model_class = add_generation_mixin_to_remote_model(model_class)
            return model_class._from_config(config, **kwargs)
        elif has_local_code:
            model_class = _get_model_class(config, cls._model_mapping)
            return model_class._from_config(config, **kwargs)

        raise ValueError(
            f"Unrecognized configuration class {config.__class__} for this kind of AutoModel: {cls.__name__}.\n"
            f"Model type should be one of {', '.join(c.__name__ for c in cls._model_mapping.keys())}."
        )

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        """
        Load a pretrained model, choosing the concrete model class from the checkpoint's configuration.

        Args:
            pretrained_model_name_or_path:
                Identifier or path forwarded to the configuration and model loaders. When PEFT is available and
                an adapter configuration is found for it, it is replaced by that configuration's
                `base_model_name_or_path` and the original value is kept in
                `adapter_kwargs["_adapter_model_path"]`.
            model_args:
                Positional arguments forwarded to the selected model class' `from_pretrained`.
            kwargs:
                `config`, `trust_remote_code`, `code_revision`, `_commit_hash`, `adapter_kwargs` and the hub
                options (`cache_dir`, `force_download`, `local_files_only`, `proxies`, `resume_download`,
                `revision`, `subfolder`, `use_auth_token`, `token`) are consumed here. When no configuration
                object is given, the remaining keyword arguments first go through `AutoConfig.from_pretrained`;
                whatever it returns as unused is forwarded to the model class.

        Returns:
            Whatever the selected model class' `from_pretrained` returns.

        Raises:
            ValueError: If both `token` and `use_auth_token` are given, or if the configuration class is neither
                registered locally nor served by trusted remote code.
        """
        config = kwargs.pop("config", None)
        trust_remote_code = kwargs.pop("trust_remote_code", None)
        kwargs["_from_auto"] = True
        hub_kwargs_names = [
            "cache_dir",
            "force_download",
            "local_files_only",
            "proxies",
            "resume_download",
            "revision",
            "subfolder",
            "use_auth_token",
            "token",
        ]
        hub_kwargs = {name: kwargs.pop(name) for name in hub_kwargs_names if name in kwargs}
        code_revision = kwargs.pop("code_revision", None)
        commit_hash = kwargs.pop("_commit_hash", None)
        adapter_kwargs = kwargs.pop("adapter_kwargs", None)

        token = hub_kwargs.pop("token", None)
        use_auth_token = hub_kwargs.pop("use_auth_token", None)
        if use_auth_token is not None:
            warnings.warn(
                "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.",
                FutureWarning,
            )
            if token is not None:
                raise ValueError(
                    "`token` and `use_auth_token` are both specified. Please set only the argument `token`."
                )
            token = use_auth_token

        if token is not None:
            hub_kwargs["token"] = token

        if commit_hash is None:
            if not isinstance(config, PretrainedConfig):
                # We make a call to the config file first (which may be absent) to get the commit hash as soon as
                # possible.
                resolved_config_file = cached_file(
                    pretrained_model_name_or_path,
                    CONFIG_NAME,
                    _raise_exceptions_for_gated_repo=False,
                    _raise_exceptions_for_missing_entries=False,
                    _raise_exceptions_for_connection_errors=False,
                    **hub_kwargs,
                )
                commit_hash = extract_commit_hash(resolved_config_file, commit_hash)
            else:
                commit_hash = getattr(config, "_commit_hash", None)

        if is_peft_available():
            if adapter_kwargs is None:
                adapter_kwargs = {}
                if token is not None:
                    adapter_kwargs["token"] = token

            maybe_adapter_path = find_adapter_config_file(
                pretrained_model_name_or_path, _commit_hash=commit_hash, **adapter_kwargs
            )

            if maybe_adapter_path is not None:
                with open(maybe_adapter_path, "r", encoding="utf-8") as f:
                    adapter_config = json.load(f)

                # Load the base model the adapter was trained on, remembering where the adapter itself lives.
                adapter_kwargs["_adapter_model_path"] = pretrained_model_name_or_path
                pretrained_model_name_or_path = adapter_config["base_model_name_or_path"]

        if not isinstance(config, PretrainedConfig):
            # Two arguments must bypass the config and reach the model untouched. Only these are snapshotted:
            # deep-copying every keyword argument would also duplicate large objects such as a `state_dict`.
            requested_auto_dtype = kwargs.get("torch_dtype", None) == "auto"
            quantization_config = kwargs.get("quantization_config", None)

            # ensure not to pollute the config object with torch_dtype="auto" - since it's
            # meaningless in the context of the config object - torch.dtype values are acceptable
            if requested_auto_dtype:
                _ = kwargs.pop("torch_dtype")
            # to not overwrite the quantization_config if config has a quantization_config
            if quantization_config is not None:
                _ = kwargs.pop("quantization_config")
                # The model keeps receiving its own copy of the quantization config.
                quantization_config = copy.deepcopy(quantization_config)

            config, kwargs = AutoConfig.from_pretrained(
                pretrained_model_name_or_path,
                return_unused_kwargs=True,
                trust_remote_code=trust_remote_code,
                code_revision=code_revision,
                _commit_hash=commit_hash,
                **hub_kwargs,
                **kwargs,
            )

            # if torch_dtype=auto was passed here, ensure to pass it on
            if requested_auto_dtype:
                kwargs["torch_dtype"] = "auto"
            if quantization_config is not None:
                kwargs["quantization_config"] = quantization_config

        has_remote_code = hasattr(config, "auto_map") and cls.__name__ in config.auto_map
        has_local_code = type(config) in cls._model_mapping.keys()
        trust_remote_code = resolve_trust_remote_code(
            trust_remote_code, pretrained_model_name_or_path, has_local_code, has_remote_code
        )

        # Set the adapter kwargs
        kwargs["adapter_kwargs"] = adapter_kwargs

        if has_remote_code and trust_remote_code:
            class_ref = config.auto_map[cls.__name__]
            model_class = get_class_from_dynamic_module(
                class_ref, pretrained_model_name_or_path, code_revision=code_revision, **hub_kwargs, **kwargs
            )
            cls.register(config.__class__, model_class, exist_ok=True)
            model_class = add_generation_mixin_to_remote_model(model_class)
            return model_class.from_pretrained(
                pretrained_model_name_or_path, *model_args, config=config, **hub_kwargs, **kwargs
            )
        elif has_local_code:
            model_class = _get_model_class(config, cls._model_mapping)
            return model_class.from_pretrained(
                pretrained_model_name_or_path, *model_args, config=config, **hub_kwargs, **kwargs
            )

        raise ValueError(
            f"Unrecognized configuration class {config.__class__} for this kind of AutoModel: {cls.__name__}.\n"
            f"Model type should be one of {', '.join(c.__name__ for c in cls._model_mapping.keys())}."
        )

    @classmethod
    def register(cls, config_class, model_class, exist_ok=False):
        """
        Register a new model for this class.

        Args:
            config_class ([`PretrainedConfig`]):
                The configuration corresponding to the model to register.
            model_class ([`PreTrainedModel`]):
                The model to register.
            exist_ok (`bool`, *optional*, defaults to `False`):
                Forwarded unchanged to the underlying mapping's `register` method.

        Raises:
            ValueError: If `model_class` has a `config_class` attribute whose string form differs from that of
                `config_class`.
        """
        # The comparison is made on the string forms of the two classes, not on identity.
        if hasattr(model_class, "config_class") and str(model_class.config_class) != str(config_class):
            raise ValueError(
                "The model class you are passing has a `config_class` attribute that is not consistent with the "
                f"config class you passed (model has {model_class.config_class} and you passed {config_class}). Fix "
                "one of those so they match!"
            )
        cls._model_mapping.register(config_class, model_class, exist_ok=exist_ok)
class _BaseAutoBackboneClass(_BaseAutoModelClass):
    # Auto-class base for backbones: same factory behaviour as `_BaseAutoModelClass`, plus an opt-in path that
    # builds the backbone through a `TimmBackboneConfig` when `use_timm_backbone=True` is passed.
    _model_mapping = None

    @classmethod
    def _load_timm_backbone_from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        """
        Build a timm backbone by creating a `TimmBackboneConfig` and delegating to `from_config`.

        `pretrained_model_name_or_path` becomes the config's `backbone`. `num_channels`, `features_only`,
        `use_pretrained_backbone` and `out_indices` are taken from `kwargs` when present, otherwise from the
        `config` keyword argument (a default `TimmBackboneConfig()` if none is given). `model_args` is accepted
        but not used.

        Raises:
            ValueError: If `out_features` is given (not `None`) or `output_loading_info` is truthy.
        """
        requires_backends(cls, ["vision", "timm"])
        from ...models.timm_backbone import TimmBackboneConfig

        base_config = kwargs.pop("config", TimmBackboneConfig())

        # Options that have no meaning for a timm backbone are rejected up front.
        if kwargs.get("out_features", None) is not None:
            raise ValueError("Cannot specify `out_features` for timm backbones")
        if kwargs.get("output_loading_info", False):
            raise ValueError("Cannot specify `output_loading_info=True` when loading from timm")

        # Explicit keyword arguments win over the values stored on the base config.
        timm_options = {}
        for option in ("num_channels", "features_only", "use_pretrained_backbone", "out_indices"):
            timm_options[option] = kwargs.pop(option, getattr(base_config, option))

        timm_config = TimmBackboneConfig(backbone=pretrained_model_name_or_path, **timm_options)
        return super().from_config(timm_config, **kwargs)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        """
        Load a backbone, going through timm when `use_timm_backbone=True` is passed.

        The `use_timm_backbone` flag is removed from `kwargs` before dispatching; everything else is forwarded
        unchanged to either the timm loader or the parent class' `from_pretrained`.
        """
        if not kwargs.pop("use_timm_backbone", False):
            return super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
        return cls._load_timm_backbone_from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
def insert_head_doc(docstring, head_doc=""):
    """
    Specialize the generic "one of the model classes of the library " phrase inside `docstring`.

    With a non-empty `head_doc` the phrase gains a "(with a <head_doc> head)" suffix; with an empty one it is
    turned into "one of the base model classes of the library ". Every occurrence is replaced and the new string
    is returned.
    """
    generic_phrase = "one of the model classes of the library "
    if len(head_doc) == 0:
        specialized_phrase = "one of the base model classes of the library "
    else:
        specialized_phrase = f"one of the model classes of the library (with a {head_doc} head) "
    return docstring.replace(generic_phrase, specialized_phrase)
def auto_class_update(cls, checkpoint_for_example="google-bert/bert-base-cased", head_doc=""):
    """
    Give `cls` its own class docstring and its own documented `from_config` / `from_pretrained` class methods.

    The docstrings are produced from the module-level templates by inserting the head description, the class
    name and the example checkpoint. `cls` is modified in place and returned.
    """
    lazy_mapping = cls._model_mapping
    class_name = cls.__name__

    # Class-level docstring.
    cls.__doc__ = insert_head_doc(CLASS_DOCSTRING, head_doc=head_doc).replace("BaseAutoModelClass", class_name)

    # `from_config` and `from_pretrained` are copied from the base class and re-registered as class methods so
    # that each auto class can carry its own docstrings for them.
    config_doc = (
        insert_head_doc(FROM_CONFIG_DOCSTRING, head_doc=head_doc)
        .replace("BaseAutoModelClass", class_name)
        .replace("checkpoint_placeholder", checkpoint_for_example)
    )
    config_loader = copy_func(_BaseAutoModelClass.from_config)
    config_loader.__doc__ = config_doc
    config_loader = replace_list_option_in_docstrings(lazy_mapping._model_mapping, use_model_types=False)(
        config_loader
    )
    cls.from_config = classmethod(config_loader)

    # The `from_pretrained` template is selected from the class-name prefix.
    pretrained_template = (
        FROM_PRETRAINED_TF_DOCSTRING
        if class_name.startswith("TF")
        else FROM_PRETRAINED_FLAX_DOCSTRING
        if class_name.startswith("Flax")
        else FROM_PRETRAINED_TORCH_DOCSTRING
    )
    pretrained_loader = copy_func(_BaseAutoModelClass.from_pretrained)
    # The shortcut is the part of the checkpoint's last path component that precedes the first "-".
    shortcut_name = checkpoint_for_example.rsplit("/", 1)[-1].split("-", 1)[0]
    pretrained_doc = (
        insert_head_doc(pretrained_template, head_doc=head_doc)
        .replace("BaseAutoModelClass", class_name)
        .replace("checkpoint_placeholder", checkpoint_for_example)
        .replace("shortcut_placeholder", shortcut_name)
    )
    pretrained_loader.__doc__ = pretrained_doc
    pretrained_loader = replace_list_option_in_docstrings(lazy_mapping._model_mapping)(pretrained_loader)
    cls.from_pretrained = classmethod(pretrained_loader)
    return cls
def get_values(model_mapping):
    """
    Return the values of `model_mapping` as one flat list.

    A value that is a list or tuple contributes each of its elements (one level only); any other value is
    appended as-is.
    """
    flattened = []
    for entry in model_mapping.values():
        flattened.extend(entry if isinstance(entry, (list, tuple)) else (entry,))
    return flattened
def getattribute_from_module(module, attr):
    """
    Resolve `attr` on `module`, falling back to the top-level `transformers` module.

    `None` resolves to `None`, and a tuple of names resolves element-wise to a tuple of results.

    Raises:
        ValueError: If the name is found neither on `module` nor on the top-level `transformers` module.
    """
    if attr is None:
        return None
    if isinstance(attr, tuple):
        return tuple(getattribute_from_module(module, name) for name in attr)
    if hasattr(module, attr):
        return getattr(module, attr)

    # Some of the mappings have entries model_type -> object of another model type. In that case we try to grab
    # the object at the top level.
    top_level = importlib.import_module("transformers")
    if module != top_level:
        try:
            return getattribute_from_module(top_level, attr)
        except ValueError:
            raise ValueError(f"Could not find {attr} neither in {module} nor in {top_level}!")
    raise ValueError(f"Could not find {attr} in {top_level}!")
def add_generation_mixin_to_remote_model(model_class):
    """
    Adds `GenerationMixin` to the inheritance of `model_class`, if `model_class` is a PyTorch model.

    This function is used for backwards compatibility purposes: in v4.45, we've started a deprecation cycle to make
    `PreTrainedModel` stop inheriting from `GenerationMixin`. Without this function, older models dynamically loaded
    from the Hub may not have the `generate` method after we remove the inheritance.

    Args:
        model_class: The class to inspect.

    Returns:
        `model_class` itself when nothing needs to change, otherwise a new class with the same name and class
        dictionary whose bases are `(model_class, GenerationMixin)`.
    """

    def _defines_own(method_name):
        # A method counts as custom when it exists on the class and its string form does not mention
        # `GenerationMixin`. The `hasattr` guard keeps classes that have no such method at all (e.g. a torch
        # module that never inherited `GenerationMixin`) from raising `AttributeError`.
        return hasattr(model_class, method_name) and "GenerationMixin" not in str(getattr(model_class, method_name))

    # 1. If it is not a PT model (i.e. doesn't inherit Module), do nothing
    if "torch.nn.modules.module.Module" not in str(model_class.__mro__):
        return model_class

    # 2. If it already **directly** inherits from GenerationMixin, do nothing
    if "GenerationMixin" in str(model_class.__bases__):
        return model_class

    # 3. Prior to v4.45, we could detect whether a model was `generate`-compatible if it had its own `generate`
    # and/or `prepare_inputs_for_generation` method.
    has_custom_generate = _defines_own("generate")
    has_custom_prepare_inputs = _defines_own("prepare_inputs_for_generation")
    if has_custom_generate or has_custom_prepare_inputs:
        return type(model_class.__name__, (model_class, GenerationMixin), {**model_class.__dict__})
    return model_class
- class _LazyAutoMapping(OrderedDict):
- """
- " A mapping config to object (model or tokenizer for instance) that will load keys and values when it is accessed.
- Args:
- - config_mapping: The map model type to config class
- - model_mapping: The map model type to model (or tokenizer) class
- """
- def __init__(self, config_mapping, model_mapping):
- self._config_mapping = config_mapping
- self._reverse_config_mapping = {v: k for k, v in config_mapping.items()}
- self._model_mapping = model_mapping
- self._model_mapping._model_mapping = self
- self._extra_content = {}
- self._modules = {}
- def __len__(self):
- common_keys = set(self._config_mapping.keys()).intersection(self._model_mapping.keys())
- return len(common_keys) + len(self._extra_content)
- def __getitem__(self, key):
- if key in self._extra_content:
- return self._extra_content[key]
- model_type = self._reverse_config_mapping[key.__name__]
- if model_type in self._model_mapping:
- model_name = self._model_mapping[model_type]
- return self._load_attr_from_module(model_type, model_name)
- # Maybe there was several model types associated with this config.
- model_types = [k for k, v in self._config_mapping.items() if v == key.__name__]
- for mtype in model_types:
- if mtype in self._model_mapping:
- model_name = self._model_mapping[mtype]
- return self._load_attr_from_module(mtype, model_name)
- raise KeyError(key)
- def _load_attr_from_module(self, model_type, attr):
- module_name = model_type_to_module_name(model_type)
- if module_name not in self._modules:
- self._modules[module_name] = importlib.import_module(f".{module_name}", "transformers.models")
- return getattribute_from_module(self._modules[module_name], attr)
- def keys(self):
- mapping_keys = [
- self._load_attr_from_module(key, name)
- for key, name in self._config_mapping.items()
- if key in self._model_mapping.keys()
- ]
- return mapping_keys + list(self._extra_content.keys())
- def get(self, key, default):
- try:
- return self.__getitem__(key)
- except KeyError:
- return default
- def __bool__(self):
- return bool(self.keys())
- def values(self):
- mapping_values = [
- self._load_attr_from_module(key, name)
- for key, name in self._model_mapping.items()
- if key in self._config_mapping.keys()
- ]
- return mapping_values + list(self._extra_content.values())
- def items(self):
- mapping_items = [
- (
- self._load_attr_from_module(key, self._config_mapping[key]),
- self._load_attr_from_module(key, self._model_mapping[key]),
- )
- for key in self._model_mapping.keys()
- if key in self._config_mapping.keys()
- ]
- return mapping_items + list(self._extra_content.items())
- def __iter__(self):
- return iter(self.keys())
- def __contains__(self, item):
- if item in self._extra_content:
- return True
- if not hasattr(item, "__name__") or item.__name__ not in self._reverse_config_mapping:
- return False
- model_type = self._reverse_config_mapping[item.__name__]
- return model_type in self._model_mapping
- def register(self, key, value, exist_ok=False):
- """
- Register a new model in this mapping.
- """
- if hasattr(key, "__name__") and key.__name__ in self._reverse_config_mapping:
- model_type = self._reverse_config_mapping[key.__name__]
- if model_type in self._model_mapping.keys() and not exist_ok:
- raise ValueError(f"'{key}' is already used by a Transformers model.")
- self._extra_content[key] = value
|