# benchmark_args_tf.py
  1. # coding=utf-8
  2. # Copyright 2018 The HuggingFace Inc. team.
  3. # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
from dataclasses import dataclass, field
from typing import Optional, Tuple

from ..utils import cached_property, is_tf_available, logging, requires_backends
from .benchmark_args_utils import BenchmarkArguments
  20. if is_tf_available():
  21. import tensorflow as tf
  22. logger = logging.get_logger(__name__)
  23. @dataclass
  24. class TensorFlowBenchmarkArguments(BenchmarkArguments):
  25. deprecated_args = [
  26. "no_inference",
  27. "no_cuda",
  28. "no_tpu",
  29. "no_speed",
  30. "no_memory",
  31. "no_env_print",
  32. "no_multi_process",
  33. ]
  34. def __init__(self, **kwargs):
  35. """
  36. This __init__ is there for legacy code. When removing deprecated args completely, the class can simply be
  37. deleted
  38. """
  39. for deprecated_arg in self.deprecated_args:
  40. if deprecated_arg in kwargs:
  41. positive_arg = deprecated_arg[3:]
  42. kwargs[positive_arg] = not kwargs.pop(deprecated_arg)
  43. logger.warning(
  44. f"{deprecated_arg} is depreciated. Please use --no-{positive_arg} or"
  45. f" {positive_arg}={kwargs[positive_arg]}"
  46. )
  47. self.tpu_name = kwargs.pop("tpu_name", self.tpu_name)
  48. self.device_idx = kwargs.pop("device_idx", self.device_idx)
  49. self.eager_mode = kwargs.pop("eager_mode", self.eager_mode)
  50. self.use_xla = kwargs.pop("use_xla", self.use_xla)
  51. super().__init__(**kwargs)
  52. tpu_name: str = field(
  53. default=None,
  54. metadata={"help": "Name of TPU"},
  55. )
  56. device_idx: int = field(
  57. default=0,
  58. metadata={"help": "CPU / GPU device index. Defaults to 0."},
  59. )
  60. eager_mode: bool = field(default=False, metadata={"help": "Benchmark models in eager model."})
  61. use_xla: bool = field(
  62. default=False,
  63. metadata={
  64. "help": "Benchmark models using XLA JIT compilation. Note that `eager_model` has to be set to `False`."
  65. },
  66. )
  67. @cached_property
  68. def _setup_tpu(self) -> Tuple["tf.distribute.cluster_resolver.TPUClusterResolver"]:
  69. requires_backends(self, ["tf"])
  70. tpu = None
  71. if self.tpu:
  72. try:
  73. if self.tpu_name:
  74. tpu = tf.distribute.cluster_resolver.TPUClusterResolver(self.tpu_name)
  75. else:
  76. tpu = tf.distribute.cluster_resolver.TPUClusterResolver()
  77. except ValueError:
  78. tpu = None
  79. return tpu
  80. @cached_property
  81. def _setup_strategy(self) -> Tuple["tf.distribute.Strategy", "tf.distribute.cluster_resolver.TPUClusterResolver"]:
  82. requires_backends(self, ["tf"])
  83. if self.is_tpu:
  84. tf.config.experimental_connect_to_cluster(self._setup_tpu)
  85. tf.tpu.experimental.initialize_tpu_system(self._setup_tpu)
  86. strategy = tf.distribute.TPUStrategy(self._setup_tpu)
  87. else:
  88. # currently no multi gpu is allowed
  89. if self.is_gpu:
  90. # TODO: Currently only single GPU is supported
  91. tf.config.set_visible_devices(self.gpu_list[self.device_idx], "GPU")
  92. strategy = tf.distribute.OneDeviceStrategy(device=f"/gpu:{self.device_idx}")
  93. else:
  94. tf.config.set_visible_devices([], "GPU") # disable GPU
  95. strategy = tf.distribute.OneDeviceStrategy(device=f"/cpu:{self.device_idx}")
  96. return strategy
  97. @property
  98. def is_tpu(self) -> bool:
  99. requires_backends(self, ["tf"])
  100. return self._setup_tpu is not None
  101. @property
  102. def strategy(self) -> "tf.distribute.Strategy":
  103. requires_backends(self, ["tf"])
  104. return self._setup_strategy
  105. @property
  106. def gpu_list(self):
  107. requires_backends(self, ["tf"])
  108. return tf.config.list_physical_devices("GPU")
  109. @property
  110. def n_gpu(self) -> int:
  111. requires_backends(self, ["tf"])
  112. if self.cuda:
  113. return len(self.gpu_list)
  114. return 0
  115. @property
  116. def is_gpu(self) -> bool:
  117. return self.n_gpu > 0