# env.py
# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
  14. import importlib.util
  15. import os
  16. import platform
  17. from argparse import ArgumentParser
  18. import huggingface_hub
  19. from .. import __version__ as version
  20. from ..utils import (
  21. is_accelerate_available,
  22. is_flax_available,
  23. is_safetensors_available,
  24. is_tf_available,
  25. is_torch_available,
  26. is_torch_npu_available,
  27. )
  28. from . import BaseTransformersCLICommand
  29. def info_command_factory(_):
  30. return EnvironmentCommand()
  31. def download_command_factory(args):
  32. return EnvironmentCommand(args.accelerate_config_file)
  33. class EnvironmentCommand(BaseTransformersCLICommand):
  34. @staticmethod
  35. def register_subcommand(parser: ArgumentParser):
  36. download_parser = parser.add_parser("env")
  37. download_parser.set_defaults(func=info_command_factory)
  38. download_parser.add_argument(
  39. "--accelerate-config_file",
  40. default=None,
  41. help="The accelerate config file to use for the default values in the launching script.",
  42. )
  43. download_parser.set_defaults(func=download_command_factory)
  44. def __init__(self, accelerate_config_file, *args) -> None:
  45. self._accelerate_config_file = accelerate_config_file
  46. def run(self):
  47. safetensors_version = "not installed"
  48. if is_safetensors_available():
  49. import safetensors
  50. safetensors_version = safetensors.__version__
  51. elif importlib.util.find_spec("safetensors") is not None:
  52. import safetensors
  53. safetensors_version = f"{safetensors.__version__} but is ignored because of PyTorch version too old."
  54. accelerate_version = "not installed"
  55. accelerate_config = accelerate_config_str = "not found"
  56. if is_accelerate_available():
  57. import accelerate
  58. from accelerate.commands.config import default_config_file, load_config_from_file
  59. accelerate_version = accelerate.__version__
  60. # Get the default from the config file.
  61. if self._accelerate_config_file is not None or os.path.isfile(default_config_file):
  62. accelerate_config = load_config_from_file(self._accelerate_config_file).to_dict()
  63. accelerate_config_str = (
  64. "\n".join([f"\t- {prop}: {val}" for prop, val in accelerate_config.items()])
  65. if isinstance(accelerate_config, dict)
  66. else f"\t{accelerate_config}"
  67. )
  68. pt_version = "not installed"
  69. pt_cuda_available = "NA"
  70. if is_torch_available():
  71. import torch
  72. pt_version = torch.__version__
  73. pt_cuda_available = torch.cuda.is_available()
  74. pt_npu_available = is_torch_npu_available()
  75. tf_version = "not installed"
  76. tf_cuda_available = "NA"
  77. if is_tf_available():
  78. import tensorflow as tf
  79. tf_version = tf.__version__
  80. try:
  81. # deprecated in v2.1
  82. tf_cuda_available = tf.test.is_gpu_available()
  83. except AttributeError:
  84. # returns list of devices, convert to bool
  85. tf_cuda_available = bool(tf.config.list_physical_devices("GPU"))
  86. flax_version = "not installed"
  87. jax_version = "not installed"
  88. jaxlib_version = "not installed"
  89. jax_backend = "NA"
  90. if is_flax_available():
  91. import flax
  92. import jax
  93. import jaxlib
  94. flax_version = flax.__version__
  95. jax_version = jax.__version__
  96. jaxlib_version = jaxlib.__version__
  97. jax_backend = jax.lib.xla_bridge.get_backend().platform
  98. info = {
  99. "`transformers` version": version,
  100. "Platform": platform.platform(),
  101. "Python version": platform.python_version(),
  102. "Huggingface_hub version": huggingface_hub.__version__,
  103. "Safetensors version": f"{safetensors_version}",
  104. "Accelerate version": f"{accelerate_version}",
  105. "Accelerate config": f"{accelerate_config_str}",
  106. "PyTorch version (GPU?)": f"{pt_version} ({pt_cuda_available})",
  107. "Tensorflow version (GPU?)": f"{tf_version} ({tf_cuda_available})",
  108. "Flax version (CPU?/GPU?/TPU?)": f"{flax_version} ({jax_backend})",
  109. "Jax version": f"{jax_version}",
  110. "JaxLib version": f"{jaxlib_version}",
  111. "Using distributed or parallel set-up in script?": "<fill in>",
  112. }
  113. if is_torch_available():
  114. if pt_cuda_available:
  115. info["Using GPU in script?"] = "<fill in>"
  116. info["GPU type"] = torch.cuda.get_device_name()
  117. elif pt_npu_available:
  118. info["Using NPU in script?"] = "<fill in>"
  119. info["NPU type"] = torch.npu.get_device_name()
  120. info["CANN version"] = torch.version.cann
  121. print("\nCopy-and-paste the text below in your GitHub issue and FILL OUT the two last points.\n")
  122. print(self.format_dict(info))
  123. return info
  124. @staticmethod
  125. def format_dict(d):
  126. return "\n".join([f"- {prop}: {val}" for prop, val in d.items()]) + "\n"