quantizer_eetq.py 6.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169
  1. # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from typing import TYPE_CHECKING, Any, Dict, List, Optional
  15. from .base import HfQuantizer
  16. if TYPE_CHECKING:
  17. from ..modeling_utils import PreTrainedModel
  18. from ..utils import is_accelerate_available, is_eetq_available, is_torch_available, logging
  19. from .quantizers_utils import get_module_from_name
  20. if is_torch_available():
  21. import torch
# Module-level logger; EetqHfQuantizer uses it for the CPU-placement warning and the torch_dtype info messages.
logger = logging.get_logger(__name__)
  23. class EetqHfQuantizer(HfQuantizer):
  24. """
  25. 8-bit quantization from EETQ quantization method:
  26. before loading: converts transformer layers into W8A16Linear during loading: load 16bit weight and pass to the
  27. layer object after: quantizes individual weights in Linear8bitLt into 8bit at first .cuda() call
  28. """
  29. requires_parameters_quantization = True
  30. requires_calibration = False
  31. required_packages = ["eetq", "accelerate"]
  32. def __init__(self, quantization_config, **kwargs):
  33. super().__init__(quantization_config, **kwargs)
  34. self.quantization_config = quantization_config
  35. def validate_environment(self, *args, **kwargs):
  36. if not is_eetq_available():
  37. raise ImportError(
  38. "Using `eetq` 8-bit quantization requires eetq."
  39. "Please install the latest version of eetq from : https://github.com/NetEase-FuXi/EETQ"
  40. )
  41. if not is_accelerate_available():
  42. raise ImportError("Loading an EETQ quantized model requires accelerate (`pip install accelerate`)")
  43. if kwargs.get("from_tf", False) or kwargs.get("from_flax", False):
  44. raise ValueError(
  45. "Converting into 8-bit weights from tf/flax weights is currently not supported, please make"
  46. " sure the weights are in PyTorch format."
  47. )
  48. if not torch.cuda.is_available():
  49. raise RuntimeError("No GPU found. A GPU is needed for quantization.")
  50. device_map = kwargs.get("device_map", None)
  51. if device_map is None:
  52. logger.warning_once(
  53. "You have loaded an EETQ model on CPU and have a CUDA device available, make sure to set "
  54. "your model on a GPU device in order to run your model."
  55. )
  56. elif device_map is not None:
  57. if isinstance(device_map, dict) and ("cpu" in device_map.values() or "disk" in device_map.values()):
  58. raise ValueError(
  59. "You are attempting to load an EETQ model with a device_map that contains a CPU or disk device."
  60. " This is not supported. Please remove the CPU or disk device from the device_map."
  61. )
  62. def update_torch_dtype(self, torch_dtype: "torch.dtype") -> "torch.dtype":
  63. if torch_dtype is None:
  64. torch_dtype = torch.float16
  65. logger.info(
  66. "Overriding torch_dtype=%s with `torch_dtype=torch.float16` due to "
  67. "requirements of `eetq` to enable model loading in 8-bit. "
  68. "Pass your own torch_dtype to specify the dtype of the remaining non-linear layers or pass"
  69. " torch_dtype=torch.float16 to remove this warning.",
  70. torch_dtype,
  71. )
  72. elif torch_dtype != torch.float16:
  73. logger.info("We suggest you to set `torch_dtype=torch.float16` for better efficiency with EETQ.")
  74. return torch_dtype
  75. def check_quantized_param(
  76. self,
  77. model: "PreTrainedModel",
  78. param_value: "torch.Tensor",
  79. param_name: str,
  80. state_dict: Dict[str, Any],
  81. **kwargs,
  82. ):
  83. from eetq import EetqLinear
  84. module, tensor_name = get_module_from_name(model, param_name)
  85. if isinstance(module, EetqLinear):
  86. if self.pre_quantized or tensor_name == "bias":
  87. if tensor_name == "weight" and param_value.dtype != torch.int8:
  88. raise ValueError("Expect quantized weights but got an unquantized weight")
  89. return False
  90. else:
  91. if tensor_name == "weight_scale":
  92. raise ValueError("Expect unquantized weights but got a quantized weight_scale")
  93. return True
  94. return False
  95. def create_quantized_param(
  96. self,
  97. model: "PreTrainedModel",
  98. param_value: "torch.Tensor",
  99. param_name: str,
  100. target_device: "torch.device",
  101. state_dict: Dict[str, Any],
  102. unexpected_keys: Optional[List[str]] = None,
  103. ):
  104. """
  105. quantizes weights into qweight and weight_scales
  106. """
  107. from eetq import quantize_and_preprocess_weights
  108. module, tensor_name = get_module_from_name(model, param_name)
  109. new_value, weight_scale = quantize_and_preprocess_weights(param_value)
  110. module._buffers[tensor_name] = new_value.to(target_device)
  111. module.register("weight_scales", weight_scale.to(target_device))
  112. def _process_model_after_weight_loading(self, model: "PreTrainedModel", **kwargs):
  113. return model
  114. def _process_model_before_weight_loading(
  115. self,
  116. model: "PreTrainedModel",
  117. device_map,
  118. keep_in_fp32_modules: List[str] = [],
  119. **kwargs,
  120. ):
  121. from ..integrations import get_keys_to_not_convert, replace_with_eetq_linear
  122. self.modules_to_not_convert = get_keys_to_not_convert(model)
  123. if self.quantization_config.modules_to_not_convert is not None:
  124. self.modules_to_not_convert.extend(self.quantization_config.modules_to_not_convert)
  125. model = replace_with_eetq_linear(
  126. model,
  127. modules_to_not_convert=self.modules_to_not_convert,
  128. quantization_config=self.quantization_config,
  129. pre_quantized=self.pre_quantized,
  130. )
  131. model.config.quantization_config = self.quantization_config
  132. def is_serializable(self, safe_serialization=None):
  133. return True
  134. @property
  135. def is_trainable(self) -> bool:
  136. return True