auto.py

# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import warnings
from typing import Dict, Optional, Union

from ..models.auto.configuration_auto import AutoConfig
from ..utils.quantization_config import (
    AqlmConfig,
    AwqConfig,
    BitNetConfig,
    BitsAndBytesConfig,
    CompressedTensorsConfig,
    EetqConfig,
    FbgemmFp8Config,
    GPTQConfig,
    HqqConfig,
    QuantizationConfigMixin,
    QuantizationMethod,
    QuantoConfig,
    TorchAoConfig,
)
from .quantizer_aqlm import AqlmHfQuantizer
from .quantizer_awq import AwqQuantizer
from .quantizer_bitnet import BitNetHfQuantizer
from .quantizer_bnb_4bit import Bnb4BitHfQuantizer
from .quantizer_bnb_8bit import Bnb8BitHfQuantizer
from .quantizer_compressed_tensors import CompressedTensorsHfQuantizer
from .quantizer_eetq import EetqHfQuantizer
from .quantizer_fbgemm_fp8 import FbgemmFp8HfQuantizer
from .quantizer_gptq import GptqHfQuantizer
from .quantizer_hqq import HqqHfQuantizer
from .quantizer_quanto import QuantoHfQuantizer
from .quantizer_torchao import TorchAoHfQuantizer

AUTO_QUANTIZER_MAPPING = {
    "awq": AwqQuantizer,
    "bitsandbytes_4bit": Bnb4BitHfQuantizer,
    "bitsandbytes_8bit": Bnb8BitHfQuantizer,
    "gptq": GptqHfQuantizer,
    "aqlm": AqlmHfQuantizer,
    "quanto": QuantoHfQuantizer,
    "eetq": EetqHfQuantizer,
    "hqq": HqqHfQuantizer,
    "compressed-tensors": CompressedTensorsHfQuantizer,
    "fbgemm_fp8": FbgemmFp8HfQuantizer,
    "torchao": TorchAoHfQuantizer,
    "bitnet": BitNetHfQuantizer,
}

AUTO_QUANTIZATION_CONFIG_MAPPING = {
    "awq": AwqConfig,
    "bitsandbytes_4bit": BitsAndBytesConfig,
    "bitsandbytes_8bit": BitsAndBytesConfig,
    "eetq": EetqConfig,
    "gptq": GPTQConfig,
    "aqlm": AqlmConfig,
    "quanto": QuantoConfig,
    "hqq": HqqConfig,
    "compressed-tensors": CompressedTensorsConfig,
    "fbgemm_fp8": FbgemmFp8Config,
    "torchao": TorchAoConfig,
    "bitnet": BitNetConfig,
}
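
# Illustrative note (not part of the original module): both mappings are keyed by the
# same `quant_method` strings, so a serialized method name resolves to a config class
# and a quantizer class in lockstep. A minimal sketch:
#
#     config_cls = AUTO_QUANTIZATION_CONFIG_MAPPING["gptq"]  # -> GPTQConfig
#     quantizer_cls = AUTO_QUANTIZER_MAPPING["gptq"]         # -> GptqHfQuantizer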


class AutoQuantizationConfig:
    """
    The Auto-HF quantization config class that takes care of automatically dispatching to the correct
    quantization config given a quantization config stored in a dictionary.
    """

    @classmethod
    def from_dict(cls, quantization_config_dict: Dict):
        quant_method = quantization_config_dict.get("quant_method", None)
        # We need special care for bnb models to make sure everything stays backward-compatible.
        if quantization_config_dict.get("load_in_8bit", False) or quantization_config_dict.get("load_in_4bit", False):
            suffix = "_4bit" if quantization_config_dict.get("load_in_4bit", False) else "_8bit"
            quant_method = QuantizationMethod.BITS_AND_BYTES + suffix
        elif quant_method is None:
            raise ValueError(
                "The model's quantization config from the arguments has no `quant_method` attribute. Make sure that the model has been correctly quantized."
            )

        if quant_method not in AUTO_QUANTIZATION_CONFIG_MAPPING.keys():
            raise ValueError(
                f"Unknown quantization type, got {quant_method} - supported types are:"
                f" {list(AUTO_QUANTIZER_MAPPING.keys())}"
            )

        target_cls = AUTO_QUANTIZATION_CONFIG_MAPPING[quant_method]
        return target_cls.from_dict(quantization_config_dict)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
        model_config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
        if getattr(model_config, "quantization_config", None) is None:
            raise ValueError(
                f"Did not find a `quantization_config` in {pretrained_model_name_or_path}. Make sure that the model is correctly quantized."
            )
        quantization_config_dict = model_config.quantization_config
        quantization_config = cls.from_dict(quantization_config_dict)
        # Update with potential kwargs that are passed through from_pretrained.
        quantization_config.update(**kwargs)
        return quantization_config
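
# Usage sketch (illustrative only; the dict values and checkpoint name below are
# placeholders, not taken from the original module):
#
#     from transformers.quantizers import AutoQuantizationConfig
#
#     # Dispatch from a raw dict, e.g. the `quantization_config` entry of a config.json:
#     qconfig = AutoQuantizationConfig.from_dict({"quant_method": "gptq", "bits": 4, "group_size": 128})
#
#     # Or resolve it directly from a quantized checkpoint:
#     qconfig = AutoQuantizationConfig.from_pretrained("some-org/some-gptq-model")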


class AutoHfQuantizer:
    """
    The Auto-HF quantizer class that takes care of automatically instantiating the correct
    `HfQuantizer` given the `QuantizationConfig`.
    """

    @classmethod
    def from_config(cls, quantization_config: Union[QuantizationConfigMixin, Dict], **kwargs):
        # Convert it to a QuantizationConfig if the q_config is a dict
        if isinstance(quantization_config, dict):
            quantization_config = AutoQuantizationConfig.from_dict(quantization_config)

        quant_method = quantization_config.quant_method

        # Again, we need special care for bnb, as we have a single quantization config
        # class for both 4-bit and 8-bit quantization.
        if quant_method == QuantizationMethod.BITS_AND_BYTES:
            if quantization_config.load_in_8bit:
                quant_method += "_8bit"
            else:
                quant_method += "_4bit"

        if quant_method not in AUTO_QUANTIZER_MAPPING.keys():
            raise ValueError(
                f"Unknown quantization type, got {quant_method} - supported types are:"
                f" {list(AUTO_QUANTIZER_MAPPING.keys())}"
            )

        target_cls = AUTO_QUANTIZER_MAPPING[quant_method]
        return target_cls(quantization_config, **kwargs)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
        quantization_config = AutoQuantizationConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
        return cls.from_config(quantization_config)

    @classmethod
    def merge_quantization_configs(
        cls,
        quantization_config: Union[dict, QuantizationConfigMixin],
        quantization_config_from_args: Optional[QuantizationConfigMixin],
    ):
        """
        Handles situations where both a quantization config from the arguments and a quantization config from the
        model config are present.
        """
        if quantization_config_from_args is not None:
            warning_msg = (
                "You passed `quantization_config` or equivalent parameters to `from_pretrained` but the model you're loading"
                " already has a `quantization_config` attribute. The `quantization_config` from the model will be used."
            )
        else:
            warning_msg = ""

        if isinstance(quantization_config, dict):
            quantization_config = AutoQuantizationConfig.from_dict(quantization_config)

        if (
            isinstance(quantization_config, (GPTQConfig, AwqConfig, FbgemmFp8Config))
            and quantization_config_from_args is not None
        ):
            # Special case for GPTQ / AWQ / FbgemmFp8 config collision: loading attributes
            # passed through the arguments take precedence over those stored with the model.
            loading_attr_dict = quantization_config_from_args.get_loading_attributes()
            for attr, val in loading_attr_dict.items():
                setattr(quantization_config, attr, val)
            warning_msg += f" However, loading attributes (e.g. {list(loading_attr_dict.keys())}) will be overwritten with the ones you passed to `from_pretrained`. The rest will be ignored."

        if warning_msg != "":
            warnings.warn(warning_msg)

        return quantization_config
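
# Usage sketch (illustrative only; the checkpoint name is a placeholder, not taken
# from the original module):
#
#     from transformers import BitsAndBytesConfig
#     from transformers.quantizers import AutoHfQuantizer
#
#     # Build a quantizer from an explicit config ...
#     quantizer = AutoHfQuantizer.from_config(BitsAndBytesConfig(load_in_4bit=True))
#
#     # ... or resolve both the config and the quantizer straight from a quantized checkpoint.
#     quantizer = AutoHfQuantizer.from_pretrained("some-org/some-quantized-model")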