| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684 |
- # coding=utf-8
- # Copyright 2021 The HuggingFace Inc. team.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- """Utilities to dynamically load objects from the Hub."""
- import filecmp
- import hashlib
- import importlib
- import importlib.util
- import os
- import re
- import shutil
- import signal
- import sys
- import threading
- import typing
- import warnings
- from pathlib import Path
- from types import ModuleType
- from typing import Any, Dict, List, Optional, Union
- from huggingface_hub import try_to_load_from_cache
- from .utils import (
- HF_MODULES_CACHE,
- TRANSFORMERS_DYNAMIC_MODULE_NAME,
- cached_file,
- extract_commit_hash,
- is_offline_mode,
- logging,
- )
# Module-level logger used by the helpers in this file.
logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
# Held by `get_class_in_module` while it updates `sys.modules` and executes a dynamic module.
_HF_REMOTE_CODE_LOCK = threading.Lock()
def init_hf_modules():
    """
    Makes sure the modules cache directory exists, holds an `__init__.py`, and is on the Python path.

    Nothing is done when `HF_MODULES_CACHE` is already listed in `sys.path`.
    """
    # Being present in `sys.path` is the marker that a previous call already performed the setup.
    if HF_MODULES_CACHE not in sys.path:
        sys.path.append(HF_MODULES_CACHE)
        os.makedirs(HF_MODULES_CACHE, exist_ok=True)
        package_marker = Path(HF_MODULES_CACHE) / "__init__.py"
        if not package_marker.exists():
            package_marker.touch()
            # A new file appeared in an importable location: refresh the import system's finder caches.
            importlib.invalidate_caches()
def create_dynamic_module(name: Union[str, os.PathLike]) -> None:
    """
    Creates, if needed, the package directory of a dynamic module inside the modules cache.

    Args:
        name (`str` or `os.PathLike`):
            The name of the dynamic module to create.
    """
    init_hf_modules()
    package_dir = (Path(HF_MODULES_CACHE) / name).resolve()
    parent_dir = package_dir.parent
    # Parents are handled first (recursively) so that each missing level is created through this same function.
    if not parent_dir.exists():
        create_dynamic_module(parent_dir)
    os.makedirs(package_dir, exist_ok=True)

    package_marker = package_dir / "__init__.py"
    if package_marker.exists():
        return
    package_marker.touch()
    # It is extremely important to invalidate the cache when we change stuff in those modules, or users end up
    # with errors about module that do not exist. Same for all other `invalidate_caches` in this file.
    importlib.invalidate_caches()
def get_relative_imports(module_file: Union[str, os.PathLike]) -> List[str]:
    """
    Collect the names of the modules a file imports relatively.

    Args:
        module_file (`str` or `os.PathLike`): The module file to inspect.

    Returns:
        `List[str]`: The relatively imported module names (leading dot removed), without duplicates and in no
        particular order.
    """
    with open(module_file, "r", encoding="utf-8") as handle:
        source = handle.read()

    patterns = (
        # Imports of the form `import .xxx`
        r"^\s*import\s+\.(\S+)\s*$",
        # Imports of the form `from .xxx import yyy`
        r"^\s*from\s+\.(\S+)\s+import",
    )
    found = set()
    for pattern in patterns:
        found.update(re.findall(pattern, source, flags=re.MULTILINE))
    return list(found)
def get_relative_import_files(module_file: Union[str, os.PathLike]) -> List[str]:
    """
    Get the list of all files that are needed for a given module. Note that this function recurses through the relative
    imports (if a imports b and b imports c, it will return module files for b and c).

    Every needed file is reported once, and cyclic relative imports (a imports b, b imports a) terminate: a file is
    only scanned the first time it is discovered.

    Args:
        module_file (`str` or `os.PathLike`): The module file to inspect.

    Returns:
        `List[str]`: The list of all relative imports a given module needs (recursively), which will give us the list
        of module files a given module needs. Paths are built as `<directory of module_file>/<import name>.py`; the
        entry file itself is not part of the result.
    """
    module_dir = Path(module_file).parent
    # Seeded with the entry file so that a cycle leading back to it is neither re-scanned nor reported.
    seen = {str(Path(module_file))}
    all_relative_imports = []
    files_to_check = [module_file]

    # Breadth-first walk over the relative imports; stops once a pass discovers no new file.
    while files_to_check:
        new_imports = []
        for f in files_to_check:
            new_imports.extend(get_relative_imports(f))

        files_to_check = []
        for imported in new_imports:
            # Membership is tested on the final `.py` path, i.e. the same form that is stored in the result.
            candidate = f"{module_dir / imported}.py"
            if candidate not in seen:
                seen.add(candidate)
                files_to_check.append(candidate)

        all_relative_imports.extend(files_to_check)

    return all_relative_imports
def get_imports(filename: Union[str, os.PathLike]) -> List[str]:
    """
    Lists the top-level packages a file imports with absolute imports (relative imports are ignored).

    Imports located in a `try`/`except` block, or in a `from flash_attn ...` run guarded by an
    `if is_flash_attn...available():` test, are removed from the text before the search and so are not reported.

    Args:
        filename (`str` or `os.PathLike`): The module file to inspect.

    Returns:
        `List[str]`: The list of all packages required to use the input module, without duplicates and in no
        particular order.
    """
    with open(filename, "r", encoding="utf-8") as handle:
        source = handle.read()

    # filter out try/except block so in custom code we can have try/except imports
    source = re.sub(r"\s*try\s*:\s*.*?\s*except\s*.*?:", "", source, flags=re.MULTILINE | re.DOTALL)

    # filter out imports under is_flash_attn_2_available block for avoid import issues in cpu only environment
    source = re.sub(
        r"if is_flash_attn[a-zA-Z0-9_]+available\(\):\s*(from flash_attn\s*.*\s*)+", "", source, flags=re.MULTILINE
    )

    # Imports of the form `import xxx`
    plain_imports = re.findall(r"^\s*import\s+(\S+)\s*$", source, flags=re.MULTILINE)
    # Imports of the form `from xxx import yyy`
    from_imports = re.findall(r"^\s*from\s+(\S+)\s+import", source, flags=re.MULTILINE)

    top_level = set()
    for target in plain_imports + from_imports:
        if target.startswith("."):
            # Relative import: not an installable package.
            continue
        # Only keep the top-level module
        top_level.add(target.split(".")[0])
    return list(top_level)
def check_imports(filename: Union[str, os.PathLike]) -> List[str]:
    """
    Verifies that every library imported by a file can be imported in the current Python environment.

    Args:
        filename (`str` or `os.PathLike`): The module file to check.

    Returns:
        `List[str]`: The list of relative imports in the file.

    Raises:
        `ImportError`: Listing the packages that were not found, or re-raised as is when importing a package fails
        with an `ImportError` whose message does not contain "No module named".
    """
    unavailable = []
    for package in get_imports(filename):
        try:
            importlib.import_module(package)
        except ImportError as exception:
            logger.warning(f"Encountered exception while importing {package}: {exception}")
            # Some packages can fail with an ImportError because of a dependency issue.
            # This check avoids hiding such errors.
            # See https://github.com/huggingface/transformers/issues/33604
            if "No module named" not in str(exception):
                raise
            unavailable.append(package)

    if unavailable:
        raise ImportError(
            "This modeling file requires the following packages that were not found in your environment: "
            f"{', '.join(unavailable)}. Run `pip install {' '.join(unavailable)}`"
        )

    return get_relative_imports(filename)
def get_class_in_module(
    class_name: str,
    module_path: Union[str, os.PathLike],
    *,
    force_reload: bool = False,
) -> typing.Type:
    """
    Import a module on the cache directory for modules and extract a class from it.

    The module is registered in `sys.modules` under the dotted name derived from `module_path`. Its code is
    (re-)executed when the module is not loaded yet, or when the hash computed over the module file and its relative
    import files differs from the one stored on the loaded module.

    Args:
        class_name (`str`): The name of the class to import.
        module_path (`str` or `os.PathLike`): The path to the module to import.
        force_reload (`bool`, *optional*, defaults to `False`):
            Whether to reload the dynamic module from file if it already exists in `sys.modules`.
            Otherwise, the module is only reloaded if the file has changed.

    Returns:
        `typing.Type`: The class looked for.
    """
    # Turn the path into a dotted module name: normalize, drop the `.py` suffix, path separators become dots.
    name = os.path.normpath(module_path)
    if name.endswith(".py"):
        name = name[:-3]
    name = name.replace(os.path.sep, ".")
    # The file is looked up relative to the modules cache directory.
    module_file: Path = Path(HF_MODULES_CACHE) / module_path
    # Everything that reads or writes `sys.modules`, or executes the module, happens under the lock.
    with _HF_REMOTE_CODE_LOCK:
        if force_reload:
            # Forget the loaded module so a fresh module object is created and executed below.
            sys.modules.pop(name, None)
            importlib.invalidate_caches()
        cached_module: Optional[ModuleType] = sys.modules.get(name)
        module_spec = importlib.util.spec_from_file_location(name, location=module_file)

        # Hash the module file and all its relative imports to check if we need to reload it
        module_files: List[Path] = [module_file] + sorted(map(Path, get_relative_import_files(module_file)))
        # Each file contributes its path bytes followed by its content bytes.
        module_hash: str = hashlib.sha256(b"".join(bytes(f) + f.read_bytes() for f in module_files)).hexdigest()

        module: ModuleType
        if cached_module is None:
            module = importlib.util.module_from_spec(module_spec)
            # insert it into sys.modules before any loading begins
            sys.modules[name] = module
        else:
            module = cached_module
        # reload in both cases, unless the module is already imported and the hash hits
        # (a freshly created module has no stored hash, so the "" default never matches and it is executed)
        if getattr(module, "__transformers_module_hash__", "") != module_hash:
            module_spec.loader.exec_module(module)
            module.__transformers_module_hash__ = module_hash
        return getattr(module, class_name)
def get_cached_module_file(
    pretrained_model_name_or_path: Union[str, os.PathLike],
    module_file: str,
    cache_dir: Optional[Union[str, os.PathLike]] = None,
    force_download: bool = False,
    resume_download: Optional[bool] = None,
    proxies: Optional[Dict[str, str]] = None,
    token: Optional[Union[bool, str]] = None,
    revision: Optional[str] = None,
    local_files_only: bool = False,
    repo_type: Optional[str] = None,
    _commit_hash: Optional[str] = None,
    **deprecated_kwargs,
) -> str:
    """
    Downloads a module from a local folder or a distant repo and returns its path inside the cached Transformers
    module.

    Args:
        pretrained_model_name_or_path (`str` or `os.PathLike`):
            This can be either:

            - a string, the *model id* of a pretrained model configuration hosted inside a model repo on
              huggingface.co.
            - a path to a *directory* containing a configuration file saved using the
              [`~PreTrainedTokenizer.save_pretrained`] method, e.g., `./my_model_directory/`.

        module_file (`str`):
            The name of the module file containing the class to look for.
        cache_dir (`str` or `os.PathLike`, *optional*):
            Path to a directory in which a downloaded pretrained model configuration should be cached if the standard
            cache should not be used.
        force_download (`bool`, *optional*, defaults to `False`):
            Whether or not to force to (re-)download the configuration files and override the cached versions if they
            exist.
        resume_download:
            Deprecated and ignored. All downloads are now resumed by default when possible.
            Will be removed in v5 of Transformers.
        proxies (`Dict[str, str]`, *optional*):
            A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
            'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request.
        token (`str` or *bool*, *optional*):
            The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
            when running `huggingface-cli login` (stored in `~/.huggingface`).
        revision (`str`, *optional*, defaults to `"main"`):
            The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
            git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
            identifier allowed by git.
        local_files_only (`bool`, *optional*, defaults to `False`):
            If `True`, will only try to load the tokenizer configuration from local files.
        repo_type (`str`, *optional*):
            Specify the repo type (useful when downloading from a space for instance).
        _commit_hash (`str`, *optional*):
            Commit hash forwarded to the cache lookup, to `cached_file` and to `extract_commit_hash`. It is set by
            this function itself when it recurses to fetch relatively imported files.

    <Tip>

    Passing `token=True` is required when you want to use a private model.

    </Tip>

    Returns:
        `str`: The path to the module inside the cache.

    Raises:
        `ValueError`: If both `token` and the deprecated `use_auth_token` are passed.
        `EnvironmentError`: Re-raised (after logging) when `module_file` cannot be located.
    """
    use_auth_token = deprecated_kwargs.pop("use_auth_token", None)
    if use_auth_token is not None:
        warnings.warn(
            "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.",
            FutureWarning,
        )
        if token is not None:
            raise ValueError("`token` and `use_auth_token` are both specified. Please set only the argument `token`.")
        token = use_auth_token

    if is_offline_mode() and not local_files_only:
        logger.info("Offline mode: forcing local_files_only=True")
        local_files_only = True

    # Download and cache module_file from the repo `pretrained_model_name_or_path` or grab it if it's a local file.
    pretrained_model_name_or_path = str(pretrained_model_name_or_path)
    is_local = os.path.isdir(pretrained_model_name_or_path)
    # Only looked up (and only compared below) for non-local sources.
    cached_module = None
    if is_local:
        submodule = os.path.basename(pretrained_model_name_or_path)
    else:
        submodule = pretrained_model_name_or_path.replace("/", os.path.sep)
        cached_module = try_to_load_from_cache(
            pretrained_model_name_or_path, module_file, cache_dir=cache_dir, revision=_commit_hash, repo_type=repo_type
        )

    new_files = []
    try:
        # Load from URL or cache if already cached
        resolved_module_file = cached_file(
            pretrained_model_name_or_path,
            module_file,
            cache_dir=cache_dir,
            force_download=force_download,
            proxies=proxies,
            resume_download=resume_download,
            local_files_only=local_files_only,
            token=token,
            revision=revision,
            repo_type=repo_type,
            _commit_hash=_commit_hash,
        )
        # The resolved file differs from what the cache held before the call: report it as new further down.
        if not is_local and cached_module != resolved_module_file:
            new_files.append(module_file)

    except EnvironmentError:
        logger.error(f"Could not locate the {module_file} inside {pretrained_model_name_or_path}.")
        raise

    # Check we have all the requirements in our environment
    modules_needed = check_imports(resolved_module_file)

    # Now we move the module inside our cached dynamic modules.
    full_submodule = TRANSFORMERS_DYNAMIC_MODULE_NAME + os.path.sep + submodule
    create_dynamic_module(full_submodule)
    submodule_path = Path(HF_MODULES_CACHE) / full_submodule
    if submodule == os.path.basename(pretrained_model_name_or_path):
        # We copy local files to avoid putting too many folders in sys.path. This copy is done when the file is new or
        # has changed since last copy.
        if not (submodule_path / module_file).exists() or not filecmp.cmp(
            resolved_module_file, str(submodule_path / module_file)
        ):
            shutil.copy(resolved_module_file, submodule_path / module_file)
            importlib.invalidate_caches()
        for module_needed in modules_needed:
            module_needed = f"{module_needed}.py"
            module_needed_file = os.path.join(pretrained_model_name_or_path, module_needed)
            if not (submodule_path / module_needed).exists() or not filecmp.cmp(
                module_needed_file, str(submodule_path / module_needed)
            ):
                shutil.copy(module_needed_file, submodule_path / module_needed)
                importlib.invalidate_caches()
    else:
        # Get the commit hash
        commit_hash = extract_commit_hash(resolved_module_file, _commit_hash)

        # The module file will end up being placed in a subfolder with the git hash of the repo. This way we get the
        # benefit of versioning.
        submodule_path = submodule_path / commit_hash
        full_submodule = full_submodule + os.path.sep + commit_hash
        create_dynamic_module(full_submodule)

        if not (submodule_path / module_file).exists():
            shutil.copy(resolved_module_file, submodule_path / module_file)
            importlib.invalidate_caches()
        # Make sure we also have every file that is relatively imported.
        for module_needed in modules_needed:
            if not (submodule_path / f"{module_needed}.py").exists():
                get_cached_module_file(
                    pretrained_model_name_or_path,
                    f"{module_needed}.py",
                    cache_dir=cache_dir,
                    force_download=force_download,
                    resume_download=resume_download,
                    proxies=proxies,
                    token=token,
                    revision=revision,
                    local_files_only=local_files_only,
                    # The dependency lives in the same repo as `module_file`, so it is the same kind of repo.
                    repo_type=repo_type,
                    _commit_hash=commit_hash,
                )
                new_files.append(f"{module_needed}.py")

    if len(new_files) > 0 and revision is None:
        new_files_listing = "\n".join([f"- {f}" for f in new_files])
        repo_type_str = "" if repo_type is None else f"{repo_type}s/"
        url = f"https://huggingface.co/{repo_type_str}{pretrained_model_name_or_path}"
        logger.warning(
            f"A new version of the following files was downloaded from {url}:\n{new_files_listing}"
            "\n. Make sure to double-check they do not contain any added malicious code. To avoid downloading new "
            "versions of the code file, you can pin a revision."
        )

    return os.path.join(full_submodule, module_file)
def get_class_from_dynamic_module(
    class_reference: str,
    pretrained_model_name_or_path: Union[str, os.PathLike],
    cache_dir: Optional[Union[str, os.PathLike]] = None,
    force_download: bool = False,
    resume_download: Optional[bool] = None,
    proxies: Optional[Dict[str, str]] = None,
    token: Optional[Union[bool, str]] = None,
    revision: Optional[str] = None,
    local_files_only: bool = False,
    repo_type: Optional[str] = None,
    code_revision: Optional[str] = None,
    **kwargs,
) -> typing.Type:
    """
    Extracts a class from a module file, present in the local folder or repository of a model.

    <Tip warning={true}>

    Calling this function will execute the code in the module file found locally or downloaded from the Hub. It should
    therefore only be called on trusted repos.

    </Tip>

    Args:
        class_reference (`str`):
            The full name of the class to load, written `module.ClassName`, optionally prefixed with the repo holding
            the code as `repo_id--module.ClassName`.
        pretrained_model_name_or_path (`str` or `os.PathLike`):
            This can be either:

            - a string, the *model id* of a pretrained model configuration hosted inside a model repo on
              huggingface.co.
            - a path to a *directory* containing a configuration file saved using the
              [`~PreTrainedTokenizer.save_pretrained`] method, e.g., `./my_model_directory/`.

            This is used when `class_reference` does not specify another repo.
        cache_dir (`str` or `os.PathLike`, *optional*):
            Path to a directory in which a downloaded pretrained model configuration should be cached if the standard
            cache should not be used.
        force_download (`bool`, *optional*, defaults to `False`):
            Whether or not to force to (re-)download the configuration files and override the cached versions if they
            exist. Also forces the module to be reloaded if it is already imported.
        resume_download:
            Deprecated and ignored. All downloads are now resumed by default when possible.
            Will be removed in v5 of Transformers.
        proxies (`Dict[str, str]`, *optional*):
            A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
            'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request.
        token (`str` or `bool`, *optional*):
            The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
            when running `huggingface-cli login` (stored in `~/.huggingface`).
        revision (`str`, *optional*, defaults to `"main"`):
            The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
            git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
            identifier allowed by git.
        local_files_only (`bool`, *optional*, defaults to `False`):
            If `True`, will only try to load the tokenizer configuration from local files.
        repo_type (`str`, *optional*):
            Specify the repo type (useful when downloading from a space for instance).
        code_revision (`str`, *optional*, defaults to `"main"`):
            The specific revision to use for the code on the Hub, if the code lives in a different repository than the
            rest of the model. It can be a branch name, a tag name, or a commit id, since we use a git-based system for
            storing models and other artifacts on huggingface.co, so `revision` can be any identifier allowed by git.
            When left unset and the code comes from `pretrained_model_name_or_path` itself, `revision` is used.

    <Tip>

    Passing `token=True` is required when you want to use a private model.

    </Tip>

    Returns:
        `typing.Type`: The class, dynamically imported from the module.

    Examples:

    ```python
    # Download module `modeling.py` from huggingface.co and cache then extract the class `MyBertModel` from this
    # module.
    cls = get_class_from_dynamic_module("modeling.MyBertModel", "sgugger/my-bert-model")

    # Download module `modeling.py` from a given repo and cache then extract the class `MyBertModel` from this
    # module.
    cls = get_class_from_dynamic_module("sgugger/my-bert-model--modeling.MyBertModel", "sgugger/another-bert-model")
    ```"""
    legacy_token = kwargs.pop("use_auth_token", None)
    if legacy_token is not None:
        warnings.warn(
            "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.",
            FutureWarning,
        )
        if token is not None:
            raise ValueError("`token` and `use_auth_token` are both specified. Please set only the argument `token`.")
        token = legacy_token

    # The code comes from the model's own repo unless `class_reference` names another one before `--`.
    repo_id = pretrained_model_name_or_path
    if "--" in class_reference:
        repo_id, class_reference = class_reference.split("--")
    module_file, class_name = class_reference.split(".")

    # The model revision only applies to the code when both live in the same repo.
    if code_revision is None and pretrained_model_name_or_path == repo_id:
        code_revision = revision

    download_options = {
        "cache_dir": cache_dir,
        "force_download": force_download,
        "resume_download": resume_download,
        "proxies": proxies,
        "token": token,
        "revision": code_revision,
        "local_files_only": local_files_only,
        "repo_type": repo_type,
    }
    # Bring the module file into the dynamic modules cache, then import the class from it.
    final_module = get_cached_module_file(repo_id, module_file + ".py", **download_options)
    return get_class_in_module(class_name, final_module, force_reload=force_download)
def custom_object_save(obj: Any, folder: Union[str, os.PathLike], config: Optional[Dict] = None) -> List[str]:
    """
    Copies the source files defining a custom model/configuration/tokenizer etc. into a folder and, optionally,
    registers the object's class in the `auto_map` of one or several configs.

    Args:
        obj (`Any`): The object for which to save the module files.
        folder (`str` or `os.PathLike`): The folder where to save.
        config (`PretrainedConfig` or dictionary, `optional`):
            A config in which to register the auto_map corresponding to this custom object. A list or tuple of
            configs is also accepted, in which case each of them is updated.

    Returns:
        `List[str]`: The list of files saved. Note that nothing is saved and `None` is returned (after a warning) when
        `obj` is defined in `__main__`.
    """
    if obj.__module__ == "__main__":
        logger.warning(
            f"We can't save the code defining {obj} in {folder} as it's been defined in __main__. You should put "
            "this code in a separate module so we can include it in the saved folder and make it easier to share via "
            "the Hub."
        )
        return

    def _set_auto_map_in_config(_config):
        obj_class = obj.__class__
        module_tail = obj_class.__module__.split(".")[-1]
        full_name = f"{module_tail}.{obj_class.__name__}"

        # Special handling for tokenizers: the entry is a (slow class, fast class) pair.
        if "Tokenizer" in full_name:
            if obj_class.__name__.endswith("Fast"):
                # Fast tokenizer: we have the fast tokenizer class and we may have the slow one as an attribute.
                fast_reference = full_name
                slow_reference = None
                slow_class = getattr(obj, "slow_tokenizer_class", None)
                if slow_class is not None:
                    slow_module_tail = slow_class.__module__.split(".")[-1]
                    slow_reference = f"{slow_module_tail}.{slow_class.__name__}"
            else:
                # Slow tokenizer: no way to have the fast class
                slow_reference, fast_reference = full_name, None
            full_name = (slow_reference, fast_reference)

        if isinstance(_config, dict):
            mapping = _config.get("auto_map", {})
            mapping[obj._auto_class] = full_name
            _config["auto_map"] = mapping
        elif getattr(_config, "auto_map", None) is None:
            _config.auto_map = {obj._auto_class: full_name}
        else:
            _config.auto_map[obj._auto_class] = full_name

    # Add object class to the config auto_map
    if isinstance(config, (list, tuple)):
        configs_to_update = config
    elif config is not None:
        configs_to_update = [config]
    else:
        configs_to_update = []
    for single_config in configs_to_update:
        _set_auto_map_in_config(single_config)

    saved_files = []

    def _copy_into_folder(source_file):
        # Files are flattened into `folder`: only the file name of the source is kept.
        target = Path(folder) / (Path(source_file).name)
        shutil.copy(source_file, target)
        saved_files.append(target)

    # Copy module file to the output folder.
    object_file = sys.modules[obj.__module__].__file__
    _copy_into_folder(object_file)

    # Gather all relative imports recursively and make sure they are copied as well.
    for needed_file in get_relative_import_files(object_file):
        _copy_into_folder(needed_file)

    return saved_files
- def _raise_timeout_error(signum, frame):
- raise ValueError(
- "Loading this model requires you to execute custom code contained in the model repository on your local "
- "machine. Please set the option `trust_remote_code=True` to permit loading of this model."
- )
# Delay, in seconds, passed to `signal.alarm` while `resolve_trust_remote_code` waits for the user's answer.
# With a value that is not > 0 the prompt is skipped and remote code is refused with an error straight away.
TIME_OUT_REMOTE_CODE = 15
def resolve_trust_remote_code(trust_remote_code, model_name, has_local_code, has_remote_code):
    """
    Decides whether custom code from a model repository may be executed.

    When `trust_remote_code` is `None`: local code makes the answer `False`; otherwise, if the repo has remote code,
    the user is asked interactively (the prompt is guarded by a `SIGALRM` alarm of `TIME_OUT_REMOTE_CODE` seconds), or
    an error is raised directly when that timeout is not positive.

    Args:
        trust_remote_code: The caller's choice (`True`/`False`), or `None` to let this function decide.
        model_name: Name of the repository, used in the prompt and in error messages.
        has_local_code: Whether code for the model is available locally.
        has_remote_code: Whether the repository ships custom code.

    Returns:
        The resolved `trust_remote_code` value. It is still `None` when it was passed as `None` and there is neither
        local nor remote code.

    Raises:
        `ValueError`: If the prompt fails for any reason (timeout, unsupported `SIGALRM`, ...), or if remote code
        would be needed (no local code) and is not trusted.
    """
    if trust_remote_code is None:
        if has_local_code:
            trust_remote_code = False
        elif has_remote_code and TIME_OUT_REMOTE_CODE > 0:
            # Accepted answers (after lower-casing); anything else maps to None and the question is asked again.
            decisions = {"yes": True, "y": True, "1": True, "no": False, "n": False, "0": False, "": False}
            previous_handler = None
            try:
                previous_handler = signal.signal(signal.SIGALRM, _raise_timeout_error)
                signal.alarm(TIME_OUT_REMOTE_CODE)
                while trust_remote_code is None:
                    reply = input(
                        f"The repository for {model_name} contains custom code which must be executed to correctly "
                        f"load the model. You can inspect the repository content at https://hf.co/{model_name}.\n"
                        f"You can avoid this prompt in future by passing the argument `trust_remote_code=True`.\n\n"
                        f"Do you wish to run the custom code? [y/N] "
                    ).lower()
                    trust_remote_code = decisions.get(reply)
                signal.alarm(0)
            except Exception:
                # OS which does not support signal.SIGALRM
                raise ValueError(
                    f"The repository for {model_name} contains custom code which must be executed to correctly "
                    f"load the model. You can inspect the repository content at https://hf.co/{model_name}.\n"
                    f"Please pass the argument `trust_remote_code=True` to allow custom code to be run."
                )
            finally:
                # Put the previous SIGALRM handler back and cancel any pending alarm.
                if previous_handler is not None:
                    signal.signal(signal.SIGALRM, previous_handler)
                    signal.alarm(0)
        elif has_remote_code:
            # For the CI which puts the timeout at 0
            _raise_timeout_error(None, None)

    if has_remote_code and not (has_local_code or trust_remote_code):
        raise ValueError(
            f"Loading {model_name} requires you to execute the configuration file in that"
            " repo on your local machine. Make sure you have read the code there to avoid malicious use, then"
            " set the option `trust_remote_code=True` to remove this error."
        )

    return trust_remote_code
|