# coding=utf-8
# Copyright 2024 The ggml.ai team and The HuggingFace Inc. team. and pygguf author (github.com/99991)
# https://github.com/99991/pygguf
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import re
from typing import Dict, Optional

import numpy as np
from tqdm import tqdm

from .integrations import (
    GGUF_CONFIG_MAPPING,
    GGUF_TENSOR_MAPPING,
    GGUF_TOKENIZER_MAPPING,
    _gguf_parse_value,
)
from .utils import is_torch_available
from .utils.import_utils import is_gguf_available
from .utils.logging import get_logger


if is_torch_available():
    import torch

logger = get_logger(__name__)


GGUF_TO_TRANSFORMERS_MAPPING = {
    "ignore": {
        "GGUF": {
            "version": "version",
            "tensor_count": "tensor_count",
            "kv_count": "kv_count",
        },
        "general": {"file_type": "file_type", "quantization_version": "quantization_version"},
    },
    "config": GGUF_CONFIG_MAPPING,
    "tensors": GGUF_TENSOR_MAPPING,
    "tokenizer": {"tokenizer": GGUF_TOKENIZER_MAPPING["tokenizer"]},
    "tokenizer_config": {"tokenizer": GGUF_TOKENIZER_MAPPING["tokenizer_config"]},
}

GGUF_SUPPORTED_ARCHITECTURES = list(GGUF_TO_TRANSFORMERS_MAPPING["tensors"].keys())


def read_field(reader, field):
    value = reader.fields[field]
    return [_gguf_parse_value(value.parts[_data_index], value.types) for _data_index in value.data]
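
# Usage sketch (illustrative comment, not executed on import): `read_field` returns the parsed
# values of a single GGUF metadata field as a list. The file path below is hypothetical and the
# returned values depend on the checkpoint being read.
#
#     from gguf import GGUFReader
#     reader = GGUFReader("/path/to/model.gguf")
#     architecture = read_field(reader, "general.architecture")[0]  # e.g. "llama"
#     model_name = read_field(reader, "general.name")               # e.g. ["Mistral-7B-v0.1"]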


def load_gguf_checkpoint(gguf_checkpoint_path, return_tensors=False):
    """
    Load a GGUF file and return a dictionary of parsed parameters containing tensors, the parsed
    tokenizer and config attributes.

    Args:
        gguf_checkpoint_path (`str`):
            The path to the GGUF file to load.
        return_tensors (`bool`, defaults to `False`):
            Whether to read the tensors from the file and return them. Not doing so is faster
            and only loads the metadata in memory.
    """
    if is_gguf_available() and is_torch_available():
        from gguf import GGUFReader, dequantize
    else:
        logger.error(
            "Loading a GGUF checkpoint in PyTorch requires both PyTorch and GGUF>=0.10.0 to be installed. Please see "
            "https://pytorch.org/ and https://github.com/ggerganov/llama.cpp/tree/master/gguf-py for installation instructions."
        )
        raise ImportError("Please install torch and gguf>=0.10.0 to load a GGUF checkpoint in PyTorch.")

    reader = GGUFReader(gguf_checkpoint_path)
    fields = reader.fields
    reader_keys = list(fields.keys())

    parsed_parameters = {k: {} for k in GGUF_TO_TRANSFORMERS_MAPPING}

    architecture = read_field(reader, "general.architecture")[0]
    model_name = read_field(reader, "general.name")

    # in llama.cpp mistral models use the same architecture as llama. We need
    # to add this patch to ensure things work correctly on our side.
    if "llama" in architecture and "mistral" in model_name:
        updated_architecture = "mistral"
    else:
        updated_architecture = architecture

    if "qwen2moe" in architecture:
        updated_architecture = "qwen2_moe"

    model_size = ""
    # extract the number of params from the file name, as architectures can differ;
    # e.g. for falcon: `...falcon-7b-...`
    if "falcon" in architecture:
        gguf_file_name = gguf_checkpoint_path.split("/")[-1].lower()
        m = re.search(r"-\d+b-", gguf_file_name)  # regex to catch `-7b-`
        if m is None:
            raise ValueError(
                f"From file name, cannot determine the number of parameters for {architecture} architecture"
            )
        model_size = m.group().strip("-")  # only keeps `7b`

    if architecture + model_size not in GGUF_SUPPORTED_ARCHITECTURES:
        raise ValueError(f"Architecture {architecture + model_size} not supported")

    # Parse all metadata key-value pairs and map them onto their transformers equivalents
    for gguf_key, field in reader.fields.items():
        gguf_key = gguf_key.replace(architecture, updated_architecture)
        split = gguf_key.split(".")
        prefix = split[0]
        config_key = ".".join(split[1:])
        value = [_gguf_parse_value(field.parts[_data_index], field.types) for _data_index in field.data]

        if len(value) == 1:
            value = value[0]

        if isinstance(value, str) and architecture in value:
            value = value.replace(architecture, updated_architecture)

        for parameter in GGUF_TO_TRANSFORMERS_MAPPING:
            parameter_renames = GGUF_TO_TRANSFORMERS_MAPPING[parameter]
            if prefix in parameter_renames and config_key in parameter_renames[prefix]:
                renamed_config_key = parameter_renames[prefix][config_key]
                if renamed_config_key == -1:
                    continue

                if renamed_config_key is not None:
                    parsed_parameters[parameter][renamed_config_key] = value

                if gguf_key in reader_keys:
                    reader_keys.remove(gguf_key)

        if gguf_key in reader_keys:
            logger.info(f"Some keys were not parsed into the model parameters: {gguf_key} | {value}")

    # retrieve config vocab_size from tokenizer
    # Please refer to https://github.com/huggingface/transformers/issues/32526 for more details
  121. if "vocab_size" not in parsed_parameters["config"]:
  122. tokenizer_parameters = parsed_parameters["tokenizer"]
  123. if "tokens" in tokenizer_parameters:
  124. parsed_parameters["config"]["vocab_size"] = len(tokenizer_parameters["tokens"])
  125. else:
  126. logger.warning(
  127. "Can't find a way to retrieve missing config vocab_size from tokenizer parameters. "
  128. "This will use default value from model config class and cause unexpected behavior."
  129. )

    if return_tensors:
        tensor_key_mapping = GGUF_TO_TRANSFORMERS_MAPPING["tensors"][architecture + model_size]

        for tensor in tqdm(reader.tensors, desc="Converting and de-quantizing GGUF tensors..."):
            name = tensor.name

            weights = dequantize(tensor.data, tensor.tensor_type)

            if architecture == "llama" and (".attn_k." in name or ".attn_q." in name):
                num_heads = parsed_parameters["config"]["num_attention_heads"]
                num_kv_heads = parsed_parameters["config"]["num_key_value_heads"]
                if ".attn_q." in name:
                    weights = reverse_permute_weights(weights, num_heads, num_heads)
                elif ".attn_k." in name:
                    weights = reverse_permute_weights(weights, num_heads, num_kv_heads)

            if architecture == "qwen2moe":
                if "_exp" in name:
                    split_moe_expert_tensor(weights, parsed_parameters, name, tensor_key_mapping)
                    continue
                if "ffn_gate_inp_shexp" in name:
                    # for compatibility, the shared_expert_gate tensor must have shape (1, 2048);
                    # the quantized one has shape (2048,)
                    weights = np.expand_dims(weights, axis=0)

            if architecture == "bloom" and "attn_qkv" in name:
                num_heads = parsed_parameters["config"]["n_head"]
                n_embed = parsed_parameters["config"]["hidden_size"]
                if "weight" in name:
                    weights = reverse_reshape_weights(weights, num_heads, n_embed)
                else:
                    weights = reverse_reshape_bias(weights, num_heads, n_embed)

            if architecture == "gpt2":
                if (
                    "attn_qkv.weight" in name
                    or "ffn_down.weight" in name
                    or "ffn_up.weight" in name
                    or "attn_output.weight" in name
                ):
                    # Original transpose implementation
                    # https://github.com/ggerganov/llama.cpp/blob/a38b884c6c4b0c256583acfaaabdf556c62fabea/convert_hf_to_gguf.py#L2060-L2061
                    weights = weights.T
                if name == "output.weight":
                    # output.weight conflicts with attn_output.weight in the substring checks above,
                    # so we have to check explicitly that the name is exactly output.weight
                    name = "lm_head.weight"
                    parsed_parameters["tensors"][name] = torch.from_numpy(np.copy(weights))
                    continue

            for tensor_name in tensor_key_mapping:
                if tensor_name in name:
                    name = name.replace(tensor_name, tensor_key_mapping[tensor_name])

            # Use copy to avoid errors with numpy and pytorch
            parsed_parameters["tensors"][name] = torch.from_numpy(np.copy(weights))

    if len(reader_keys) > 0:
        logger.info(f"Some keys of the GGUF file were not considered: {reader_keys}")

    return parsed_parameters
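
# Usage sketch (illustrative comment): the returned dictionary groups everything by destination.
# The path below is hypothetical, and "tensors" is only populated when `return_tensors=True`.
#
#     parsed = load_gguf_checkpoint("/path/to/model.gguf", return_tensors=True)
#     config_dict = parsed["config"]         # kwargs used to build the transformers config
#     tokenizer_dict = parsed["tokenizer"]   # raw tokenizer fields parsed from the GGUF metadata
#     state_dict = parsed["tensors"]         # de-quantized torch tensors with transformers key names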


def reverse_permute_weights(weights: np.ndarray, n_head: int, num_kv_heads: Optional[int] = None) -> np.ndarray:
    # Original permutation implementation
    # https://github.com/ggerganov/llama.cpp/blob/a38b884c6c4b0c256583acfaaabdf556c62fabea/convert_hf_to_gguf.py#L1402-L1408
    if num_kv_heads is not None and n_head != num_kv_heads:
        n_head = num_kv_heads

    dim = weights.shape[0] // n_head // 2
    w = weights.reshape(n_head, dim, 2, *weights.shape[1:])
    return w.swapaxes(2, 1).reshape(weights.shape)
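
# For reference, a minimal sketch of the forward permutation this function undoes, paraphrased
# from the convert_hf_to_gguf.py link above (an assumption about that script, not verbatim code):
#
#     def permute(weights: np.ndarray, n_head: int, n_head_kv: Optional[int] = None) -> np.ndarray:
#         if n_head_kv is not None and n_head != n_head_kv:
#             n_head = n_head_kv
#         w = weights.reshape(n_head, 2, weights.shape[0] // n_head // 2, *weights.shape[1:])
#         return w.swapaxes(1, 2).reshape(weights.shape)
#
# With that definition, `reverse_permute_weights(permute(w, h, kv), h, kv)` round-trips back to `w`
# whenever the first dimension of `w` is divisible by twice the effective head count.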


def reverse_reshape_weights(weights: np.ndarray, n_head: int, n_embed: int):
    # Original reshape implementation
    # https://github.com/ggerganov/llama.cpp/blob/master/convert_hf_to_gguf.py#L972-L985
    q, k, v = np.array_split(weights, 3, axis=0)

    q = q.reshape(n_head, n_embed // n_head, n_embed)
    k = k.reshape(n_head, n_embed // n_head, n_embed)
    v = v.reshape(n_head, n_embed // n_head, n_embed)

    qkv_weights = np.stack([q, k, v], axis=1)

    return qkv_weights.reshape(n_head * 3 * (n_embed // n_head), n_embed)


def reverse_reshape_bias(weights: np.ndarray, n_head: int, n_embed: int):
    # Original reshape implementation
    # https://github.com/ggerganov/llama.cpp/blob/master/convert_hf_to_gguf.py#L986-L998
    q_bias, k_bias, v_bias = np.array_split(weights, 3)

    q_bias = q_bias.reshape(n_head, n_embed // n_head)
    k_bias = k_bias.reshape(n_head, n_embed // n_head)
    v_bias = v_bias.reshape(n_head, n_embed // n_head)

    qkv_bias = np.stack([q_bias, k_bias, v_bias], axis=1).flatten()
    return qkv_bias
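
# Shape sketch (illustrative, toy values rather than a real Bloom config): as the splits above
# suggest, the GGUF file stores the fused QKV projection as a plain [Q; K; V] concatenation along
# the output dimension, while the transformers Bloom checkpoint expects a per-head
# (q-block, k-block, v-block) layout, which is what the stack along axis=1 reconstructs.
#
#     weights = np.arange(3 * 8 * 8, dtype=np.float32).reshape(3 * 8, 8)  # (3 * n_embed, n_embed)
#     reverse_reshape_weights(weights, n_head=2, n_embed=8).shape          # (24, 8), rows regrouped per head
#
#     bias = np.arange(3 * 8, dtype=np.float32)
#     reverse_reshape_bias(bias, n_head=2, n_embed=8).shape                # (24,), interleaved per head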


def split_moe_expert_tensor(
    weights: np.ndarray, parsed_parameters: Dict[str, Dict], name: str, tensor_key_mapping: dict
):
    # Original merge implementation
    # https://github.com/ggerganov/llama.cpp/blob/master/convert_hf_to_gguf.py#L1994-L2022
    exp_name = ""
    if "ffn_gate_exps" in name:
        exp_name = "gate_proj"
    elif "ffn_down_exps" in name:
        exp_name = "down_proj"
    elif "ffn_up_exps" in name:
        exp_name = "up_proj"
    else:
        raise ValueError(f"Cannot map expert tensor {name} in Qwen2Moe architecture.")

    for tensor_name in tensor_key_mapping:
        if tensor_name in name:
            name = name.replace(tensor_name, tensor_key_mapping[tensor_name])

    w_counter = parsed_parameters["config"].get("num_experts", 60)
    for i in range(0, w_counter):
        temp_name = name.replace(".weight", f".{i}.{exp_name}.weight")
        exp_weight = weights[i]
        parsed_parameters["tensors"][temp_name] = torch.from_numpy(np.copy(exp_weight))
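
# Usage sketch (illustrative comment, hypothetical names and sizes): a stacked Qwen2MoE expert
# tensor of shape (num_experts, out_features, in_features) is split into one entry per expert,
# with ".weight" expanded to f".{i}.<gate_proj|up_proj|down_proj>.weight" after the key renaming.
# `tensor_key_mapping` is assumed to be the qwen2_moe entry of GGUF_TENSOR_MAPPING, and the GGUF
# tensor name below is only an example.
#
#     parsed = {"config": {"num_experts": 4}, "tensors": {}}
#     stacked = np.zeros((4, 16, 32), dtype=np.float32)
#     split_moe_expert_tensor(stacked, parsed, "blk.0.ffn_gate_exps.weight", tensor_key_mapping)
#     # parsed["tensors"] now holds four (16, 32) tensors, one per expert.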