hqq.py

# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"HQQ (Half-Quadratic Quantization) integration file"

from ..utils import is_hqq_available, is_torch_available, logging


if is_torch_available():
    import torch

logger = logging.get_logger(__name__)


# Name all modules inside the model
def autoname_modules(model):
    for name, module in model.named_modules():
        module.name = name


# Get the linear_tag from a module name. For example: model.layers.31.self_attn.k_proj -> self_attn.k_proj
def name_to_linear_tag(name):
    return ".".join([n for n in name.split(".") if ((n not in ["model", "layers"]) and (not n.isnumeric()))])


# Get all linear tags available
def get_linear_tags(model):
    if is_hqq_available():
        from hqq.core.quantize import HQQLinear

    linear_tags = set()
    for name, module in model.named_modules():
        if isinstance(module, (torch.nn.Linear, HQQLinear)):
            linear_tags.add(name_to_linear_tag(name))
    return list(linear_tags)


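# Illustrative output of get_linear_tags() for a LLaMA-style decoder (hypothetical example;
# the exact tags depend on the model architecture):
#   ["self_attn.q_proj", "self_attn.k_proj", "self_attn.v_proj", "self_attn.o_proj",
#    "mlp.gate_proj", "mlp.up_proj", "mlp.down_proj", "lm_head"]

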
def _prepare_for_hqq_linear(model, patch_params, has_been_replaced, current_key_name=None):
    for name, module in model.named_children():
        if current_key_name is None:
            current_key_name = []
        current_key_name.append(name)

        if isinstance(module, torch.nn.Linear):
            # Get linear tag
            linear_tag = name_to_linear_tag(module.name)

            # We put the module quant_config into the nn.Linear layer so we can access it later in quantizer_hqq.create_quantized_param()
            if linear_tag in patch_params:
                if patch_params[linear_tag] is not None:
                    model._modules[name].quant_config = patch_params[linear_tag]
                    # Store the module class in case we need to transpose the weight later
                    model._modules[name].source_cls = type(module)
                    # Force requires grad to False to avoid unexpected errors
                    model._modules[name].requires_grad_(False)

            has_been_replaced = True

            # Add these fake parameters to avoid loading fail
            for att in ["W_q", "meta"]:
                setattr(module, att, None)

        if len(list(module.children())) > 0:
            _, has_been_replaced = _prepare_for_hqq_linear(
                module,
                patch_params=patch_params,
                has_been_replaced=has_been_replaced,
            )
        # Remove the last key for recursion
        current_key_name.pop(-1)

    return model, has_been_replaced


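# Illustrative shape of the `patch_params` dict consumed by the helper above (hypothetical values):
#   {"self_attn.q_proj": {"weight_quant_params": {...}}, "mlp.down_proj": None, ...}
# Tags mapped to None receive no quant_config and are left unquantized.

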
def prepare_for_hqq_linear(model, quantization_config=None, modules_to_not_convert=None, has_been_replaced=False):
    """
    Prepares nn.Linear layers for HQQ quantization.
    Since each layer type can have separate quantization parameters, we need to do the following:
    1- Tag each module with its name via autoname_modules()
    2- Extract linear_tags (e.g. ['self_attn.q_proj', ...])
    3- Map quantization parameters as a dictionary linear_tag -> quant_params as HQQLinear expects it, this is referred to as patch_params
    """

    modules_to_not_convert = [] if modules_to_not_convert is None else modules_to_not_convert

    # Add name to module
    autoname_modules(model)

    # Get linear tags. This allows us to use different quant params for different layer types
    linear_tags = get_linear_tags(model)

    # Convert quantization_config to layer-wise config
    skip_modules = quantization_config.skip_modules
    quant_config = quantization_config.quant_config
    linear_tags = list(set(linear_tags) - set(skip_modules) - set(modules_to_not_convert))

    if any(key in linear_tags for key in quant_config.keys()):
        # If the user doesn't specify a key from get_linear_tags, the layer is not quantized via (key, None)
        patch_params = {key: None for key in linear_tags}
        patch_params.update(quant_config)
    else:
        # Same quant_config for all layers
        patch_params = {k: quant_config for k in linear_tags}

    model, has_been_replaced = _prepare_for_hqq_linear(
        model, patch_params=patch_params, has_been_replaced=has_been_replaced
    )

    # We store the quantization config as linear_tag -> hqq quant config
    model.config.quantization_config = {
        "quant_config": quant_config,
        "quant_method": quantization_config.quant_method,
        "skip_modules": skip_modules,
    }

    if not has_been_replaced:
        logger.warning("No linear modules were found in your model for quantization.")

    return model


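# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of this module). In practice this
# function is driven by the HQQ quantizer during `from_pretrained`; the stand-in
# SimpleNamespace config below and its field values are hypothetical and merely mirror
# the attributes accessed above (skip_modules, quant_config, quant_method):
#
#   from types import SimpleNamespace
#   from transformers import AutoModelForCausalLM
#
#   model = AutoModelForCausalLM.from_pretrained("some/causal-lm")  # hypothetical checkpoint
#   quantization_config = SimpleNamespace(
#       skip_modules=["lm_head"],                      # layers to keep in full precision
#       quant_config={"nbits": 4, "group_size": 64},   # hypothetical HQQ kwargs
#       quant_method="hqq",
#   )
#   model = prepare_for_hqq_linear(model, quantization_config=quantization_config)
#   # Non-skipped nn.Linear layers now carry .quant_config / .source_cls plus fake W_q / meta
#   # attributes, ready for quantizer_hqq.create_quantized_param() to build HQQLinear layers.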