| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125 |
- # Copyright 2023 The HuggingFace Team. All rights reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- import importlib
- import os
- from typing import Dict, Optional, Union
- from packaging import version
- from .hub import cached_file
- from .import_utils import is_peft_available
# Well-known file names of a PEFT adapter checkpoint: the adapter configuration file, followed by the adapter
# weights in `.bin` and `.safetensors` form.
ADAPTER_CONFIG_NAME, ADAPTER_WEIGHTS_NAME, ADAPTER_SAFE_WEIGHTS_NAME = (
    "adapter_config.json",
    "adapter_model.bin",
    "adapter_model.safetensors",
)
def find_adapter_config_file(
    model_id: str,
    cache_dir: Optional[Union[str, os.PathLike]] = None,
    force_download: bool = False,
    resume_download: Optional[bool] = None,
    proxies: Optional[Dict[str, str]] = None,
    token: Optional[Union[bool, str]] = None,
    revision: Optional[str] = None,
    local_files_only: bool = False,
    subfolder: str = "",
    _commit_hash: Optional[str] = None,
) -> Optional[str]:
    r"""
    Check whether the model stored on the Hub or locally is an adapter model, and return the path of its adapter
    config file if it is, `None` otherwise.

    Args:
        model_id (`str`):
            The identifier of the model to look for, can be either a local path or an id to the repository on the Hub.
        cache_dir (`str` or `os.PathLike`, *optional*):
            Path to a directory in which a downloaded pretrained model configuration should be cached if the standard
            cache should not be used.
        force_download (`bool`, *optional*, defaults to `False`):
            Whether or not to force to (re-)download the configuration files and override the cached versions if they
            exist.
        resume_download:
            Deprecated and ignored. All downloads are now resumed by default when possible.
            Will be removed in v5 of Transformers.
        proxies (`Dict[str, str]`, *optional*):
            A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
            'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request.
        token (`str` or *bool*, *optional*):
            The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
            when running `huggingface-cli login` (stored in `~/.huggingface`).
        revision (`str`, *optional*, defaults to `"main"`):
            The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
            git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
            identifier allowed by git.

            <Tip>

            To test a pull request you made on the Hub, you can pass `revision="refs/pr/<pr_number>"`.

            </Tip>

        local_files_only (`bool`, *optional*, defaults to `False`):
            If `True`, will only try to load the adapter configuration from local files.
        subfolder (`str`, *optional*, defaults to `""`):
            In case the relevant files are located inside a subfolder of the model repo (on huggingface.co or in a
            local directory), you can specify the folder name here.
        _commit_hash (`str`, *optional*):
            Internal argument, forwarded unchanged to `cached_file` for non-local lookups.

    Returns:
        `Optional[str]`: The path to the adapter config file, or `None` if `model_id` is `None` or no adapter config
        file could be found.
    """
    if model_id is None:
        return None

    if os.path.isdir(model_id):
        # Local directory: check the disk directly. `subfolder` is honored here as well, so that a local path resolves
        # the same location that the non-local branch below asks `cached_file` for (`subfolder or ""` also tolerates
        # an explicit `None`). With the default `subfolder=""` this is exactly `model_id/adapter_config.json`.
        adapter_config_path = os.path.join(model_id, subfolder or "", ADAPTER_CONFIG_NAME)
        # `isfile` (a single stat) rather than listing the directory; it also rejects a directory that merely
        # happens to carry the config file's name.
        return adapter_config_path if os.path.isfile(adapter_config_path) else None

    # Hub (or cache) lookup. The `_raise_exceptions_for_*=False` flags ask `cached_file` not to raise for gated repos,
    # missing files or connection errors; it is expected to return None in those cases (per this function's contract).
    return cached_file(
        model_id,
        ADAPTER_CONFIG_NAME,
        cache_dir=cache_dir,
        force_download=force_download,
        resume_download=resume_download,
        proxies=proxies,
        token=token,
        revision=revision,
        local_files_only=local_files_only,
        subfolder=subfolder,
        _commit_hash=_commit_hash,
        _raise_exceptions_for_gated_repo=False,
        _raise_exceptions_for_missing_entries=False,
        _raise_exceptions_for_connection_errors=False,
    )
def check_peft_version(min_version: str) -> None:
    r"""
    Checks if the installed version of PEFT is compatible, i.e. at least `min_version`.

    Args:
        min_version (`str`):
            The minimum PEFT version required. Any installed version greater than or equal to it is accepted.

    Raises:
        `ValueError`: If PEFT is not installed, or if the installed PEFT version is older than `min_version`.
    """
    # `import importlib` alone does not guarantee that the `importlib.metadata` submodule is loaded, so import it
    # explicitly before using it.
    import importlib.metadata

    if not is_peft_available():
        raise ValueError("PEFT is not installed. Please install it with `pip install peft`")

    installed_version = importlib.metadata.version("peft")
    if version.parse(installed_version) < version.parse(min_version):
        raise ValueError(
            f"The version of PEFT you are using ({installed_version}) is not compatible, please use a version that"
            f" is greater than or equal to {min_version}"
        )
|