quantizer_fbgemm_fp8.py

# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib
from typing import TYPE_CHECKING, Any, Dict, List, Optional

from packaging import version

from .base import HfQuantizer


if TYPE_CHECKING:
    from ..modeling_utils import PreTrainedModel

from ..utils import is_accelerate_available, is_fbgemm_gpu_available, is_torch_available, logging
from .quantizers_utils import get_module_from_name


if is_torch_available():
    import torch

logger = logging.get_logger(__name__)

class FbgemmFp8HfQuantizer(HfQuantizer):
    """
    FP8 quantization using fbgemm kernels
    """

    requires_parameters_quantization = True
    requires_calibration = False

    required_packages = ["fbgemm-gpu", "accelerate"]

    def __init__(self, quantization_config, **kwargs):
        super().__init__(quantization_config, **kwargs)
        self.quantization_config = quantization_config

    def validate_environment(self, *args, **kwargs):
        if not is_torch_available() or version.parse(importlib.metadata.version("torch")) < version.parse("2.1.0"):
            raise ImportError(
                "Using fbgemm fp8 quantization requires torch >= 2.1.0. "
                "Please install the latest version of torch (`pip install --upgrade torch`)."
            )

        if not is_fbgemm_gpu_available():
            raise ImportError(
                "Using fbgemm fp8 quantization requires the fbgemm-gpu library. "
                "Please install the latest version of fbgemm-gpu by following the instructions at "
                "https://pytorch.org/FBGEMM/fbgemm_gpu-development/InstallationInstructions.html#fbgemm-gpu-install-libraries"
            )

        if not is_accelerate_available("0.32.2"):
            raise ImportError(
                "Loading an FP8 quantized model requires accelerate >= 0.32.2 (`pip install --upgrade accelerate`)."
            )

        if not torch.cuda.is_available():
            raise RuntimeError("Using FP8 quantized models with fbgemm kernels requires a GPU")

        compute_capability = torch.cuda.get_device_capability()
        major, minor = compute_capability
        if major < 9:
            raise ValueError(
                "FP8 quantized models are only supported on GPUs with compute capability >= 9.0 (e.g. H100)"
            )

        device_map = kwargs.get("device_map", None)
        if device_map is None:
            logger.warning_once(
                "You have loaded an FP8 model on CPU and have a CUDA device available. Make sure to move "
                "your model to a GPU device in order to run it. To remove this warning, pass device_map='cuda'."
            )
        elif (
            not self.pre_quantized
            and isinstance(device_map, dict)
            and ("cpu" in device_map.values() or "disk" in device_map.values())
        ):
            raise ValueError(
                "You are attempting to load an FP8 model with a device_map that contains a CPU or disk device. "
                "This is not supported when the model is quantized on the fly. "
                "Please use a quantized checkpoint or remove the CPU or disk device from the device_map."
            )

    def update_torch_dtype(self, torch_dtype: "torch.dtype") -> "torch.dtype":
        if torch_dtype is None:
            logger.info(
                "Overriding torch_dtype=%s with `torch_dtype=torch.bfloat16` due to "
                "requirements of `fbgemm-gpu` to enable model loading in fp8. "
                "Pass your own torch_dtype to specify the dtype of the remaining non-linear layers or pass "
                "torch_dtype=torch.bfloat16 to remove this warning.",
                torch_dtype,
            )
            torch_dtype = torch.bfloat16
        elif torch_dtype == torch.float16:
            raise ValueError(
                "You cannot use FP8 with torch_dtype=torch.float16. "
                "We recommend you pass torch_dtype=torch.bfloat16 instead."
            )
        return torch_dtype

    def check_quantized_param(
        self,
        model: "PreTrainedModel",
        param_value: "torch.Tensor",
        param_name: str,
        state_dict: Dict[str, Any],
        **kwargs,
    ):
        from ..integrations import FbgemmFp8Linear

        module, tensor_name = get_module_from_name(model, param_name)

        if isinstance(module, FbgemmFp8Linear):
            if self.pre_quantized or tensor_name == "bias":
                if tensor_name == "weight" and param_value.dtype != torch.float8_e4m3fn:
                    raise ValueError("Expected quantized weights but got an unquantized weight")
                return False
            else:
                if tensor_name == "weight_scale":
                    raise ValueError("Expected unquantized weights but got a quantized weight_scale")
                return True
        return False
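
    # Illustration (comments only, not executed): for a module that has been swapped to
    # FbgemmFp8Linear, a param_name such as "model.layers.0.mlp.gate_proj.weight" (a made-up
    # name used only as an example) returns True when quantizing on the fly, because the
    # high-precision weight still needs to be converted, and False when loading an already
    # quantized checkpoint, where the serialized fp8 "weight" and its "weight_scale" are
    # loaded as-is.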

    def create_quantized_param(
        self,
        model: "PreTrainedModel",
        param_value: "torch.Tensor",
        param_name: str,
        target_device: "torch.device",
        state_dict: Dict[str, Any],
        unexpected_keys: Optional[List[str]] = None,
    ):
        """
        Quantizes weights into weight and weight_scale
        """
        new_value, weight_scale = torch.ops.fbgemm.quantize_fp8_per_row(param_value)

        module, tensor_name = get_module_from_name(model, param_name)
        module._buffers[tensor_name] = new_value.to(target_device)
        # to have the right output shape -> (out_features, 1)
        module._buffers["weight_scale"] = weight_scale.view(weight_scale.shape[0], 1).to(target_device)

        if unexpected_keys is not None and param_name in unexpected_keys:
            unexpected_keys.remove(param_name)
        del param_name
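
    # A rough pure-PyTorch sketch of what per-row FP8 quantization computes. This helper is
    # NOT used by the quantizer and is not the fbgemm kernel itself; it is included only to
    # illustrate the idea: each output row of the weight gets its own scale, chosen so the
    # row's largest magnitude maps onto the e4m3 range, and the scaled row is cast to fp8.
    @staticmethod
    def _reference_quantize_fp8_per_row(weight: "torch.Tensor"):
        fp8_max = torch.finfo(torch.float8_e4m3fn).max
        # one scale per output row, shape (out_features,)
        scale = weight.abs().amax(dim=1).float().clamp(min=1e-12) / fp8_max
        quantized = (weight.float() / scale.unsqueeze(1)).to(torch.float8_e4m3fn)
        # dequantization is approximately quantized.float() * scale.unsqueeze(1)
        return quantized, scale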

    def _process_model_after_weight_loading(self, model: "PreTrainedModel", **kwargs):
        return model

    def _process_model_before_weight_loading(
        self,
        model: "PreTrainedModel",
        device_map,
        keep_in_fp32_modules: List[str] = [],
        **kwargs,
    ):
        from ..integrations import get_keys_to_not_convert, replace_with_fbgemm_fp8_linear

        self.modules_to_not_convert = get_keys_to_not_convert(model)

        if self.quantization_config.modules_to_not_convert is not None:
            self.modules_to_not_convert.extend(self.quantization_config.modules_to_not_convert)

        model = replace_with_fbgemm_fp8_linear(
            model,
            modules_to_not_convert=self.modules_to_not_convert,
            quantization_config=self.quantization_config,
            pre_quantized=self.pre_quantized,
        )

        model.config.quantization_config = self.quantization_config

    def update_missing_keys(self, model, missing_keys: List[str], prefix: str) -> List[str]:
        from ..integrations import FbgemmFp8Linear

        not_missing_keys = []
        for name, module in model.named_modules():
            if isinstance(module, FbgemmFp8Linear):
                for missing in missing_keys:
                    if (
                        (name in missing or name in f"{prefix}.{missing}")
                        and not missing.endswith(".weight")
                        and not missing.endswith(".bias")
                    ):
                        not_missing_keys.append(missing)
        return [k for k in missing_keys if k not in not_missing_keys]

    def is_serializable(self, safe_serialization=None):
        return True

    @property
    def is_trainable(self) -> bool:
        return False
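
# A minimal usage sketch (comments only, not part of this module): loading a model with
# on-the-fly FP8 quantization through FbgemmFp8Config. The model id below is just an example;
# any checkpoint whose linear layers fbgemm-gpu supports, run on an H100-class GPU, should
# behave similarly.
#
#     from transformers import AutoModelForCausalLM, FbgemmFp8Config
#
#     model_id = "meta-llama/Meta-Llama-3-8B"  # example checkpoint, swap in your own
#     quantization_config = FbgemmFp8Config()
#     model = AutoModelForCausalLM.from_pretrained(
#         model_id,
#         device_map="cuda",
#         torch_dtype=torch.bfloat16,
#         quantization_config=quantization_config,
#     )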