# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"AQLM (Additive Quantization of Language Model) integration file"

from ..utils import ACCELERATE_MIN_VERSION, is_accelerate_available, is_aqlm_available, is_torch_available


if is_torch_available():
    import torch.nn as nn


def replace_with_aqlm_linear(
    model,
    quantization_config=None,
    linear_weights_not_to_quantize=None,
    current_key_name=None,
    has_been_replaced=False,
):
  25. """
  26. Public method that recursively replaces the Linear layers of the given model with AQLM quantized layers.
  27. `accelerate` is needed to use this method. Returns the converted model and a boolean that indicates if the
  28. conversion has been successfull or not.
  29. Args:
  30. model (`torch.nn.Module`):
  31. The model to convert, can be any `torch.nn.Module` instance.
  32. quantization_config (`AqlmConfig`):
  33. The quantization config object that contains the quantization parameters.
  34. linear_weights_not_to_quantize (`list[str]`, *optional*):
  35. A list of nn.Linear weights to not convert. If a parameter path is in the list (e.g. `lm_head.weight`), the corresponding module will not be
  36. converted.
  37. current_key_name (`list`, *optional*):
  38. A list that contains the current key name. This is used for recursion and should not be passed by the user.
  39. has_been_replaced (`bool`, *optional*):
  40. A boolean that indicates if the conversion has been successful or not. This is used for recursion and
  41. should not be passed by the user.
  42. """
    if not is_aqlm_available():
        raise ValueError("AQLM is not available. Please install it with `pip install aqlm[cpu,gpu]`")

    if not is_accelerate_available():
        raise ValueError(
            f"AQLM requires Accelerate to be installed: `pip install 'accelerate>={ACCELERATE_MIN_VERSION}'`"
        )

    if linear_weights_not_to_quantize is None:
        linear_weights_not_to_quantize = []

    from accelerate import init_empty_weights
    from aqlm import QuantizedLinear
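
    # Note: the `accelerate` and `aqlm` imports above are deferred to call time so this
    # module stays importable when those packages are missing (the availability checks
    # above report the problem first). `init_empty_weights` creates the replacement
    # layers on the meta device, so no real weight memory is allocated here; the actual
    # quantized weights are expected to be loaded afterwards.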
    for name, module in model.named_children():
        if current_key_name is None:
            current_key_name = []
        current_key_name.append(name)

        if isinstance(module, nn.Linear):
            # Check that the current key is not in `linear_weights_not_to_quantize`
            if ".".join(current_key_name) + ".weight" not in linear_weights_not_to_quantize:
                with init_empty_weights():
                    in_features = module.in_features
                    out_features = module.out_features

                    model._modules[name] = QuantizedLinear(
                        in_features,
                        out_features,
                        bias=module.bias is not None,
                        in_group_size=quantization_config.in_group_size,
                        out_group_size=quantization_config.out_group_size,
                        num_codebooks=quantization_config.num_codebooks,
                        nbits_per_codebook=quantization_config.nbits_per_codebook,
                    )
                    has_been_replaced = True

                    # Store the module class in case we need to transpose the weight later
                    model._modules[name].source_cls = type(module)
                    # Force requires_grad to False to avoid unexpected errors
                    model._modules[name].requires_grad_(False)
        if len(list(module.children())) > 0:
            _, has_been_replaced = replace_with_aqlm_linear(
                module,
                quantization_config=quantization_config,
                linear_weights_not_to_quantize=linear_weights_not_to_quantize,
                current_key_name=current_key_name,
                has_been_replaced=has_been_replaced,
            )
        # Remove the last key for recursion
        current_key_name.pop(-1)
    return model, has_been_replaced
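

# Minimal usage sketch (illustrative, not part of the original module). It assumes the
# surrounding `transformers` package, from which `AqlmConfig` is importable; the model
# name and config values below are example choices, not requirements:
#
#     from transformers import AqlmConfig, AutoModelForCausalLM
#
#     model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")
#     aqlm_config = AqlmConfig(
#         in_group_size=8, out_group_size=1, num_codebooks=1, nbits_per_codebook=16
#     )
#     model, has_been_replaced = replace_with_aqlm_linear(
#         model,
#         quantization_config=aqlm_config,
#         linear_weights_not_to_quantize=["lm_head.weight"],
#     )
#     # `has_been_replaced` is True once at least one nn.Linear was swapped. The new
#     # `QuantizedLinear` modules live on the meta device, so real quantized weights
#     # still have to be loaded from an AQLM checkpoint before running inference.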