__init__.py 8.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313
  1. #!/usr/bin/env python
  2. # coding=utf-8
  3. # Copyright 2021 The HuggingFace Inc. team. All rights reserved.
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. from functools import lru_cache
  17. from typing import FrozenSet
  18. from huggingface_hub import get_full_repo_name # for backward compatibility
  19. from huggingface_hub.constants import HF_HUB_DISABLE_TELEMETRY as DISABLE_TELEMETRY # for backward compatibility
  20. from packaging import version
  21. from .. import __version__
  22. from .backbone_utils import BackboneConfigMixin, BackboneMixin
  23. from .chat_template_utils import DocstringParsingException, TypeHintParsingException, get_json_schema
  24. from .constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_STANDARD_MEAN, IMAGENET_STANDARD_STD
  25. from .doc import (
  26. add_code_sample_docstrings,
  27. add_end_docstrings,
  28. add_start_docstrings,
  29. add_start_docstrings_to_model_forward,
  30. copy_func,
  31. replace_return_docstrings,
  32. )
  33. from .generic import (
  34. ContextManagers,
  35. ExplicitEnum,
  36. ModelOutput,
  37. PaddingStrategy,
  38. TensorType,
  39. add_model_info_to_auto_map,
  40. add_model_info_to_custom_pipelines,
  41. cached_property,
  42. can_return_loss,
  43. expand_dims,
  44. filter_out_non_signature_kwargs,
  45. find_labels,
  46. flatten_dict,
  47. infer_framework,
  48. is_jax_tensor,
  49. is_numpy_array,
  50. is_tensor,
  51. is_tf_symbolic_tensor,
  52. is_tf_tensor,
  53. is_torch_device,
  54. is_torch_dtype,
  55. is_torch_tensor,
  56. reshape,
  57. squeeze,
  58. strtobool,
  59. tensor_size,
  60. to_numpy,
  61. to_py_obj,
  62. torch_float,
  63. torch_int,
  64. transpose,
  65. working_or_temp_dir,
  66. )
  67. from .hub import (
  68. CLOUDFRONT_DISTRIB_PREFIX,
  69. HF_MODULES_CACHE,
  70. HUGGINGFACE_CO_PREFIX,
  71. HUGGINGFACE_CO_RESOLVE_ENDPOINT,
  72. PYTORCH_PRETRAINED_BERT_CACHE,
  73. PYTORCH_TRANSFORMERS_CACHE,
  74. S3_BUCKET_PREFIX,
  75. TRANSFORMERS_CACHE,
  76. TRANSFORMERS_DYNAMIC_MODULE_NAME,
  77. EntryNotFoundError,
  78. PushInProgress,
  79. PushToHubMixin,
  80. RepositoryNotFoundError,
  81. RevisionNotFoundError,
  82. cached_file,
  83. default_cache_path,
  84. define_sagemaker_information,
  85. download_url,
  86. extract_commit_hash,
  87. get_cached_models,
  88. get_file_from_repo,
  89. has_file,
  90. http_user_agent,
  91. is_offline_mode,
  92. is_remote_url,
  93. move_cache,
  94. send_example_telemetry,
  95. try_to_load_from_cache,
  96. )
  97. from .import_utils import (
  98. ACCELERATE_MIN_VERSION,
  99. ENV_VARS_TRUE_AND_AUTO_VALUES,
  100. ENV_VARS_TRUE_VALUES,
  101. GGUF_MIN_VERSION,
  102. TORCH_FX_REQUIRED_VERSION,
  103. USE_JAX,
  104. USE_TF,
  105. USE_TORCH,
  106. XLA_FSDPV2_MIN_VERSION,
  107. DummyObject,
  108. OptionalDependencyNotAvailable,
  109. _LazyModule,
  110. ccl_version,
  111. direct_transformers_import,
  112. get_torch_version,
  113. is_accelerate_available,
  114. is_apex_available,
  115. is_aqlm_available,
  116. is_auto_awq_available,
  117. is_auto_gptq_available,
  118. is_av_available,
  119. is_bitsandbytes_available,
  120. is_bitsandbytes_multi_backend_available,
  121. is_bs4_available,
  122. is_coloredlogs_available,
  123. is_compressed_tensors_available,
  124. is_cv2_available,
  125. is_cython_available,
  126. is_datasets_available,
  127. is_detectron2_available,
  128. is_eetq_available,
  129. is_essentia_available,
  130. is_faiss_available,
  131. is_fbgemm_gpu_available,
  132. is_flash_attn_2_available,
  133. is_flash_attn_greater_or_equal,
  134. is_flash_attn_greater_or_equal_2_10,
  135. is_flax_available,
  136. is_fsdp_available,
  137. is_ftfy_available,
  138. is_g2p_en_available,
  139. is_galore_torch_available,
  140. is_gguf_available,
  141. is_grokadamw_available,
  142. is_hqq_available,
  143. is_in_notebook,
  144. is_ipex_available,
  145. is_jieba_available,
  146. is_jinja_available,
  147. is_jumanpp_available,
  148. is_kenlm_available,
  149. is_keras_nlp_available,
  150. is_levenshtein_available,
  151. is_librosa_available,
  152. is_liger_kernel_available,
  153. is_lomo_available,
  154. is_mlx_available,
  155. is_natten_available,
  156. is_ninja_available,
  157. is_nltk_available,
  158. is_onnx_available,
  159. is_openai_available,
  160. is_optimum_available,
  161. is_optimum_quanto_available,
  162. is_pandas_available,
  163. is_peft_available,
  164. is_phonemizer_available,
  165. is_pretty_midi_available,
  166. is_protobuf_available,
  167. is_psutil_available,
  168. is_py3nvml_available,
  169. is_pyctcdecode_available,
  170. is_pytesseract_available,
  171. is_pytest_available,
  172. is_pytorch_quantization_available,
  173. is_quanto_available,
  174. is_rjieba_available,
  175. is_sacremoses_available,
  176. is_safetensors_available,
  177. is_sagemaker_dp_enabled,
  178. is_sagemaker_mp_enabled,
  179. is_schedulefree_available,
  180. is_scipy_available,
  181. is_sentencepiece_available,
  182. is_seqio_available,
  183. is_sklearn_available,
  184. is_soundfile_availble,
  185. is_spacy_available,
  186. is_speech_available,
  187. is_sudachi_available,
  188. is_sudachi_projection_available,
  189. is_tensorflow_probability_available,
  190. is_tensorflow_text_available,
  191. is_tf2onnx_available,
  192. is_tf_available,
  193. is_tiktoken_available,
  194. is_timm_available,
  195. is_tokenizers_available,
  196. is_torch_available,
  197. is_torch_bf16_available,
  198. is_torch_bf16_available_on_device,
  199. is_torch_bf16_cpu_available,
  200. is_torch_bf16_gpu_available,
  201. is_torch_compile_available,
  202. is_torch_cuda_available,
  203. is_torch_deterministic,
  204. is_torch_fp16_available_on_device,
  205. is_torch_fx_available,
  206. is_torch_fx_proxy,
  207. is_torch_mlu_available,
  208. is_torch_mps_available,
  209. is_torch_musa_available,
  210. is_torch_neuroncore_available,
  211. is_torch_npu_available,
  212. is_torch_sdpa_available,
  213. is_torch_tensorrt_fx_available,
  214. is_torch_tf32_available,
  215. is_torch_tpu_available,
  216. is_torch_xla_available,
  217. is_torch_xpu_available,
  218. is_torchao_available,
  219. is_torchaudio_available,
  220. is_torchdistx_available,
  221. is_torchdynamo_available,
  222. is_torchdynamo_compiling,
  223. is_torchvision_available,
  224. is_torchvision_v2_available,
  225. is_training_run_on_sagemaker,
  226. is_uroman_available,
  227. is_vision_available,
  228. requires_backends,
  229. torch_only_method,
  230. )
  231. from .peft_utils import (
  232. ADAPTER_CONFIG_NAME,
  233. ADAPTER_SAFE_WEIGHTS_NAME,
  234. ADAPTER_WEIGHTS_NAME,
  235. check_peft_version,
  236. find_adapter_config_file,
  237. )
  238. WEIGHTS_NAME = "pytorch_model.bin"
  239. WEIGHTS_INDEX_NAME = "pytorch_model.bin.index.json"
  240. TF2_WEIGHTS_NAME = "tf_model.h5"
  241. TF2_WEIGHTS_INDEX_NAME = "tf_model.h5.index.json"
  242. TF_WEIGHTS_NAME = "model.ckpt"
  243. FLAX_WEIGHTS_NAME = "flax_model.msgpack"
  244. FLAX_WEIGHTS_INDEX_NAME = "flax_model.msgpack.index.json"
  245. SAFE_WEIGHTS_NAME = "model.safetensors"
  246. SAFE_WEIGHTS_INDEX_NAME = "model.safetensors.index.json"
  247. CONFIG_NAME = "config.json"
  248. FEATURE_EXTRACTOR_NAME = "preprocessor_config.json"
  249. IMAGE_PROCESSOR_NAME = FEATURE_EXTRACTOR_NAME
  250. PROCESSOR_NAME = "processor_config.json"
  251. CHAT_TEMPLATE_NAME = "chat_template.json"
  252. GENERATION_CONFIG_NAME = "generation_config.json"
  253. MODEL_CARD_NAME = "modelcard.json"
  254. SENTENCEPIECE_UNDERLINE = "▁"
  255. SPIECE_UNDERLINE = SENTENCEPIECE_UNDERLINE # Kept for backward compatibility
  256. MULTIPLE_CHOICE_DUMMY_INPUTS = [
  257. [[0, 1, 0, 1], [1, 0, 0, 1]]
  258. ] * 2 # Needs to have 0s and 1s only since XLM uses it for langs too.
  259. DUMMY_INPUTS = [[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]]
  260. DUMMY_MASK = [[1, 1, 1, 1, 1], [1, 1, 1, 0, 0], [0, 0, 0, 1, 1]]
  261. def check_min_version(min_version):
  262. if version.parse(__version__) < version.parse(min_version):
  263. if "dev" in min_version:
  264. error_message = (
  265. "This example requires a source install from HuggingFace Transformers (see "
  266. "`https://huggingface.co/docs/transformers/installation#install-from-source`),"
  267. )
  268. else:
  269. error_message = f"This example requires a minimum version of {min_version},"
  270. error_message += f" but the version found is {__version__}.\n"
  271. raise ImportError(
  272. error_message
  273. + "Check out https://github.com/huggingface/transformers/tree/main/examples#important-note for the examples corresponding to other "
  274. "versions of HuggingFace Transformers."
  275. )
  276. @lru_cache()
  277. def get_available_devices() -> FrozenSet[str]:
  278. """
  279. Returns a frozenset of devices available for the current PyTorch installation.
  280. """
  281. devices = {"cpu"} # `cpu` is always supported as a device in PyTorch
  282. if is_torch_cuda_available():
  283. devices.add("cuda")
  284. if is_torch_mps_available():
  285. devices.add("mps")
  286. if is_torch_xpu_available():
  287. devices.add("xpu")
  288. if is_torch_npu_available():
  289. devices.add("npu")
  290. if is_torch_mlu_available():
  291. devices.add("mlu")
  292. if is_torch_musa_available():
  293. devices.add("musa")
  294. return frozenset(devices)