quantizer_bnb_4bit.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357
  1. # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import importlib
  15. from functools import cached_property
  16. from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
  17. from packaging import version
  18. from .base import HfQuantizer
  19. from .quantizers_utils import get_module_from_name
  20. if TYPE_CHECKING:
  21. from ..modeling_utils import PreTrainedModel
  22. from ..utils import (
  23. ACCELERATE_MIN_VERSION,
  24. is_accelerate_available,
  25. is_bitsandbytes_available,
  26. is_torch_available,
  27. is_torch_xpu_available,
  28. logging,
  29. )
  30. if is_torch_available():
  31. import torch
  32. from ..pytorch_utils import Conv1D
  # Module-level logger shared by the quantizer methods below.
  logger = logging.get_logger(__name__)


  class Bnb4BitHfQuantizer(HfQuantizer):
      """
      4-bit quantization from the bitsandbytes quantization method:
          before loading: converts transformer layers into `Linear4bit`
          during loading: loads 16-bit weights and passes them to the layer object
          after loading: quantizes individual weights in `Linear4bit` into 4-bit at the first `.cuda()` call
          saving:
              from the state dict, as usual; saves weights and `quant_state` components
          loading:
              needs to locate `quant_state` components and pass them to the `Params4bit` constructor
      """

      # Capability flags declared for the `HfQuantizer` base class (imported from `.base`).
      use_keep_in_fp32_modules = True
      requires_parameters_quantization = True
      requires_calibration = False
      required_packages = ["bitsandbytes", "accelerate"]

      def __init__(self, quantization_config, **kwargs):
          """Initialize the quantizer; an explicit `llm_int8_skip_modules` in the config takes precedence."""
          super().__init__(quantization_config, **kwargs)

          if self.quantization_config.llm_int8_skip_modules is not None:
              self.modules_to_not_convert = self.quantization_config.llm_int8_skip_modules

      def validate_environment(self, *args, **kwargs):
          """
          Check that the current environment can load a model with bitsandbytes 4-bit quantization.

          Raises:
              ImportError: if `accelerate` or `bitsandbytes` is not available.
              ValueError: if loading from TF/Flax weights; if a dict `device_map` places quantized modules on
                  CPU/disk without `llm_int8_enable_fp32_cpu_offload` (unless everything is on CPU and the
                  bitsandbytes multi-backend is enabled); or if `bitsandbytes` is older than 0.39.0.
          """
          if not is_accelerate_available():
              raise ImportError(
                  f"Using `bitsandbytes` 4-bit quantization requires Accelerate: `pip install 'accelerate>={ACCELERATE_MIN_VERSION}'`"
              )
          if not is_bitsandbytes_available():
              raise ImportError(
                  "Using `bitsandbytes` 4-bit quantization requires the latest version of bitsandbytes: `pip install -U bitsandbytes`"
              )

          from ..integrations import validate_bnb_backend_availability
          from ..utils import is_bitsandbytes_multi_backend_available

          bnb_multibackend_is_enabled = is_bitsandbytes_multi_backend_available()
          validate_bnb_backend_availability(raise_exception=True)

          if kwargs.get("from_tf", False) or kwargs.get("from_flax", False):
              raise ValueError(
                  "Converting into 4-bit or 8-bit weights from tf/flax weights is currently not supported, please make"
                  " sure the weights are in PyTorch format."
              )

          device_map = kwargs.get("device_map", None)
          if (
              device_map is not None
              and isinstance(device_map, dict)
              and not self.quantization_config.llm_int8_enable_fp32_cpu_offload
          ):
              # Modules that are not converted (e.g. the lm_head) may live on CPU/disk; only the others matter here.
              device_map_without_lm_head = {
                  key: device for key, device in device_map.items() if key not in self.modules_to_not_convert
              }
              if set(device_map.values()) == {"cpu"} and bnb_multibackend_is_enabled:
                  # A fully-CPU placement is accepted when the bitsandbytes multi-backend is enabled.
                  pass
              elif "cpu" in device_map_without_lm_head.values() or "disk" in device_map_without_lm_head.values():
                  raise ValueError(
                      "Some modules are dispatched on the CPU or the disk. Make sure you have enough GPU RAM to fit the "
                      "quantized model. If you want to dispatch the model on the CPU or the disk while keeping these modules "
                      "in 32-bit, you need to set `llm_int8_enable_fp32_cpu_offload=True` and pass a custom `device_map` to "
                      "`from_pretrained`. Check "
                      "https://huggingface.co/docs/transformers/main/en/main_classes/quantization#offload-between-cpu-and-gpu "
                      "for more details. "
                  )

          if version.parse(importlib.metadata.version("bitsandbytes")) < version.parse("0.39.0"):
              raise ValueError(
                  "You have a version of `bitsandbytes` that is not compatible with 4bit inference and training"
                  " make sure you have the latest version of `bitsandbytes` installed"
              )

      def adjust_target_dtype(self, target_dtype: "torch.dtype") -> "torch.dtype":
          """
          Replace `target_dtype` by accelerate's `CustomDtype.INT4` for 4-bit auto device-map computation.

          Raises:
              ValueError: if the installed `accelerate` is not newer than 0.19.0.
          """
          if version.parse(importlib.metadata.version("accelerate")) > version.parse("0.19.0"):
              from accelerate.utils import CustomDtype

              if target_dtype != torch.int8:
                  # Lazy %-args (as in `update_torch_dtype`): the original string lacked its `f` prefix, so the
                  # dtype placeholder was logged literally.
                  logger.info(
                      "target_dtype %s is replaced by `CustomDtype.INT4` for 4-bit BnB quantization", target_dtype
                  )
              return CustomDtype.INT4
          else:
              raise ValueError(
                  "You are using `device_map='auto'` on a 4bit loaded version of the model. To automatically compute"
                  " the appropriate device map, you should upgrade your `accelerate` library, "
                  "`pip install --upgrade accelerate` or install it from source to support fp4 auto device map "
                  "calculation. You may encounter unexpected behavior, or pass your own device map"
              )

      def check_quantized_param(
          self,
          model: "PreTrainedModel",
          param_value: "torch.Tensor",
          param_name: str,
          state_dict: Dict[str, Any],
          **kwargs,
      ) -> bool:
          """
          Return True if `param_name` is a parameter this quantizer must load itself.

          That is the case for `Params4bit` weights and for the bias of a `Linear4bit` module; every other
          parameter returns False. `param_value` and `state_dict` are not inspected.
          """
          import bitsandbytes as bnb

          module, tensor_name = get_module_from_name(model, param_name)
          if isinstance(module._parameters.get(tensor_name, None), bnb.nn.Params4bit):
              # Add here check for loaded components' dtypes once serialization is implemented
              return True
          elif isinstance(module, bnb.nn.Linear4bit) and tensor_name == "bias":
              # bias could be loaded by regular set_module_tensor_to_device() from accelerate,
              # but it would wrongly use uninitialized weight there.
              return True
          else:
              return False

      def create_quantized_param(
          self,
          model: "PreTrainedModel",
          param_value: "torch.Tensor",
          param_name: str,
          target_device: "torch.device",
          state_dict: Dict[str, Any],
          unexpected_keys: Optional[List[str]] = None,
      ):
          """
          Materialize `param_name` on its module as a 4-bit `Params4bit` (or as a plain bias) on `target_device`.

          Combines logic from `_load_state_dict_into_meta_model` and
          `.integrations.bitsandbytes.py::set_module_quantized_tensor_to_device()`.

          Args:
              model: model owning the parameter; the owning module's `_parameters` entry is replaced in place.
              param_value: tensor read from the checkpoint. For a bias it may be `None`, in which case the
                  existing bias is moved to `target_device`.
              param_name: fully qualified parameter name.
              target_device: device the new parameter is placed on.
              state_dict: checkpoint state dict. For pre-quantized checkpoints it is searched for the
                  `quant_state` components that belong to `param_name`.
              unexpected_keys: optional list; consumed `quant_state` keys are removed from it in place.

          Raises:
              ValueError: if the module has no such parameter, the parameter is not a `Params4bit`, a meta
                  tensor has no value to materialize, the installed bitsandbytes cannot load serialized 4-bit
                  weights, or the `quant_state` components are missing from `state_dict`.
          """
          import bitsandbytes as bnb

          module, tensor_name = get_module_from_name(model, param_name)

          if tensor_name not in module._parameters:
              raise ValueError(f"{module} does not have a parameter or a buffer named {tensor_name}.")

          old_value = getattr(module, tensor_name)

          if tensor_name == "bias":
              # Biases are not quantized: move (or replace) the tensor and keep its `requires_grad` flag.
              if param_value is None:
                  new_value = old_value.to(target_device)
              else:
                  new_value = param_value.to(target_device)

              new_value = torch.nn.Parameter(new_value, requires_grad=old_value.requires_grad)
              module._parameters[tensor_name] = new_value
              return

          if not isinstance(module._parameters[tensor_name], bnb.nn.Params4bit):
              raise ValueError("this function only loads `Linear4bit components`")
          if (
              old_value.device == torch.device("meta")
              and target_device not in ["meta", torch.device("meta")]
              and param_value is None
          ):
              raise ValueError(f"{tensor_name} is on the meta device, we need a `value` to put in on {target_device}.")

          # construct `new_value` for the module._parameters[tensor_name]:
          if self.pre_quantized:
              # 4bit loading. Collecting components for restoring quantized weight
              # This can be expanded to make a universal call for any quantized weight loading

              # `is_serializable` is a method and must be called: testing the bound method itself is always
              # truthy, which silently disabled this guard.
              if not self.is_serializable():
                  raise ValueError(
                      "Detected int4 weights but the version of bitsandbytes is not compatible with int4 serialization. "
                      "Make sure to download the latest `bitsandbytes` version. `pip install --upgrade bitsandbytes`."
                  )

              if (param_name + ".quant_state.bitsandbytes__fp4" not in state_dict) and (
                  param_name + ".quant_state.bitsandbytes__nf4" not in state_dict
              ):
                  raise ValueError(
                      f"Supplied state dict for {param_name} does not contain `bitsandbytes__*` and possibly other `quantized_stats` components."
                  )

              # Collect every checkpoint entry belonging to this weight's quant state. This is a substring
              # (not prefix) match; presumably so that keys still carrying a prefix that was stripped from
              # `param_name` are found as well. TODO confirm before tightening it to `startswith`.
              quantized_stats = {}
              for k, v in state_dict.items():
                  if param_name + "." in k:
                      quantized_stats[k] = v
                      if unexpected_keys is not None and k in unexpected_keys:
                          unexpected_keys.remove(k)

              param_kwargs = {}
              if self.is_bnb_supports_quant_storage_module:
                  param_kwargs["module"] = module

              new_value = bnb.nn.Params4bit.from_prequantized(
                  data=param_value,
                  quantized_stats=quantized_stats,
                  requires_grad=False,
                  device=target_device,
                  **param_kwargs,
              )
          else:
              new_value = param_value.to("cpu")

              # Support models using `Conv1D` in place of `nn.Linear` (e.g. openai-community/gpt2) by transposing the weight matrix prior to quantization.
              # Since weights are saved in the correct "orientation", we skip transposing when loading.
              if issubclass(module.source_cls, Conv1D):
                  new_value = new_value.T

              # Carry over the attributes stored on the existing `Params4bit` placeholder to the new parameter.
              param_attrs = old_value.__dict__
              new_value = bnb.nn.Params4bit(new_value, requires_grad=False, **param_attrs).to(target_device)

          module._parameters[tensor_name] = new_value

      # Scale every per-device memory budget to 90% to leave room for quantization buffers.
      # Copied from transformers.quantizers.quantizer_bnb_8bit.Bnb8BitHfQuantizer.adjust_max_memory
      def adjust_max_memory(self, max_memory: Dict[str, Union[int, str]]) -> Dict[str, Union[int, str]]:
          # need more space for buffers that are created during quantization
          max_memory = {key: val * 0.90 for key, val in max_memory.items()}
          return max_memory

      # Default an unspecified torch_dtype to float16 (logged); an explicit dtype is returned unchanged.
      # Copied from transformers.quantizers.quantizer_bnb_8bit.Bnb8BitHfQuantizer.update_torch_dtype
      def update_torch_dtype(self, torch_dtype: "torch.dtype") -> "torch.dtype":
          if torch_dtype is None:
              # We force the `dtype` to be float16, this is a requirement from `bitsandbytes`
              logger.info(
                  "Overriding torch_dtype=%s with `torch_dtype=torch.float16` due to "
                  "requirements of `bitsandbytes` to enable model loading in 8-bit or 4-bit. "
                  "Pass your own torch_dtype to specify the dtype of the remaining non-linear layers or pass"
                  " torch_dtype=torch.float16 to remove this warning.",
                  torch_dtype,
              )
              torch_dtype = torch.float16
          return torch_dtype

      # When no device_map is given, place the whole model on the current CUDA device, else XPU, else CPU.
      # Copied from transformers.quantizers.quantizer_bnb_8bit.Bnb8BitHfQuantizer.update_device_map
      def update_device_map(self, device_map):
          if device_map is None:
              if torch.cuda.is_available():
                  device_map = {"": torch.cuda.current_device()}
              elif is_torch_xpu_available():
                  device_map = {"": f"xpu:{torch.xpu.current_device()}"}
              else:
                  device_map = {"": "cpu"}
              logger.info(
                  "The device_map was not initialized. "
                  f"Setting device_map to {device_map}. "
                  "If you want to use the model for inference, please set device_map ='auto' "
              )
          return device_map

      # Decide which modules stay un-quantized, then swap the remaining layers via `replace_with_bnb_linear`.
      # NOTE(review): when `llm_int8_skip_modules` is a list, `self.modules_to_not_convert` aliases it, so the
      # `.extend(...)` calls below mutate the config's own list in place. The body mirrors the 8-bit quantizer
      # (see the marker below), so any fix has to be applied there as well to keep the copies in sync.
      # Copied from transformers.quantizers.quantizer_bnb_8bit.Bnb8BitHfQuantizer._process_model_before_weight_loading
      def _process_model_before_weight_loading(
          self,
          model: "PreTrainedModel",
          device_map,
          keep_in_fp32_modules: List[str] = [],
          **kwargs,
      ):
          from ..integrations import get_keys_to_not_convert, replace_with_bnb_linear

          llm_int8_enable_fp32_cpu_offload = self.quantization_config.llm_int8_enable_fp32_cpu_offload

          # We keep some modules such as the lm_head in their original dtype for numerical stability reasons
          if self.quantization_config.llm_int8_skip_modules is None:
              self.modules_to_not_convert = get_keys_to_not_convert(model)
          else:
              self.modules_to_not_convert = self.quantization_config.llm_int8_skip_modules

          if not isinstance(self.modules_to_not_convert, list):
              self.modules_to_not_convert = [self.modules_to_not_convert]

          self.modules_to_not_convert.extend(keep_in_fp32_modules)

          # Extend `self.modules_to_not_convert` to keys that are supposed to be offloaded to `cpu` or `disk`
          if isinstance(device_map, dict) and len(device_map.keys()) > 1:
              keys_on_cpu = [key for key, value in device_map.items() if value in ["disk", "cpu"]]

              if len(keys_on_cpu) > 0 and not llm_int8_enable_fp32_cpu_offload:
                  raise ValueError(
                      "If you want to offload some keys to `cpu` or `disk`, you need to set "
                      "`llm_int8_enable_fp32_cpu_offload=True`. Note that these modules will not be "
                      " converted to 8-bit but kept in 32-bit."
                  )
              self.modules_to_not_convert.extend(keys_on_cpu)

          model = replace_with_bnb_linear(
              model, modules_to_not_convert=self.modules_to_not_convert, quantization_config=self.quantization_config
          )
          # TODO: consider bringing replace_with_bnb_linear() code from ..integrations/bitsandbyter.py to here

          model.config.quantization_config = self.quantization_config

      # Flag the loaded model as 4-bit and record whether it can be saved in 4-bit.
      # Copied from transformers.quantizers.quantizer_bnb_8bit.Bnb8BitHfQuantizer._process_model_after_weight_loading with 8bit->4bit
      def _process_model_after_weight_loading(self, model: "PreTrainedModel", **kwargs):
          model.is_loaded_in_4bit = True
          model.is_4bit_serializable = self.is_serializable()
          return model

      def is_serializable(self, safe_serialization=None):
          """
          Return whether the installed bitsandbytes (>= 0.41.3) supports 4-bit serialization.

          Logs a warning and returns False otherwise. `safe_serialization` is not used here.
          """
          _is_4bit_serializable = version.parse(importlib.metadata.version("bitsandbytes")) >= version.parse("0.41.3")

          if not _is_4bit_serializable:
              logger.warning(
                  "You are calling `save_pretrained` to a 4-bit converted model, but your `bitsandbytes` version doesn't support it. "
                  "If you want to save 4-bit models, make sure to have `bitsandbytes>=0.41.3` installed."
              )
              return False

          return True

      @cached_property
      def is_bnb_supports_quant_storage_module(self) -> bool:
          """
          Whether the installed bitsandbytes (>= 0.43.3) accepts the `module` argument of
          `Params4bit.from_prequantized`. Computed once per instance.
          """
          return version.parse(importlib.metadata.version("bitsandbytes")) >= version.parse("0.43.3")

      @property
      def is_trainable(self) -> bool:
          """Models quantized with this quantizer are reported as trainable."""
          return True

      def _dequantize(self, model):
          """Dequantize `model` via `dequantize_and_replace`, passing along `self.modules_to_not_convert`."""
          from ..integrations import dequantize_and_replace

          model = dequantize_and_replace(
              model, self.modules_to_not_convert, quantization_config=self.quantization_config
          )
          return model