peft_utils.py 5.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125
  1. # Copyright 2023 The HuggingFace Team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import importlib
  15. import os
  16. from typing import Dict, Optional, Union
  17. from packaging import version
  18. from .hub import cached_file
  19. from .import_utils import is_peft_available
  20. ADAPTER_CONFIG_NAME = "adapter_config.json"
  21. ADAPTER_WEIGHTS_NAME = "adapter_model.bin"
  22. ADAPTER_SAFE_WEIGHTS_NAME = "adapter_model.safetensors"
  23. def find_adapter_config_file(
  24. model_id: str,
  25. cache_dir: Optional[Union[str, os.PathLike]] = None,
  26. force_download: bool = False,
  27. resume_download: Optional[bool] = None,
  28. proxies: Optional[Dict[str, str]] = None,
  29. token: Optional[Union[bool, str]] = None,
  30. revision: Optional[str] = None,
  31. local_files_only: bool = False,
  32. subfolder: str = "",
  33. _commit_hash: Optional[str] = None,
  34. ) -> Optional[str]:
  35. r"""
  36. Simply checks if the model stored on the Hub or locally is an adapter model or not, return the path of the adapter
  37. config file if it is, None otherwise.
  38. Args:
  39. model_id (`str`):
  40. The identifier of the model to look for, can be either a local path or an id to the repository on the Hub.
  41. cache_dir (`str` or `os.PathLike`, *optional*):
  42. Path to a directory in which a downloaded pretrained model configuration should be cached if the standard
  43. cache should not be used.
  44. force_download (`bool`, *optional*, defaults to `False`):
  45. Whether or not to force to (re-)download the configuration files and override the cached versions if they
  46. exist.
  47. resume_download:
  48. Deprecated and ignored. All downloads are now resumed by default when possible.
  49. Will be removed in v5 of Transformers.
  50. proxies (`Dict[str, str]`, *optional*):
  51. A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
  52. 'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request.
  53. token (`str` or *bool*, *optional*):
  54. The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
  55. when running `huggingface-cli login` (stored in `~/.huggingface`).
  56. revision (`str`, *optional*, defaults to `"main"`):
  57. The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
  58. git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
  59. identifier allowed by git.
  60. <Tip>
  61. To test a pull request you made on the Hub, you can pass `revision="refs/pr/<pr_number>".
  62. </Tip>
  63. local_files_only (`bool`, *optional*, defaults to `False`):
  64. If `True`, will only try to load the tokenizer configuration from local files.
  65. subfolder (`str`, *optional*, defaults to `""`):
  66. In case the relevant files are located inside a subfolder of the model repo on huggingface.co, you can
  67. specify the folder name here.
  68. """
  69. adapter_cached_filename = None
  70. if model_id is None:
  71. return None
  72. elif os.path.isdir(model_id):
  73. list_remote_files = os.listdir(model_id)
  74. if ADAPTER_CONFIG_NAME in list_remote_files:
  75. adapter_cached_filename = os.path.join(model_id, ADAPTER_CONFIG_NAME)
  76. else:
  77. adapter_cached_filename = cached_file(
  78. model_id,
  79. ADAPTER_CONFIG_NAME,
  80. cache_dir=cache_dir,
  81. force_download=force_download,
  82. resume_download=resume_download,
  83. proxies=proxies,
  84. token=token,
  85. revision=revision,
  86. local_files_only=local_files_only,
  87. subfolder=subfolder,
  88. _commit_hash=_commit_hash,
  89. _raise_exceptions_for_gated_repo=False,
  90. _raise_exceptions_for_missing_entries=False,
  91. _raise_exceptions_for_connection_errors=False,
  92. )
  93. return adapter_cached_filename
  94. def check_peft_version(min_version: str) -> None:
  95. r"""
  96. Checks if the version of PEFT is compatible.
  97. Args:
  98. version (`str`):
  99. The version of PEFT to check against.
  100. """
  101. if not is_peft_available():
  102. raise ValueError("PEFT is not installed. Please install it with `pip install peft`")
  103. is_peft_version_compatible = version.parse(importlib.metadata.version("peft")) >= version.parse(min_version)
  104. if not is_peft_version_compatible:
  105. raise ValueError(
  106. f"The version of PEFT you are using is not compatible, please use a version that is greater"
  107. f" than {min_version}"
  108. )