hub.py 57 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365
  1. # Copyright 2022 The HuggingFace Team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """
  15. Hub utilities: utilities related to download and cache models
  16. """
  17. import json
  18. import os
  19. import re
  20. import shutil
  21. import sys
  22. import tempfile
  23. import traceback
  24. import warnings
  25. from concurrent import futures
  26. from pathlib import Path
  27. from typing import Dict, List, Optional, Tuple, Union
  28. from urllib.parse import urlparse
  29. from uuid import uuid4
  30. import huggingface_hub
  31. import requests
  32. from huggingface_hub import (
  33. _CACHED_NO_EXIST,
  34. CommitOperationAdd,
  35. ModelCard,
  36. ModelCardData,
  37. constants,
  38. create_branch,
  39. create_commit,
  40. create_repo,
  41. get_hf_file_metadata,
  42. hf_hub_download,
  43. hf_hub_url,
  44. try_to_load_from_cache,
  45. )
  46. from huggingface_hub.file_download import REGEX_COMMIT_HASH, http_get
  47. from huggingface_hub.utils import (
  48. EntryNotFoundError,
  49. GatedRepoError,
  50. HfHubHTTPError,
  51. HFValidationError,
  52. LocalEntryNotFoundError,
  53. OfflineModeIsEnabled,
  54. RepositoryNotFoundError,
  55. RevisionNotFoundError,
  56. build_hf_headers,
  57. get_session,
  58. hf_raise_for_status,
  59. send_telemetry,
  60. )
  61. from huggingface_hub.utils._deprecation import _deprecate_method
  62. from requests.exceptions import HTTPError
  63. from . import __version__, logging
  64. from .generic import working_or_temp_dir
  65. from .import_utils import (
  66. ENV_VARS_TRUE_VALUES,
  67. _tf_version,
  68. _torch_version,
  69. is_tf_available,
  70. is_torch_available,
  71. is_training_run_on_sagemaker,
  72. )
  73. from .logging import tqdm
  74. logger = logging.get_logger(__name__) # pylint: disable=invalid-name
  75. _is_offline_mode = huggingface_hub.constants.HF_HUB_OFFLINE
  76. def is_offline_mode():
  77. return _is_offline_mode
  78. torch_cache_home = os.getenv("TORCH_HOME", os.path.join(os.getenv("XDG_CACHE_HOME", "~/.cache"), "torch"))
  79. default_cache_path = constants.default_cache_path
  80. old_default_cache_path = os.path.join(torch_cache_home, "transformers")
  81. # Determine default cache directory. Lots of legacy environment variables to ensure backward compatibility.
  82. # The best way to set the cache path is with the environment variable HF_HOME. For more details, checkout this
  83. # documentation page: https://huggingface.co/docs/huggingface_hub/package_reference/environment_variables.
  84. #
  85. # In code, use `HF_HUB_CACHE` as the default cache path. This variable is set by the library and is guaranteed
  86. # to be set to the right value.
  87. #
  88. # TODO: clean this for v5?
  89. PYTORCH_PRETRAINED_BERT_CACHE = os.getenv("PYTORCH_PRETRAINED_BERT_CACHE", constants.HF_HUB_CACHE)
  90. PYTORCH_TRANSFORMERS_CACHE = os.getenv("PYTORCH_TRANSFORMERS_CACHE", PYTORCH_PRETRAINED_BERT_CACHE)
  91. TRANSFORMERS_CACHE = os.getenv("TRANSFORMERS_CACHE", PYTORCH_TRANSFORMERS_CACHE)
  92. # Onetime move from the old location to the new one if no ENV variable has been set.
  93. if (
  94. os.path.isdir(old_default_cache_path)
  95. and not os.path.isdir(constants.HF_HUB_CACHE)
  96. and "PYTORCH_PRETRAINED_BERT_CACHE" not in os.environ
  97. and "PYTORCH_TRANSFORMERS_CACHE" not in os.environ
  98. and "TRANSFORMERS_CACHE" not in os.environ
  99. ):
  100. logger.warning(
  101. "In Transformers v4.22.0, the default path to cache downloaded models changed from"
  102. " '~/.cache/torch/transformers' to '~/.cache/huggingface/hub'. Since you don't seem to have"
  103. " overridden and '~/.cache/torch/transformers' is a directory that exists, we're moving it to"
  104. " '~/.cache/huggingface/hub' to avoid redownloading models you have already in the cache. You should"
  105. " only see this message once."
  106. )
  107. shutil.move(old_default_cache_path, constants.HF_HUB_CACHE)
# Directory for cached modules: HF_MODULES_CACHE if set, otherwise `<HF_HOME>/modules`.
HF_MODULES_CACHE = os.getenv("HF_MODULES_CACHE", os.path.join(constants.HF_HOME, "modules"))
# Presumably the package name under which dynamically loaded modules are imported -- confirm in the dynamic
# module loading code.
TRANSFORMERS_DYNAMIC_MODULE_NAME = "transformers_modules"
# Random id generated once when this module is imported; `http_user_agent` includes it as `session_id/<id>`.
SESSION_ID = uuid4().hex
# Add deprecation warning for old environment variables.
# A FutureWarning is emitted at import time for every legacy cache variable that is set.
for key in ("PYTORCH_PRETRAINED_BERT_CACHE", "PYTORCH_TRANSFORMERS_CACHE", "TRANSFORMERS_CACHE"):
    if os.getenv(key) is not None:
        warnings.warn(
            f"Using `{key}` is deprecated and will be removed in v5 of Transformers. Use `HF_HOME` instead.",
            FutureWarning,
        )
# Legacy download locations; not referenced in the visible part of this module (presumably kept for backward
# compatibility).
S3_BUCKET_PREFIX = "https://s3.amazonaws.com/models.huggingface.co/bert"
CLOUDFRONT_DISTRIB_PREFIX = "https://cdn.huggingface.co"
# When HUGGINGFACE_CO_STAGING (upper-cased) is in ENV_VARS_TRUE_VALUES, the default endpoint is
# https://hub-ci.huggingface.co instead of https://huggingface.co.
_staging_mode = os.environ.get("HUGGINGFACE_CO_STAGING", "NO").upper() in ENV_VARS_TRUE_VALUES
_default_endpoint = "https://hub-ci.huggingface.co" if _staging_mode else "https://huggingface.co"
# Endpoint resolution, highest priority first: HF_ENDPOINT, then the deprecated HUGGINGFACE_CO_RESOLVE_ENDPOINT
# (which also triggers a FutureWarning at import time), then the default above.
HUGGINGFACE_CO_RESOLVE_ENDPOINT = _default_endpoint
if os.environ.get("HUGGINGFACE_CO_RESOLVE_ENDPOINT", None) is not None:
    warnings.warn(
        "Using the environment variable `HUGGINGFACE_CO_RESOLVE_ENDPOINT` is deprecated and will be removed in "
        "Transformers v5. Use `HF_ENDPOINT` instead.",
        FutureWarning,
    )
    HUGGINGFACE_CO_RESOLVE_ENDPOINT = os.environ.get("HUGGINGFACE_CO_RESOLVE_ENDPOINT", None)
HUGGINGFACE_CO_RESOLVE_ENDPOINT = os.environ.get("HF_ENDPOINT", HUGGINGFACE_CO_RESOLVE_ENDPOINT)
# URL template: the `{model_id}`, `{revision}` and `{filename}` placeholders are left for the caller to fill in
# (presumably with `str.format` -- confirm at call sites).
HUGGINGFACE_CO_PREFIX = HUGGINGFACE_CO_RESOLVE_ENDPOINT + "/{model_id}/resolve/{revision}/{filename}"
HUGGINGFACE_CO_EXAMPLES_TELEMETRY = HUGGINGFACE_CO_RESOLVE_ENDPOINT + "/api/telemetry/examples"
  133. def _get_cache_file_to_return(
  134. path_or_repo_id: str, full_filename: str, cache_dir: Union[str, Path, None] = None, revision: Optional[str] = None
  135. ):
  136. # We try to see if we have a cached version (not up to date):
  137. resolved_file = try_to_load_from_cache(path_or_repo_id, full_filename, cache_dir=cache_dir, revision=revision)
  138. if resolved_file is not None and resolved_file != _CACHED_NO_EXIST:
  139. return resolved_file
  140. return None
  141. def is_remote_url(url_or_filename):
  142. parsed = urlparse(url_or_filename)
  143. return parsed.scheme in ("http", "https")
  144. # TODO: remove this once fully deprecated
  145. # TODO? remove from './examples/research_projects/lxmert/utils.py' as well
  146. # TODO? remove from './examples/research_projects/visual_bert/utils.py' as well
  147. @_deprecate_method(version="4.39.0", message="This method is outdated and does not support the new cache system.")
  148. def get_cached_models(cache_dir: Union[str, Path] = None) -> List[Tuple]:
  149. """
  150. Returns a list of tuples representing model binaries that are cached locally. Each tuple has shape `(model_url,
  151. etag, size_MB)`. Filenames in `cache_dir` are use to get the metadata for each model, only urls ending with *.bin*
  152. are added.
  153. Args:
  154. cache_dir (`Union[str, Path]`, *optional*):
  155. The cache directory to search for models within. Will default to the transformers cache if unset.
  156. Returns:
  157. List[Tuple]: List of tuples each with shape `(model_url, etag, size_MB)`
  158. """
  159. if cache_dir is None:
  160. cache_dir = TRANSFORMERS_CACHE
  161. elif isinstance(cache_dir, Path):
  162. cache_dir = str(cache_dir)
  163. if not os.path.isdir(cache_dir):
  164. return []
  165. cached_models = []
  166. for file in os.listdir(cache_dir):
  167. if file.endswith(".json"):
  168. meta_path = os.path.join(cache_dir, file)
  169. with open(meta_path, encoding="utf-8") as meta_file:
  170. metadata = json.load(meta_file)
  171. url = metadata["url"]
  172. etag = metadata["etag"]
  173. if url.endswith(".bin"):
  174. size_MB = os.path.getsize(meta_path.strip(".json")) / 1e6
  175. cached_models.append((url, etag, size_MB))
  176. return cached_models
  177. def define_sagemaker_information():
  178. try:
  179. instance_data = requests.get(os.environ["ECS_CONTAINER_METADATA_URI"]).json()
  180. dlc_container_used = instance_data["Image"]
  181. dlc_tag = instance_data["Image"].split(":")[1]
  182. except Exception:
  183. dlc_container_used = None
  184. dlc_tag = None
  185. sagemaker_params = json.loads(os.getenv("SM_FRAMEWORK_PARAMS", "{}"))
  186. runs_distributed_training = True if "sagemaker_distributed_dataparallel_enabled" in sagemaker_params else False
  187. account_id = os.getenv("TRAINING_JOB_ARN").split(":")[4] if "TRAINING_JOB_ARN" in os.environ else None
  188. sagemaker_object = {
  189. "sm_framework": os.getenv("SM_FRAMEWORK_MODULE", None),
  190. "sm_region": os.getenv("AWS_REGION", None),
  191. "sm_number_gpu": os.getenv("SM_NUM_GPUS", 0),
  192. "sm_number_cpu": os.getenv("SM_NUM_CPUS", 0),
  193. "sm_distributed_training": runs_distributed_training,
  194. "sm_deep_learning_container": dlc_container_used,
  195. "sm_deep_learning_container_tag": dlc_tag,
  196. "sm_account_id": account_id,
  197. }
  198. return sagemaker_object
  199. def http_user_agent(user_agent: Union[Dict, str, None] = None) -> str:
  200. """
  201. Formats a user-agent string with basic info about a request.
  202. """
  203. ua = f"transformers/{__version__}; python/{sys.version.split()[0]}; session_id/{SESSION_ID}"
  204. if is_torch_available():
  205. ua += f"; torch/{_torch_version}"
  206. if is_tf_available():
  207. ua += f"; tensorflow/{_tf_version}"
  208. if constants.HF_HUB_DISABLE_TELEMETRY:
  209. return ua + "; telemetry/off"
  210. if is_training_run_on_sagemaker():
  211. ua += "; " + "; ".join(f"{k}/{v}" for k, v in define_sagemaker_information().items())
  212. # CI will set this value to True
  213. if os.environ.get("TRANSFORMERS_IS_CI", "").upper() in ENV_VARS_TRUE_VALUES:
  214. ua += "; is_ci/true"
  215. if isinstance(user_agent, dict):
  216. ua += "; " + "; ".join(f"{k}/{v}" for k, v in user_agent.items())
  217. elif isinstance(user_agent, str):
  218. ua += "; " + user_agent
  219. return ua
  220. def extract_commit_hash(resolved_file: Optional[str], commit_hash: Optional[str]) -> Optional[str]:
  221. """
  222. Extracts the commit hash from a resolved filename toward a cache file.
  223. """
  224. if resolved_file is None or commit_hash is not None:
  225. return commit_hash
  226. resolved_file = str(Path(resolved_file).as_posix())
  227. search = re.search(r"snapshots/([^/]+)/", resolved_file)
  228. if search is None:
  229. return None
  230. commit_hash = search.groups()[0]
  231. return commit_hash if REGEX_COMMIT_HASH.match(commit_hash) else None
  232. def cached_file(
  233. path_or_repo_id: Union[str, os.PathLike],
  234. filename: str,
  235. cache_dir: Optional[Union[str, os.PathLike]] = None,
  236. force_download: bool = False,
  237. resume_download: Optional[bool] = None,
  238. proxies: Optional[Dict[str, str]] = None,
  239. token: Optional[Union[bool, str]] = None,
  240. revision: Optional[str] = None,
  241. local_files_only: bool = False,
  242. subfolder: str = "",
  243. repo_type: Optional[str] = None,
  244. user_agent: Optional[Union[str, Dict[str, str]]] = None,
  245. _raise_exceptions_for_gated_repo: bool = True,
  246. _raise_exceptions_for_missing_entries: bool = True,
  247. _raise_exceptions_for_connection_errors: bool = True,
  248. _commit_hash: Optional[str] = None,
  249. **deprecated_kwargs,
  250. ) -> Optional[str]:
  251. """
  252. Tries to locate a file in a local folder and repo, downloads and cache it if necessary.
  253. Args:
  254. path_or_repo_id (`str` or `os.PathLike`):
  255. This can be either:
  256. - a string, the *model id* of a model repo on huggingface.co.
  257. - a path to a *directory* potentially containing the file.
  258. filename (`str`):
  259. The name of the file to locate in `path_or_repo`.
  260. cache_dir (`str` or `os.PathLike`, *optional*):
  261. Path to a directory in which a downloaded pretrained model configuration should be cached if the standard
  262. cache should not be used.
  263. force_download (`bool`, *optional*, defaults to `False`):
  264. Whether or not to force to (re-)download the configuration files and override the cached versions if they
  265. exist.
  266. resume_download:
  267. Deprecated and ignored. All downloads are now resumed by default when possible.
  268. Will be removed in v5 of Transformers.
  269. proxies (`Dict[str, str]`, *optional*):
  270. A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
  271. 'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request.
  272. token (`str` or *bool*, *optional*):
  273. The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
  274. when running `huggingface-cli login` (stored in `~/.huggingface`).
  275. revision (`str`, *optional*, defaults to `"main"`):
  276. The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
  277. git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
  278. identifier allowed by git.
  279. local_files_only (`bool`, *optional*, defaults to `False`):
  280. If `True`, will only try to load the tokenizer configuration from local files.
  281. subfolder (`str`, *optional*, defaults to `""`):
  282. In case the relevant files are located inside a subfolder of the model repo on huggingface.co, you can
  283. specify the folder name here.
  284. repo_type (`str`, *optional*):
  285. Specify the repo type (useful when downloading from a space for instance).
  286. <Tip>
  287. Passing `token=True` is required when you want to use a private model.
  288. </Tip>
  289. Returns:
  290. `Optional[str]`: Returns the resolved file (to the cache folder if downloaded from a repo).
  291. Examples:
  292. ```python
  293. # Download a model weight from the Hub and cache it.
  294. model_weights_file = cached_file("google-bert/bert-base-uncased", "pytorch_model.bin")
  295. ```
  296. """
  297. use_auth_token = deprecated_kwargs.pop("use_auth_token", None)
  298. if use_auth_token is not None:
  299. warnings.warn(
  300. "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.",
  301. FutureWarning,
  302. )
  303. if token is not None:
  304. raise ValueError("`token` and `use_auth_token` are both specified. Please set only the argument `token`.")
  305. token = use_auth_token
  306. # Private arguments
  307. # _raise_exceptions_for_gated_repo: if False, do not raise an exception for gated repo error but return
  308. # None.
  309. # _raise_exceptions_for_missing_entries: if False, do not raise an exception for missing entries but return
  310. # None.
  311. # _raise_exceptions_for_connection_errors: if False, do not raise an exception for connection errors but return
  312. # None.
  313. # _commit_hash: passed when we are chaining several calls to various files (e.g. when loading a tokenizer or
  314. # a pipeline). If files are cached for this commit hash, avoid calls to head and get from the cache.
  315. if is_offline_mode() and not local_files_only:
  316. logger.info("Offline mode: forcing local_files_only=True")
  317. local_files_only = True
  318. if subfolder is None:
  319. subfolder = ""
  320. path_or_repo_id = str(path_or_repo_id)
  321. full_filename = os.path.join(subfolder, filename)
  322. if os.path.isdir(path_or_repo_id):
  323. resolved_file = os.path.join(os.path.join(path_or_repo_id, subfolder), filename)
  324. if not os.path.isfile(resolved_file):
  325. if _raise_exceptions_for_missing_entries and filename not in ["config.json", f"{subfolder}/config.json"]:
  326. raise EnvironmentError(
  327. f"{path_or_repo_id} does not appear to have a file named {full_filename}. Checkout "
  328. f"'https://huggingface.co/{path_or_repo_id}/tree/{revision}' for available files."
  329. )
  330. else:
  331. return None
  332. return resolved_file
  333. if cache_dir is None:
  334. cache_dir = TRANSFORMERS_CACHE
  335. if isinstance(cache_dir, Path):
  336. cache_dir = str(cache_dir)
  337. if _commit_hash is not None and not force_download:
  338. # If the file is cached under that commit hash, we return it directly.
  339. resolved_file = try_to_load_from_cache(
  340. path_or_repo_id, full_filename, cache_dir=cache_dir, revision=_commit_hash, repo_type=repo_type
  341. )
  342. if resolved_file is not None:
  343. if resolved_file is not _CACHED_NO_EXIST:
  344. return resolved_file
  345. elif not _raise_exceptions_for_missing_entries:
  346. return None
  347. else:
  348. raise EnvironmentError(f"Could not locate {full_filename} inside {path_or_repo_id}.")
  349. user_agent = http_user_agent(user_agent)
  350. try:
  351. # Load from URL or cache if already cached
  352. resolved_file = hf_hub_download(
  353. path_or_repo_id,
  354. filename,
  355. subfolder=None if len(subfolder) == 0 else subfolder,
  356. repo_type=repo_type,
  357. revision=revision,
  358. cache_dir=cache_dir,
  359. user_agent=user_agent,
  360. force_download=force_download,
  361. proxies=proxies,
  362. resume_download=resume_download,
  363. token=token,
  364. local_files_only=local_files_only,
  365. )
  366. except GatedRepoError as e:
  367. resolved_file = _get_cache_file_to_return(path_or_repo_id, full_filename, cache_dir, revision)
  368. if resolved_file is not None or not _raise_exceptions_for_gated_repo:
  369. return resolved_file
  370. raise EnvironmentError(
  371. "You are trying to access a gated repo.\nMake sure to have access to it at "
  372. f"https://huggingface.co/{path_or_repo_id}.\n{str(e)}"
  373. ) from e
  374. except RepositoryNotFoundError as e:
  375. raise EnvironmentError(
  376. f"{path_or_repo_id} is not a local folder and is not a valid model identifier "
  377. "listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a token "
  378. "having permission to this repo either by logging in with `huggingface-cli login` or by passing "
  379. "`token=<your_token>`"
  380. ) from e
  381. except RevisionNotFoundError as e:
  382. raise EnvironmentError(
  383. f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists "
  384. "for this model name. Check the model page at "
  385. f"'https://huggingface.co/{path_or_repo_id}' for available revisions."
  386. ) from e
  387. except LocalEntryNotFoundError as e:
  388. resolved_file = _get_cache_file_to_return(path_or_repo_id, full_filename, cache_dir, revision)
  389. if (
  390. resolved_file is not None
  391. or not _raise_exceptions_for_missing_entries
  392. or not _raise_exceptions_for_connection_errors
  393. ):
  394. return resolved_file
  395. raise EnvironmentError(
  396. f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this file, couldn't find it in the"
  397. f" cached files and it looks like {path_or_repo_id} is not the path to a directory containing a file named"
  398. f" {full_filename}.\nCheckout your internet connection or see how to run the library in offline mode at"
  399. " 'https://huggingface.co/docs/transformers/installation#offline-mode'."
  400. ) from e
  401. except EntryNotFoundError as e:
  402. if not _raise_exceptions_for_missing_entries:
  403. return None
  404. if revision is None:
  405. revision = "main"
  406. if filename in ["config.json", f"{subfolder}/config.json"]:
  407. return None
  408. raise EnvironmentError(
  409. f"{path_or_repo_id} does not appear to have a file named {full_filename}. Checkout "
  410. f"'https://huggingface.co/{path_or_repo_id}/tree/{revision}' for available files."
  411. ) from e
  412. except HTTPError as err:
  413. resolved_file = _get_cache_file_to_return(path_or_repo_id, full_filename, cache_dir, revision)
  414. if resolved_file is not None or not _raise_exceptions_for_connection_errors:
  415. return resolved_file
  416. raise EnvironmentError(f"There was a specific connection error when trying to load {path_or_repo_id}:\n{err}")
  417. except HFValidationError as e:
  418. raise EnvironmentError(
  419. f"Incorrect path_or_model_id: '{path_or_repo_id}'. Please provide either the path to a local folder or the repo_id of a model on the Hub."
  420. ) from e
  421. return resolved_file
  422. # TODO: deprecate `get_file_from_repo` or document it differently?
  423. # Docstring is exactly the same as `cached_repo` but behavior is slightly different. If file is missing or if
  424. # there is a connection error, `cached_repo` will return None while `get_file_from_repo` will raise an error.
  425. # IMO we should keep only 1 method and have a single `raise_error` argument (to be discussed).
  426. def get_file_from_repo(
  427. path_or_repo: Union[str, os.PathLike],
  428. filename: str,
  429. cache_dir: Optional[Union[str, os.PathLike]] = None,
  430. force_download: bool = False,
  431. resume_download: Optional[bool] = None,
  432. proxies: Optional[Dict[str, str]] = None,
  433. token: Optional[Union[bool, str]] = None,
  434. revision: Optional[str] = None,
  435. local_files_only: bool = False,
  436. subfolder: str = "",
  437. **deprecated_kwargs,
  438. ):
  439. """
  440. Tries to locate a file in a local folder and repo, downloads and cache it if necessary.
  441. Args:
  442. path_or_repo (`str` or `os.PathLike`):
  443. This can be either:
  444. - a string, the *model id* of a model repo on huggingface.co.
  445. - a path to a *directory* potentially containing the file.
  446. filename (`str`):
  447. The name of the file to locate in `path_or_repo`.
  448. cache_dir (`str` or `os.PathLike`, *optional*):
  449. Path to a directory in which a downloaded pretrained model configuration should be cached if the standard
  450. cache should not be used.
  451. force_download (`bool`, *optional*, defaults to `False`):
  452. Whether or not to force to (re-)download the configuration files and override the cached versions if they
  453. exist.
  454. resume_download:
  455. Deprecated and ignored. All downloads are now resumed by default when possible.
  456. Will be removed in v5 of Transformers.
  457. proxies (`Dict[str, str]`, *optional*):
  458. A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
  459. 'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request.
  460. token (`str` or *bool*, *optional*):
  461. The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
  462. when running `huggingface-cli login` (stored in `~/.huggingface`).
  463. revision (`str`, *optional*, defaults to `"main"`):
  464. The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
  465. git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
  466. identifier allowed by git.
  467. local_files_only (`bool`, *optional*, defaults to `False`):
  468. If `True`, will only try to load the tokenizer configuration from local files.
  469. subfolder (`str`, *optional*, defaults to `""`):
  470. In case the relevant files are located inside a subfolder of the model repo on huggingface.co, you can
  471. specify the folder name here.
  472. <Tip>
  473. Passing `token=True` is required when you want to use a private model.
  474. </Tip>
  475. Returns:
  476. `Optional[str]`: Returns the resolved file (to the cache folder if downloaded from a repo) or `None` if the
  477. file does not exist.
  478. Examples:
  479. ```python
  480. # Download a tokenizer configuration from huggingface.co and cache.
  481. tokenizer_config = get_file_from_repo("google-bert/bert-base-uncased", "tokenizer_config.json")
  482. # This model does not have a tokenizer config so the result will be None.
  483. tokenizer_config = get_file_from_repo("FacebookAI/xlm-roberta-base", "tokenizer_config.json")
  484. ```
  485. """
  486. use_auth_token = deprecated_kwargs.pop("use_auth_token", None)
  487. if use_auth_token is not None:
  488. warnings.warn(
  489. "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.",
  490. FutureWarning,
  491. )
  492. if token is not None:
  493. raise ValueError("`token` and `use_auth_token` are both specified. Please set only the argument `token`.")
  494. token = use_auth_token
  495. return cached_file(
  496. path_or_repo_id=path_or_repo,
  497. filename=filename,
  498. cache_dir=cache_dir,
  499. force_download=force_download,
  500. resume_download=resume_download,
  501. proxies=proxies,
  502. token=token,
  503. revision=revision,
  504. local_files_only=local_files_only,
  505. subfolder=subfolder,
  506. _raise_exceptions_for_gated_repo=False,
  507. _raise_exceptions_for_missing_entries=False,
  508. _raise_exceptions_for_connection_errors=False,
  509. )
  510. def download_url(url, proxies=None):
  511. """
  512. Downloads a given url in a temporary file. This function is not safe to use in multiple processes. Its only use is
  513. for deprecated behavior allowing to download config/models with a single url instead of using the Hub.
  514. Args:
  515. url (`str`): The url of the file to download.
  516. proxies (`Dict[str, str]`, *optional*):
  517. A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
  518. 'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request.
  519. Returns:
  520. `str`: The location of the temporary file where the url was downloaded.
  521. """
  522. warnings.warn(
  523. f"Using `from_pretrained` with the url of a file (here {url}) is deprecated and won't be possible anymore in"
  524. " v5 of Transformers. You should host your file on the Hub (hf.co) instead and use the repository ID. Note"
  525. " that this is not compatible with the caching system (your file will be downloaded at each execution) or"
  526. " multiple processes (each process will download the file in a different temporary file).",
  527. FutureWarning,
  528. )
  529. tmp_fd, tmp_file = tempfile.mkstemp()
  530. with os.fdopen(tmp_fd, "wb") as f:
  531. http_get(url, f, proxies=proxies)
  532. return tmp_file
  533. def has_file(
  534. path_or_repo: Union[str, os.PathLike],
  535. filename: str,
  536. revision: Optional[str] = None,
  537. proxies: Optional[Dict[str, str]] = None,
  538. token: Optional[Union[bool, str]] = None,
  539. *,
  540. local_files_only: bool = False,
  541. cache_dir: Union[str, Path, None] = None,
  542. repo_type: Optional[str] = None,
  543. **deprecated_kwargs,
  544. ):
  545. """
  546. Checks if a repo contains a given file without downloading it. Works for remote repos and local folders.
  547. If offline mode is enabled, checks if the file exists in the cache.
  548. <Tip warning={false}>
  549. This function will raise an error if the repository `path_or_repo` is not valid or if `revision` does not exist for
  550. this repo, but will return False for regular connection errors.
  551. </Tip>
  552. """
  553. use_auth_token = deprecated_kwargs.pop("use_auth_token", None)
  554. if use_auth_token is not None:
  555. warnings.warn(
  556. "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.",
  557. FutureWarning,
  558. )
  559. if token is not None:
  560. raise ValueError("`token` and `use_auth_token` are both specified. Please set only the argument `token`.")
  561. token = use_auth_token
  562. # If path to local directory, check if the file exists
  563. if os.path.isdir(path_or_repo):
  564. return os.path.isfile(os.path.join(path_or_repo, filename))
  565. # Else it's a repo => let's check if the file exists in local cache or on the Hub
  566. # Check if file exists in cache
  567. # This information might be outdated so it's best to also make a HEAD call (if allowed).
  568. cached_path = try_to_load_from_cache(
  569. repo_id=path_or_repo,
  570. filename=filename,
  571. revision=revision,
  572. repo_type=repo_type,
  573. cache_dir=cache_dir,
  574. )
  575. has_file_in_cache = isinstance(cached_path, str)
  576. # If local_files_only, don't try the HEAD call
  577. if local_files_only:
  578. return has_file_in_cache
  579. # Check if the file exists
  580. try:
  581. response = get_session().head(
  582. hf_hub_url(path_or_repo, filename=filename, revision=revision, repo_type=repo_type),
  583. headers=build_hf_headers(token=token, user_agent=http_user_agent()),
  584. allow_redirects=False,
  585. proxies=proxies,
  586. timeout=10,
  587. )
  588. except (requests.exceptions.SSLError, requests.exceptions.ProxyError):
  589. # Actually raise for those subclasses of ConnectionError
  590. raise
  591. except (
  592. requests.exceptions.ConnectionError,
  593. requests.exceptions.Timeout,
  594. OfflineModeIsEnabled,
  595. ):
  596. return has_file_in_cache
  597. try:
  598. hf_raise_for_status(response)
  599. return True
  600. except GatedRepoError as e:
  601. logger.error(e)
  602. raise EnvironmentError(
  603. f"{path_or_repo} is a gated repository. Make sure to request access at "
  604. f"https://huggingface.co/{path_or_repo} and pass a token having permission to this repo either by "
  605. "logging in with `huggingface-cli login` or by passing `token=<your_token>`."
  606. ) from e
  607. except RepositoryNotFoundError as e:
  608. logger.error(e)
  609. raise EnvironmentError(
  610. f"{path_or_repo} is not a local folder or a valid repository name on 'https://hf.co'."
  611. ) from e
  612. except RevisionNotFoundError as e:
  613. logger.error(e)
  614. raise EnvironmentError(
  615. f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for this "
  616. f"model name. Check the model page at 'https://huggingface.co/{path_or_repo}' for available revisions."
  617. ) from e
  618. except EntryNotFoundError:
  619. return False # File does not exist
  620. except requests.HTTPError:
  621. # Any authentication/authorization error will be caught here => default to cache
  622. return has_file_in_cache
  623. class PushToHubMixin:
  624. """
  625. A Mixin containing the functionality to push a model or tokenizer to the hub.
  626. """
  627. def _create_repo(
  628. self,
  629. repo_id: str,
  630. private: Optional[bool] = None,
  631. token: Optional[Union[bool, str]] = None,
  632. repo_url: Optional[str] = None,
  633. organization: Optional[str] = None,
  634. ) -> str:
  635. """
  636. Create the repo if needed, cleans up repo_id with deprecated kwargs `repo_url` and `organization`, retrieves
  637. the token.
  638. """
  639. if repo_url is not None:
  640. warnings.warn(
  641. "The `repo_url` argument is deprecated and will be removed in v5 of Transformers. Use `repo_id` "
  642. "instead."
  643. )
  644. if repo_id is not None:
  645. raise ValueError(
  646. "`repo_id` and `repo_url` are both specified. Please set only the argument `repo_id`."
  647. )
  648. repo_id = repo_url.replace(f"{HUGGINGFACE_CO_RESOLVE_ENDPOINT}/", "")
  649. if organization is not None:
  650. warnings.warn(
  651. "The `organization` argument is deprecated and will be removed in v5 of Transformers. Set your "
  652. "organization directly in the `repo_id` passed instead (`repo_id={organization}/{model_id}`)."
  653. )
  654. if not repo_id.startswith(organization):
  655. if "/" in repo_id:
  656. repo_id = repo_id.split("/")[-1]
  657. repo_id = f"{organization}/{repo_id}"
  658. url = create_repo(repo_id=repo_id, token=token, private=private, exist_ok=True)
  659. return url.repo_id
  660. def _get_files_timestamps(self, working_dir: Union[str, os.PathLike]):
  661. """
  662. Returns the list of files with their last modification timestamp.
  663. """
  664. return {f: os.path.getmtime(os.path.join(working_dir, f)) for f in os.listdir(working_dir)}
  665. def _upload_modified_files(
  666. self,
  667. working_dir: Union[str, os.PathLike],
  668. repo_id: str,
  669. files_timestamps: Dict[str, float],
  670. commit_message: Optional[str] = None,
  671. token: Optional[Union[bool, str]] = None,
  672. create_pr: bool = False,
  673. revision: str = None,
  674. commit_description: str = None,
  675. ):
  676. """
  677. Uploads all modified files in `working_dir` to `repo_id`, based on `files_timestamps`.
  678. """
  679. if commit_message is None:
  680. if "Model" in self.__class__.__name__:
  681. commit_message = "Upload model"
  682. elif "Config" in self.__class__.__name__:
  683. commit_message = "Upload config"
  684. elif "Tokenizer" in self.__class__.__name__:
  685. commit_message = "Upload tokenizer"
  686. elif "FeatureExtractor" in self.__class__.__name__:
  687. commit_message = "Upload feature extractor"
  688. elif "Processor" in self.__class__.__name__:
  689. commit_message = "Upload processor"
  690. else:
  691. commit_message = f"Upload {self.__class__.__name__}"
  692. modified_files = [
  693. f
  694. for f in os.listdir(working_dir)
  695. if f not in files_timestamps or os.path.getmtime(os.path.join(working_dir, f)) > files_timestamps[f]
  696. ]
  697. # filter for actual files + folders at the root level
  698. modified_files = [
  699. f
  700. for f in modified_files
  701. if os.path.isfile(os.path.join(working_dir, f)) or os.path.isdir(os.path.join(working_dir, f))
  702. ]
  703. operations = []
  704. # upload standalone files
  705. for file in modified_files:
  706. if os.path.isdir(os.path.join(working_dir, file)):
  707. # go over individual files of folder
  708. for f in os.listdir(os.path.join(working_dir, file)):
  709. operations.append(
  710. CommitOperationAdd(
  711. path_or_fileobj=os.path.join(working_dir, file, f), path_in_repo=os.path.join(file, f)
  712. )
  713. )
  714. else:
  715. operations.append(
  716. CommitOperationAdd(path_or_fileobj=os.path.join(working_dir, file), path_in_repo=file)
  717. )
  718. if revision is not None and not revision.startswith("refs/pr"):
  719. try:
  720. create_branch(repo_id=repo_id, branch=revision, token=token, exist_ok=True)
  721. except HfHubHTTPError as e:
  722. if e.response.status_code == 403 and create_pr:
  723. # If we are creating a PR on a repo we don't have access to, we can't create the branch.
  724. # so let's assume the branch already exists. If it's not the case, an error will be raised when
  725. # calling `create_commit` below.
  726. pass
  727. else:
  728. raise
  729. logger.info(f"Uploading the following files to {repo_id}: {','.join(modified_files)}")
  730. return create_commit(
  731. repo_id=repo_id,
  732. operations=operations,
  733. commit_message=commit_message,
  734. commit_description=commit_description,
  735. token=token,
  736. create_pr=create_pr,
  737. revision=revision,
  738. )
  739. def push_to_hub(
  740. self,
  741. repo_id: str,
  742. use_temp_dir: Optional[bool] = None,
  743. commit_message: Optional[str] = None,
  744. private: Optional[bool] = None,
  745. token: Optional[Union[bool, str]] = None,
  746. max_shard_size: Optional[Union[int, str]] = "5GB",
  747. create_pr: bool = False,
  748. safe_serialization: bool = True,
  749. revision: str = None,
  750. commit_description: str = None,
  751. tags: Optional[List[str]] = None,
  752. **deprecated_kwargs,
  753. ) -> str:
  754. """
  755. Upload the {object_files} to the 🤗 Model Hub.
  756. Parameters:
  757. repo_id (`str`):
  758. The name of the repository you want to push your {object} to. It should contain your organization name
  759. when pushing to a given organization.
  760. use_temp_dir (`bool`, *optional*):
  761. Whether or not to use a temporary directory to store the files saved before they are pushed to the Hub.
  762. Will default to `True` if there is no directory named like `repo_id`, `False` otherwise.
  763. commit_message (`str`, *optional*):
  764. Message to commit while pushing. Will default to `"Upload {object}"`.
  765. private (`bool`, *optional*):
  766. Whether or not the repository created should be private.
  767. token (`bool` or `str`, *optional*):
  768. The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
  769. when running `huggingface-cli login` (stored in `~/.huggingface`). Will default to `True` if `repo_url`
  770. is not specified.
  771. max_shard_size (`int` or `str`, *optional*, defaults to `"5GB"`):
  772. Only applicable for models. The maximum size for a checkpoint before being sharded. Checkpoints shard
  773. will then be each of size lower than this size. If expressed as a string, needs to be digits followed
  774. by a unit (like `"5MB"`). We default it to `"5GB"` so that users can easily load models on free-tier
  775. Google Colab instances without any CPU OOM issues.
  776. create_pr (`bool`, *optional*, defaults to `False`):
  777. Whether or not to create a PR with the uploaded files or directly commit.
  778. safe_serialization (`bool`, *optional*, defaults to `True`):
  779. Whether or not to convert the model weights in safetensors format for safer serialization.
  780. revision (`str`, *optional*):
  781. Branch to push the uploaded files to.
  782. commit_description (`str`, *optional*):
  783. The description of the commit that will be created
  784. tags (`List[str]`, *optional*):
  785. List of tags to push on the Hub.
  786. Examples:
  787. ```python
  788. from transformers import {object_class}
  789. {object} = {object_class}.from_pretrained("google-bert/bert-base-cased")
  790. # Push the {object} to your namespace with the name "my-finetuned-bert".
  791. {object}.push_to_hub("my-finetuned-bert")
  792. # Push the {object} to an organization with the name "my-finetuned-bert".
  793. {object}.push_to_hub("huggingface/my-finetuned-bert")
  794. ```
  795. """
  796. use_auth_token = deprecated_kwargs.pop("use_auth_token", None)
  797. ignore_metadata_errors = deprecated_kwargs.pop("ignore_metadata_errors", False)
  798. if use_auth_token is not None:
  799. warnings.warn(
  800. "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.",
  801. FutureWarning,
  802. )
  803. if token is not None:
  804. raise ValueError(
  805. "`token` and `use_auth_token` are both specified. Please set only the argument `token`."
  806. )
  807. token = use_auth_token
  808. repo_path_or_name = deprecated_kwargs.pop("repo_path_or_name", None)
  809. if repo_path_or_name is not None:
  810. # Should use `repo_id` instead of `repo_path_or_name`. When using `repo_path_or_name`, we try to infer
  811. # repo_id from the folder path, if it exists.
  812. warnings.warn(
  813. "The `repo_path_or_name` argument is deprecated and will be removed in v5 of Transformers. Use "
  814. "`repo_id` instead.",
  815. FutureWarning,
  816. )
  817. if repo_id is not None:
  818. raise ValueError(
  819. "`repo_id` and `repo_path_or_name` are both specified. Please set only the argument `repo_id`."
  820. )
  821. if os.path.isdir(repo_path_or_name):
  822. # repo_path: infer repo_id from the path
  823. repo_id = repo_id.split(os.path.sep)[-1]
  824. working_dir = repo_id
  825. else:
  826. # repo_name: use it as repo_id
  827. repo_id = repo_path_or_name
  828. working_dir = repo_id.split("/")[-1]
  829. else:
  830. # Repo_id is passed correctly: infer working_dir from it
  831. working_dir = repo_id.split("/")[-1]
  832. # Deprecation warning will be sent after for repo_url and organization
  833. repo_url = deprecated_kwargs.pop("repo_url", None)
  834. organization = deprecated_kwargs.pop("organization", None)
  835. repo_id = self._create_repo(
  836. repo_id, private=private, token=token, repo_url=repo_url, organization=organization
  837. )
  838. # Create a new empty model card and eventually tag it
  839. model_card = create_and_tag_model_card(
  840. repo_id, tags, token=token, ignore_metadata_errors=ignore_metadata_errors
  841. )
  842. if use_temp_dir is None:
  843. use_temp_dir = not os.path.isdir(working_dir)
  844. with working_or_temp_dir(working_dir=working_dir, use_temp_dir=use_temp_dir) as work_dir:
  845. files_timestamps = self._get_files_timestamps(work_dir)
  846. # Save all files.
  847. self.save_pretrained(work_dir, max_shard_size=max_shard_size, safe_serialization=safe_serialization)
  848. # Update model card if needed:
  849. model_card.save(os.path.join(work_dir, "README.md"))
  850. return self._upload_modified_files(
  851. work_dir,
  852. repo_id,
  853. files_timestamps,
  854. commit_message=commit_message,
  855. token=token,
  856. create_pr=create_pr,
  857. revision=revision,
  858. commit_description=commit_description,
  859. )
  860. def send_example_telemetry(example_name, *example_args, framework="pytorch"):
  861. """
  862. Sends telemetry that helps tracking the examples use.
  863. Args:
  864. example_name (`str`): The name of the example.
  865. *example_args (dataclasses or `argparse.ArgumentParser`): The arguments to the script. This function will only
  866. try to extract the model and dataset name from those. Nothing else is tracked.
  867. framework (`str`, *optional*, defaults to `"pytorch"`): The framework for the example.
  868. """
  869. if is_offline_mode():
  870. return
  871. data = {"example": example_name, "framework": framework}
  872. for args in example_args:
  873. args_as_dict = {k: v for k, v in args.__dict__.items() if not k.startswith("_") and v is not None}
  874. if "model_name_or_path" in args_as_dict:
  875. model_name = args_as_dict["model_name_or_path"]
  876. # Filter out local paths
  877. if not os.path.isdir(model_name):
  878. data["model_name"] = args_as_dict["model_name_or_path"]
  879. if "dataset_name" in args_as_dict:
  880. data["dataset_name"] = args_as_dict["dataset_name"]
  881. elif "task_name" in args_as_dict:
  882. # Extract script name from the example_name
  883. script_name = example_name.replace("tf_", "").replace("flax_", "").replace("run_", "")
  884. script_name = script_name.replace("_no_trainer", "")
  885. data["dataset_name"] = f"{script_name}-{args_as_dict['task_name']}"
  886. # Send telemetry in the background
  887. send_telemetry(
  888. topic="examples", library_name="transformers", library_version=__version__, user_agent=http_user_agent(data)
  889. )
  890. def convert_file_size_to_int(size: Union[int, str]):
  891. """
  892. Converts a size expressed as a string with digits an unit (like `"5MB"`) to an integer (in bytes).
  893. Args:
  894. size (`int` or `str`): The size to convert. Will be directly returned if an `int`.
  895. Example:
  896. ```py
  897. >>> convert_file_size_to_int("1MiB")
  898. 1048576
  899. ```
  900. """
  901. if isinstance(size, int):
  902. return size
  903. if size.upper().endswith("GIB"):
  904. return int(size[:-3]) * (2**30)
  905. if size.upper().endswith("MIB"):
  906. return int(size[:-3]) * (2**20)
  907. if size.upper().endswith("KIB"):
  908. return int(size[:-3]) * (2**10)
  909. if size.upper().endswith("GB"):
  910. int_size = int(size[:-2]) * (10**9)
  911. return int_size // 8 if size.endswith("b") else int_size
  912. if size.upper().endswith("MB"):
  913. int_size = int(size[:-2]) * (10**6)
  914. return int_size // 8 if size.endswith("b") else int_size
  915. if size.upper().endswith("KB"):
  916. int_size = int(size[:-2]) * (10**3)
  917. return int_size // 8 if size.endswith("b") else int_size
  918. raise ValueError("`size` is not in a valid format. Use an integer followed by the unit, e.g., '5GB'.")
  919. def get_checkpoint_shard_files(
  920. pretrained_model_name_or_path,
  921. index_filename,
  922. cache_dir=None,
  923. force_download=False,
  924. proxies=None,
  925. resume_download=None,
  926. local_files_only=False,
  927. token=None,
  928. user_agent=None,
  929. revision=None,
  930. subfolder="",
  931. _commit_hash=None,
  932. **deprecated_kwargs,
  933. ):
  934. """
  935. For a given model:
  936. - download and cache all the shards of a sharded checkpoint if `pretrained_model_name_or_path` is a model ID on the
  937. Hub
  938. - returns the list of paths to all the shards, as well as some metadata.
  939. For the description of each arg, see [`PreTrainedModel.from_pretrained`]. `index_filename` is the full path to the
  940. index (downloaded and cached if `pretrained_model_name_or_path` is a model ID on the Hub).
  941. """
  942. import json
  943. use_auth_token = deprecated_kwargs.pop("use_auth_token", None)
  944. if use_auth_token is not None:
  945. warnings.warn(
  946. "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.",
  947. FutureWarning,
  948. )
  949. if token is not None:
  950. raise ValueError("`token` and `use_auth_token` are both specified. Please set only the argument `token`.")
  951. token = use_auth_token
  952. if not os.path.isfile(index_filename):
  953. raise ValueError(f"Can't find a checkpoint index ({index_filename}) in {pretrained_model_name_or_path}.")
  954. with open(index_filename, "r") as f:
  955. index = json.loads(f.read())
  956. shard_filenames = sorted(set(index["weight_map"].values()))
  957. sharded_metadata = index["metadata"]
  958. sharded_metadata["all_checkpoint_keys"] = list(index["weight_map"].keys())
  959. sharded_metadata["weight_map"] = index["weight_map"].copy()
  960. # First, let's deal with local folder.
  961. if os.path.isdir(pretrained_model_name_or_path):
  962. shard_filenames = [os.path.join(pretrained_model_name_or_path, subfolder, f) for f in shard_filenames]
  963. return shard_filenames, sharded_metadata
  964. # At this stage pretrained_model_name_or_path is a model identifier on the Hub
  965. cached_filenames = []
  966. # Check if the model is already cached or not. We only try the last checkpoint, this should cover most cases of
  967. # downloaded (if interrupted).
  968. last_shard = try_to_load_from_cache(
  969. pretrained_model_name_or_path, shard_filenames[-1], cache_dir=cache_dir, revision=_commit_hash
  970. )
  971. show_progress_bar = last_shard is None or force_download
  972. for shard_filename in tqdm(shard_filenames, desc="Downloading shards", disable=not show_progress_bar):
  973. try:
  974. # Load from URL
  975. cached_filename = cached_file(
  976. pretrained_model_name_or_path,
  977. shard_filename,
  978. cache_dir=cache_dir,
  979. force_download=force_download,
  980. proxies=proxies,
  981. resume_download=resume_download,
  982. local_files_only=local_files_only,
  983. token=token,
  984. user_agent=user_agent,
  985. revision=revision,
  986. subfolder=subfolder,
  987. _commit_hash=_commit_hash,
  988. )
  989. # We have already dealt with RepositoryNotFoundError and RevisionNotFoundError when getting the index, so
  990. # we don't have to catch them here.
  991. except EntryNotFoundError:
  992. raise EnvironmentError(
  993. f"{pretrained_model_name_or_path} does not appear to have a file named {shard_filename} which is "
  994. "required according to the checkpoint index."
  995. )
  996. except HTTPError:
  997. raise EnvironmentError(
  998. f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load {shard_filename}. You should try"
  999. " again after checking your internet connection."
  1000. )
  1001. cached_filenames.append(cached_filename)
  1002. return cached_filenames, sharded_metadata
  1003. # All what is below is for conversion between old cache format and new cache format.
  1004. def get_all_cached_files(cache_dir=None):
  1005. """
  1006. Returns a list for all files cached with appropriate metadata.
  1007. """
  1008. if cache_dir is None:
  1009. cache_dir = TRANSFORMERS_CACHE
  1010. else:
  1011. cache_dir = str(cache_dir)
  1012. if not os.path.isdir(cache_dir):
  1013. return []
  1014. cached_files = []
  1015. for file in os.listdir(cache_dir):
  1016. meta_path = os.path.join(cache_dir, f"{file}.json")
  1017. if not os.path.isfile(meta_path):
  1018. continue
  1019. with open(meta_path, encoding="utf-8") as meta_file:
  1020. metadata = json.load(meta_file)
  1021. url = metadata["url"]
  1022. etag = metadata["etag"].replace('"', "")
  1023. cached_files.append({"file": file, "url": url, "etag": etag})
  1024. return cached_files
  1025. def extract_info_from_url(url):
  1026. """
  1027. Extract repo_name, revision and filename from an url.
  1028. """
  1029. search = re.search(r"^https://huggingface\.co/(.*)/resolve/([^/]*)/(.*)$", url)
  1030. if search is None:
  1031. return None
  1032. repo, revision, filename = search.groups()
  1033. cache_repo = "--".join(["models"] + repo.split("/"))
  1034. return {"repo": cache_repo, "revision": revision, "filename": filename}
  1035. def create_and_tag_model_card(
  1036. repo_id: str,
  1037. tags: Optional[List[str]] = None,
  1038. token: Optional[str] = None,
  1039. ignore_metadata_errors: bool = False,
  1040. ):
  1041. """
  1042. Creates or loads an existing model card and tags it.
  1043. Args:
  1044. repo_id (`str`):
  1045. The repo_id where to look for the model card.
  1046. tags (`List[str]`, *optional*):
  1047. The list of tags to add in the model card
  1048. token (`str`, *optional*):
  1049. Authentication token, obtained with `huggingface_hub.HfApi.login` method. Will default to the stored token.
  1050. ignore_metadata_errors (`str`):
  1051. If True, errors while parsing the metadata section will be ignored. Some information might be lost during
  1052. the process. Use it at your own risk.
  1053. """
  1054. try:
  1055. # Check if the model card is present on the remote repo
  1056. model_card = ModelCard.load(repo_id, token=token, ignore_metadata_errors=ignore_metadata_errors)
  1057. except EntryNotFoundError:
  1058. # Otherwise create a simple model card from template
  1059. model_description = "This is the model card of a 🤗 transformers model that has been pushed on the Hub. This model card has been automatically generated."
  1060. card_data = ModelCardData(tags=[] if tags is None else tags, library_name="transformers")
  1061. model_card = ModelCard.from_template(card_data, model_description=model_description)
  1062. if tags is not None:
  1063. # Ensure model_card.data.tags is a list and not None
  1064. if model_card.data.tags is None:
  1065. model_card.data.tags = []
  1066. for model_tag in tags:
  1067. if model_tag not in model_card.data.tags:
  1068. model_card.data.tags.append(model_tag)
  1069. return model_card
  1070. def clean_files_for(file):
  1071. """
  1072. Remove, if they exist, file, file.json and file.lock
  1073. """
  1074. for f in [file, f"{file}.json", f"{file}.lock"]:
  1075. if os.path.isfile(f):
  1076. os.remove(f)
  1077. def move_to_new_cache(file, repo, filename, revision, etag, commit_hash):
  1078. """
  1079. Move file to repo following the new huggingface hub cache organization.
  1080. """
  1081. os.makedirs(repo, exist_ok=True)
  1082. # refs
  1083. os.makedirs(os.path.join(repo, "refs"), exist_ok=True)
  1084. if revision != commit_hash:
  1085. ref_path = os.path.join(repo, "refs", revision)
  1086. with open(ref_path, "w") as f:
  1087. f.write(commit_hash)
  1088. # blobs
  1089. os.makedirs(os.path.join(repo, "blobs"), exist_ok=True)
  1090. blob_path = os.path.join(repo, "blobs", etag)
  1091. shutil.move(file, blob_path)
  1092. # snapshots
  1093. os.makedirs(os.path.join(repo, "snapshots"), exist_ok=True)
  1094. os.makedirs(os.path.join(repo, "snapshots", commit_hash), exist_ok=True)
  1095. pointer_path = os.path.join(repo, "snapshots", commit_hash, filename)
  1096. huggingface_hub.file_download._create_relative_symlink(blob_path, pointer_path)
  1097. clean_files_for(file)
  1098. def move_cache(cache_dir=None, new_cache_dir=None, token=None):
  1099. if new_cache_dir is None:
  1100. new_cache_dir = TRANSFORMERS_CACHE
  1101. if cache_dir is None:
  1102. # Migrate from old cache in .cache/huggingface/transformers
  1103. old_cache = Path(TRANSFORMERS_CACHE).parent / "transformers"
  1104. if os.path.isdir(str(old_cache)):
  1105. cache_dir = str(old_cache)
  1106. else:
  1107. cache_dir = new_cache_dir
  1108. cached_files = get_all_cached_files(cache_dir=cache_dir)
  1109. logger.info(f"Moving {len(cached_files)} files to the new cache system")
  1110. hub_metadata = {}
  1111. for file_info in tqdm(cached_files):
  1112. url = file_info.pop("url")
  1113. if url not in hub_metadata:
  1114. try:
  1115. hub_metadata[url] = get_hf_file_metadata(url, token=token)
  1116. except requests.HTTPError:
  1117. continue
  1118. etag, commit_hash = hub_metadata[url].etag, hub_metadata[url].commit_hash
  1119. if etag is None or commit_hash is None:
  1120. continue
  1121. if file_info["etag"] != etag:
  1122. # Cached file is not up to date, we just throw it as a new version will be downloaded anyway.
  1123. clean_files_for(os.path.join(cache_dir, file_info["file"]))
  1124. continue
  1125. url_info = extract_info_from_url(url)
  1126. if url_info is None:
  1127. # Not a file from huggingface.co
  1128. continue
  1129. repo = os.path.join(new_cache_dir, url_info["repo"])
  1130. move_to_new_cache(
  1131. file=os.path.join(cache_dir, file_info["file"]),
  1132. repo=repo,
  1133. filename=url_info["filename"],
  1134. revision=url_info["revision"],
  1135. etag=etag,
  1136. commit_hash=commit_hash,
  1137. )
  1138. class PushInProgress:
  1139. """
  1140. Internal class to keep track of a push in progress (which might contain multiple `Future` jobs).
  1141. """
  1142. def __init__(self, jobs: Optional[futures.Future] = None) -> None:
  1143. self.jobs = [] if jobs is None else jobs
  1144. def is_done(self):
  1145. return all(job.done() for job in self.jobs)
  1146. def wait_until_done(self):
  1147. futures.wait(self.jobs)
  1148. def cancel(self) -> None:
  1149. self.jobs = [
  1150. job
  1151. for job in self.jobs
  1152. # Cancel the job if it wasn't started yet and remove cancelled/done jobs from the list
  1153. if not (job.cancel() or job.done())
  1154. ]
  1155. cache_version_file = os.path.join(TRANSFORMERS_CACHE, "version.txt")
  1156. if not os.path.isfile(cache_version_file):
  1157. cache_version = 0
  1158. else:
  1159. with open(cache_version_file) as f:
  1160. try:
  1161. cache_version = int(f.read())
  1162. except ValueError:
  1163. cache_version = 0
  1164. cache_is_not_empty = os.path.isdir(TRANSFORMERS_CACHE) and len(os.listdir(TRANSFORMERS_CACHE)) > 0
  1165. if cache_version < 1 and cache_is_not_empty:
  1166. if is_offline_mode():
  1167. logger.warning(
  1168. "You are offline and the cache for model files in Transformers v4.22.0 has been updated while your local "
  1169. "cache seems to be the one of a previous version. It is very likely that all your calls to any "
  1170. "`from_pretrained()` method will fail. Remove the offline mode and enable internet connection to have "
  1171. "your cache be updated automatically, then you can go back to offline mode."
  1172. )
  1173. else:
  1174. logger.warning(
  1175. "The cache for model files in Transformers v4.22.0 has been updated. Migrating your old cache. This is a "
  1176. "one-time only operation. You can interrupt this and resume the migration later on by calling "
  1177. "`transformers.utils.move_cache()`."
  1178. )
  1179. try:
  1180. if TRANSFORMERS_CACHE != constants.HF_HUB_CACHE:
  1181. # Users set some env variable to customize cache storage
  1182. move_cache(TRANSFORMERS_CACHE, TRANSFORMERS_CACHE)
  1183. else:
  1184. move_cache()
  1185. except Exception as e:
  1186. trace = "\n".join(traceback.format_tb(e.__traceback__))
  1187. logger.error(
  1188. f"There was a problem when trying to move your cache:\n\n{trace}\n{e.__class__.__name__}: {e}\n\nPlease "
  1189. "file an issue at https://github.com/huggingface/transformers/issues/new/choose and copy paste this whole "
  1190. "message and we will do our best to help."
  1191. )
  1192. if cache_version < 1:
  1193. try:
  1194. os.makedirs(TRANSFORMERS_CACHE, exist_ok=True)
  1195. with open(cache_version_file, "w") as f:
  1196. f.write("1")
  1197. except Exception:
  1198. logger.warning(
  1199. f"There was a problem when trying to write in your cache folder ({TRANSFORMERS_CACHE}). You should set "
  1200. "the environment variable TRANSFORMERS_CACHE to a writable directory."
  1201. )