# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib.metadata
from typing import TYPE_CHECKING, Optional

from packaging import version

from .base import HfQuantizer

if TYPE_CHECKING:
    from ..modeling_utils import PreTrainedModel

from ..utils import is_auto_gptq_available, is_optimum_available, is_torch_available, logging
from ..utils.quantization_config import GPTQConfig, QuantizationConfigMixin

if is_torch_available():
    import torch

logger = logging.get_logger(__name__)

class GptqHfQuantizer(HfQuantizer):
    """
    Quantizer for the GPTQ method: the quantizer supports calibration of the model through the `auto_gptq` package.
    Quantization is done under the hood for users if they load a non-prequantized model. A minimal usage sketch is
    included at the bottom of this module.
    """
    requires_calibration = False
    required_packages = ["optimum", "auto_gptq"]
    optimum_quantizer = None
    def __init__(self, quantization_config: QuantizationConfigMixin, **kwargs):
        super().__init__(quantization_config, **kwargs)
        from optimum.gptq import GPTQQuantizer

        self.optimum_quantizer = GPTQQuantizer.from_dict(self.quantization_config.to_dict_optimum())
    def validate_environment(self, *args, **kwargs):
        if not (is_optimum_available() and is_auto_gptq_available()):
            raise ImportError(
                "Loading a GPTQ quantized model requires optimum (`pip install optimum`) and the auto-gptq library (`pip install auto-gptq`)."
            )
        if version.parse(importlib.metadata.version("auto_gptq")) < version.parse("0.4.2"):
            raise ImportError(
                "You need a version of auto_gptq >= 0.4.2 to use GPTQ: `pip install --upgrade auto-gptq`."
            )
        # CPU quantization/inference is only supported by auto-gptq versions newer than 0.4.2.
        gptq_supports_cpu = version.parse(importlib.metadata.version("auto-gptq")) > version.parse("0.4.2")
        if not gptq_supports_cpu and not torch.cuda.is_available():
            raise RuntimeError("GPU is required to quantize or run a quantized model.")
    def update_torch_dtype(self, torch_dtype: "torch.dtype") -> "torch.dtype":
        if torch_dtype is None:
            torch_dtype = torch.float16
        elif torch_dtype != torch.float16:
            logger.info("We suggest setting `torch_dtype=torch.float16` for better efficiency with GPTQ.")
        return torch_dtype
    def _process_model_before_weight_loading(self, model: "PreTrainedModel", **kwargs):
        if model.__class__.main_input_name != "input_ids":
            raise RuntimeError("We can only quantize pure text models.")

        if self.pre_quantized:
            # Convert the model so that the quantized checkpoint weights can be loaded into it.
            model = self.optimum_quantizer.convert_model(model)
    def _process_model_after_weight_loading(self, model: "PreTrainedModel", **kwargs):
        if self.pre_quantized:
            # The weights are already quantized: run optimum's post-init step on the loaded model.
            model = self.optimum_quantizer.post_init_model(model)
        else:
            if self.quantization_config.tokenizer is None:
                self.quantization_config.tokenizer = model.name_or_path
            # Calibrate and quantize the model on the fly, then store the resulting config on the model.
            self.optimum_quantizer.quantize_model(model, self.quantization_config.tokenizer)
            model.config.quantization_config = GPTQConfig.from_dict(self.optimum_quantizer.to_dict())
    @property
    def is_trainable(self, model: Optional["PreTrainedModel"] = None):
        return True

    def is_serializable(self, safe_serialization=None):
        return True
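
# A minimal usage sketch (not part of the original module): it shows how this quantizer is
# typically triggered through `from_pretrained`, either to quantize a model on the fly with
# calibration data or to load an already-quantized checkpoint. The model identifiers and the
# calibration dataset below are placeholders, not recommendations.
if __name__ == "__main__":
    from transformers import AutoModelForCausalLM, AutoTokenizer, GPTQConfig

    # Quantize a non-prequantized model on the fly (calibration handled by auto_gptq/optimum).
    tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")
    gptq_config = GPTQConfig(bits=4, dataset="c4", tokenizer=tokenizer)
    quantized_model = AutoModelForCausalLM.from_pretrained(
        "facebook/opt-125m", device_map="auto", quantization_config=gptq_config
    )

    # Load a checkpoint that was already quantized with GPTQ: no calibration is needed,
    # the quantizer only converts the model and runs the post-init step.
    prequantized_model = AutoModelForCausalLM.from_pretrained(
        "facebook/opt-125m-gptq", device_map="auto"  # placeholder repo id for a GPTQ checkpoint
    )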