- # coding=utf-8
- # Copyright 2022 The HuggingFace Inc. team.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- """Convert Jukebox checkpoints"""
- import argparse
- import json
- import os
- import re
- from pathlib import Path
-
- import requests
- import torch
-
- from transformers import JukeboxConfig, JukeboxModel
- from transformers.utils import logging
-
-
- logging.set_verbosity_info()
- logger = logging.get_logger(__name__)
-
-
- PREFIX = "https://openaipublic.azureedge.net/jukebox/models/"
- MODEL_MAPPING = {
- "jukebox-1b-lyrics": [
- "5b/vqvae.pth.tar",
- "5b/prior_level_0.pth.tar",
- "5b/prior_level_1.pth.tar",
- "1b_lyrics/prior_level_2.pth.tar",
- ],
- "jukebox-5b-lyrics": [
- "5b/vqvae.pth.tar",
- "5b/prior_level_0.pth.tar",
- "5b/prior_level_1.pth.tar",
- "5b_lyrics/prior_level_2.pth.tar",
- ],
- }
- def replace_key(key):
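-     """Translate an original Jukebox parameter name (suffixes and sub-module segments) into the transformers naming scheme."""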
- if key.endswith(".model.1.bias") and len(key.split(".")) > 10:
- key = key.replace(".model.1.bias", ".conv1d_1.bias")
- elif key.endswith(".model.1.weight") and len(key.split(".")) > 10:
- key = key.replace(".model.1.weight", ".conv1d_1.weight")
- elif key.endswith(".model.3.bias") and len(key.split(".")) > 10:
- key = key.replace(".model.3.bias", ".conv1d_2.bias")
- elif key.endswith(".model.3.weight") and len(key.split(".")) > 10:
- key = key.replace(".model.3.weight", ".conv1d_2.weight")
- if "conditioner_blocks.0." in key:
- key = key.replace("conditioner_blocks.0", "conditioner_blocks")
- if "prime_prior" in key:
- key = key.replace("prime_prior", "encoder")
- if ".emb." in key and "total" not in key and "absolute" not in key and "relative" not in key:
- key = key.replace(".emb.", ".")
- if key.endswith("k"): # replace vqvae.X.k with vqvae.X.codebook
- return key.replace(".k", ".codebook")
- if "y_emb." in key:
- return key.replace("y_emb.", "metadata_embedding.")
- if "x_emb.emb." in key:
- key = key.replace("0.x_emb.emb", "embed_tokens")
- if "prime_state_ln" in key:
- return key.replace("prime_state_ln", "encoder.final_layer_norm")
- if ".ln" in key:
- return key.replace(".ln", ".layer_norm")
- if "_ln" in key:
- return key.replace("_ln", "_layer_norm")
- if "prime_state_proj" in key:
- return key.replace("prime_state_proj", "encoder.proj_in")
- if "prime_x_out" in key:
- return key.replace("prime_x_out", "encoder.lm_head")
- if "prior.x_out" in key:
- return key.replace("x_out", "fc_proj_out")
- if "x_emb" in key:
- return key.replace("x_emb", "embed_tokens")
- return key
- def fix_jukebox_keys(state_dict, model_state_dict, key_prefix, mapping):
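-     """
-     Rename the keys of `state_dict` to match the transformers Jukebox layout, checking every new key against
-     `model_state_dict` (prefixed with `key_prefix`) and recording the renames in `mapping`.
-     """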
-     new_dict = {}
-     re_encoder_block_conv_in = re.compile(r"encoders.(\d*).level_blocks.(\d*).model.(\d*).(\d).(bias|weight)")
-     re_encoder_block_resnet = re.compile(
-         r"encoders.(\d*).level_blocks.(\d*).model.(\d*).(\d).model.(\d*).model.(\d*).(bias|weight)"
-     )
-     re_encoder_block_proj_out = re.compile(r"encoders.(\d*).level_blocks.(\d*).model.(\d*).(bias|weight)")
-     re_decoder_block_conv_out = re.compile(r"decoders.(\d*).level_blocks.(\d*).model.(\d*).(\d).(bias|weight)")
-     re_decoder_block_resnet = re.compile(
-         r"decoders.(\d*).level_blocks.(\d*).model.(\d*).(\d).model.(\d*).model.(\d*).(bias|weight)"
-     )
-     re_decoder_block_proj_in = re.compile(r"decoders.(\d*).level_blocks.(\d*).model.(\d*).(bias|weight)")
-     re_prior_cond_conv_out = re.compile(r"conditioner_blocks.(\d*).cond.model.(\d*).(\d).(bias|weight)")
-     re_prior_cond_resnet = re.compile(
-         r"conditioner_blocks.(\d*).cond.model.(\d*).(\d).model.(\d*).model.(\d*).(bias|weight)"
-     )
-     re_prior_cond_proj_in = re.compile(r"conditioner_blocks.(\d*).cond.model.(\d*).(bias|weight)")
-
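-     # regex-based renames for the convolutional blocks first, then the generic replacements from replace_key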
-     for original_key, value in state_dict.items():
-         # rename vqvae.encoder keys
-         if re_encoder_block_conv_in.fullmatch(original_key):
-             regex_match = re_encoder_block_conv_in.match(original_key)
-             groups = regex_match.groups()
-             # flatten the nested (module, sub-module) indices into a single downsample_block index
-             block_index = int(groups[2]) * 2 + int(groups[3])
-             re_new_key = f"encoders.{groups[0]}.level_blocks.{groups[1]}.downsample_block.{block_index}.{groups[-1]}"
-             key = re_encoder_block_conv_in.sub(re_new_key, original_key)
-         elif re_encoder_block_resnet.fullmatch(original_key):
-             regex_match = re_encoder_block_resnet.match(original_key)
-             groups = regex_match.groups()
-             block_index = int(groups[2]) * 2 + int(groups[3])
-             conv_index = {"1": 1, "3": 2}[groups[-2]]
-             prefix = f"encoders.{groups[0]}.level_blocks.{groups[1]}.downsample_block.{block_index}."
-             resnet_block = f"resnet_block.{groups[-3]}.conv1d_{conv_index}.{groups[-1]}"
-             re_new_key = prefix + resnet_block
-             key = re_encoder_block_resnet.sub(re_new_key, original_key)
-         elif re_encoder_block_proj_out.fullmatch(original_key):
-             regex_match = re_encoder_block_proj_out.match(original_key)
-             groups = regex_match.groups()
-             re_new_key = f"encoders.{groups[0]}.level_blocks.{groups[1]}.proj_out.{groups[-1]}"
-             key = re_encoder_block_proj_out.sub(re_new_key, original_key)
-         # rename vqvae.decoder keys
-         elif re_decoder_block_conv_out.fullmatch(original_key):
-             regex_match = re_decoder_block_conv_out.match(original_key)
-             groups = regex_match.groups()
-             # same flattening as the encoder, shifted by 2 so the first upsample module lands at index 0
-             block_index = int(groups[2]) * 2 + int(groups[3]) - 2
-             re_new_key = f"decoders.{groups[0]}.level_blocks.{groups[1]}.upsample_block.{block_index}.{groups[-1]}"
-             key = re_decoder_block_conv_out.sub(re_new_key, original_key)
-         elif re_decoder_block_resnet.fullmatch(original_key):
-             regex_match = re_decoder_block_resnet.match(original_key)
-             groups = regex_match.groups()
-             block_index = int(groups[2]) * 2 + int(groups[3]) - 2
-             conv_index = {"1": 1, "3": 2}[groups[-2]]
-             prefix = f"decoders.{groups[0]}.level_blocks.{groups[1]}.upsample_block.{block_index}."
-             resnet_block = f"resnet_block.{groups[-3]}.conv1d_{conv_index}.{groups[-1]}"
-             re_new_key = prefix + resnet_block
-             key = re_decoder_block_resnet.sub(re_new_key, original_key)
-         elif re_decoder_block_proj_in.fullmatch(original_key):
-             regex_match = re_decoder_block_proj_in.match(original_key)
-             groups = regex_match.groups()
-             re_new_key = f"decoders.{groups[0]}.level_blocks.{groups[1]}.proj_in.{groups[-1]}"
-             key = re_decoder_block_proj_in.sub(re_new_key, original_key)
-         # rename prior cond.model to upsampler.upsample_block and resnet
-         elif re_prior_cond_conv_out.fullmatch(original_key):
-             regex_match = re_prior_cond_conv_out.match(original_key)
-             groups = regex_match.groups()
-             block_index = int(groups[1]) * 2 + int(groups[2]) - 2
-             re_new_key = f"conditioner_blocks.upsampler.upsample_block.{block_index}.{groups[-1]}"
-             key = re_prior_cond_conv_out.sub(re_new_key, original_key)
-         elif re_prior_cond_resnet.fullmatch(original_key):
-             regex_match = re_prior_cond_resnet.match(original_key)
-             groups = regex_match.groups()
-             block_index = int(groups[1]) * 2 + int(groups[2]) - 2
-             conv_index = {"1": 1, "3": 2}[groups[-2]]
-             prefix = f"conditioner_blocks.upsampler.upsample_block.{block_index}."
-             resnet_block = f"resnet_block.{groups[-3]}.conv1d_{conv_index}.{groups[-1]}"
-             re_new_key = prefix + resnet_block
-             key = re_prior_cond_resnet.sub(re_new_key, original_key)
-         elif re_prior_cond_proj_in.fullmatch(original_key):
-             regex_match = re_prior_cond_proj_in.match(original_key)
-             groups = regex_match.groups()
-             re_new_key = f"conditioner_blocks.upsampler.proj_in.{groups[-1]}"
-             key = re_prior_cond_proj_in.sub(re_new_key, original_key)
-         # keep original key
-         else:
-             key = original_key
-         key = replace_key(key)
-         if f"{key_prefix}.{key}" not in model_state_dict or key is None:
-             print(f"failed converting {original_key} to {key}, does not match")
-         # handle mismatched shape
-         elif value.shape != model_state_dict[f"{key_prefix}.{key}"].shape:
-             val = model_state_dict[f"{key_prefix}.{key}"]
-             print(f"{original_key} -> {key}: shapes {val.shape} and {value.shape} do not match")
-             key = original_key
-         mapping[key] = original_key
-         new_dict[key] = value
-     return new_dict
-
-
- @torch.no_grad()
- def convert_openai_checkpoint(model_name=None, pytorch_dump_folder_path=None):
- """
- Copy/paste/tweak model's weights to our Jukebox structure.
- """
-     for file in MODEL_MAPPING[model_name]:
-         if not os.path.isfile(f"{pytorch_dump_folder_path}/{file.split('/')[-1]}"):
-             r = requests.get(f"{PREFIX}{file}", allow_redirects=True)
-             os.makedirs(f"{pytorch_dump_folder_path}/", exist_ok=True)
-             with open(f"{pytorch_dump_folder_path}/{file.split('/')[-1]}", "wb") as f:
-                 f.write(r.content)
-
-     model_to_convert = MODEL_MAPPING[model_name.split("/")[-1]]
-
-     config = JukeboxConfig.from_pretrained(model_name)
-     model = JukeboxModel(config)
-
-     weight_dict = []
-     mapping = {}
-
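-     # convert each checkpoint: normalise the .b/.w parameter suffixes, then remap the keys to the transformers layout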
-     for i, dict_name in enumerate(model_to_convert):
-         old_dic = torch.load(f"{pytorch_dump_folder_path}/{dict_name.split('/')[-1]}")["model"]
-
-         new_dic = {}
-         for k in old_dic.keys():
-             if k.endswith(".b"):
-                 # only rewrite the trailing suffix so other characters in the key are left untouched
-                 new_dic[k[: -len(".b")] + ".bias"] = old_dic[k]
-             elif k.endswith(".w"):
-                 new_dic[k[: -len(".w")] + ".weight"] = old_dic[k]
-             elif "level_2" not in dict_name and "cond.model." in k:
-                 new_dic[k.replace(".blocks.", ".model.")] = old_dic[k]
-             else:
-                 new_dic[k] = old_dic[k]
-
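-         # the first checkpoint is the VQ-VAE; the priors are stored top level first in the model, so checkpoint i maps to priors.{3 - i}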
- key_prefix = "vqvae" if i == 0 else f"priors.{3 - i}"
- new_dic = fix_jukebox_keys(new_dic, model.state_dict(), key_prefix, mapping)
- weight_dict.append(new_dic)
-     vqvae_state_dict = weight_dict.pop(0)
-     model.vqvae.load_state_dict(vqvae_state_dict)
-
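-     # weight_dict is ordered bottom-up (level_0 first) while model.priors is ordered top-down, hence the reversed index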
-     for i in range(len(weight_dict)):
-         model.priors[i].load_state_dict(weight_dict[2 - i])
-
-     Path(pytorch_dump_folder_path).mkdir(exist_ok=True)
-     with open(f"{pytorch_dump_folder_path}/mapping.json", "w") as txtfile:
-         json.dump(mapping, txtfile)
-
-     print(f"Saving model {model_name} to {pytorch_dump_folder_path}")
-     model.save_pretrained(pytorch_dump_folder_path)
-
-     return weight_dict
-
-
- if __name__ == "__main__":
-     parser = argparse.ArgumentParser()
-     # Both arguments are optional and default to converting jukebox-5b-lyrics
-     parser.add_argument(
-         "--model_name",
-         default="jukebox-5b-lyrics",
-         type=str,
-         help="Name of the model you'd like to convert.",
-     )
-     parser.add_argument(
-         "--pytorch_dump_folder_path",
-         default="jukebox-5b-lyrics-converted",
-         type=str,
-         help="Path to the output PyTorch model directory.",
-     )
-     args = parser.parse_args()
-     convert_openai_checkpoint(args.model_name, args.pytorch_dump_folder_path)
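-     # The converted model can then be reloaded from the dump folder, e.g.:
-     #     model = JukeboxModel.from_pretrained(args.pytorch_dump_folder_path)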