# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ..utils import is_optimum_quanto_available, is_quanto_available, is_torch_available, logging


if is_torch_available():
    import torch


logger = logging.get_logger(__name__)


def replace_with_quanto_layers(
    model,
    quantization_config=None,
    modules_to_not_convert=None,
    current_key_name=None,
    has_been_replaced=False,
):
    """
    Public method that recursively replaces the Linear layers of the given model with Quanto quantized layers.
    Returns the converted model and a boolean that indicates if the conversion has been successful or not.

    Args:
        model (`torch.nn.Module`):
            The model to convert, can be any `torch.nn.Module` instance.
        quantization_config (`QuantoConfig`, defaults to `None`):
            The quantization config object that contains the quantization parameters.
        modules_to_not_convert (`list`, *optional*, defaults to `None`):
            A list of modules to not convert. If a module name is in the list (e.g. `lm_head`), it will not be
            converted.
        current_key_name (`list`, *optional*, defaults to `None`):
            A list that contains the current key name. This is used for recursion and should not be passed by the
            user.
        has_been_replaced (`bool`, *optional*, defaults to `False`):
            A boolean that indicates if the conversion has been successful or not. This is used for recursion and
            should not be passed by the user.
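
    Example:

        A minimal illustrative sketch; the checkpoint name and config values below are assumptions made for the
        example (in practice this function is invoked by the quanto quantizer during `from_pretrained`, and the
        replaced layers are created on the meta device, so weights still need to be loaded afterwards):

        ```python
        from transformers import AutoModelForCausalLM, QuantoConfig
        from transformers.integrations import replace_with_quanto_layers

        config = QuantoConfig(weights="int8")
        model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")
        model, has_been_replaced = replace_with_quanto_layers(
            model, quantization_config=config, modules_to_not_convert=["lm_head"]
        )
        ```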
    """
    from accelerate import init_empty_weights

    if is_optimum_quanto_available():
        from optimum.quanto import QLayerNorm, QLinear, qfloat8, qint2, qint4, qint8
    elif is_quanto_available():
        logger.warning_once(
            "Importing from quanto will be deprecated in v4.47. Please install optimum-quanto instead: "
            "`pip install optimum-quanto`"
        )
        from quanto import QLayerNorm, QLinear, qfloat8, qint2, qint4, qint8

    # Map the string dtypes from the quantization config to the corresponding quanto dtypes
    w_mapping = {"float8": qfloat8, "int8": qint8, "int4": qint4, "int2": qint2}
    a_mapping = {None: None, "float8": qfloat8, "int8": qint8}

    if modules_to_not_convert is None:
        modules_to_not_convert = []
    for name, module in model.named_children():
        if current_key_name is None:
            current_key_name = []
        current_key_name.append(name)

        if not any(key in ".".join(current_key_name) for key in modules_to_not_convert):
            # Instantiate the quantized layers on the meta device to avoid allocating real weights here
            with init_empty_weights():
                if isinstance(module, torch.nn.Linear):
                    model._modules[name] = QLinear(
                        in_features=module.in_features,
                        out_features=module.out_features,
                        bias=module.bias is not None,
                        dtype=module.weight.dtype,
                        weights=w_mapping[quantization_config.weights],
                        activations=a_mapping[quantization_config.activations],
                    )
                    model._modules[name].requires_grad_(False)
                    has_been_replaced = True
                elif isinstance(module, torch.nn.LayerNorm):
                    # LayerNorm has no weights to quantize, so it is only replaced when activations are quantized
                    if quantization_config.activations is not None:
                        model._modules[name] = QLayerNorm(
                            module.normalized_shape,
                            module.eps,
                            module.elementwise_affine,
                            module.bias is not None,
                            activations=a_mapping[quantization_config.activations],
                        )
                        has_been_replaced = True
        if len(list(module.children())) > 0:
            _, has_been_replaced = replace_with_quanto_layers(
                module,
                quantization_config=quantization_config,
                modules_to_not_convert=modules_to_not_convert,
                current_key_name=current_key_name,
                has_been_replaced=has_been_replaced,
            )
        # Remove the last key for recursion
        current_key_name.pop(-1)
    return model, has_been_replaced