# coding=utf-8
# Copyright 2018 The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
  16. from dataclasses import dataclass, field
  17. from typing import Tuple
  18. from ..utils import (
  19. cached_property,
  20. is_torch_available,
  21. is_torch_xla_available,
  22. is_torch_xpu_available,
  23. logging,
  24. requires_backends,
  25. )
  26. from .benchmark_args_utils import BenchmarkArguments
  27. if is_torch_available():
  28. import torch
  29. if is_torch_xla_available():
  30. import torch_xla.core.xla_model as xm
  31. logger = logging.get_logger(__name__)
  32. @dataclass
  33. class PyTorchBenchmarkArguments(BenchmarkArguments):
  34. deprecated_args = [
  35. "no_inference",
  36. "no_cuda",
  37. "no_tpu",
  38. "no_speed",
  39. "no_memory",
  40. "no_env_print",
  41. "no_multi_process",
  42. ]
  43. def __init__(self, **kwargs):
  44. """
  45. This __init__ is there for legacy code. When removing deprecated args completely, the class can simply be
  46. deleted
  47. """
  48. for deprecated_arg in self.deprecated_args:
  49. if deprecated_arg in kwargs:
  50. positive_arg = deprecated_arg[3:]
  51. setattr(self, positive_arg, not kwargs.pop(deprecated_arg))
  52. logger.warning(
  53. f"{deprecated_arg} is depreciated. Please use --no_{positive_arg} or"
  54. f" {positive_arg}={kwargs[positive_arg]}"
  55. )
  56. self.torchscript = kwargs.pop("torchscript", self.torchscript)
  57. self.torch_xla_tpu_print_metrics = kwargs.pop("torch_xla_tpu_print_metrics", self.torch_xla_tpu_print_metrics)
  58. self.fp16_opt_level = kwargs.pop("fp16_opt_level", self.fp16_opt_level)
  59. super().__init__(**kwargs)
  60. torchscript: bool = field(default=False, metadata={"help": "Trace the models using torchscript"})
  61. torch_xla_tpu_print_metrics: bool = field(default=False, metadata={"help": "Print Xla/PyTorch tpu metrics"})
  62. fp16_opt_level: str = field(
  63. default="O1",
  64. metadata={
  65. "help": (
  66. "For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']. "
  67. "See details at https://nvidia.github.io/apex/amp.html"
  68. )
  69. },
  70. )
  71. @cached_property
  72. def _setup_devices(self) -> Tuple["torch.device", int]:
  73. requires_backends(self, ["torch"])
  74. logger.info("PyTorch: setting up devices")
  75. if not self.cuda:
  76. device = torch.device("cpu")
  77. n_gpu = 0
  78. elif is_torch_xla_available():
  79. device = xm.xla_device()
  80. n_gpu = 0
  81. elif is_torch_xpu_available():
  82. device = torch.device("xpu")
  83. n_gpu = torch.xpu.device_count()
  84. else:
  85. device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
  86. n_gpu = torch.cuda.device_count()
  87. return device, n_gpu
  88. @property
  89. def is_tpu(self):
  90. return is_torch_xla_available() and self.tpu
  91. @property
  92. def device_idx(self) -> int:
  93. requires_backends(self, ["torch"])
  94. # TODO(PVP): currently only single GPU is supported
  95. return torch.cuda.current_device()
  96. @property
  97. def device(self) -> "torch.device":
  98. requires_backends(self, ["torch"])
  99. return self._setup_devices[0]
  100. @property
  101. def n_gpu(self):
  102. requires_backends(self, ["torch"])
  103. return self._setup_devices[1]
  104. @property
  105. def is_gpu(self):
  106. return self.n_gpu > 0