# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"AWQ (Activation aware Weight Quantization) integration file"

import importlib.metadata

from packaging import version

from ..activations import ACT2FN
from ..modeling_utils import PreTrainedModel
from ..utils import is_auto_awq_available, is_ipex_available, is_torch_available, logging
from ..utils.quantization_config import (
    AwqBackendPackingMethod,
    AwqConfig,
    AWQLinearVersion,
    ExllamaVersion,
)


if is_torch_available():
    import torch
    import torch.nn as nn


logger = logging.get_logger(__name__)


AWQ_FUSED_MAPPINGS = {
    "mistral": {
        "attention": ["q_proj", "k_proj", "v_proj", "o_proj"],
        "mlp": ["gate_proj", "up_proj", "down_proj"],
        "layernorm": ["input_layernorm", "post_attention_layernorm", "norm"],
        "use_alibi": False,
    },
    "mixtral": {
        "attention": ["q_proj", "k_proj", "v_proj", "o_proj"],
        "mlp": ["w1", "w3", "w2"],
        "layernorm": ["input_layernorm", "post_attention_layernorm", "norm"],
        "use_alibi": False,
        "rope_theta": 1000000.0,
    },
    "llama": {
        "attention": ["q_proj", "k_proj", "v_proj", "o_proj"],
        "mlp": ["gate_proj", "up_proj", "down_proj"],
        "layernorm": ["input_layernorm", "post_attention_layernorm", "norm"],
        "use_alibi": False,
    },
    "llava": {
        "attention": ["q_proj", "k_proj", "v_proj", "o_proj"],
        "mlp": ["gate_proj", "up_proj", "down_proj"],
        "layernorm": ["input_layernorm", "post_attention_layernorm", "norm"],
        "use_alibi": False,
    },
}

AWQ_SCALES_MAPPINGS = {
    "starcoder2": {"act": "act", "layer_before_act": "c_fc"},
    "RefinedWebModel": {"act": "act", "layer_before_act": "dense_h_to_4h"},
    "falcon": {"act": "act", "layer_before_act": "dense_h_to_4h"},
    "mpt": {"act": "act", "layer_before_act": "up_proj"},
    "gptj": {"act": "act", "layer_before_act": "fc_in"},
    "gpt_neox": {"act": "act", "layer_before_act": "dense_h_to_4h"},
    "gpt_bigcode": {"act": "act", "layer_before_act": "c_fc"},
    "bloom": {"act": "gelu_impl", "layer_before_act": "dense_h_to_4h"},
}
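
# Reading the tables above: AWQ_FUSED_MAPPINGS lists, per model type, which attention / MLP / layernorm
# submodules can be fused by autoawq, while AWQ_SCALES_MAPPINGS names the activation module that
# `replace_quantization_scales` below wraps in a `ScaledActivation`. For example, for a "falcon" checkpoint the
# function looks for a child literally named "act" whose parent also owns a "dense_h_to_4h" linear layer, and
# initializes the scales to ones sized to that layer's `out_features` (a descriptive note inferred from the code
# below, not from the autoawq documentation):
#
#     scale_like = torch.ones(parent.dense_h_to_4h.out_features)
#     parent.act = ScaledActivation(parent.act, scale_like)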


def replace_quantization_scales(model, model_type):
    from awq.modules.act import ScaledActivation

    if model_type not in AWQ_SCALES_MAPPINGS:
        return model
    for name, module in model.named_children():
        act_name = AWQ_SCALES_MAPPINGS[model_type]["act"]
        layer_before_act_name = AWQ_SCALES_MAPPINGS[model_type]["layer_before_act"]
        if name == act_name and hasattr(model, layer_before_act_name):
            layer_before_act = getattr(model, AWQ_SCALES_MAPPINGS[model_type]["layer_before_act"])
            size = layer_before_act.out_features
            scale_like = torch.ones(size)
            model._modules[name] = ScaledActivation(module, scale_like)
        _ = replace_quantization_scales(module, model_type)
    return model
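
# Hedged usage sketch for the helper above: it is normally invoked by the AWQ quantizer right after the linear
# layers have been swapped, but calling it directly is harmless because unsupported model types are returned
# unchanged. The model below is assumed to be an already AWQ-converted checkpoint.
#
#     model = replace_quantization_scales(model, model.config.model_type)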


def replace_with_awq_linear(
    model,
    modules_to_not_convert=None,
    quantization_config=None,
    current_key_name=None,
    has_been_replaced=False,
) -> bool:
    """
    Public method that recursively replaces the Linear layers of the given model with AWQ quantized layers.
    `accelerate` is needed to use this method. Returns the converted model and a boolean that indicates whether
    the conversion has been successful or not.

    During the module replacement, we also infer the backend to use through the `quantization_config` object.

    Args:
        model (`torch.nn.Module`):
            The model to convert, can be any `torch.nn.Module` instance.
        quantization_config (`AwqConfig`):
            The quantization config object that contains the quantization parameters.
        modules_to_not_convert (`list`, *optional*):
            A list of modules to not convert. If a module name is in the list (e.g. `lm_head`), it will not be
            converted.
        current_key_name (`list`, *optional*):
            A list that contains the current key name. This is used for recursion and should not be passed by the
            user.
        has_been_replaced (`bool`, *optional*):
            A boolean that indicates if the conversion has been successful or not. This is used for recursion and
            should not be passed by the user.
    """
    if modules_to_not_convert is None:
        modules_to_not_convert = []

    backend = quantization_config.backend

    if not is_auto_awq_available():
        raise ValueError(
            "AWQ (either `autoawq` or `llmawq`) is not available. Please install it with `pip install autoawq` or check out the installation guide in https://github.com/mit-han-lab/llm-awq"
        )

    if backend == AwqBackendPackingMethod.AUTOAWQ:
        if quantization_config.version == AWQLinearVersion.GEMM:
            from awq.modules.linear.gemm import WQLinear_GEMM

            target_cls = WQLinear_GEMM
        elif quantization_config.version == AWQLinearVersion.GEMV:
            from awq.modules.linear.gemv import WQLinear_GEMV

            target_cls = WQLinear_GEMV
        elif quantization_config.version == AWQLinearVersion.EXLLAMA:
            if quantization_config.exllama_config["version"] == ExllamaVersion.ONE:
                from awq.modules.linear.exllama import WQLinear_Exllama

                target_cls = WQLinear_Exllama
            elif quantization_config.exllama_config["version"] == ExllamaVersion.TWO:
                from awq.modules.linear.exllamav2 import WQLinear_ExllamaV2

                target_cls = WQLinear_ExllamaV2
            else:
                raise ValueError(f"Unrecognized Exllama version: {quantization_config.exllama_config['version']}")
        elif quantization_config.version == AWQLinearVersion.IPEX:
            from awq.modules.linear.gemm_ipex import WQLinear_IPEX

            target_cls = WQLinear_IPEX
        else:
            raise ValueError(f"Unrecognized AWQ version: {quantization_config.version}")
    else:
        from awq.quantize.qmodule import WQLinear

        target_cls = WQLinear

    for name, module in model.named_children():
        if current_key_name is None:
            current_key_name = []
        current_key_name.append(name)

        if isinstance(module, nn.Linear) and name not in modules_to_not_convert:
            # Check if the current key is not in the `modules_to_not_convert`
            if not any(key in ".".join(current_key_name) for key in modules_to_not_convert):
                in_features = module.in_features
                out_features = module.out_features

                model._modules[name] = target_cls(
                    w_bit=quantization_config.bits,
                    group_size=quantization_config.group_size,
                    in_features=in_features,
                    out_features=out_features,
                    bias=module.bias is not None,
                    dev=module.weight.device,
                )
                has_been_replaced = True

                # Force requires grad to False to avoid unexpected errors
                model._modules[name].requires_grad_(False)
        if len(list(module.children())) > 0:
            _, has_been_replaced = replace_with_awq_linear(
                module,
                modules_to_not_convert=modules_to_not_convert,
                current_key_name=current_key_name,
                quantization_config=quantization_config,
                has_been_replaced=has_been_replaced,
            )
        # Remove the last key for recursion
        current_key_name.pop(-1)
    return model, has_been_replaced
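
# Hedged usage sketch for `replace_with_awq_linear`: this mirrors what the AWQ quantizer does before the
# quantized state dict is loaded; the call only swaps `nn.Linear` modules for empty AWQ layers, it does not
# quantize any weights itself. The checkpoint name is a placeholder and the config values are illustrative.
#
#     from transformers import AutoModelForCausalLM, AwqConfig
#
#     quantization_config = AwqConfig(bits=4, group_size=128)  # AutoAWQ backend with GEMM kernels by default
#     model = AutoModelForCausalLM.from_pretrained("<base-checkpoint>")
#     model, has_been_replaced = replace_with_awq_linear(
#         model,
#         modules_to_not_convert=["lm_head"],
#         quantization_config=quantization_config,
#     )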


def get_modules_to_fuse(model, quantization_config):
    """
    Returns the fusing mapping given the quantization config and the model

    Args:
        model (`~PreTrainedModel`):
            The model to fuse - note this model should have been converted into AWQ format beforehand.
        quantization_config (`~transformers.quantization_config.AWQConfig`):
            The quantization configuration to use.
    """
    if not isinstance(model, PreTrainedModel):
        raise TypeError(f"The model should be an instance of `PreTrainedModel`, got {model.__class__.__name__}")

    # Always default to `quantization_config.modules_to_fuse`
    if quantization_config.modules_to_fuse is not None:
        current_fused_mapping = quantization_config.modules_to_fuse
        current_fused_mapping["max_seq_len"] = quantization_config.fuse_max_seq_len
    elif model.config.model_type in AWQ_FUSED_MAPPINGS:
        current_fused_mapping = AWQ_FUSED_MAPPINGS[model.config.model_type]

        # Properly deal with the case where we have a multi-modal model as well (e.g. Llava)
        config = model.config.get_text_config(decoder=True)

        # Handle hidden_size, num_attention_heads, num_key_value_heads on our own.
        hidden_size = config.hidden_size
        num_attention_heads = config.num_attention_heads
        num_key_value_heads = getattr(config, "num_key_value_heads", num_attention_heads)

        # Fill `current_fused_mapping` with the expected values
        current_fused_mapping["hidden_size"] = hidden_size
        current_fused_mapping["num_attention_heads"] = num_attention_heads
        current_fused_mapping["num_key_value_heads"] = num_key_value_heads
        current_fused_mapping["max_seq_len"] = quantization_config.fuse_max_seq_len
    else:
        raise ValueError(
            "Fusing mapping not found either on the quantization config or the supported `AWQ_FUSED_MAPPINGS`. Please pass a `modules_to_fuse` argument"
            " in the `quantization_config` or raise an issue on transformers https://github.com/huggingface/transformers to add its support."
        )
    return current_fused_mapping
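
# For architectures that are not listed in AWQ_FUSED_MAPPINGS, the mapping can be supplied directly on the
# quantization config. Hedged sketch of the expected shape; the module names and sizes below are illustrative.
# When the mapping comes from the user it must also carry `hidden_size`, `num_attention_heads` and
# `num_key_value_heads`, since only `max_seq_len` is filled in automatically above.
#
#     custom_mapping = {
#         "attention": ["q_proj", "k_proj", "v_proj", "o_proj"],
#         "mlp": ["gate_proj", "up_proj", "down_proj"],
#         "layernorm": ["input_layernorm", "post_attention_layernorm", "norm"],
#         "use_alibi": False,
#         "hidden_size": 4096,
#         "num_attention_heads": 32,
#         "num_key_value_heads": 8,
#     }
#     quantization_config = AwqConfig(do_fuse=True, fuse_max_seq_len=2048, modules_to_fuse=custom_mapping)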


def fuse_awq_modules(model, quantization_config):
    """
    Optionally fuse some modules in the model to speedup inference.

    Args:
        model (`~PreTrainedModel`):
            The model to fuse - note this model should have been converted into AWQ format beforehand.
        quantization_config (`Union[AwqConfig, dict]`):
            The quantization configuration to use.
    """
    # We need to convert it from dict in order to get an AwqConfig object
    # otherwise the fields `backend` etc. will not be available
    # https://github.com/huggingface/transformers/pull/27411#discussion_r1414044495
    if isinstance(quantization_config, dict):
        quantization_config = AwqConfig.from_dict(quantization_config)
    backend = quantization_config.backend

    modules_to_fuse = get_modules_to_fuse(model, quantization_config)
    modules_to_not_convert = getattr(quantization_config, "modules_to_not_convert", None)

    if backend == AwqBackendPackingMethod.AUTOAWQ:
        from awq.modules.fused.attn import QuantAttentionFused
        from awq.modules.fused.mlp import QuantFusedMLP
        from awq.modules.fused.norm import FasterTransformerRMSNorm
    else:
        raise ValueError("Fusing is only supported for the AutoAWQ backend")

    fused_attention_modules = []

    for name, module in model.named_modules():
        if modules_to_not_convert is not None:
            if any(module_name_to_not_convert in name for module_name_to_not_convert in modules_to_not_convert):
                continue

        # Replace layer norms
        _fuse_awq_layernorm(modules_to_fuse["layernorm"], module, FasterTransformerRMSNorm)

        # Replace MLP layers if the AWQ version is not IPEX.
        if quantization_config.version != "ipex":
            _fuse_awq_mlp(model, name, modules_to_fuse["mlp"], module, QuantFusedMLP)
        else:
            logger.info("The IPEX version of AWQ does not support fused MLP modules for now.")

        # Replace attention layers
        attention_has_been_fused = _fuse_awq_attention_layers(
            model, module, modules_to_fuse, name, QuantAttentionFused
        )

        if attention_has_been_fused:
            fused_attention_modules.append(name.split(".")[0])

    # For AWQ fused + Llama we need to set `config._attn_implementation` = "custom" to avoid unexpected behavior and pass
    # `None` attention mask to the fused attention modules as now the attention mask is dropped by our models and dealt
    # with by the `AttentionMaskConverter` module.
    if len(fused_attention_modules) > 0:
        for module_name, module in model.named_modules():
            if module_name in fused_attention_modules:
                if hasattr(module, "config") and hasattr(module.config, "_attn_implementation"):
                    module.config._attn_implementation = "custom"
    return model
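
# Hedged usage sketch: fusing is normally requested through `from_pretrained` rather than by calling
# `fuse_awq_modules` directly; passing an AwqConfig with `do_fuse=True` makes the AWQ quantizer invoke the
# function above once the model is loaded. The checkpoint name and sequence length are placeholders.
#
#     from transformers import AutoModelForCausalLM, AwqConfig
#
#     quantization_config = AwqConfig(do_fuse=True, fuse_max_seq_len=512)
#     model = AutoModelForCausalLM.from_pretrained(
#         "<awq-quantized-checkpoint>", quantization_config=quantization_config
#     )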


def _fuse_awq_layernorm(fuse_module_names, module, target_cls):
    """
    Fuse the LayerNorm layers into a target class using autoawq

    Args:
        fuse_module_names (`List[str]`):
            The list of module names to fuse
        module (`nn.Module`):
            The pytorch parent module that has layernorm modules to fuse
        target_cls (`~autoawq.FasterTransformerRMSNorm`):
            The `FasterTransformerRMSNorm` class as it only supports that class
            for now.
    """
    for module_name in fuse_module_names:
        if hasattr(module, module_name):
            old_module = getattr(module, module_name)

            module._modules[module_name] = target_cls(
                old_module.weight,
                old_module.variance_epsilon,
            ).to(old_module.weight.device)
            del old_module
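
# Note: the swap above is purely a kernel change; the fused FasterTransformerRMSNorm is built from the existing
# `weight` and `variance_epsilon`, so no parameters are re-learned or re-scaled.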


def _fuse_awq_mlp(model, current_module_name, fuse_module_names, module, target_cls):
    """
    Fuse the MLP layers into a target class using autoawq

    Args:
        model (`~PreTrainedModel`):
            The input pretrained model
        current_module_name (`str`):
            The current submodule name
        fuse_module_names (`List[str]`):
            The list of module names to fuse. For the MLP layers it has to be an array
            of length 3 that consists of the 3 MLP layers in the order (gate (dense layer post-attention) / up / down layers)
        module (`nn.Module`):
            The pytorch parent module that has MLP modules to fuse
        target_cls (`~autoawq.QuantFusedMLP`):
            The `QuantFusedMLP` class as it only supports that class
            for now.
    """
    if len(fuse_module_names) == 0:
        return

    if hasattr(module, fuse_module_names[0]):
        gate_proj = getattr(module, fuse_module_names[0])
        up_proj = getattr(module, fuse_module_names[1])
        down_proj = getattr(module, fuse_module_names[2])

        previous_device = gate_proj.qweight.device

        # Deal also with the case model has `text_config` attribute
        config = model.config.get_text_config(decoder=True)
        hidden_act = config.hidden_act
        activation_fn = ACT2FN[hidden_act]
        new_module = target_cls(gate_proj, down_proj, up_proj, activation_fn)

        parent_name, child_name = current_module_name.rsplit(".", 1)
        parent = model.get_submodule(parent_name)
        setattr(parent, child_name, new_module.to(previous_device))

        del gate_proj, up_proj, down_proj
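
# The projections passed above follow the gated MLP used by Llama-style models; the fused autoawq module is
# expected to compute the same thing as this reference sketch (an assumption about QuantFusedMLP, not taken
# from its source):
#
#     def _reference_gated_mlp(x):
#         return down_proj(activation_fn(gate_proj(x)) * up_proj(x))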


def _fuse_awq_attention_layers(model, module, modules_to_fuse, current_module_name, target_cls):
    """
    Fuse the Attention layers into a target class using autoawq

    Args:
        model (`~PreTrainedModel`):
            The input pretrained model
        module (`nn.Module`):
            The pytorch parent module that has attention modules to fuse
        modules_to_fuse (`dict`):
            The module fusing mapping. The dictionary has to contain a field `attention` with attention module names
            in the correct order: q, k, v, o layer
        current_module_name (`str`):
            The current submodule name
        target_cls (`~autoawq.QuantAttentionFused`):
            The `QuantAttentionFused` class as it only supports that class
            for now.
    """
    from awq.modules.linear import WQLinear_GEMM, WQLinear_GEMV

    module_has_been_fused = False

    if len(modules_to_fuse["attention"]) == 0:
        return module_has_been_fused

    if hasattr(module, modules_to_fuse["attention"][0]):
        # First, we pack the QKV layers together
        q_proj = getattr(module, modules_to_fuse["attention"][0])

        if isinstance(q_proj, WQLinear_GEMV):
            linear_target_cls = WQLinear_GEMV
            cat_dim = 0
        elif isinstance(q_proj, WQLinear_GEMM):
            linear_target_cls = WQLinear_GEMM
            cat_dim = 1
        elif is_ipex_available() and version.parse(importlib.metadata.version("autoawq")) > version.parse("0.2.6"):
            from awq.modules.linear import WQLinear_IPEX

            if isinstance(q_proj, WQLinear_IPEX):
                linear_target_cls = WQLinear_IPEX
                cat_dim = 1
            else:
                raise ValueError(f"Unsupported q_proj type: {type(q_proj)}")
        else:
            raise ValueError(f"Unsupported q_proj type: {type(q_proj)}")

        previous_device = q_proj.qweight.device

        k_proj = getattr(module, modules_to_fuse["attention"][1])
        v_proj = getattr(module, modules_to_fuse["attention"][2])
        o_proj = getattr(module, modules_to_fuse["attention"][3])

        bias = torch.cat([q_proj.bias, k_proj.bias, v_proj.bias], dim=0) if q_proj.bias is not None else None

        qkv_layer = linear_target_cls(
            q_proj.w_bit,
            q_proj.group_size,
            q_proj.in_features,
            q_proj.out_features + k_proj.out_features + v_proj.out_features,
            q_proj.bias is not None,
            next(iter(module.state_dict().values())).device,
        )

        qkv_layer.qweight = torch.cat([q_proj.qweight, k_proj.qweight, v_proj.qweight], dim=cat_dim)
        qkv_layer.qzeros = torch.cat([q_proj.qzeros, k_proj.qzeros, v_proj.qzeros], dim=cat_dim)
        qkv_layer.scales = torch.cat([q_proj.scales, k_proj.scales, v_proj.scales], dim=cat_dim)

        if isinstance(qkv_layer, WQLinear_GEMV):
            qkv_layer.split_k_iters = q_proj.split_k_iters

        qkv_layer.bias = bias

        fused_attention_layer = target_cls(
            modules_to_fuse["hidden_size"],
            modules_to_fuse["num_attention_heads"],
            modules_to_fuse["num_key_value_heads"],
            qkv_layer,
            o_proj,
            previous_device,
            modules_to_fuse["max_seq_len"],
            use_alibi=modules_to_fuse["use_alibi"],
            # The default value in autoawq is set to 10000.0
            rope_theta=modules_to_fuse.get("rope_theta", 10000.0),
        )

        fused_attention_layer.is_hf_transformers = True

        parent_name, child_name = current_module_name.rsplit(".", 1)
        parent = model.get_submodule(parent_name)
        setattr(parent, child_name, fused_attention_layer.to(previous_device))

        del q_proj, k_proj, v_proj, o_proj
        module_has_been_fused = True

    return module_has_been_fused
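
# Note on `cat_dim` above: the q/k/v projections are merged by concatenating their packed buffers along the
# out_features axis, which sits on dim 0 for the GEMV layout and on dim 1 for the GEMM/IPEX layouts. A hedged
# shape sketch, assuming the GEMM layout packs eight 4-bit values per int32 column:
#
#     q_proj.qweight.shape    == (in_features, q_out // 8)
#     qkv_layer.qweight.shape == (in_features, (q_out + k_out + v_out) // 8)  # torch.cat(..., dim=1)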


def post_init_awq_exllama_modules(model, exllama_config):
    """
    Runs post init for Exllama layers which performs:
        - Weights unpacking, reordering and repacking
        - Devices scratch space allocation
    """
    if exllama_config["version"] == ExllamaVersion.ONE:
        from awq.modules.linear.exllama import exllama_post_init

        model = exllama_post_init(model)
    elif exllama_config["version"] == ExllamaVersion.TWO:
        from awq.modules.linear.exllamav2 import exllamav2_post_init

        model = exllamav2_post_init(
            model,
            max_input_len=exllama_config["max_input_len"],
            max_batch_size=exllama_config["max_batch_size"],
        )
    else:
        raise ValueError(f"Unrecognized Exllama version: {exllama_config['version']}")

    return model
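
# Hedged usage sketch: the Exllama kernels are selected through the quantization config, e.g.
# `AwqConfig(version="exllama")`, in which case the AWQ quantizer is expected to call
# `post_init_awq_exllama_modules(model, quantization_config.exllama_config)` once the weights are loaded so the
# kernels can allocate their scratch buffers.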


def post_init_awq_ipex_modules(model):
    """
    Runs post init for IPEX layers which performs:
        - Weights packing, reordering and repacking
    """
    from awq.modules.linear.gemm_ipex import ipex_post_init

    model = ipex_post_init(model)

    return model