# modelcard.py — (scraped page header and line-number gutter removed)
  1. # coding=utf-8
  2. # Copyright 2018 The HuggingFace Inc. team.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. """Configuration base class and utilities."""
  16. import copy
  17. import json
  18. import os
  19. import warnings
  20. from dataclasses import dataclass
  21. from pathlib import Path
  22. from typing import Any, Dict, List, Optional, Union
  23. import requests
  24. import yaml
  25. from huggingface_hub import model_info
  26. from huggingface_hub.utils import HFValidationError
  27. from . import __version__
  28. from .models.auto.modeling_auto import (
  29. MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES,
  30. MODEL_FOR_CAUSAL_LM_MAPPING_NAMES,
  31. MODEL_FOR_CTC_MAPPING_NAMES,
  32. MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES,
  33. MODEL_FOR_IMAGE_SEGMENTATION_MAPPING_NAMES,
  34. MODEL_FOR_MASKED_LM_MAPPING_NAMES,
  35. MODEL_FOR_OBJECT_DETECTION_MAPPING_NAMES,
  36. MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES,
  37. MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES,
  38. MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES,
  39. MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES,
  40. MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING_NAMES,
  41. MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES,
  42. MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING_NAMES,
  43. )
  44. from .training_args import ParallelMode
  45. from .utils import (
  46. MODEL_CARD_NAME,
  47. cached_file,
  48. is_datasets_available,
  49. is_offline_mode,
  50. is_tf_available,
  51. is_tokenizers_available,
  52. is_torch_available,
  53. logging,
  54. )
# Maps a pipeline task tag to the auto-model class-name mapping for that task.
# Used by `TrainingSummary.from_trainer`/`from_keras` to infer the task from the
# model's class name.
TASK_MAPPING = {
    "text-generation": MODEL_FOR_CAUSAL_LM_MAPPING_NAMES,
    "image-classification": MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES,
    "image-segmentation": MODEL_FOR_IMAGE_SEGMENTATION_MAPPING_NAMES,
    "fill-mask": MODEL_FOR_MASKED_LM_MAPPING_NAMES,
    "object-detection": MODEL_FOR_OBJECT_DETECTION_MAPPING_NAMES,
    "question-answering": MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES,
    "text2text-generation": MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES,
    "text-classification": MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES,
    "table-question-answering": MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING_NAMES,
    "token-classification": MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES,
    "audio-classification": MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES,
    # Speech recognition models live in two separate mappings (CTC and seq2seq), so merge both.
    "automatic-speech-recognition": {**MODEL_FOR_CTC_MAPPING_NAMES, **MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES},
    "zero-shot-image-classification": MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING_NAMES,
}

logger = logging.get_logger(__name__)
class ModelCard:
    r"""
    Structured Model Card class. Store model card as well as methods for loading/downloading/saving model cards.

    Please read the following paper for details and explanation on the sections: "Model Cards for Model Reporting" by
    Margaret Mitchell, Simone Wu, Andrew Zaldivar, Parker Barnes, Lucy Vasserman, Ben Hutchinson, Elena Spitzer,
    Inioluwa Deborah Raji and Timnit Gebru for the proposal behind model cards. Link: https://arxiv.org/abs/1810.03993

    Note: A model card can be loaded and saved to disk.
    """

    def __init__(self, **kwargs):
        # The class is kept only for backward compatibility; it always warns on construction.
        warnings.warn(
            "The class `ModelCard` is deprecated and will be removed in version 5 of Transformers", FutureWarning
        )
        # Recommended attributes from https://arxiv.org/abs/1810.03993 (see papers)
        self.model_details = kwargs.pop("model_details", {})
        self.intended_use = kwargs.pop("intended_use", {})
        self.factors = kwargs.pop("factors", {})
        self.metrics = kwargs.pop("metrics", {})
        self.evaluation_data = kwargs.pop("evaluation_data", {})
        self.training_data = kwargs.pop("training_data", {})
        self.quantitative_analyses = kwargs.pop("quantitative_analyses", {})
        self.ethical_considerations = kwargs.pop("ethical_considerations", {})
        self.caveats_and_recommendations = kwargs.pop("caveats_and_recommendations", {})
        # Open additional attributes: any remaining kwarg becomes an instance attribute.
        for key, value in kwargs.items():
            try:
                setattr(self, key, value)
            except AttributeError as err:
                logger.error(f"Can't set {key} with value {value} for {self}")
                raise err

    def save_pretrained(self, save_directory_or_file):
        """Save a model card object to the directory or file `save_directory_or_file`."""
        if os.path.isdir(save_directory_or_file):
            # If we save using the predefined names, we can load using `from_pretrained`
            output_model_card_file = os.path.join(save_directory_or_file, MODEL_CARD_NAME)
        else:
            output_model_card_file = save_directory_or_file

        self.to_json_file(output_model_card_file)
        logger.info(f"Model card saved in {output_model_card_file}")

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
        r"""
        Instantiate a [`ModelCard`] from a pre-trained model model card.

        Parameters:
            pretrained_model_name_or_path: either:

                - a string, the *model id* of a pretrained model card hosted inside a model repo on huggingface.co.
                - a path to a *directory* containing a model card file saved using the [`~ModelCard.save_pretrained`]
                  method, e.g.: `./my_model_directory/`.
                - a path or url to a saved model card JSON *file*, e.g.: `./my_model_directory/modelcard.json`.

            cache_dir: (*optional*) string:
                Path to a directory in which a downloaded pre-trained model card should be cached if the standard cache
                should not be used.

            kwargs: (*optional*) dict: key/value pairs with which to update the ModelCard object after loading.

                - The values in kwargs of any keys which are model card attributes will be used to override the loaded
                  values.
                - Behavior concerning key/value pairs whose keys are *not* model card attributes is controlled by the
                  *return_unused_kwargs* keyword parameter.

            proxies: (*optional*) dict, default None:
                A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128',
                'http://hostname': 'foo.bar:4012'}. The proxies are used on each request.

            return_unused_kwargs: (*optional*) bool:

                - If False, then this function returns just the final model card object.
                - If True, then this functions returns a tuple *(model card, unused_kwargs)* where *unused_kwargs* is a
                  dictionary consisting of the key/value pairs whose keys are not model card attributes: ie the part of
                  kwargs which has not been used to update *ModelCard* and is otherwise ignored.

        Examples:

        ```python
        # Download model card from huggingface.co and cache.
        modelcard = ModelCard.from_pretrained("google-bert/bert-base-uncased")
        # Model card was saved using *save_pretrained('./test/saved_model/')*
        modelcard = ModelCard.from_pretrained("./test/saved_model/")
        modelcard = ModelCard.from_pretrained("./test/saved_model/modelcard.json")
        modelcard = ModelCard.from_pretrained("google-bert/bert-base-uncased", output_attentions=True, foo=False)
        ```"""
        cache_dir = kwargs.pop("cache_dir", None)
        proxies = kwargs.pop("proxies", None)
        return_unused_kwargs = kwargs.pop("return_unused_kwargs", False)
        from_pipeline = kwargs.pop("_from_pipeline", None)

        # Telemetry: record that the download came from a pipeline, when applicable.
        user_agent = {"file_type": "model_card"}
        if from_pipeline is not None:
            user_agent["using_pipeline"] = from_pipeline

        is_local = os.path.isdir(pretrained_model_name_or_path)
        if os.path.isfile(pretrained_model_name_or_path):
            resolved_model_card_file = pretrained_model_name_or_path
            is_local = True
            # NOTE(review): on this direct-file path the card is not parsed below
            # (`modelcard` is only assigned inside the `else` branch) — confirm the
            # intended indentation against upstream.
        else:
            try:
                # Load from URL or cache if already cached
                resolved_model_card_file = cached_file(
                    pretrained_model_name_or_path,
                    filename=MODEL_CARD_NAME,
                    cache_dir=cache_dir,
                    proxies=proxies,
                    user_agent=user_agent,
                )
                if is_local:
                    logger.info(f"loading model card file {resolved_model_card_file}")
                else:
                    logger.info(f"loading model card file {MODEL_CARD_NAME} from cache at {resolved_model_card_file}")
                # Load model card
                modelcard = cls.from_json_file(resolved_model_card_file)

            except (EnvironmentError, json.JSONDecodeError):
                # We fall back on creating an empty model card
                modelcard = cls()

        # Update model card with kwargs if needed
        to_remove = []
        for key, value in kwargs.items():
            if hasattr(modelcard, key):
                setattr(modelcard, key, value)
                to_remove.append(key)
        for key in to_remove:
            kwargs.pop(key, None)

        logger.info(f"Model card: {modelcard}")
        if return_unused_kwargs:
            return modelcard, kwargs
        else:
            return modelcard

    @classmethod
    def from_dict(cls, json_object):
        """Constructs a `ModelCard` from a Python dictionary of parameters."""
        return cls(**json_object)

    @classmethod
    def from_json_file(cls, json_file):
        """Constructs a `ModelCard` from a json file of parameters."""
        with open(json_file, "r", encoding="utf-8") as reader:
            text = reader.read()
        dict_obj = json.loads(text)
        return cls(**dict_obj)

    def __eq__(self, other):
        # NOTE(review): no isinstance check — comparing against an object without
        # `__dict__` raises AttributeError rather than returning NotImplemented.
        return self.__dict__ == other.__dict__

    def __repr__(self):
        return str(self.to_json_string())

    def to_dict(self):
        """Serializes this instance to a Python dictionary."""
        output = copy.deepcopy(self.__dict__)
        return output

    def to_json_string(self):
        """Serializes this instance to a JSON string."""
        return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"

    def to_json_file(self, json_file_path):
        """Save this instance to a json file."""
        with open(json_file_path, "w", encoding="utf-8") as writer:
            writer.write(self.to_json_string())
# Boilerplate prepended to cards generated by the `Trainer` (an HTML comment, so it
# does not render on the Hub).
AUTOGENERATED_TRAINER_COMMENT = """
<!-- This model card has been generated automatically according to the information the Trainer had access to. You
should probably proofread and complete it, then remove this comment. -->
"""

# Same boilerplate for cards generated from a Keras training run.
AUTOGENERATED_KERAS_COMMENT = """
<!-- This model card has been generated automatically according to the information Keras had access to. You should
probably proofread and complete it, then remove this comment. -->
"""
# Human-readable names for the task tags accepted in `model-index` metadata.
TASK_TAG_TO_NAME_MAPPING = {
    "fill-mask": "Masked Language Modeling",
    "image-classification": "Image Classification",
    "image-segmentation": "Image Segmentation",
    "multiple-choice": "Multiple Choice",
    "object-detection": "Object Detection",
    "question-answering": "Question Answering",
    "summarization": "Summarization",
    "table-question-answering": "Table Question Answering",
    "text-classification": "Text Classification",
    "text-generation": "Causal Language Modeling",
    "text2text-generation": "Sequence-to-sequence Language Modeling",
    "token-classification": "Token Classification",
    "translation": "Translation",
    "zero-shot-classification": "Zero Shot Classification",
    "automatic-speech-recognition": "Automatic Speech Recognition",
    "audio-classification": "Audio Classification",
}

# Normalized (lowercase, underscored) metric keys recognized as Hub metric tags; see
# `infer_metric_tags_from_eval_results`.
METRIC_TAGS = [
    "accuracy",
    "bleu",
    "f1",
    "matthews_correlation",
    "pearsonr",
    "precision",
    "recall",
    "rouge",
    "sacrebleu",
    "spearmanr",
    "wer",
]
  252. def _listify(obj):
  253. if obj is None:
  254. return []
  255. elif isinstance(obj, str):
  256. return [obj]
  257. else:
  258. return obj
  259. def _insert_values_as_list(metadata, name, values):
  260. if values is None:
  261. return metadata
  262. if isinstance(values, str):
  263. values = [values]
  264. values = [v for v in values if v is not None]
  265. if len(values) == 0:
  266. return metadata
  267. metadata[name] = values
  268. return metadata
  269. def infer_metric_tags_from_eval_results(eval_results):
  270. if eval_results is None:
  271. return {}
  272. result = {}
  273. for key in eval_results.keys():
  274. if key.lower().replace(" ", "_") in METRIC_TAGS:
  275. result[key.lower().replace(" ", "_")] = key
  276. elif key.lower() == "rouge1":
  277. result["rouge"] = key
  278. return result
  279. def _insert_value(metadata, name, value):
  280. if value is None:
  281. return metadata
  282. metadata[name] = value
  283. return metadata
  284. def is_hf_dataset(dataset):
  285. if not is_datasets_available():
  286. return False
  287. from datasets import Dataset, IterableDataset
  288. return isinstance(dataset, (Dataset, IterableDataset))
  289. def _get_mapping_values(mapping):
  290. result = []
  291. for v in mapping.values():
  292. if isinstance(v, (tuple, list)):
  293. result += list(v)
  294. else:
  295. result.append(v)
  296. return result
@dataclass
class TrainingSummary:
    """
    Gathers the information logged during a training run (by a `Trainer` or a Keras
    callback) and renders it as a model card: Hub-valid YAML metadata
    (`create_metadata`) plus a Markdown body (`to_model_card`).
    """

    # Repo/directory name; used as the card title and the model-index name.
    model_name: str
    language: Optional[Union[str, List[str]]] = None
    license: Optional[str] = None
    tags: Optional[Union[str, List[str]]] = None
    # Checkpoint this model was fine-tuned from (None when trained from scratch).
    finetuned_from: Optional[str] = None
    tasks: Optional[Union[str, List[str]]] = None
    # Human-readable dataset names; `dataset_tags` are the Hub ids, `dataset_args` the configs.
    dataset: Optional[Union[str, List[str]]] = None
    dataset_tags: Optional[Union[str, List[str]]] = None
    dataset_args: Optional[Union[str, List[str]]] = None
    dataset_metadata: Optional[Dict[str, Any]] = None
    eval_results: Optional[Dict[str, float]] = None
    eval_lines: Optional[List[str]] = None
    hyperparameters: Optional[Dict[str, Any]] = None
    # Either "trainer" or "keras"; selects the autogenerated comment and version lines.
    source: Optional[str] = "trainer"

    def __post_init__(self):
        # Infer default license from the checkpoint used, if possible.
        if (
            self.license is None
            and not is_offline_mode()
            and self.finetuned_from is not None
            and len(self.finetuned_from) > 0
        ):
            try:
                info = model_info(self.finetuned_from)
                for tag in info.tags:
                    if tag.startswith("license:"):
                        self.license = tag[8:]
            except (requests.exceptions.HTTPError, requests.exceptions.ConnectionError, HFValidationError):
                # Best effort only: the Hub may be unreachable or the repo id invalid.
                pass

    def create_model_index(self, metric_mapping):
        """Build the `model-index` metadata entry: one result per (task, dataset) pair.

        `metric_mapping` maps metric tags to keys of `self.eval_results` (see
        `infer_metric_tags_from_eval_results`). Results missing any of task/dataset/
        metrics are dropped, as the Hub rejects partial entries.
        """
        model_index = {"name": self.model_name}

        # Dataset mapping tag -> name
        dataset_names = _listify(self.dataset)
        dataset_tags = _listify(self.dataset_tags)
        dataset_args = _listify(self.dataset_args)
        dataset_metadata = _listify(self.dataset_metadata)
        if len(dataset_args) < len(dataset_tags):
            # Pad args so the zip below keeps every tag.
            dataset_args = dataset_args + [None] * (len(dataset_tags) - len(dataset_args))
        dataset_mapping = dict(zip(dataset_tags, dataset_names))
        dataset_arg_mapping = dict(zip(dataset_tags, dataset_args))
        dataset_metadata_mapping = dict(zip(dataset_tags, dataset_metadata))

        # Only tasks with a known human-readable name are kept.
        task_mapping = {
            task: TASK_TAG_TO_NAME_MAPPING[task] for task in _listify(self.tasks) if task in TASK_TAG_TO_NAME_MAPPING
        }

        model_index["results"] = []

        if len(task_mapping) == 0 and len(dataset_mapping) == 0:
            return [model_index]
        if len(task_mapping) == 0:
            task_mapping = {None: None}
        if len(dataset_mapping) == 0:
            dataset_mapping = {None: None}

        # One entry per dataset and per task
        all_possibilities = [(task_tag, ds_tag) for task_tag in task_mapping for ds_tag in dataset_mapping]
        for task_tag, ds_tag in all_possibilities:
            result = {}
            if task_tag is not None:
                result["task"] = {"name": task_mapping[task_tag], "type": task_tag}

            if ds_tag is not None:
                metadata = dataset_metadata_mapping.get(ds_tag, {})
                result["dataset"] = {
                    "name": dataset_mapping[ds_tag],
                    "type": ds_tag,
                    **metadata,
                }
                if dataset_arg_mapping[ds_tag] is not None:
                    result["dataset"]["args"] = dataset_arg_mapping[ds_tag]

            if len(metric_mapping) > 0:
                result["metrics"] = []
                for metric_tag, metric_name in metric_mapping.items():
                    result["metrics"].append(
                        {
                            "name": metric_name,
                            "type": metric_tag,
                            "value": self.eval_results[metric_name],
                        }
                    )

            # Remove partial results to avoid the model card being rejected.
            if "task" in result and "dataset" in result and "metrics" in result:
                model_index["results"].append(result)
            else:
                logger.info(f"Dropping the following result as it does not have all the necessary fields:\n{result}")

        return [model_index]

    def create_metadata(self):
        """Assemble the card's YAML metadata dict (language, license, tags, datasets, metrics, model-index)."""
        metric_mapping = infer_metric_tags_from_eval_results(self.eval_results)

        metadata = {}
        metadata = _insert_value(metadata, "library_name", "transformers")
        metadata = _insert_values_as_list(metadata, "language", self.language)
        metadata = _insert_value(metadata, "license", self.license)
        # Only advertise a base model when `finetuned_from` is a non-empty string.
        if self.finetuned_from is not None and isinstance(self.finetuned_from, str) and len(self.finetuned_from) > 0:
            metadata = _insert_value(metadata, "base_model", self.finetuned_from)
        metadata = _insert_values_as_list(metadata, "tags", self.tags)
        metadata = _insert_values_as_list(metadata, "datasets", self.dataset_tags)
        metadata = _insert_values_as_list(metadata, "metrics", list(metric_mapping.keys()))
        metadata["model-index"] = self.create_model_index(metric_mapping)

        return metadata

    def to_model_card(self):
        """Render the full model card text: YAML front matter followed by the Markdown body."""
        model_card = ""

        metadata = yaml.dump(self.create_metadata(), sort_keys=False)
        if len(metadata) > 0:
            model_card = f"---\n{metadata}---\n"

        # Now the model card for realsies.
        if self.source == "trainer":
            model_card += AUTOGENERATED_TRAINER_COMMENT
        else:
            model_card += AUTOGENERATED_KERAS_COMMENT

        model_card += f"\n# {self.model_name}\n\n"

        if self.finetuned_from is None:
            model_card += "This model was trained from scratch on "
        else:
            model_card += (
                "This model is a fine-tuned version of"
                f" [{self.finetuned_from}](https://huggingface.co/{self.finetuned_from}) on "
            )

        if self.dataset is None:
            model_card += "an unknown dataset."
        else:
            if isinstance(self.dataset, str):
                model_card += f"the {self.dataset} dataset."
            elif isinstance(self.dataset, (tuple, list)) and len(self.dataset) == 1:
                model_card += f"the {self.dataset[0]} dataset."
            else:
                # Oxford-less enumeration: "the a, the b and the c datasets."
                model_card += (
                    ", ".join([f"the {ds}" for ds in self.dataset[:-1]]) + f" and the {self.dataset[-1]} datasets."
                )

        if self.eval_results is not None:
            model_card += "\nIt achieves the following results on the evaluation set:\n"
            model_card += "\n".join([f"- {name}: {_maybe_round(value)}" for name, value in self.eval_results.items()])
            model_card += "\n"

        model_card += "\n## Model description\n\nMore information needed\n"
        model_card += "\n## Intended uses & limitations\n\nMore information needed\n"
        model_card += "\n## Training and evaluation data\n\nMore information needed\n"

        model_card += "\n## Training procedure\n"
        model_card += "\n### Training hyperparameters\n"
        if self.hyperparameters is not None:
            model_card += "\nThe following hyperparameters were used during training:\n"
            model_card += "\n".join([f"- {name}: {value}" for name, value in self.hyperparameters.items()])
            model_card += "\n"
        else:
            model_card += "\nMore information needed\n"

        if self.eval_lines is not None:
            model_card += "\n### Training results\n\n"
            model_card += make_markdown_table(self.eval_lines)
            model_card += "\n"

        model_card += "\n### Framework versions\n\n"
        model_card += f"- Transformers {__version__}\n"

        # Framework version lines depend on which backends are actually installed.
        if self.source == "trainer" and is_torch_available():
            import torch

            model_card += f"- Pytorch {torch.__version__}\n"
        elif self.source == "keras" and is_tf_available():
            import tensorflow as tf

            model_card += f"- TensorFlow {tf.__version__}\n"
        if is_datasets_available():
            import datasets

            model_card += f"- Datasets {datasets.__version__}\n"
        if is_tokenizers_available():
            import tokenizers

            model_card += f"- Tokenizers {tokenizers.__version__}\n"

        return model_card

    @classmethod
    def from_trainer(
        cls,
        trainer,
        language=None,
        license=None,
        tags=None,
        model_name=None,
        finetuned_from=None,
        tasks=None,
        dataset_tags=None,
        dataset_metadata=None,
        dataset=None,
        dataset_args=None,
    ):
        """Build a `TrainingSummary` from a `Trainer`, inferring every argument left as `None`."""
        # Infer default from dataset
        one_dataset = trainer.eval_dataset if trainer.eval_dataset is not None else trainer.train_dataset
        if is_hf_dataset(one_dataset) and (dataset_tags is None or dataset_args is None or dataset_metadata is None):
            default_tag = one_dataset.builder_name
            # Those are not real datasets from the Hub so we exclude them.
            if default_tag not in ["csv", "json", "pandas", "parquet", "text"]:
                if dataset_metadata is None:
                    dataset_metadata = [{"config": one_dataset.config_name, "split": str(one_dataset.split)}]
                if dataset_tags is None:
                    dataset_tags = [default_tag]
                if dataset_args is None:
                    dataset_args = [one_dataset.config_name]

        if dataset is None and dataset_tags is not None:
            dataset = dataset_tags

        # Infer default finetuned_from
        if (
            finetuned_from is None
            and hasattr(trainer.model.config, "_name_or_path")
            and not os.path.isdir(trainer.model.config._name_or_path)
        ):
            finetuned_from = trainer.model.config._name_or_path

        # Infer default task tag:
        if tasks is None:
            model_class_name = trainer.model.__class__.__name__
            for task, mapping in TASK_MAPPING.items():
                if model_class_name in _get_mapping_values(mapping):
                    tasks = task

        if model_name is None:
            model_name = Path(trainer.args.output_dir).name
        if len(model_name) == 0:
            # Fall back to the checkpoint name when the output dir has no basename.
            model_name = finetuned_from

        # Add `generated_from_trainer` to the tags
        if tags is None:
            tags = ["generated_from_trainer"]
        elif isinstance(tags, str) and tags != "generated_from_trainer":
            tags = [tags, "generated_from_trainer"]
        elif "generated_from_trainer" not in tags:
            tags.append("generated_from_trainer")

        _, eval_lines, eval_results = parse_log_history(trainer.state.log_history)
        hyperparameters = extract_hyperparameters_from_trainer(trainer)

        return cls(
            language=language,
            license=license,
            tags=tags,
            model_name=model_name,
            finetuned_from=finetuned_from,
            tasks=tasks,
            dataset=dataset,
            dataset_tags=dataset_tags,
            dataset_args=dataset_args,
            dataset_metadata=dataset_metadata,
            eval_results=eval_results,
            eval_lines=eval_lines,
            hyperparameters=hyperparameters,
        )

    @classmethod
    def from_keras(
        cls,
        model,
        model_name,
        keras_history=None,
        language=None,
        license=None,
        tags=None,
        finetuned_from=None,
        tasks=None,
        dataset_tags=None,
        dataset=None,
        dataset_args=None,
    ):
        """Build a `TrainingSummary` (with `source="keras"`) from a Keras model and optional training history."""
        # Infer default from dataset
        if dataset is not None:
            if is_hf_dataset(dataset) and (dataset_tags is None or dataset_args is None):
                default_tag = dataset.builder_name
                # Those are not real datasets from the Hub so we exclude them.
                if default_tag not in ["csv", "json", "pandas", "parquet", "text"]:
                    if dataset_tags is None:
                        dataset_tags = [default_tag]
                    if dataset_args is None:
                        dataset_args = [dataset.config_name]

        if dataset is None and dataset_tags is not None:
            dataset = dataset_tags

        # Infer default finetuned_from
        if (
            finetuned_from is None
            and hasattr(model.config, "_name_or_path")
            and not os.path.isdir(model.config._name_or_path)
        ):
            finetuned_from = model.config._name_or_path

        # Infer default task tag:
        if tasks is None:
            model_class_name = model.__class__.__name__
            for task, mapping in TASK_MAPPING.items():
                if model_class_name in _get_mapping_values(mapping):
                    tasks = task

        # Add `generated_from_keras_callback` to the tags
        if tags is None:
            tags = ["generated_from_keras_callback"]
        elif isinstance(tags, str) and tags != "generated_from_keras_callback":
            tags = [tags, "generated_from_keras_callback"]
        elif "generated_from_keras_callback" not in tags:
            tags.append("generated_from_keras_callback")

        if keras_history is not None:
            _, eval_lines, eval_results = parse_keras_history(keras_history)
        else:
            eval_lines = []
            eval_results = {}
        hyperparameters = extract_hyperparameters_from_keras(model)

        return cls(
            language=language,
            license=license,
            tags=tags,
            model_name=model_name,
            finetuned_from=finetuned_from,
            tasks=tasks,
            dataset_tags=dataset_tags,
            dataset=dataset,
            dataset_args=dataset_args,
            eval_results=eval_results,
            eval_lines=eval_lines,
            hyperparameters=hyperparameters,
            source="keras",
        )
  595. def parse_keras_history(logs):
  596. """
  597. Parse the `logs` of either a `keras.History` object returned by `model.fit()` or an accumulated logs `dict`
  598. passed to the `PushToHubCallback`. Returns lines and logs compatible with those returned by `parse_log_history`.
  599. """
  600. if hasattr(logs, "history"):
  601. # This looks like a `History` object
  602. if not hasattr(logs, "epoch"):
  603. # This history looks empty, return empty results
  604. return None, [], {}
  605. logs.history["epoch"] = logs.epoch
  606. logs = logs.history
  607. else:
  608. # Training logs is a list of dicts, let's invert it to a dict of lists to match a History object
  609. logs = {log_key: [single_dict[log_key] for single_dict in logs] for log_key in logs[0]}
  610. lines = []
  611. for i in range(len(logs["epoch"])):
  612. epoch_dict = {log_key: log_value_list[i] for log_key, log_value_list in logs.items()}
  613. values = {}
  614. for k, v in epoch_dict.items():
  615. if k.startswith("val_"):
  616. k = "validation_" + k[4:]
  617. elif k != "epoch":
  618. k = "train_" + k
  619. splits = k.split("_")
  620. name = " ".join([part.capitalize() for part in splits])
  621. values[name] = v
  622. lines.append(values)
  623. eval_results = lines[-1]
  624. return logs, lines, eval_results
def parse_log_history(log_history):
    """
    Parse the `log_history` of a Trainer to get the intermediate and final evaluation results.

    Returns a `(train_log, lines, eval_results)` tuple where `train_log` is the entry
    holding the final training stats, `lines` the per-evaluation table rows, and
    `eval_results` the last evaluation metrics (any element may be `None`/missing).
    """
    # `train_runtime` is logged once, at the end of training — find that entry.
    idx = 0
    while idx < len(log_history) and "train_runtime" not in log_history[idx]:
        idx += 1

    # If there are no training logs
    if idx == len(log_history):
        # Fall back to the last evaluation entry, if any.
        idx -= 1
        while idx >= 0 and "eval_loss" not in log_history[idx]:
            idx -= 1

        if idx >= 0:
            return None, None, log_history[idx]
        else:
            return None, None, None

    # From now one we can assume we have training logs:
    train_log = log_history[idx]
    lines = []
    training_loss = "No log"
    for i in range(idx):
        # Remember the most recent training loss so eval rows can show it.
        if "loss" in log_history[i]:
            training_loss = log_history[i]["loss"]
        if "eval_loss" in log_history[i]:
            metrics = log_history[i].copy()
            # Strip bookkeeping entries that should not appear as metrics.
            _ = metrics.pop("total_flos", None)
            epoch = metrics.pop("epoch", None)
            step = metrics.pop("step", None)
            _ = metrics.pop("eval_runtime", None)
            _ = metrics.pop("eval_samples_per_second", None)
            _ = metrics.pop("eval_steps_per_second", None)
            _ = metrics.pop("eval_jit_compilation_time", None)
            values = {"Training Loss": training_loss, "Epoch": epoch, "Step": step}
            for k, v in metrics.items():
                if k == "eval_loss":
                    values["Validation Loss"] = v
                else:
                    # Drop the "eval" prefix and title-case the rest: "eval_f1" -> "F1".
                    splits = k.split("_")
                    name = " ".join([part.capitalize() for part in splits[1:]])
                    values[name] = v
            lines.append(values)

    # Locate the last evaluation entry for the final results.
    idx = len(log_history) - 1
    while idx >= 0 and "eval_loss" not in log_history[idx]:
        idx -= 1

    # NOTE(review): `> 0` (not `>= 0`) ignores an eval entry at index 0 — presumably
    # fine since training logs precede it, but confirm this is intentional.
    if idx > 0:
        eval_results = {}
        for key, value in log_history[idx].items():
            if key.startswith("eval_"):
                key = key[5:]
            if key not in ["runtime", "samples_per_second", "steps_per_second", "epoch", "step"]:
                camel_cased_key = " ".join([part.capitalize() for part in key.split("_")])
                eval_results[camel_cased_key] = value
        return train_log, lines, eval_results
    else:
        return train_log, lines, None
  680. def extract_hyperparameters_from_keras(model):
  681. from .modeling_tf_utils import keras
  682. hyperparameters = {}
  683. if hasattr(model, "optimizer") and model.optimizer is not None:
  684. hyperparameters["optimizer"] = model.optimizer.get_config()
  685. else:
  686. hyperparameters["optimizer"] = None
  687. hyperparameters["training_precision"] = keras.mixed_precision.global_policy().name
  688. return hyperparameters
  689. def _maybe_round(v, decimals=4):
  690. if isinstance(v, float) and len(str(v).split(".")) > 1 and len(str(v).split(".")[1]) > decimals:
  691. return f"{v:.{decimals}f}"
  692. return str(v)
  693. def _regular_table_line(values, col_widths):
  694. values_with_space = [f"| {v}" + " " * (w - len(v) + 1) for v, w in zip(values, col_widths)]
  695. return "".join(values_with_space) + "|\n"
  696. def _second_table_line(col_widths):
  697. values = ["|:" + "-" * w + ":" for w in col_widths]
  698. return "".join(values) + "|\n"
  699. def make_markdown_table(lines):
  700. """
  701. Create a nice Markdown table from the results in `lines`.
  702. """
  703. if lines is None or len(lines) == 0:
  704. return ""
  705. col_widths = {key: len(str(key)) for key in lines[0].keys()}
  706. for line in lines:
  707. for key, value in line.items():
  708. if col_widths[key] < len(_maybe_round(value)):
  709. col_widths[key] = len(_maybe_round(value))
  710. table = _regular_table_line(list(lines[0].keys()), list(col_widths.values()))
  711. table += _second_table_line(list(col_widths.values()))
  712. for line in lines:
  713. table += _regular_table_line([_maybe_round(v) for v in line.values()], list(col_widths.values()))
  714. return table
# Subset of `TrainingArguments` attributes that are always copied into the
# model-card hyperparameters by `extract_hyperparameters_from_trainer`.
_TRAINING_ARGS_KEYS = [
    "learning_rate",
    "train_batch_size",
    "eval_batch_size",
    "seed",
]
  721. def extract_hyperparameters_from_trainer(trainer):
  722. hyperparameters = {k: getattr(trainer.args, k) for k in _TRAINING_ARGS_KEYS}
  723. if trainer.args.parallel_mode not in [ParallelMode.NOT_PARALLEL, ParallelMode.NOT_DISTRIBUTED]:
  724. hyperparameters["distributed_type"] = (
  725. "multi-GPU" if trainer.args.parallel_mode == ParallelMode.DISTRIBUTED else trainer.args.parallel_mode.value
  726. )
  727. if trainer.args.world_size > 1:
  728. hyperparameters["num_devices"] = trainer.args.world_size
  729. if trainer.args.gradient_accumulation_steps > 1:
  730. hyperparameters["gradient_accumulation_steps"] = trainer.args.gradient_accumulation_steps
  731. total_train_batch_size = (
  732. trainer.args.train_batch_size * trainer.args.world_size * trainer.args.gradient_accumulation_steps
  733. )
  734. if total_train_batch_size != hyperparameters["train_batch_size"]:
  735. hyperparameters["total_train_batch_size"] = total_train_batch_size
  736. total_eval_batch_size = trainer.args.eval_batch_size * trainer.args.world_size
  737. if total_eval_batch_size != hyperparameters["eval_batch_size"]:
  738. hyperparameters["total_eval_batch_size"] = total_eval_batch_size
  739. if trainer.args.optim:
  740. optimizer_name = trainer.args.optim
  741. optimizer_args = trainer.args.optim_args if trainer.args.optim_args else "No additional optimizer arguments"
  742. if "adam" in optimizer_name.lower():
  743. hyperparameters["optimizer"] = (
  744. f"Use {optimizer_name} with betas=({trainer.args.adam_beta1},{trainer.args.adam_beta2}) and"
  745. f" epsilon={trainer.args.adam_epsilon} and optimizer_args={optimizer_args}"
  746. )
  747. else:
  748. hyperparameters["optimizer"] = f"Use {optimizer_name} and the args are:\n{optimizer_args}"
  749. hyperparameters["lr_scheduler_type"] = trainer.args.lr_scheduler_type.value
  750. if trainer.args.warmup_ratio != 0.0:
  751. hyperparameters["lr_scheduler_warmup_ratio"] = trainer.args.warmup_ratio
  752. if trainer.args.warmup_steps != 0.0:
  753. hyperparameters["lr_scheduler_warmup_steps"] = trainer.args.warmup_steps
  754. if trainer.args.max_steps != -1:
  755. hyperparameters["training_steps"] = trainer.args.max_steps
  756. else:
  757. hyperparameters["num_epochs"] = trainer.args.num_train_epochs
  758. if trainer.args.fp16:
  759. if trainer.use_apex:
  760. hyperparameters["mixed_precision_training"] = f"Apex, opt level {trainer.args.fp16_opt_level}"
  761. else:
  762. hyperparameters["mixed_precision_training"] = "Native AMP"
  763. if trainer.args.label_smoothing_factor != 0.0:
  764. hyperparameters["label_smoothing_factor"] = trainer.args.label_smoothing_factor
  765. return hyperparameters