# coding=utf-8
# Copyright 2024 Descript and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert a Descript Audio Codec (DAC) checkpoint into the Hugging Face format."""

import argparse
import fnmatch
import re

import torch

from transformers import (
    DacConfig,
    DacFeatureExtractor,
    DacModel,
    logging,
)


# Checkpoints can be downloaded using:
# pip install descript-audio-codec
# python3 -m dac download                      # downloads the default 44kHz variant
# python3 -m dac download --model_type 44khz   # downloads the 44kHz variant
# python3 -m dac download --model_type 24khz   # downloads the 24kHz variant
# python3 -m dac download --model_type 16khz   # downloads the 16kHz variant
# More information: https://github.com/descriptinc/descript-audio-codec/tree/main
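#
# Example invocation (a sketch; the checkpoint path below is a placeholder, use the
# location reported by `python3 -m dac download` on your machine):
#
#   python3 convert_dac_checkpoint.py \
#       --model dac_44khz \
#       --checkpoint_path /path/to/weights_44khz.pth \
#       --pytorch_dump_folder_path ./dac_44khz_converted
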
logging.set_verbosity_info()
logger = logging.get_logger("transformers.models.dac")


def match_pattern(string, pattern):
    # Split the pattern and the string into dot-separated parts
    pattern_parts = pattern.split(".")
    string_parts = string.split(".")

    pattern_block_count = string_block_count = 0

    # Count the "block" segments on each side; a match additionally requires the same
    # number of nested blocks as the fnmatch pattern
    for part in pattern_parts:
        if part.startswith("block"):
            pattern_block_count += 1

    for part in string_parts:
        if part.startswith("block"):
            string_block_count += 1

    return fnmatch.fnmatch(string, pattern) and string_block_count == pattern_block_count
TOP_LEVEL_KEYS = []
IGNORE_KEYS = []


MAPPING_ENCODER = {
    "encoder.block.0": ["encoder.conv1"],
    "encoder.block.5": ["encoder.snake1"],
    "encoder.block.6": ["encoder.conv2"],
    "encoder.block.*.block.*.block.0".replace("*", r"\d+"): ["encoder.block", "res_unit", "snake1"],
    "encoder.block.*.block.*.block.1".replace("*", r"\d+"): ["encoder.block", "res_unit", "conv1"],
    "encoder.block.*.block.*.block.2".replace("*", r"\d+"): ["encoder.block", "res_unit", "snake2"],
    "encoder.block.*.block.*.block.3".replace("*", r"\d+"): ["encoder.block", "res_unit", "conv2"],
    "encoder.block.*.block.3".replace("*", r"\d+"): ["encoder.block", "snake1"],
    "encoder.block.*.block.4".replace("*", r"\d+"): ["encoder.block", "conv1"],
}

MAPPING_QUANTIZER = {
    "quantizer.quantizers.*": ["quantizer.quantizers.*"],
}

MAPPING_DECODER = {
    "decoder.model.0": ["decoder.conv1"],
    "decoder.model.5": ["decoder.snake1"],
    "decoder.model.6": ["decoder.conv2"],
    "decoder.model.*.block.0".replace("*", r"\d+"): ["decoder.block", "snake1"],
    "decoder.model.*.block.1".replace("*", r"\d+"): ["decoder.block", "conv_t1"],
    "decoder.model.*.block.*.block.0".replace("*", r"\d+"): ["decoder.block", "res_unit", "snake1"],
    "decoder.model.*.block.*.block.1".replace("*", r"\d+"): ["decoder.block", "res_unit", "conv1"],
    "decoder.model.*.block.*.block.2".replace("*", r"\d+"): ["decoder.block", "res_unit", "snake2"],
    "decoder.model.*.block.*.block.3".replace("*", r"\d+"): ["decoder.block", "res_unit", "conv2"],
}

MAPPING = {
    **MAPPING_ENCODER,
    **MAPPING_QUANTIZER,
    **MAPPING_DECODER,
}
def set_recursively(hf_pointer, key, value, full_name, weight_type):
    for attribute in key.split("."):
        hf_pointer = getattr(hf_pointer, attribute)

    if weight_type is not None:
        hf_shape = getattr(hf_pointer, weight_type).shape
    else:
        hf_shape = hf_pointer.shape

    if hf_shape != value.shape:
        raise ValueError(
            f"Shape of hf {key + '.' + weight_type if weight_type is not None else ''} is {hf_shape}, but should be"
            f" {value.shape} for {full_name}"
        )

    if weight_type == "weight":
        hf_pointer.weight.data = value
    elif weight_type == "weight_g":
        hf_pointer.weight_g.data = value
    elif weight_type == "weight_v":
        hf_pointer.weight_v.data = value
    elif weight_type == "bias":
        hf_pointer.bias.data = value
    elif weight_type == "alpha":
        hf_pointer.alpha.data = value

    logger.info(f"{key + ('.' + weight_type if weight_type is not None else '')} was initialized from {full_name}.")


def should_ignore(name, ignore_keys):
    for key in ignore_keys:
        if key.endswith(".*"):
            if name.startswith(key[:-1]):
                return True
        elif ".*." in key:
            prefix, suffix = key.split(".*.")
            if prefix in name and suffix in name:
                return True
        elif key in name:
            return True
    return False
def recursively_load_weights(orig_dict, hf_model, model_name):
    unused_weights = []

    if model_name not in ["dac_16khz", "dac_24khz", "dac_44khz"]:
        raise ValueError(f"Unsupported model: {model_name}")

    for name, value in orig_dict.items():
        is_used = False
        for key, mapped_key in MAPPING.items():
            regex = re.compile(key)
            if regex.search(name):
                if len(mapped_key) == 1:
                    if mapped_key[0][0] == "q":
                        # Quantizer weights keep their original path, minus the parameter suffix
                        mapped_key = ".".join(name.split(".")[:-1])
                    else:
                        mapped_key = mapped_key[0]
                elif len(mapped_key) == 3:
                    # Residual-unit weights: rebuild the target name from the block and
                    # residual-unit indices found in the original name
                    integers = re.findall(r"\b\d+\b", name)
                    if mapped_key[0][0] == "d":
                        mapped_key = "{}.{}.{}{}.{}".format(
                            mapped_key[0],
                            str(int(integers[0]) - 1),
                            mapped_key[1],
                            str(int(integers[1]) - 1),
                            mapped_key[2],
                        )
                    else:
                        mapped_key = "{}.{}.{}{}.{}".format(
                            mapped_key[0],
                            str(int(integers[0]) - 1),
                            mapped_key[1],
                            str(int(integers[1]) + 1),
                            mapped_key[2],
                        )
                elif len(mapped_key) == 2:
                    integers = re.findall(r"\b\d+\b", name)
                    mapped_key = "{}.{}.{}".format(mapped_key[0], str(int(integers[0]) - 1), mapped_key[1])

                is_used = True
                # Determine which parameter of the target module is being set
                if "weight_g" in name:
                    weight_type = "weight_g"
                elif "weight_v" in name:
                    weight_type = "weight_v"
                elif "bias" in name:
                    weight_type = "bias"
                elif "alpha" in name:
                    weight_type = "alpha"
                elif "weight" in name:
                    weight_type = "weight"

                set_recursively(hf_model, mapped_key, value, name, weight_type)

        if not is_used:
            unused_weights.append(name)

    print(list(set(unused_weights)))
    logger.warning(f"Unused weights: {unused_weights}")
@torch.no_grad()
def convert_checkpoint(
    model_name,
    checkpoint_path,
    pytorch_dump_folder_path,
    sample_rate=16000,
    repo_id=None,
):
    model_dict = torch.load(checkpoint_path, map_location="cpu")

    config = DacConfig()

    metadata = model_dict["metadata"]["kwargs"]
    config.encoder_hidden_size = metadata["encoder_dim"]
    config.downsampling_ratios = metadata["encoder_rates"]
    config.codebook_size = metadata["codebook_size"]
    config.n_codebooks = metadata["n_codebooks"]
    config.codebook_dim = metadata["codebook_dim"]
    config.decoder_hidden_size = metadata["decoder_dim"]
    config.upsampling_ratios = metadata["decoder_rates"]
    config.quantizer_dropout = float(metadata["quantizer_dropout"])
    config.sampling_rate = sample_rate

    model = DacModel(config)

    feature_extractor = DacFeatureExtractor()
    feature_extractor.sampling_rate = sample_rate

    original_checkpoint = model_dict["state_dict"]

    # Apply weight norm so the original weight_g / weight_v parameters can be copied,
    # then fold them back into plain weights
    model.apply_weight_norm()
    recursively_load_weights(original_checkpoint, model, model_name)
    model.remove_weight_norm()

    model.save_pretrained(pytorch_dump_folder_path)

    if repo_id:
        print("Pushing to the hub...")
        feature_extractor.push_to_hub(repo_id)
        model.push_to_hub(repo_id)
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model",
        default="dac_44khz",
        type=str,
        help="The model to convert. Should be one of 'dac_16khz', 'dac_24khz', 'dac_44khz'.",
    )
    parser.add_argument("--checkpoint_path", required=True, default=None, type=str, help="Path to original checkpoint")
    parser.add_argument(
        "--pytorch_dump_folder_path", required=True, default=None, type=str, help="Path to the output PyTorch model."
    )
    parser.add_argument(
        "--push_to_hub", default=None, type=str, help="Where to upload the converted model on the 🤗 hub."
    )
    parser.add_argument("--sample_rate", default=None, type=int, help="Sample rate used by DacFeatureExtractor")
    args = parser.parse_args()

    convert_checkpoint(
        args.model, args.checkpoint_path, args.pytorch_dump_folder_path, args.sample_rate, args.push_to_hub
    )
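
# A minimal reload sketch (not part of the conversion itself; the folder name is whatever
# was passed as --pytorch_dump_folder_path above):
#
#   from transformers import DacModel
#
#   model = DacModel.from_pretrained("/path/to/pytorch_dump_folder")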