training_args_tf.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299
  1. # Copyright 2020 The HuggingFace Team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import warnings
  15. from dataclasses import dataclass, field
  16. from typing import Optional, Tuple
  17. from .training_args import TrainingArguments
  18. from .utils import cached_property, is_tf_available, logging, requires_backends
# Module-level logger shared by this file.
logger = logging.get_logger(__name__)

# TensorFlow is an optional dependency of transformers; only import it (and the
# keras shim) when it is actually installed, so this module stays importable
# in torch-only environments.
if is_tf_available():
    import tensorflow as tf

    from .modeling_tf_utils import keras
  23. @dataclass
  24. class TFTrainingArguments(TrainingArguments):
  25. """
  26. TrainingArguments is the subset of the arguments we use in our example scripts **which relate to the training loop
  27. itself**.
  28. Using [`HfArgumentParser`] we can turn this class into
  29. [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
  30. command line.
  31. Parameters:
  32. output_dir (`str`):
  33. The output directory where the model predictions and checkpoints will be written.
  34. overwrite_output_dir (`bool`, *optional*, defaults to `False`):
  35. If `True`, overwrite the content of the output directory. Use this to continue training if `output_dir`
  36. points to a checkpoint directory.
  37. do_train (`bool`, *optional*, defaults to `False`):
  38. Whether to run training or not. This argument is not directly used by [`Trainer`], it's intended to be used
  39. by your training/evaluation scripts instead. See the [example
  40. scripts](https://github.com/huggingface/transformers/tree/main/examples) for more details.
  41. do_eval (`bool`, *optional*):
  42. Whether to run evaluation on the validation set or not. Will be set to `True` if `eval_strategy` is
  43. different from `"no"`. This argument is not directly used by [`Trainer`], it's intended to be used by your
  44. training/evaluation scripts instead. See the [example
  45. scripts](https://github.com/huggingface/transformers/tree/main/examples) for more details.
  46. do_predict (`bool`, *optional*, defaults to `False`):
  47. Whether to run predictions on the test set or not. This argument is not directly used by [`Trainer`], it's
  48. intended to be used by your training/evaluation scripts instead. See the [example
  49. scripts](https://github.com/huggingface/transformers/tree/main/examples) for more details.
  50. eval_strategy (`str` or [`~trainer_utils.IntervalStrategy`], *optional*, defaults to `"no"`):
  51. The evaluation strategy to adopt during training. Possible values are:
  52. - `"no"`: No evaluation is done during training.
  53. - `"steps"`: Evaluation is done (and logged) every `eval_steps`.
  54. - `"epoch"`: Evaluation is done at the end of each epoch.
  55. per_device_train_batch_size (`int`, *optional*, defaults to 8):
  56. The batch size per GPU/TPU core/CPU for training.
  57. per_device_eval_batch_size (`int`, *optional*, defaults to 8):
  58. The batch size per GPU/TPU core/CPU for evaluation.
  59. gradient_accumulation_steps (`int`, *optional*, defaults to 1):
  60. Number of updates steps to accumulate the gradients for, before performing a backward/update pass.
  61. <Tip warning={true}>
  62. When using gradient accumulation, one step is counted as one step with backward pass. Therefore, logging,
  63. evaluation, save will be conducted every `gradient_accumulation_steps * xxx_step` training examples.
  64. </Tip>
  65. learning_rate (`float`, *optional*, defaults to 5e-5):
  66. The initial learning rate for Adam.
  67. weight_decay (`float`, *optional*, defaults to 0):
  68. The weight decay to apply (if not zero).
  69. adam_beta1 (`float`, *optional*, defaults to 0.9):
  70. The beta1 hyperparameter for the Adam optimizer.
  71. adam_beta2 (`float`, *optional*, defaults to 0.999):
  72. The beta2 hyperparameter for the Adam optimizer.
  73. adam_epsilon (`float`, *optional*, defaults to 1e-8):
  74. The epsilon hyperparameter for the Adam optimizer.
  75. max_grad_norm (`float`, *optional*, defaults to 1.0):
  76. Maximum gradient norm (for gradient clipping).
  77. num_train_epochs(`float`, *optional*, defaults to 3.0):
  78. Total number of training epochs to perform.
  79. max_steps (`int`, *optional*, defaults to -1):
  80. If set to a positive number, the total number of training steps to perform. Overrides `num_train_epochs`.
  81. For a finite dataset, training is reiterated through the dataset (if all data is exhausted) until
  82. `max_steps` is reached.
  83. warmup_ratio (`float`, *optional*, defaults to 0.0):
  84. Ratio of total training steps used for a linear warmup from 0 to `learning_rate`.
  85. warmup_steps (`int`, *optional*, defaults to 0):
  86. Number of steps used for a linear warmup from 0 to `learning_rate`. Overrides any effect of `warmup_ratio`.
  87. logging_dir (`str`, *optional*):
  88. [TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to
  89. *runs/**CURRENT_DATETIME_HOSTNAME***.
  90. logging_strategy (`str` or [`~trainer_utils.IntervalStrategy`], *optional*, defaults to `"steps"`):
  91. The logging strategy to adopt during training. Possible values are:
  92. - `"no"`: No logging is done during training.
  93. - `"epoch"`: Logging is done at the end of each epoch.
  94. - `"steps"`: Logging is done every `logging_steps`.
  95. logging_first_step (`bool`, *optional*, defaults to `False`):
  96. Whether to log and evaluate the first `global_step` or not.
  97. logging_steps (`int`, *optional*, defaults to 500):
  98. Number of update steps between two logs if `logging_strategy="steps"`.
  99. save_strategy (`str` or [`~trainer_utils.IntervalStrategy`], *optional*, defaults to `"steps"`):
  100. The checkpoint save strategy to adopt during training. Possible values are:
  101. - `"no"`: No save is done during training.
  102. - `"epoch"`: Save is done at the end of each epoch.
  103. - `"steps"`: Save is done every `save_steps`.
  104. save_steps (`int`, *optional*, defaults to 500):
  105. Number of updates steps before two checkpoint saves if `save_strategy="steps"`.
  106. save_total_limit (`int`, *optional*):
  107. If a value is passed, will limit the total amount of checkpoints. Deletes the older checkpoints in
  108. `output_dir`.
  109. no_cuda (`bool`, *optional*, defaults to `False`):
  110. Whether to not use CUDA even when it is available or not.
  111. seed (`int`, *optional*, defaults to 42):
  112. Random seed that will be set at the beginning of training.
  113. fp16 (`bool`, *optional*, defaults to `False`):
  114. Whether to use 16-bit (mixed) precision training (through NVIDIA Apex) instead of 32-bit training.
  115. fp16_opt_level (`str`, *optional*, defaults to 'O1'):
  116. For `fp16` training, Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']. See details on
  117. the [Apex documentation](https://nvidia.github.io/apex/amp).
  118. local_rank (`int`, *optional*, defaults to -1):
  119. During distributed training, the rank of the process.
  120. tpu_num_cores (`int`, *optional*):
  121. When training on TPU, the number of TPU cores (automatically passed by launcher script).
  122. debug (`bool`, *optional*, defaults to `False`):
  123. Whether to activate the trace to record computation graphs and profiling information or not.
  124. dataloader_drop_last (`bool`, *optional*, defaults to `False`):
  125. Whether to drop the last incomplete batch (if the length of the dataset is not divisible by the batch size)
  126. or not.
  127. eval_steps (`int`, *optional*, defaults to 1000):
  128. Number of update steps before two evaluations.
  129. past_index (`int`, *optional*, defaults to -1):
  130. Some models like [TransformerXL](../model_doc/transformerxl) or :doc*XLNet <../model_doc/xlnet>* can make
  131. use of the past hidden states for their predictions. If this argument is set to a positive int, the
  132. `Trainer` will use the corresponding output (usually index 2) as the past state and feed it to the model at
  133. the next training step under the keyword argument `mems`.
  134. tpu_name (`str`, *optional*):
  135. The name of the TPU the process is running on.
  136. tpu_zone (`str`, *optional*):
  137. The zone of the TPU the process is running on. If not specified, we will attempt to automatically detect
  138. from metadata.
  139. gcp_project (`str`, *optional*):
  140. Google Cloud Project name for the Cloud TPU-enabled project. If not specified, we will attempt to
  141. automatically detect from metadata.
  142. run_name (`str`, *optional*):
  143. A descriptor for the run. Notably used for wandb, mlflow and comet logging.
  144. xla (`bool`, *optional*):
  145. Whether to activate the XLA compilation or not.
  146. """
  147. framework = "tf"
  148. tpu_name: Optional[str] = field(
  149. default=None,
  150. metadata={"help": "Name of TPU"},
  151. )
  152. tpu_zone: Optional[str] = field(
  153. default=None,
  154. metadata={"help": "Zone of TPU"},
  155. )
  156. gcp_project: Optional[str] = field(
  157. default=None,
  158. metadata={"help": "Name of Cloud TPU-enabled project"},
  159. )
  160. poly_power: float = field(
  161. default=1.0,
  162. metadata={"help": "Power for the Polynomial decay LR scheduler."},
  163. )
  164. xla: bool = field(default=False, metadata={"help": "Whether to activate the XLA compilation or not"})
  165. @cached_property
  166. def _setup_strategy(self) -> Tuple["tf.distribute.Strategy", int]:
  167. requires_backends(self, ["tf"])
  168. logger.info("Tensorflow: setting up strategy")
  169. gpus = tf.config.list_physical_devices("GPU")
  170. # Set to float16 at first
  171. if self.fp16:
  172. keras.mixed_precision.set_global_policy("mixed_float16")
  173. if self.no_cuda:
  174. strategy = tf.distribute.OneDeviceStrategy(device="/cpu:0")
  175. else:
  176. try:
  177. if self.tpu_name:
  178. tpu = tf.distribute.cluster_resolver.TPUClusterResolver(
  179. self.tpu_name, zone=self.tpu_zone, project=self.gcp_project
  180. )
  181. else:
  182. tpu = tf.distribute.cluster_resolver.TPUClusterResolver()
  183. except ValueError:
  184. if self.tpu_name:
  185. raise RuntimeError(f"Couldn't connect to TPU {self.tpu_name}!")
  186. else:
  187. tpu = None
  188. if tpu:
  189. # Set to bfloat16 in case of TPU
  190. if self.fp16:
  191. keras.mixed_precision.set_global_policy("mixed_bfloat16")
  192. tf.config.experimental_connect_to_cluster(tpu)
  193. tf.tpu.experimental.initialize_tpu_system(tpu)
  194. strategy = tf.distribute.TPUStrategy(tpu)
  195. elif len(gpus) == 0:
  196. strategy = tf.distribute.OneDeviceStrategy(device="/cpu:0")
  197. elif len(gpus) == 1:
  198. strategy = tf.distribute.OneDeviceStrategy(device="/gpu:0")
  199. elif len(gpus) > 1:
  200. # If you only want to use a specific subset of GPUs use `CUDA_VISIBLE_DEVICES=0`
  201. strategy = tf.distribute.MirroredStrategy()
  202. else:
  203. raise ValueError("Cannot find the proper strategy, please check your environment properties.")
  204. return strategy
  205. @property
  206. def strategy(self) -> "tf.distribute.Strategy":
  207. """
  208. The strategy used for distributed training.
  209. """
  210. requires_backends(self, ["tf"])
  211. return self._setup_strategy
  212. @property
  213. def n_replicas(self) -> int:
  214. """
  215. The number of replicas (CPUs, GPUs or TPU cores) used in this training.
  216. """
  217. requires_backends(self, ["tf"])
  218. return self._setup_strategy.num_replicas_in_sync
  219. @property
  220. def should_log(self):
  221. """
  222. Whether or not the current process should produce log.
  223. """
  224. return False # TF Logging is handled by Keras not the Trainer
  225. @property
  226. def train_batch_size(self) -> int:
  227. """
  228. The actual batch size for training (may differ from `per_gpu_train_batch_size` in distributed training).
  229. """
  230. if self.per_gpu_train_batch_size:
  231. logger.warning(
  232. "Using deprecated `--per_gpu_train_batch_size` argument which will be removed in a future "
  233. "version. Using `--per_device_train_batch_size` is preferred."
  234. )
  235. per_device_batch_size = self.per_gpu_train_batch_size or self.per_device_train_batch_size
  236. return per_device_batch_size * self.n_replicas
  237. @property
  238. def eval_batch_size(self) -> int:
  239. """
  240. The actual batch size for evaluation (may differ from `per_gpu_eval_batch_size` in distributed training).
  241. """
  242. if self.per_gpu_eval_batch_size:
  243. logger.warning(
  244. "Using deprecated `--per_gpu_eval_batch_size` argument which will be removed in a future "
  245. "version. Using `--per_device_eval_batch_size` is preferred."
  246. )
  247. per_device_batch_size = self.per_gpu_eval_batch_size or self.per_device_eval_batch_size
  248. return per_device_batch_size * self.n_replicas
  249. @property
  250. def n_gpu(self) -> int:
  251. """
  252. The number of replicas (CPUs, GPUs or TPU cores) used in this training.
  253. """
  254. requires_backends(self, ["tf"])
  255. warnings.warn(
  256. "The n_gpu argument is deprecated and will be removed in a future version, use n_replicas instead.",
  257. FutureWarning,
  258. )
  259. return self._setup_strategy.num_replicas_in_sync