# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert ESM checkpoint."""

import argparse
import pathlib
from pathlib import Path
from tempfile import TemporaryDirectory

import esm as esm_module
import torch
from esm.esmfold.v1.misc import batch_encode_sequences as esmfold_encode_sequences
from esm.esmfold.v1.pretrained import esmfold_v1

from transformers.models.esm.configuration_esm import EsmConfig, EsmFoldConfig
from transformers.models.esm.modeling_esm import (
    EsmForMaskedLM,
    EsmForSequenceClassification,
    EsmIntermediate,
    EsmLayer,
    EsmOutput,
    EsmSelfAttention,
    EsmSelfOutput,
)
from transformers.models.esm.modeling_esmfold import EsmForProteinFolding
from transformers.models.esm.tokenization_esm import EsmTokenizer
from transformers.utils import logging


logging.set_verbosity_info()
logger = logging.get_logger(__name__)

SAMPLE_DATA = [
    (
        "protein1",
        "MNGTEGPNFYVPFSNATGVVRSPFEYPQYYLAEPWQFSMLAAYMFLLIVLGFPINFLTLYVTVQHKKLRTPLNYILLNLAVADLFMVLGGFTSTLYTSLHGYFVFGPTGCNLEGFFATLGGEIALWSLVVLAIERYVVVCKPMSNFRFGENHAIMGVAFTWVMALACAAPPLAGWSRYIPEGLQCSCGIDYYTLKPEVNNESFVIYMFVVHFTIPMIIIFFCYGQLVFTVKEAAAQQQESATTQKAEKEVTRMVIIMVIAFLICWVPYASVAFYIFTHQGSNFGPIFMTIPAFFAKSAAIYNPVIYIMMNKQFRNCMLTTICCGKNPLGDDEASATVSKTETSQVAPA",
    ),
    ("protein2", "MKTVRQERLKSIVRILERSKEPVSGAQLAEELSVSRQVIVQDIAYLRSLGYNIVATPRGYVLA"),
    ("protein3", "MKTVRQERLKSI<mask>RILERSKEPVSGAQLAEELS<mask>SRQVIVQDIAYLRSLGYN<mask>VATPRGYVLAGG"),
    ("protein4", "MKTVRQERLKSI<mask>RILERSKEPVSGAQLAEELS<mask>SRQVIVQDIAYLRSLGYN<mask>VATPRGYVLA"),
]
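
# Map each checkpoint name to the zero-argument loader provided by the `esm` package. The ESM-1/ESM-2
# loaders return a (model, alphabet) tuple, while `esmfold_v1` returns only the model; the unpacking
# logic in convert_esm_checkpoint_to_pytorch() below handles both cases.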
MODEL_MAPPING = {
    "esm1b_t33_650M_UR50S": esm_module.pretrained.esm1b_t33_650M_UR50S,
    "esm1v_t33_650M_UR90S_1": esm_module.pretrained.esm1v_t33_650M_UR90S_1,
    "esm1v_t33_650M_UR90S_2": esm_module.pretrained.esm1v_t33_650M_UR90S_2,
    "esm1v_t33_650M_UR90S_3": esm_module.pretrained.esm1v_t33_650M_UR90S_3,
    "esm1v_t33_650M_UR90S_4": esm_module.pretrained.esm1v_t33_650M_UR90S_4,
    "esm1v_t33_650M_UR90S_5": esm_module.pretrained.esm1v_t33_650M_UR90S_5,
    "esm2_t48_15B_UR50D": esm_module.pretrained.esm2_t48_15B_UR50D,
    "esm2_t36_3B_UR50D": esm_module.pretrained.esm2_t36_3B_UR50D,
    "esm2_t33_650M_UR50D": esm_module.pretrained.esm2_t33_650M_UR50D,
    "esm2_t30_150M_UR50D": esm_module.pretrained.esm2_t30_150M_UR50D,
    "esm2_t12_35M_UR50D": esm_module.pretrained.esm2_t12_35M_UR50D,
    "esm2_t6_8M_UR50D": esm_module.pretrained.esm2_t6_8M_UR50D,
    "esmfold_v1": esmfold_v1,
}

restypes = list("ARNDCQEGHILKMFPSTWYV")

restypes_with_x = restypes + ["X"]
restypes_with_extras = restypes_with_x + ["<pad>", "<mask>", "<cls>", "<sep>", "<eos>"]
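

# Helper: build an EsmTokenizer for ESMFold from the reduced residue vocabulary above,
# written out to a temporary vocab file.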
def get_esmfold_tokenizer():
    with TemporaryDirectory() as tempdir:
        vocab = "\n".join(restypes_with_extras)
        vocab_file = Path(tempdir) / "vocab.txt"
        vocab_file.write_text(vocab)
        hf_tokenizer = EsmTokenizer(vocab_file=str(vocab_file))

    hf_tokenizer.pad_token_id = 0  # Overlaps with 'A' but that seems to be what they want
    return hf_tokenizer


def transfer_and_check_weights(original_module, our_module):
    status = our_module.load_state_dict(original_module.state_dict())
    if status.missing_keys:
        raise ValueError(f"Missing keys: {status.missing_keys}")
    if status.unexpected_keys:
        raise ValueError(f"Unexpected keys: {status.unexpected_keys}")


def convert_esm_checkpoint_to_pytorch(
    model: str, pytorch_dump_folder_path: str, classification_head: bool, push_to_repo: str, auth_token: str
):
    """
    Copy/paste/tweak esm's weights to our BERT structure.
    """
    if model.startswith("esmfold"):
        esm = MODEL_MAPPING[model]()
    else:
        esm, alphabet = MODEL_MAPPING[model]()
    esm.eval()  # disable dropout

    if model.startswith("esmfold"):
        embed_dim = esm.esm.embed_dim
        num_layers = esm.esm.num_layers
        num_attention_heads = esm.esm.attention_heads
        intermediate_size = 4 * embed_dim
        token_dropout = esm.esm.token_dropout
        emb_layer_norm_before = False  # This code path does not exist in ESM-2
        position_embedding_type = "rotary"
        is_folding_model = True
        esmfold_config = EsmFoldConfig()
        for key, val in esm.cfg.items():
            if hasattr(esmfold_config, key) and key != "trunk":
                setattr(esmfold_config, key, val)
        for key, val in esm.cfg.trunk.items():
            if hasattr(esmfold_config.trunk, key) and key != "structure_module":
                setattr(esmfold_config.trunk, key, val)
        for key, val in esm.cfg.trunk.structure_module.items():
            if hasattr(esmfold_config.trunk.structure_module, key):
                setattr(esmfold_config.trunk.structure_module, key, val)
    elif hasattr(esm, "args"):
        # Indicates an ESM-1b or ESM-1v model
        embed_dim = esm.args.embed_dim
        num_layers = esm.args.layers
        num_attention_heads = esm.args.attention_heads
        intermediate_size = esm.args.ffn_embed_dim
        token_dropout = esm.args.token_dropout
        emb_layer_norm_before = True if esm.emb_layer_norm_before else False
        position_embedding_type = "absolute"
        is_folding_model = False
        esmfold_config = None
    else:
        # Indicates an ESM-2 model
        embed_dim = esm.embed_dim
        num_layers = esm.num_layers
        num_attention_heads = esm.attention_heads
        intermediate_size = 4 * embed_dim  # This is hardcoded in ESM-2
        token_dropout = esm.token_dropout
        emb_layer_norm_before = False  # This code path does not exist in ESM-2
        position_embedding_type = "rotary"
        is_folding_model = False
        esmfold_config = None

    if is_folding_model:
        alphabet = esm.esm.alphabet
    vocab_list = tuple(alphabet.all_toks)
    mask_token_id = alphabet.mask_idx
    pad_token_id = alphabet.padding_idx

    if is_folding_model:
        original_esm_model = esm.esm
    else:
        original_esm_model = esm

    config = EsmConfig(
        vocab_size=original_esm_model.embed_tokens.num_embeddings,
        mask_token_id=mask_token_id,
        hidden_size=embed_dim,
        num_hidden_layers=num_layers,
        num_attention_heads=num_attention_heads,
        intermediate_size=intermediate_size,
        max_position_embeddings=1026,
        layer_norm_eps=1e-5,  # PyTorch default used in fairseq
        attention_probs_dropout_prob=0.0,
        hidden_dropout_prob=0.0,
        pad_token_id=pad_token_id,
        emb_layer_norm_before=emb_layer_norm_before,
        token_dropout=token_dropout,
        position_embedding_type=position_embedding_type,
        is_folding_model=is_folding_model,
        esmfold_config=esmfold_config,
        vocab_list=vocab_list,
    )
    if classification_head:
        config.num_labels = esm.classification_heads["mnli"].out_proj.weight.shape[0]
    print("Our ESM config:", config)

    if model.startswith("esmfold"):
        model_class = EsmForProteinFolding
    elif classification_head:
        model_class = EsmForSequenceClassification
    else:
        model_class = EsmForMaskedLM
    model = model_class(config)
    model.eval()

    # Now let's copy all the weights.
    # Embeddings
    model.esm.embeddings.word_embeddings.weight = original_esm_model.embed_tokens.weight
    if position_embedding_type == "absolute":
        model.esm.embeddings.position_embeddings.weight = original_esm_model.embed_positions.weight

    if config.emb_layer_norm_before:
        model.esm.embeddings.layer_norm.weight = original_esm_model.emb_layer_norm_before.weight
        model.esm.embeddings.layer_norm.bias = original_esm_model.emb_layer_norm_before.bias

    model.esm.encoder.emb_layer_norm_after.weight = original_esm_model.emb_layer_norm_after.weight
    model.esm.encoder.emb_layer_norm_after.bias = original_esm_model.emb_layer_norm_after.bias

    for i in range(config.num_hidden_layers):
        # Encoder: start of layer
        layer: EsmLayer = model.esm.encoder.layer[i]
        # esm_layer: TransformerSentenceEncoderLayer = original_esm_model.layers[i]
        esm_layer = original_esm_model.layers[i]

        # self attention
        self_attn: EsmSelfAttention = layer.attention.self
        assert (
            esm_layer.self_attn.k_proj.weight.data.shape
            == esm_layer.self_attn.q_proj.weight.data.shape
            == esm_layer.self_attn.v_proj.weight.data.shape
            == torch.Size((config.hidden_size, config.hidden_size))
        )

        self_attn.query.weight.data = esm_layer.self_attn.q_proj.weight
        self_attn.query.bias.data = esm_layer.self_attn.q_proj.bias
        self_attn.key.weight.data = esm_layer.self_attn.k_proj.weight
        self_attn.key.bias.data = esm_layer.self_attn.k_proj.bias
        self_attn.value.weight.data = esm_layer.self_attn.v_proj.weight
        self_attn.value.bias.data = esm_layer.self_attn.v_proj.bias

        if getattr(esm_layer.self_attn, "rot_emb", None) is not None:
            # Matt: Although inv_freq is not a trainable weight, it is computed at model init and cached.
            # During the training of ESM-2 the model was converted to float16 precision, which also converts
            # the inv_freq tensor, and the loss of precision remains even if the model is loaded later as float32.
            # If we recompute inv_freq without this loss of precision then we will get subtly different rotary
            # embeddings, which are enough to cause significant discrepancies in model outputs. To avoid this,
            # we make sure the new model copies the data from the old inv_freq.
            self_attn.rotary_embeddings.inv_freq.data = esm_layer.self_attn.rot_emb.inv_freq

        # LayerNorm changes for pre-activation
        layer.attention.LayerNorm.weight = esm_layer.self_attn_layer_norm.weight
        layer.attention.LayerNorm.bias = esm_layer.self_attn_layer_norm.bias
        layer.LayerNorm.weight = esm_layer.final_layer_norm.weight
        layer.LayerNorm.bias = esm_layer.final_layer_norm.bias

        # self-attention output
        self_output: EsmSelfOutput = layer.attention.output
        assert self_output.dense.weight.shape == esm_layer.self_attn.out_proj.weight.shape
        self_output.dense.weight = esm_layer.self_attn.out_proj.weight
        self_output.dense.bias = esm_layer.self_attn.out_proj.bias

        # intermediate
        intermediate: EsmIntermediate = layer.intermediate
        assert intermediate.dense.weight.shape == esm_layer.fc1.weight.shape
        intermediate.dense.weight = esm_layer.fc1.weight
        intermediate.dense.bias = esm_layer.fc1.bias

        # output
        bert_output: EsmOutput = layer.output
        assert bert_output.dense.weight.shape == esm_layer.fc2.weight.shape
        bert_output.dense.weight = esm_layer.fc2.weight
        bert_output.dense.bias = esm_layer.fc2.bias
        # end of layer

    if is_folding_model:
        model.esm_s_combine.data = esm.esm_s_combine.data
        model.af2_to_esm.data = esm.af2_to_esm.data
        transfer_and_check_weights(esm.embedding, model.embedding)
        transfer_and_check_weights(esm.esm_s_mlp, model.esm_s_mlp)
        transfer_and_check_weights(esm.trunk, model.trunk)
        transfer_and_check_weights(esm.distogram_head, model.distogram_head)
        transfer_and_check_weights(esm.ptm_head, model.ptm_head)
        transfer_and_check_weights(esm.lm_head, model.lm_head)
        transfer_and_check_weights(esm.lddt_head, model.lddt_head)
    elif classification_head:
        model.classifier.dense.weight = esm.classification_heads["mnli"].dense.weight
        model.classifier.dense.bias = esm.classification_heads["mnli"].dense.bias
        model.classifier.out_proj.weight = esm.classification_heads["mnli"].out_proj.weight
        model.classifier.out_proj.bias = esm.classification_heads["mnli"].out_proj.bias
    else:
        # LM Head
        model.lm_head.dense.weight = esm.lm_head.dense.weight
        model.lm_head.dense.bias = esm.lm_head.dense.bias
        model.lm_head.layer_norm.weight = esm.lm_head.layer_norm.weight
        model.lm_head.layer_norm.bias = esm.lm_head.layer_norm.bias
        model.lm_head.decoder.weight = esm.lm_head.weight
        model.lm_head.bias = esm.lm_head.bias

        # Contact prediction head
        transfer_and_check_weights(esm.contact_head, model.esm.contact_head)

    # Prepare data (first 2 sequences from ESMStructuralSplitDataset superfamily / 4)
    if is_folding_model:
        # Folding models aren't trained on masked inputs and don't like mask tokens.
        sample_data = SAMPLE_DATA[:2]
    else:
        sample_data = SAMPLE_DATA

    if is_folding_model:
        hf_tokenizer = get_esmfold_tokenizer()
        hf_tokens = hf_tokenizer(
            [row[1] for row in sample_data], return_tensors="pt", padding=True, add_special_tokens=False
        )
        esmfold_aas, esmfold_mask, _, _, _ = esmfold_encode_sequences([row[1] for row in sample_data])
        success = torch.all(hf_tokens["input_ids"] == esmfold_aas) and torch.all(
            hf_tokens["attention_mask"] == esmfold_mask
        )
    else:
        # Let's check that we get the same results.
        batch_converter = alphabet.get_batch_converter()
        batch_labels, batch_strs, batch_tokens = batch_converter(sample_data)

        # Prepare tokenizer and make sure it matches
        with TemporaryDirectory() as tempdir:
            vocab = "\n".join(alphabet.all_toks)
            vocab_file = Path(tempdir) / "vocab.txt"
            vocab_file.write_text(vocab)
            hf_tokenizer = EsmTokenizer(vocab_file=str(vocab_file))

        hf_tokens = hf_tokenizer([row[1] for row in sample_data], return_tensors="pt", padding=True)
        success = torch.all(hf_tokens["input_ids"] == batch_tokens)

    print("Do both models tokenizers output the same tokens?", "🔥" if success else "💩")
    if not success:
        raise Exception("Tokenization does not match!")

    with torch.no_grad():
        if is_folding_model:
            # Let's test the model in parts
            # ESMFold always converts the ESM stem to float16, which requires float16 ops
            # that don't exist on CPU. Therefore, to test it we need to run it on GPU. However,
            # ESMFold is what we in the community call a "big boy" and so we desperately avoid putting both the
            # original and the converted model on the GPU at the same time.
            their_output = esm.cuda().infer([row[1] for row in sample_data])
            our_output = model.cuda()(
                input_ids=hf_tokens["input_ids"].cuda(), attention_mask=hf_tokens["attention_mask"].cuda()
            )
        else:
            our_output = model(**hf_tokens, output_hidden_states=True)
            our_output = our_output["logits"]
            if classification_head:
                their_output = esm.model.classification_heads["mnli"](esm.extract_features(batch_tokens))
            else:
                their_output = esm(hf_tokens["input_ids"], repr_layers=list(range(999)))
                their_output = their_output["logits"]

    if is_folding_model:
        max_absolute_diff = torch.max(torch.abs(our_output["positions"] - their_output["positions"])).item()
        success = torch.allclose(our_output["positions"], their_output["positions"], atol=1e-5)
    else:
        max_absolute_diff = torch.max(torch.abs(our_output - their_output)).item()
        success = torch.allclose(our_output, their_output, atol=1e-5)

    print(f"max_absolute_diff = {max_absolute_diff}")  # ~ 1e-5
    print("Do both models output the same tensors?", "🔥" if success else "💩")

    if not success:
        raise Exception("Something went wRoNg")

    if not is_folding_model:
        # Let's check contact prediction too
        our_output = model.predict_contacts(hf_tokens["input_ids"], hf_tokens["attention_mask"])
        their_output = esm.predict_contacts(hf_tokens["input_ids"])
        max_absolute_diff = torch.max(torch.abs(our_output - their_output)).item()
        success = torch.allclose(our_output, their_output, atol=1e-5)

        print("Contact prediction testing:")
        print(f"max_absolute_diff = {max_absolute_diff}")  # ~ 1e-5
        print("Do both models output the same tensors?", "🔥" if success else "💩")

        if not success:
            raise Exception("Something went wRoNg")

    pathlib.Path(pytorch_dump_folder_path).mkdir(parents=True, exist_ok=True)

    print(f"Saving model to {pytorch_dump_folder_path}")
    model.save_pretrained(pytorch_dump_folder_path)

    del esm  # Free up some memory before continuing

    print(f"Saving tokenizer to {pytorch_dump_folder_path}")
    hf_tokenizer.save_pretrained(pytorch_dump_folder_path)

    if push_to_repo:
        model.push_to_hub(repo_id=push_to_repo, use_auth_token=auth_token)
        hf_tokenizer.push_to_hub(repo_id=push_to_repo, use_auth_token=auth_token)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # Required parameters
    parser.add_argument(
        "--pytorch_dump_folder_path", type=str, required=True, help="Path to the output PyTorch model."
    )
    parser.add_argument(
        "--classification_head", action="store_true", help="Whether to convert a final classification head."
    )
    parser.add_argument("--model", default=None, type=str, required=True, help="Name of model to convert.")
    parser.add_argument("--push_to_repo", type=str, help="Repo to upload to (including username!).")
    parser.add_argument("--auth_token", type=str, help="HuggingFace auth token.")
    args = parser.parse_args()
    convert_esm_checkpoint_to_pytorch(
        args.model, args.pytorch_dump_folder_path, args.classification_head, args.push_to_repo, args.auth_token
    )