bitsandbytes.py

import importlib.metadata
import importlib.util  # needed for `importlib.util.find_spec` used below
import inspect
import warnings
from copy import deepcopy
from inspect import signature

from packaging import version

from ..utils import (
    get_available_devices,
    is_accelerate_available,
    is_bitsandbytes_available,
    is_bitsandbytes_multi_backend_available,
    is_ipex_available,
    is_torch_available,
    logging,
)


if is_bitsandbytes_available():
    import bitsandbytes as bnb
    import torch
    import torch.nn as nn

    from ..pytorch_utils import Conv1D

if is_accelerate_available():
    import accelerate
    from accelerate import init_empty_weights
    from accelerate.hooks import add_hook_to_module, remove_hook_from_module
    from accelerate.utils import find_tied_parameters


logger = logging.get_logger(__name__)

def set_module_quantized_tensor_to_device(module, tensor_name, device, value=None, quantized_stats=None):
    """
    A helper function to set a given tensor (parameter or buffer) of a module on a specific device (note that doing
    `param.to(device)` creates a new tensor not linked to the parameter, which is why we need this function). The
    function is adapted from the `set_module_tensor_to_device` function from accelerate, extended to support the
    class `Int8Params` from `bitsandbytes`.

    Args:
        module (`torch.nn.Module`):
            The module in which the tensor we want to move lives.
        tensor_name (`str`):
            The full name of the parameter/buffer.
        device (`int`, `str` or `torch.device`):
            The device on which to set the tensor.
        value (`torch.Tensor`, *optional*):
            The value of the tensor (useful when going from the meta device to any other device).
        quantized_stats (`dict[str, Any]`, *optional*):
            Dict with items for either 4-bit or 8-bit serialization
    """
    # Recurse if needed
    if "." in tensor_name:
        splits = tensor_name.split(".")
        for split in splits[:-1]:
            new_module = getattr(module, split)
            if new_module is None:
                raise ValueError(f"{module} has no attribute {split}.")
            module = new_module
        tensor_name = splits[-1]

    if tensor_name not in module._parameters and tensor_name not in module._buffers:
        raise ValueError(f"{module} does not have a parameter or a buffer named {tensor_name}.")
    is_buffer = tensor_name in module._buffers
    old_value = getattr(module, tensor_name)

    if old_value.device == torch.device("meta") and device not in ["meta", torch.device("meta")] and value is None:
        raise ValueError(f"{tensor_name} is on the meta device, we need a `value` to put it on {device}.")

    prequantized_loading = quantized_stats is not None
    if is_buffer or not is_bitsandbytes_available():
        is_8bit = False
        is_4bit = False
    else:
        is_4bit = hasattr(bnb.nn, "Params4bit") and isinstance(module._parameters[tensor_name], bnb.nn.Params4bit)
        is_8bit = isinstance(module._parameters[tensor_name], bnb.nn.Int8Params)

    if is_8bit or is_4bit:
        param = module._parameters[tensor_name]
        if param.device.type != "cuda":
            if value is None:
                new_value = old_value.to(device)
            elif isinstance(value, torch.Tensor):
                new_value = value.to("cpu")
            else:
                new_value = torch.tensor(value, device="cpu")

            # Support models using `Conv1D` in place of `nn.Linear` (e.g. openai-community/gpt2) by transposing the
            # weight matrix prior to quantization. Since weights are saved in the correct "orientation", we skip
            # transposing when loading.
            if issubclass(module.source_cls, Conv1D) and not prequantized_loading:
                new_value = new_value.T

            kwargs = old_value.__dict__

            if prequantized_loading != (new_value.dtype in (torch.int8, torch.uint8)):
                raise ValueError(
                    f"Value dtype `{new_value.dtype}` is not compatible with parameter quantization status."
                )

            if is_8bit:
                is_8bit_serializable = version.parse(importlib.metadata.version("bitsandbytes")) > version.parse(
                    "0.37.2"
                )
                if new_value.dtype in (torch.int8, torch.uint8) and not is_8bit_serializable:
                    raise ValueError(
                        "Detected int8 weights but the version of bitsandbytes is not compatible with int8 serialization. "
                        "Make sure to download the latest `bitsandbytes` version. `pip install --upgrade bitsandbytes`."
                    )
                new_value = bnb.nn.Int8Params(new_value, requires_grad=False, **kwargs).to(device)
                if prequantized_loading:
                    setattr(new_value, "SCB", quantized_stats["SCB"].to(device))
            elif is_4bit:
                if prequantized_loading:
                    is_4bit_serializable = version.parse(importlib.metadata.version("bitsandbytes")) >= version.parse(
                        "0.41.3"
                    )
                    if new_value.dtype in (torch.int8, torch.uint8) and not is_4bit_serializable:
                        raise ValueError(
                            "Detected 4-bit weights but the version of bitsandbytes is not compatible with 4-bit serialization. "
                            "Make sure to download the latest `bitsandbytes` version. `pip install --upgrade bitsandbytes`."
                        )
                    new_value = bnb.nn.Params4bit.from_prequantized(
                        data=new_value,
                        quantized_stats=quantized_stats,
                        requires_grad=False,
                        device=device,
                        **kwargs,
                    )
                else:
                    new_value = bnb.nn.Params4bit(new_value, requires_grad=False, **kwargs).to(device)
            module._parameters[tensor_name] = new_value
    else:
        if value is None:
            new_value = old_value.to(device)
        elif isinstance(value, torch.Tensor):
            new_value = value.to(device)
        else:
            new_value = torch.tensor(value, device=device)

        if is_buffer:
            module._buffers[tensor_name] = new_value
        else:
            new_value = nn.Parameter(new_value, requires_grad=old_value.requires_grad)
            module._parameters[tensor_name] = new_value

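# Example (illustrative sketch, not part of the original module): re-assigning a single quantized
# weight from a host tensor. The model name, parameter path and `some_fp16_tensor` below are
# assumptions for illustration only.
#
#     from transformers import AutoModelForCausalLM, BitsAndBytesConfig
#
#     model = AutoModelForCausalLM.from_pretrained(
#         "facebook/opt-350m", quantization_config=BitsAndBytesConfig(load_in_8bit=True)
#     )
#     set_module_quantized_tensor_to_device(
#         model, "model.decoder.layers.0.fc1.weight", device=0, value=some_fp16_tensor
#     )
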
def _replace_with_bnb_linear(
    model,
    modules_to_not_convert=None,
    current_key_name=None,
    quantization_config=None,
    has_been_replaced=False,
):
    """
    Private method that wraps the recursion for module replacement.

    Returns the converted model and a boolean that indicates if the conversion has been successful or not.
    """
    for name, module in model.named_children():
        if current_key_name is None:
            current_key_name = []
        current_key_name.append(name)

        if (isinstance(module, nn.Linear) or isinstance(module, Conv1D)) and name not in modules_to_not_convert:
            # Check if the current key is not in the `modules_to_not_convert`
            current_key_name_str = ".".join(current_key_name)
            if not any(
                (key + "." in current_key_name_str) or (key == current_key_name_str) for key in modules_to_not_convert
            ):
                with init_empty_weights():
                    if isinstance(module, Conv1D):
                        in_features, out_features = module.weight.shape
                    else:
                        in_features = module.in_features
                        out_features = module.out_features

                    if quantization_config.quantization_method() == "llm_int8":
                        model._modules[name] = bnb.nn.Linear8bitLt(
                            in_features,
                            out_features,
                            module.bias is not None,
                            has_fp16_weights=quantization_config.llm_int8_has_fp16_weight,
                            threshold=quantization_config.llm_int8_threshold,
                        )
                        has_been_replaced = True
                    else:
                        if (
                            quantization_config.llm_int8_skip_modules is not None
                            and name in quantization_config.llm_int8_skip_modules
                        ):
                            pass
                        else:
                            extra_kwargs = (
                                {"quant_storage": quantization_config.bnb_4bit_quant_storage}
                                if "quant_storage" in list(signature(bnb.nn.Linear4bit).parameters)
                                else {}
                            )
                            model._modules[name] = bnb.nn.Linear4bit(
                                in_features,
                                out_features,
                                module.bias is not None,
                                quantization_config.bnb_4bit_compute_dtype,
                                compress_statistics=quantization_config.bnb_4bit_use_double_quant,
                                quant_type=quantization_config.bnb_4bit_quant_type,
                                **extra_kwargs,
                            )
                            has_been_replaced = True
                    # Store the module class in case we need to transpose the weight later
                    model._modules[name].source_cls = type(module)
                    # Force requires grad to False to avoid unexpected errors
                    model._modules[name].requires_grad_(False)
        if len(list(module.children())) > 0:
            _, has_been_replaced = _replace_with_bnb_linear(
                module,
                modules_to_not_convert,
                current_key_name,
                quantization_config,
                has_been_replaced=has_been_replaced,
            )
        # Remove the last key for recursion
        current_key_name.pop(-1)
    return model, has_been_replaced

def replace_with_bnb_linear(model, modules_to_not_convert=None, current_key_name=None, quantization_config=None):
    """
    A helper function to replace all `torch.nn.Linear` modules by `bnb.nn.Linear8bit` modules from the `bitsandbytes`
    library. This will enable running your models using mixed int8 precision as described by the paper `LLM.int8():
    8-bit Matrix Multiplication for Transformers at Scale`. Make sure `bitsandbytes` compiled with the correct CUDA
    version of your hardware is installed before running this function. `pip install -i https://test.pypi.org/simple/
    bitsandbytes`

    The function is run recursively and replaces all `torch.nn.Linear` modules except for the `lm_head`, which should
    be kept as a `torch.nn.Linear` module. The replacement is done under the `init_empty_weights` context manager, so
    no CPU/GPU memory is required to run this function. Int8 mixed-precision matrix decomposition works by separating
    a matrix multiplication into two streams: (1) a systematic feature outlier stream matrix multiplied in fp16
    (0.01%), (2) a regular stream of int8 matrix multiplication (99.9%). With this method, int8 inference with no
    predictive degradation is possible for very large models (>=176B parameters).

    Parameters:
        model (`torch.nn.Module`):
            Input model or `torch.nn.Module` as the function is run recursively.
        modules_to_not_convert (`List[str]`, *optional*, defaults to `["lm_head"]`):
            Names of the modules to not convert to `Linear8bitLt`. In practice we keep the `lm_head` in full precision
            for numerical stability reasons.
        current_key_name (`List[str]`, *optional*):
            An array to track the current key of the recursion. This is used to check whether the current key (part of
            it) is not in the list of modules to not convert (for instance modules that are offloaded to `cpu` or
            `disk`).
        quantization_config (`transformers.utils.quantization_config.BitsAndBytesConfig`):
            To configure and manage settings related to quantization, a technique used to compress neural network
            models by reducing the precision of the weights and activations, thus making models more efficient in
            terms of both storage and computation.
    """
    modules_to_not_convert = ["lm_head"] if modules_to_not_convert is None else modules_to_not_convert
    model, has_been_replaced = _replace_with_bnb_linear(
        model, modules_to_not_convert, current_key_name, quantization_config
    )

    if not has_been_replaced:
        logger.warning(
            "You are loading your model in 8bit or 4bit but no linear modules were found in your model."
            " Please double check your model architecture, or submit an issue on github if you think this is"
            " a bug."
        )

    return model

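# Example (illustrative sketch, not part of the original module): replacing the linear layers of a
# meta-initialized model with 8-bit bitsandbytes layers. The model name and config values are
# assumptions for illustration.
#
#     from accelerate import init_empty_weights
#     from transformers import AutoConfig, AutoModelForCausalLM, BitsAndBytesConfig
#
#     config = AutoConfig.from_pretrained("facebook/opt-350m")
#     with init_empty_weights():
#         model = AutoModelForCausalLM.from_config(config)
#
#     bnb_config = BitsAndBytesConfig(load_in_8bit=True, llm_int8_threshold=6.0)
#     model = replace_with_bnb_linear(model, modules_to_not_convert=["lm_head"], quantization_config=bnb_config)
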
# For backward compatibility
def replace_8bit_linear(*args, **kwargs):
    warnings.warn(
        "`replace_8bit_linear` will be deprecated in a future version, please use `replace_with_bnb_linear` instead",
        FutureWarning,
    )
    return replace_with_bnb_linear(*args, **kwargs)


# For backward compatibility
def set_module_8bit_tensor_to_device(*args, **kwargs):
    warnings.warn(
        "`set_module_8bit_tensor_to_device` will be deprecated in a future version, please use `set_module_quantized_tensor_to_device` instead",
        FutureWarning,
    )
    return set_module_quantized_tensor_to_device(*args, **kwargs)

def get_keys_to_not_convert(model):
    r"""
    A utility function to get the keys of the modules to keep in full precision, if any. For example, for CausalLM
    modules we may want to keep the lm_head in full precision for numerical stability reasons. For other
    architectures, we want to keep the tied weights of the model. The function returns a list of the keys of the
    modules to not convert to int8.

    Parameters:
        model (`torch.nn.Module`):
            Input model
    """
    # Create a copy of the model and tie the weights, then
    # check if it contains tied weights
    tied_model = deepcopy(model)  # this has 0 cost since it is done inside `init_empty_weights` context manager
    tied_model.tie_weights()

    tied_params = find_tied_parameters(tied_model)
    # For compatibility with Accelerate < 0.18
    if isinstance(tied_params, dict):
        tied_keys = sum(list(tied_params.values()), []) + list(tied_params.keys())
    else:
        tied_keys = sum(tied_params, [])
    has_tied_params = len(tied_keys) > 0

    # If there are no tied weights, we want to keep the lm_head (output_embedding) in full precision
    if not has_tied_params:
        output_emb = model.get_output_embeddings()
        if output_emb is not None:
            list_last_module = [name for name, module in model.named_modules() if id(module) == id(output_emb)]
            return list_last_module

    # Otherwise, if no output embedding is defined, simply keep the last module in full precision
    list_modules = list(model.named_parameters())
    list_last_module = [list_modules[-1][0]]
    # add last module together with tied weights
    intersection = set(list_last_module) - set(tied_keys)
    list_untouched = list(set(tied_keys)) + list(intersection)

    # remove ".weight" and ".bias" from the keys
    names_to_remove = [".weight", ".bias"]
    filtered_module_names = []
    for name in list_untouched:
        for name_to_remove in names_to_remove:
            if name_to_remove in name:
                name = name.replace(name_to_remove, "")
        filtered_module_names.append(name)

    return filtered_module_names

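# Example (illustrative sketch, not part of the original module): the returned keys are typically
# passed as `modules_to_not_convert` so that tied/output embeddings stay in full precision.
# `model` and `bnb_config` are assumed to be defined as in the earlier example.
#
#     keys_to_keep = get_keys_to_not_convert(model)  # e.g. ["lm_head"] for many CausalLM models
#     model = replace_with_bnb_linear(model, modules_to_not_convert=keys_to_keep, quantization_config=bnb_config)
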
# Copied from PEFT: https://github.com/huggingface/peft/blob/47b3712898539569c02ec5b3ed4a6c36811331a1/src/peft/utils/integrations.py#L41
def dequantize_bnb_weight(weight: "torch.nn.Parameter", dtype: "torch.dtype", state=None):
    """
    Helper function to dequantize 4bit or 8bit bnb weights.

    If the weight is not a bnb quantized weight, it will be returned as is.
    """
    if not isinstance(weight, torch.nn.Parameter):
        raise TypeError(f"Input weight should be of type nn.Parameter, got {type(weight)} instead")

    cls_name = weight.__class__.__name__
    if cls_name not in ("Params4bit", "Int8Params"):
        return weight

    if cls_name == "Params4bit":
        output_tensor = bnb.functional.dequantize_4bit(weight.data, weight.quant_state)
        logger.warning_once(
            f"The model is going to be dequantized in {output_tensor.dtype} - if you want to upcast it to another dtype, make sure to pass the desired dtype when quantizing the model through `bnb_4bit_quant_type` argument of `BitsAndBytesConfig`"
        )
        return output_tensor.to(dtype)

    if state.SCB is None:
        state.SCB = weight.SCB

    im = torch.eye(weight.data.shape[-1]).contiguous().half().to(weight.device)
    im, imt, SCim, SCimt, coo_tensorim = bnb.functional.double_quant(im)
    im, Sim = bnb.functional.transform(im, "col32")
    if state.CxB is None:
        state.CxB, state.SB = bnb.functional.transform(weight.data, to_order=state.formatB)
    out32, Sout32 = bnb.functional.igemmlt(im, state.CxB, Sim, state.SB)
    return bnb.functional.mm_dequant(out32, Sout32, SCim, state.SCB, bias=None).t().to(dtype)

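# Example (illustrative sketch, not part of the original module): dequantizing a single layer's
# weight back to fp16. `some_linear4bit_module` / `some_linear8bit_module` are hypothetical
# `bnb.nn.Linear4bit` / `bnb.nn.Linear8bitLt` instances, not defined in this file.
#
#     fp16_weight = dequantize_bnb_weight(some_linear4bit_module.weight, dtype=torch.float16)
#     # For 8-bit layers, pass the module's quantization state explicitly (see `_dequantize_and_replace` below):
#     fp16_weight = dequantize_bnb_weight(
#         some_linear8bit_module.weight, dtype=torch.float16, state=some_linear8bit_module.state
#     )
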
def _create_accelerate_new_hook(old_hook):
    r"""
    Creates a new hook based on the old hook. Use it only if you know what you are doing !

    This method is a copy of: https://github.com/huggingface/peft/blob/748f7968f3a31ec06a1c2b0328993319ad9a150a/src/peft/utils/other.py#L245
    with some changes
    """
    old_hook_cls = getattr(accelerate.hooks, old_hook.__class__.__name__)
    old_hook_attr = old_hook.__dict__
    filtered_old_hook_attr = {}
    old_hook_init_signature = inspect.signature(old_hook_cls.__init__)
    for k in old_hook_attr.keys():
        if k in old_hook_init_signature.parameters:
            filtered_old_hook_attr[k] = old_hook_attr[k]
    new_hook = old_hook_cls(**filtered_old_hook_attr)
    return new_hook

def _dequantize_and_replace(
    model,
    dtype,
    modules_to_not_convert=None,
    current_key_name=None,
    quantization_config=None,
    has_been_replaced=False,
):
    """
    Converts a quantized model into its dequantized original version. The newly converted model will have
    some performance drop compared to the original model before quantization - use it only for specific use cases
    such as QLoRA adapters merging.

    Returns the converted model and a boolean that indicates if the conversion has been successful or not.
    """
    quant_method = quantization_config.quantization_method()

    target_cls = bnb.nn.Linear8bitLt if quant_method == "llm_int8" else bnb.nn.Linear4bit

    for name, module in model.named_children():
        if current_key_name is None:
            current_key_name = []
        current_key_name.append(name)

        if isinstance(module, target_cls) and name not in modules_to_not_convert:
            # Check if the current key is not in the `modules_to_not_convert`
            current_key_name_str = ".".join(current_key_name)

            if not any(
                (key + "." in current_key_name_str) or (key == current_key_name_str) for key in modules_to_not_convert
            ):
                bias = getattr(module, "bias", None)

                device = module.weight.device
                with init_empty_weights():
                    new_module = torch.nn.Linear(module.in_features, module.out_features, bias=bias is not None)

                if quant_method == "llm_int8":
                    state = module.state
                else:
                    state = None

                new_module.weight = torch.nn.Parameter(dequantize_bnb_weight(module.weight, dtype, state))

                if bias is not None:
                    new_module.bias = bias

                # Create a new hook and attach it in case we use accelerate
                if hasattr(module, "_hf_hook"):
                    old_hook = module._hf_hook
                    new_hook = _create_accelerate_new_hook(old_hook)
                    remove_hook_from_module(module)
                    add_hook_to_module(new_module, new_hook)

                new_module.to(device)
                model._modules[name] = new_module
                has_been_replaced = True
        if len(list(module.children())) > 0:
            _, has_been_replaced = _dequantize_and_replace(
                module,
                dtype,
                modules_to_not_convert,
                current_key_name,
                quantization_config,
                has_been_replaced=has_been_replaced,
            )
        # Remove the last key for recursion
        current_key_name.pop(-1)
    return model, has_been_replaced

def dequantize_and_replace(
    model,
    modules_to_not_convert=None,
    quantization_config=None,
):
    model, has_been_replaced = _dequantize_and_replace(
        model,
        model.dtype,
        modules_to_not_convert=modules_to_not_convert,
        quantization_config=quantization_config,
    )

    if not has_been_replaced:
        logger.warning(
            "For some reason the model has not been properly dequantized. You might see unexpected behavior."
        )

    return model

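# Example (illustrative sketch, not part of the original module): restoring plain `nn.Linear`
# layers, e.g. before merging QLoRA adapters. `model` and `bnb_config` are assumed to be the
# quantized model and the config it was quantized with.
#
#     model = dequantize_and_replace(
#         model,
#         modules_to_not_convert=["lm_head"],
#         quantization_config=bnb_config,
#     )
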
def _validate_bnb_multi_backend_availability(raise_exception):
    import bitsandbytes as bnb

    bnb_supported_devices = getattr(bnb, "supported_torch_devices", set())
    available_devices = get_available_devices()

    if available_devices == {"cpu"} and not is_ipex_available():
        from importlib.util import find_spec

        if find_spec("intel_extension_for_pytorch"):
            logger.warning(
                "You have Intel IPEX installed but if you're intending to use it for CPU, it might not have the right version. Be sure to double check that your PyTorch and IPEX installs are compatible."
            )

        available_devices.discard("cpu")  # Only Intel CPU is supported by BNB at the moment

    if not available_devices.intersection(bnb_supported_devices):
        if raise_exception:
            bnb_supported_devices_with_info = set(  # noqa: C401
                '"cpu" (needs an Intel CPU and intel_extension_for_pytorch installed and compatible with the PyTorch version)'
                if device == "cpu"
                else device
                for device in bnb_supported_devices
            )
            err_msg = (
                f"None of the available devices `available_devices = {available_devices or None}` are supported by the bitsandbytes version you have installed: `bnb_supported_devices = {bnb_supported_devices_with_info}`. "
                "Please check the docs to see if the backend you intend to use is available and how to install it: https://huggingface.co/docs/bitsandbytes/main/en/installation#multi-backend"
            )
            logger.error(err_msg)
            raise RuntimeError(err_msg)

        logger.warning("No supported devices found for bitsandbytes multi-backend.")
        return False

    logger.debug("Multi-backend validation successful.")
    return True

def _validate_bnb_cuda_backend_availability(raise_exception):
    if not is_torch_available():
        return False

    import torch

    if not torch.cuda.is_available():
        log_msg = (
            "CUDA is required but not available for bitsandbytes. Please consider installing the multi-platform enabled version of bitsandbytes, which is currently a work in progress. "
            "Please check currently supported platforms and installation instructions at https://huggingface.co/docs/bitsandbytes/main/en/installation#multi-backend"
        )
        if raise_exception:
            logger.error(log_msg)
            raise RuntimeError(log_msg)

        logger.warning(log_msg)
        return False

    logger.debug("CUDA backend validation successful.")
    return True

def validate_bnb_backend_availability(raise_exception=False):
    """
    Validates if the available devices are supported by bitsandbytes, optionally raising an exception if not.
    """
    if not is_bitsandbytes_available():
        if importlib.util.find_spec("bitsandbytes") and version.parse(
            importlib.metadata.version("bitsandbytes")
        ) < version.parse("0.43.1"):
            return _validate_bnb_cuda_backend_availability(raise_exception)
        return False

    if is_bitsandbytes_multi_backend_available():
        return _validate_bnb_multi_backend_availability(raise_exception)

    return _validate_bnb_cuda_backend_availability(raise_exception)
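

# Example (illustrative sketch, not part of the original module): checking backend support before
# attempting to load a quantized model.
#
#     if validate_bnb_backend_availability(raise_exception=False):
#         ...  # safe to load the model with a BitsAndBytesConfig
#     else:
#         ...  # fall back to full precision, or pass raise_exception=True to surface the error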