modeling_esmfold.py

# coding=utf-8
# Copyright 2022 Meta and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import sys
from dataclasses import dataclass
from functools import partial
from typing import Callable, Dict, List, Optional, Sequence, Tuple, Union

import numpy as np
import torch
import torch.nn as nn
from torch.nn import LayerNorm

from ...integrations.deepspeed import is_deepspeed_available
from ...modeling_outputs import ModelOutput
from ...utils import (
    ContextManagers,
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    is_scipy_available,
    logging,
    replace_return_docstrings,
)
from .configuration_esm import EsmConfig
from .modeling_esm import ESM_START_DOCSTRING, EsmModel, EsmPreTrainedModel
from .openfold_utils import (
    OFProtein,
    Rigid,
    Rotation,
    atom14_to_atom37,
    chunk_layer,
    compute_predicted_aligned_error,
    compute_tm,
    frames_and_literature_positions_to_atom14_pos,
    make_atom14_masks,
    residue_constants,
    to_pdb,
    torsion_angles_to_frames,
)


logger = logging.get_logger(__name__)

_CHECKPOINT_FOR_DOC = "facebook/esmfold_v1"
_CONFIG_FOR_DOC = "EsmConfig"
@dataclass
class EsmForProteinFoldingOutput(ModelOutput):
    """
    Output type of [`EsmForProteinFolding`].

    Args:
        frames (`torch.FloatTensor`):
            Output frames.
        sidechain_frames (`torch.FloatTensor`):
            Output sidechain frames.
        unnormalized_angles (`torch.FloatTensor`):
            Predicted unnormalized backbone and side chain torsion angles.
        angles (`torch.FloatTensor`):
            Predicted backbone and side chain torsion angles.
        positions (`torch.FloatTensor`):
            Predicted positions of the backbone and side chain atoms.
        states (`torch.FloatTensor`):
            Hidden states from the protein folding trunk.
        s_s (`torch.FloatTensor`):
            Per-residue embeddings derived by concatenating the hidden states of each layer of the ESM-2 LM stem.
        s_z (`torch.FloatTensor`):
            Pairwise residue embeddings.
        distogram_logits (`torch.FloatTensor`):
            Input logits to the distogram used to compute residue distances.
        lm_logits (`torch.FloatTensor`):
            Logits output by the ESM-2 protein language model stem.
        aatype (`torch.FloatTensor`):
            Input amino acids (AlphaFold2 indices).
        atom14_atom_exists (`torch.FloatTensor`):
            Whether each atom exists in the atom14 representation.
        residx_atom14_to_atom37 (`torch.FloatTensor`):
            Mapping between atoms in the atom14 and atom37 representations.
        residx_atom37_to_atom14 (`torch.FloatTensor`):
            Mapping between atoms in the atom37 and atom14 representations.
        atom37_atom_exists (`torch.FloatTensor`):
            Whether each atom exists in the atom37 representation.
        residue_index (`torch.FloatTensor`):
            The index of each residue in the protein chain. Unless internal padding tokens are used, this will just be
            a sequence of integers from 0 to `sequence_length`.
        lddt_head (`torch.FloatTensor`):
            Raw outputs from the lddt head used to compute plddt.
        plddt (`torch.FloatTensor`):
            Per-residue confidence scores. Regions of low confidence may indicate areas where the model's prediction is
            uncertain, or where the protein structure is disordered.
        ptm_logits (`torch.FloatTensor`):
            Raw logits used for computing ptm.
        ptm (`torch.FloatTensor`):
            TM-score output representing the model's high-level confidence in the overall structure.
        aligned_confidence_probs (`torch.FloatTensor`):
            Per-residue confidence scores for the aligned structure.
        predicted_aligned_error (`torch.FloatTensor`):
            Predicted error between the model's prediction and the ground truth.
        max_predicted_aligned_error (`torch.FloatTensor`):
            Per-sample maximum predicted error.
    """

    frames: torch.FloatTensor = None
    sidechain_frames: torch.FloatTensor = None
    unnormalized_angles: torch.FloatTensor = None
    angles: torch.FloatTensor = None
    positions: torch.FloatTensor = None
    states: torch.FloatTensor = None
    s_s: torch.FloatTensor = None
    s_z: torch.FloatTensor = None
    distogram_logits: torch.FloatTensor = None
    lm_logits: torch.FloatTensor = None
    aatype: torch.FloatTensor = None
    atom14_atom_exists: torch.FloatTensor = None
    residx_atom14_to_atom37: torch.FloatTensor = None
    residx_atom37_to_atom14: torch.FloatTensor = None
    atom37_atom_exists: torch.FloatTensor = None
    residue_index: torch.FloatTensor = None
    lddt_head: torch.FloatTensor = None
    plddt: torch.FloatTensor = None
    ptm_logits: torch.FloatTensor = None
    ptm: torch.FloatTensor = None
    aligned_confidence_probs: torch.FloatTensor = None
    predicted_aligned_error: torch.FloatTensor = None
    max_predicted_aligned_error: torch.FloatTensor = None
ESMFOLD_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`torch.LongTensor` of shape `({0})`):
            Indices of input sequence tokens in the vocabulary.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        attention_mask (`torch.FloatTensor` of shape `({0})`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)
        position_ids (`torch.LongTensor` of shape `({0})`, *optional*):
            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
            config.max_position_embeddings - 1]`.

            [What are position IDs?](../glossary#position-ids)
        masking_pattern (`torch.LongTensor` of shape `({0})`, *optional*):
            Locations of tokens to mask during training as a form of regularization. Mask values selected in `[0, 1]`.
        num_recycles (`int`, *optional*, defaults to `None`):
            Number of times to recycle the input sequence. If `None`, defaults to `config.num_recycles`. "Recycling"
            consists of passing the output of the folding trunk back in as input to the trunk. During training, the
            number of recycles should vary with each batch, to ensure that the model learns to output valid predictions
            after each recycle. During inference, num_recycles should be set to the highest value that the model was
            trained with for maximum accuracy. Accordingly, when this value is set to `None`, config.max_recycles is
            used.
"""
def is_fp16_enabled():
    # Autocast world
    fp16_enabled = torch.get_autocast_gpu_dtype() == torch.float16
    fp16_enabled = fp16_enabled and torch.is_autocast_enabled()
    return fp16_enabled


def is_deepspeed_initialized():
    if is_deepspeed_available():
        return False
    else:
        try:
            import deepspeed

            # This is not available in all DeepSpeed versions.
            return deepspeed.utils.is_initialized()
        except Exception:
            return False
def collate_dense_tensors(samples: List[torch.Tensor], pad_v: float = 0) -> torch.Tensor:
    """
    Takes a list of tensors with the following dimensions:
        [(d_11, ..., d_1K),
         (d_21, ..., d_2K), ..., (d_N1, ..., d_NK)]
    and stack + pads them into a single tensor of:
    (N, max_i=1,N { d_i1 }, ..., max_i=1,N {diK})
    """
    if len(samples) == 0:
        return torch.Tensor()
    if len({x.dim() for x in samples}) != 1:
        raise RuntimeError(f"Samples has varying dimensions: {[x.dim() for x in samples]}")
    (device,) = tuple({x.device for x in samples})  # assumes all on same device
    max_shape = [max(lst) for lst in zip(*[x.shape for x in samples])]
    result = torch.empty(len(samples), *max_shape, dtype=samples[0].dtype, device=device)
    result.fill_(pad_v)
    for i in range(len(samples)):
        result_i = result[i]
        t = samples[i]
        result_i[tuple(slice(0, k) for k in t.shape)] = t
    return result
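# Illustrative example (not part of the module): ragged inputs are padded up to
# the per-dimension maximum with `pad_v`:
#
#     batch = collate_dense_tensors([torch.ones(2, 3), torch.ones(4, 1)], pad_v=0)
#     # batch.shape == (2, 4, 3); batch[0, 2:, :] and batch[1, :, 1:] are padding.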
def flatten_final_dims(t: torch.Tensor, no_dims: int):
    return t.reshape(t.shape[:-no_dims] + (-1,))


def permute_final_dims(tensor: torch.Tensor, inds: List[int]):
    zero_index = -1 * len(inds)
    first_inds = list(range(len(tensor.shape[:zero_index])))
    return tensor.permute(first_inds + [zero_index + i for i in inds])
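# Illustrative example (not part of the module): both helpers act only on the
# trailing dimensions and leave any leading batch dimensions untouched:
#
#     t = torch.zeros(5, 7, 8, 16)
#     flatten_final_dims(t, 2).shape          # torch.Size([5, 7, 128])
#     permute_final_dims(t, (2, 0, 1)).shape  # torch.Size([5, 16, 7, 8])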
def dict_multimap(fn, dicts):
    first = dicts[0]
    new_dict = {}
    for k, v in first.items():
        all_v = [d[k] for d in dicts]
        if isinstance(v, dict):
            new_dict[k] = dict_multimap(fn, all_v)
        else:
            new_dict[k] = fn(all_v)

    return new_dict
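# Illustrative example (not part of the module): dict_multimap applies `fn` to
# the list of values gathered under each (possibly nested) key:
#
#     d1 = {"a": 1, "b": {"c": 2}}
#     d2 = {"a": 3, "b": {"c": 4}}
#     dict_multimap(sum, [d1, d2])  # {"a": 4, "b": {"c": 6}}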
def trunc_normal_init_(weights, scale=1.0, fan="fan_in"):
    shape = weights.shape
    scale = scale / max(1, shape[1])

    if not is_scipy_available():
        logger.warning(
            "This init requires scipy, but scipy was not found, default to an approximation that might not be"
            " equivalent."
        )
        std = math.sqrt(scale)
        torch.nn.init.normal_(weights, std=std).clamp(min=0.0, max=2.0 * std)
    else:
        from scipy.stats import truncnorm

        std = math.sqrt(scale) / truncnorm.std(a=-2, b=2, loc=0, scale=1)
        samples = truncnorm.rvs(a=-2, b=2, loc=0, scale=std, size=weights.numel())
        samples = np.reshape(samples, shape)
        weights.copy_(torch.tensor(samples, device=weights.device))


def ipa_point_weights_init_(weights):
    with torch.no_grad():
        softplus_inverse_1 = 0.541324854612918
        weights.fill_(softplus_inverse_1)
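# Sanity check (illustrative, not part of the module): softplus_inverse_1 above
# is the value x with softplus(x) = log(1 + exp(x)) = 1, i.e. x = log(e - 1):
#
#     math.log(math.e - 1)  # 0.5413248546129181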
class EsmFoldLinear(nn.Linear):
    """
    A Linear layer with built-in nonstandard initializations. Called just like torch.nn.Linear.

    Implements the initializers in 1.11.4, plus some additional ones found in the code.
    """

    def __init__(
        self,
        in_dim: int,
        out_dim: int,
        bias: bool = True,
        init: str = "default",
        init_fn: Optional[Callable[[torch.Tensor, torch.Tensor], None]] = None,
    ):
        """
        Args:
            in_dim:
                The final dimension of inputs to the layer
            out_dim:
                The final dimension of layer outputs
            bias:
                Whether to learn an additive bias. True by default
            init:
                The initializer to use. Choose from:

                "default": LeCun fan-in truncated normal initialization
                "relu": He initialization w/ truncated normal distribution
                "glorot": Fan-average Glorot uniform initialization
                "gating": Weights=0, Bias=1
                "normal": Normal initialization with std=1/sqrt(fan_in)
                "final": Weights=0, Bias=0

                Overridden by init_fn if the latter is not None.
            init_fn:
                A custom initializer taking weight and bias as inputs. Overrides init if not None.
        """
        super().__init__(in_dim, out_dim, bias=bias)

        if bias:
            with torch.no_grad():
                self.bias.fill_(0)
        self.init = init
        self.init_fn = init_fn
        if init not in ["default", "relu", "glorot", "gating", "normal", "final"]:
            raise ValueError("Invalid init string.")
class EsmFoldLayerNorm(nn.Module):
    def __init__(self, c_in, eps=1e-5):
        super().__init__()

        self.c_in = (c_in,)
        self.eps = eps

        self.weight = nn.Parameter(torch.ones(c_in))
        self.bias = nn.Parameter(torch.zeros(c_in))

    def forward(self, x):
        d = x.dtype
        if d is torch.bfloat16 and not is_deepspeed_initialized():
            with torch.cuda.amp.autocast(enabled=False):
                out = nn.functional.layer_norm(x, self.c_in, self.weight.to(dtype=d), self.bias.to(dtype=d), self.eps)
        else:
            out = nn.functional.layer_norm(x, self.c_in, self.weight, self.bias, self.eps)

        return out


@torch.jit.ignore
def softmax_no_cast(t: torch.Tensor, dim: int = -1) -> torch.Tensor:
    """
    Softmax, but without automatic casting to fp32 when the input is of type bfloat16
    """
    d = t.dtype
    if d is torch.bfloat16 and not is_deepspeed_initialized():
        with torch.cuda.amp.autocast(enabled=False):
            s = torch.nn.functional.softmax(t, dim=dim)
    else:
        s = torch.nn.functional.softmax(t, dim=dim)

    return s
class EsmFoldAttention(nn.Module):
    """
    Standard multi-head attention using AlphaFold's default layer initialization. Allows multiple bias vectors.
    """

    def __init__(
        self,
        c_q: int,
        c_k: int,
        c_v: int,
        c_hidden: int,
        no_heads: int,
        gating: bool = True,
    ):
        """
        Args:
            c_q:
                Input dimension of query data
            c_k:
                Input dimension of key data
            c_v:
                Input dimension of value data
            c_hidden:
                Per-head hidden dimension
            no_heads:
                Number of attention heads
            gating:
                Whether the output should be gated using query data
        """
        super().__init__()

        self.c_q = c_q
        self.c_k = c_k
        self.c_v = c_v
        self.c_hidden = c_hidden
        self.no_heads = no_heads
        self.gating = gating

        # DISCREPANCY: c_hidden is not the per-head channel dimension, as
        # stated in the supplement, but the overall channel dimension.

        self.linear_q = EsmFoldLinear(self.c_q, self.c_hidden * self.no_heads, bias=False, init="glorot")
        self.linear_k = EsmFoldLinear(self.c_k, self.c_hidden * self.no_heads, bias=False, init="glorot")
        self.linear_v = EsmFoldLinear(self.c_v, self.c_hidden * self.no_heads, bias=False, init="glorot")
        self.linear_o = EsmFoldLinear(self.c_hidden * self.no_heads, self.c_q, init="final")

        self.linear_g = None
        if self.gating:
            self.linear_g = EsmFoldLinear(self.c_q, self.c_hidden * self.no_heads, init="gating")

        self.sigmoid = nn.Sigmoid()

    def _prep_qkv(self, q_x: torch.Tensor, kv_x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        # [*, Q/K/V, H * C_hidden]
        q = self.linear_q(q_x)
        k = self.linear_k(kv_x)
        v = self.linear_v(kv_x)

        # [*, Q/K, H, C_hidden]
        q = q.view(q.shape[:-1] + (self.no_heads, -1))
        k = k.view(k.shape[:-1] + (self.no_heads, -1))
        v = v.view(v.shape[:-1] + (self.no_heads, -1))

        # [*, H, Q/K, C_hidden]
        q = q.transpose(-2, -3)
        k = k.transpose(-2, -3)
        v = v.transpose(-2, -3)

        q /= math.sqrt(self.c_hidden)

        return q, k, v

    def _wrap_up(self, o: torch.Tensor, q_x: torch.Tensor) -> torch.Tensor:
        if self.linear_g is not None:
            g = self.sigmoid(self.linear_g(q_x))

            # [*, Q, H, C_hidden]
            g = g.view(g.shape[:-1] + (self.no_heads, -1))
            o = o * g

        # [*, Q, H * C_hidden]
        o = flatten_final_dims(o, 2)

        # [*, Q, C_q]
        o = self.linear_o(o)

        return o

    def forward(
        self,
        q_x: torch.Tensor,
        kv_x: torch.Tensor,
        biases: Optional[List[torch.Tensor]] = None,
        use_memory_efficient_kernel: bool = False,
        use_lma: bool = False,
        lma_q_chunk_size: int = 1024,
        lma_kv_chunk_size: int = 4096,
        use_flash: bool = False,
        flash_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Args:
            q_x:
                [*, Q, C_q] query data
            kv_x:
                [*, K, C_k] key data
            biases:
                List of biases that broadcast to [*, H, Q, K]
            use_memory_efficient_kernel:
                Whether to use a custom memory-efficient attention kernel. This should be the default choice for most.
                If none of the "use_<...>" flags are True, a stock PyTorch implementation is used instead
            use_lma:
                Whether to use low-memory attention (Staats & Rabe 2021). If none of the "use_<...>" flags are True, a
                stock PyTorch implementation is used instead
            lma_q_chunk_size:
                Query chunk size (for LMA)
            lma_kv_chunk_size:
                Key/Value chunk size (for LMA)
        Returns:
            [*, Q, C_q] attention update
        """
        if use_lma and (lma_q_chunk_size is None or lma_kv_chunk_size is None):
            raise ValueError("If use_lma is specified, lma_q_chunk_size and lma_kv_chunk_size must be provided")

        if use_flash and biases is not None:
            raise ValueError("use_flash is incompatible with the bias option. For masking, use flash_mask instead")

        attn_options = [use_memory_efficient_kernel, use_lma, use_flash]
        if sum(attn_options) > 1:
            raise ValueError("Choose at most one alternative attention algorithm")

        if biases is None:
            biases = []

        # [*, H, Q/K, C_hidden]
        query, key, value = self._prep_qkv(q_x, kv_x)
        key = permute_final_dims(key, (1, 0))

        # [*, H, Q, K]
        output = torch.matmul(query, key)
        for b in biases:
            output += b
        output = softmax_no_cast(output, -1)

        # [*, H, Q, C_hidden]
        output = torch.matmul(output, value)
        output = output.transpose(-2, -3)
        output = self._wrap_up(output, q_x)

        return output
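# Illustrative shapes (not part of the module): with c_hidden * no_heads equal
# to the query dimension, a pair-style input attends along its second-to-last
# axis and keeps its shape:
#
#     attn = EsmFoldAttention(c_q=64, c_k=64, c_v=64, c_hidden=16, no_heads=4)
#     x = torch.randn(2, 10, 12, 64)
#     attn(q_x=x, kv_x=x).shape  # torch.Size([2, 10, 12, 64])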
class EsmFoldTriangleAttention(nn.Module):
    def __init__(self, c_in, c_hidden, no_heads, starting=True, inf=1e9):
        """
        Args:
            c_in:
                Input channel dimension
            c_hidden:
                Overall hidden channel dimension (not per-head)
            no_heads:
                Number of attention heads
        """
        super().__init__()

        self.c_in = c_in
        self.c_hidden = c_hidden
        self.no_heads = no_heads
        self.starting = starting
        self.inf = inf

        self.layer_norm = LayerNorm(self.c_in)

        self.linear = EsmFoldLinear(c_in, self.no_heads, bias=False, init="normal")

        self.mha = EsmFoldAttention(self.c_in, self.c_in, self.c_in, self.c_hidden, self.no_heads)

    @torch.jit.ignore
    def _chunk(
        self,
        x: torch.Tensor,
        biases: List[torch.Tensor],
        chunk_size: int,
        use_memory_efficient_kernel: bool = False,
        use_lma: bool = False,
        inplace_safe: bool = False,
    ) -> torch.Tensor:
        "triangle! triangle!"
        mha_inputs = {
            "q_x": x,
            "kv_x": x,
            "biases": biases,
        }

        return chunk_layer(
            partial(self.mha, use_memory_efficient_kernel=use_memory_efficient_kernel, use_lma=use_lma),
            mha_inputs,
            chunk_size=chunk_size,
            no_batch_dims=len(x.shape[:-2]),
            _out=x if inplace_safe else None,
        )

    def forward(
        self,
        x: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        chunk_size: Optional[int] = None,
        use_memory_efficient_kernel: bool = False,
        use_lma: bool = False,
        inplace_safe: bool = False,
    ) -> torch.Tensor:
        """
        Args:
            x:
                [*, I, J, C_in] input tensor (e.g. the pair representation)
        Returns:
            [*, I, J, C_in] output tensor
        """
        if mask is None:
            # [*, I, J]
            mask = x.new_ones(
                x.shape[:-1],
            )

        if not self.starting:
            x = x.transpose(-2, -3)
            mask = mask.transpose(-1, -2)

        # [*, I, J, C_in]
        x = self.layer_norm(x)

        # [*, I, 1, 1, J]
        mask_bias = (self.inf * (mask - 1))[..., :, None, None, :]

        # [*, H, I, J]
        triangle_bias = permute_final_dims(self.linear(x), (2, 0, 1))

        # [*, 1, H, I, J]
        triangle_bias = triangle_bias.unsqueeze(-4)

        biases = [mask_bias, triangle_bias]

        if chunk_size is not None:
            x = self._chunk(
                x,
                biases,
                chunk_size,
                use_memory_efficient_kernel=use_memory_efficient_kernel,
                use_lma=use_lma,
                inplace_safe=inplace_safe,
            )
        else:
            x = self.mha(
                q_x=x, kv_x=x, biases=biases, use_memory_efficient_kernel=use_memory_efficient_kernel, use_lma=use_lma
            )

        if not self.starting:
            x = x.transpose(-2, -3)

        return x
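# Illustrative shapes (not part of the module): triangle attention preserves the
# [*, I, J, C_in] shape of the pair representation; `starting` selects row- vs
# column-wise attention by transposing I and J internally:
#
#     tri = EsmFoldTriangleAttention(c_in=32, c_hidden=8, no_heads=4, starting=False)
#     z = torch.randn(1, 16, 16, 32)
#     tri(z).shape  # torch.Size([1, 16, 16, 32])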
class EsmFoldTriangleMultiplicativeUpdate(nn.Module):
    """
    Implements Algorithms 11 and 12.
    """

    def __init__(self, config, _outgoing=True):
        super().__init__()
        c_hidden = config.pairwise_state_dim
        self._outgoing = _outgoing

        self.linear_a_p = EsmFoldLinear(c_hidden, c_hidden)
        self.linear_a_g = EsmFoldLinear(c_hidden, c_hidden, init="gating")
        self.linear_b_p = EsmFoldLinear(c_hidden, c_hidden)
        self.linear_b_g = EsmFoldLinear(c_hidden, c_hidden, init="gating")
        self.linear_g = EsmFoldLinear(c_hidden, c_hidden, init="gating")
        self.linear_z = EsmFoldLinear(c_hidden, c_hidden, init="final")

        self.layer_norm_in = LayerNorm(c_hidden)
        self.layer_norm_out = LayerNorm(c_hidden)

        self.sigmoid = nn.Sigmoid()

    def _combine_projections(
        self, a: torch.Tensor, b: torch.Tensor, _inplace_chunk_size: Optional[int] = None
    ) -> torch.Tensor:
        if self._outgoing:
            a = permute_final_dims(a, (2, 0, 1))
            b = permute_final_dims(b, (2, 1, 0))
        else:
            a = permute_final_dims(a, (2, 1, 0))
            b = permute_final_dims(b, (2, 0, 1))

        if _inplace_chunk_size is not None:
            # To be replaced by torch vmap
            for i in range(0, a.shape[-3], _inplace_chunk_size):
                a_chunk = a[..., i : i + _inplace_chunk_size, :, :]
                b_chunk = b[..., i : i + _inplace_chunk_size, :, :]
                a[..., i : i + _inplace_chunk_size, :, :] = torch.matmul(
                    a_chunk,
                    b_chunk,
                )

            p = a
        else:
            p = torch.matmul(a, b)

        return permute_final_dims(p, (1, 2, 0))

    def _inference_forward(
        self,
        z: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        inplace_chunk_size: Optional[int] = None,
        with_add: bool = True,
    ):
        """
        Args:
            z:
                A [*, N, N, C_z] pair representation
            mask:
                A [*, N, N] pair mask
            inplace_chunk_size:
                Size of chunks used in the main computation. Increase to trade memory for speed.
            with_add:
                If True, z is overwritten with (z + update). Otherwise, it is overwritten with (update).
        Returns:
            A reference to the overwritten z

        More memory-efficient, inference-only version of the forward function. Uses in-place operations, fusion of the
        addition that happens after this module in the Evoformer, a smidge of recomputation, and a cache of overwritten
        values to lower peak memory consumption of this module from 5x the size of the input tensor z to 2.5x its size.
        Useful for inference on extremely long sequences.

        It works as follows. We will make reference to variables used in the default forward implementation below.
        Naively, triangle multiplication requires the manifestation of 5 tensors the size of z: 1) z, the "square"
        input tensor, 2) a, the first projection of z, 3) b, the second projection of z, 4) g, a z-sized mask, and 5) a
        z-sized tensor for intermediate computations. For large N, this is prohibitively expensive; for N=4000, for
        example, z is more than 8GB alone. To avoid this problem, we compute b, g, and all intermediate tensors in
        small chunks, noting that the chunks required to compute a chunk of the output depend only on the tensor a and
        corresponding vertical and horizontal chunks of z. This suggests an algorithm that loops over pairs of chunks
        of z: hereafter "columns" and "rows" of z, even though each "column" and "row" in fact contains
        inplace_chunk_size contiguous true columns and rows of z. Writing output chunks to a new tensor would bring
        total memory consumption down to 3x the size of z. However, more memory can be saved by writing output chunks
        directly to z in-place. WLOG, we choose to write output chunks vertically, overwriting the ith "column" of z at
        the end of the ith iteration of the main loop. Despite this overwriting, the ith column is always one column
        ahead of previously overwritten columns and can be recovered directly from z. After the first iteration,
        however, the ith row of z is always at least partially overwritten. For this reason, we introduce the z-cache,
        a tensor one-half the size of z. The z-cache initially contains the left half (2nd and 3rd quadrants) of z. For
        0 < i < N/2, the missing left part of the ith row of z is recovered from this cache at the beginning of the ith
        iteration. Once i exceeds N/2, the cache is "reoriented" to encompass the 3rd and 4th quadrants of z instead.
        Though the 3rd quadrant of the original z is entirely overwritten at this point, it can be recovered from the
        z-cache itself. Thereafter, the ith row of z can be recovered in its entirety from the reoriented z-cache.
        After the final iteration, z has been completely overwritten and contains the triangular multiplicative update.
        If with_add is True, it instead contains the sum of z and the triangular multiplicative update. In either case,
        peak memory consumption is just 2.5x the size of z, disregarding memory used for chunks and other small
        variables.
        """
        if mask is None:
            mask = z.new_ones(z.shape[:-1])

        mask = mask.unsqueeze(-1)

        def compute_projection_helper(pair, mask, a=True):
            if a:
                linear_g = self.linear_a_g
                linear_p = self.linear_a_p
            else:
                linear_g = self.linear_b_g
                linear_p = self.linear_b_p

            pair = self.layer_norm_in(pair)
            p = linear_g(pair)
            p.sigmoid_()
            p *= linear_p(pair)
            p *= mask
            p = permute_final_dims(p, (2, 0, 1))
            return p

        def compute_projection(pair, mask, a=True, chunked=True):
            need_transpose = self._outgoing ^ a
            if not chunked:
                p = compute_projection_helper(pair, mask, a)
                if need_transpose:
                    p = p.transpose(-1, -2)
            else:
                # This computation is chunked so as not to exceed our 2.5x
                # budget with a large intermediate tensor
                linear_g = self.linear_a_g if a else self.linear_b_g
                c = linear_g.bias.shape[-1]
                out_shape = pair.shape[:-3] + (c,) + pair.shape[-3:-1]
                p = pair.new_zeros(out_shape)
                for i in range(0, pair.shape[-3], inplace_chunk_size):
                    pair_chunk = pair[..., i : i + inplace_chunk_size, :, :]
                    pair_chunk = compute_projection_helper(
                        pair[..., i : i + inplace_chunk_size, :, :],
                        mask[..., i : i + inplace_chunk_size, :, :],
                        a,
                    )
                    if need_transpose:
                        pair_chunk = pair_chunk.transpose(-1, -2)
                        p[..., i : i + inplace_chunk_size] = pair_chunk
                    else:
                        p[..., i : i + inplace_chunk_size, :] = pair_chunk

                    del pair_chunk

            return p

        # We start by fully manifesting a. In addition to the input, this
        # brings total memory consumption to 2x z (disregarding size of chunks)
        # [*, N, N, c]
        a = compute_projection(z, mask, True, chunked=True)

        if inplace_chunk_size is not None:
            n = a.shape[-1]
            half_n = n // 2 + n % 2
            row_dim = -3
            col_dim = -2
            b_chunk_dim = row_dim if self._outgoing else col_dim

            def empty_slicer(t):
                return [slice(None) for _ in t.shape]

            def slice_tensor(t, start, end, dim):
                # Slices start:end from the dim dimension of t
                s = empty_slicer(t)
                s[dim] = slice(start, end)
                return t[s]

            def flip_z_cache_(z_cache, z):
                # "Reorient" the z_cache (see below), filling it with quadrants
                # 3---recovered from the z_cache---and 4---recovered from z---
                # of the input tensor z.
                quadrant_3 = slice_tensor(z_cache, half_n, None, row_dim)
                z_cache = z_cache.transpose(row_dim, col_dim)

                # If n is odd, we need to shrink the z_cache by one row
                z_cache = z_cache[..., : (n // 2), :, :]

                # Move the 3rd quadrant of z into the
                first_half_slicer = empty_slicer(z_cache)
                first_half_slicer[col_dim] = slice(0, half_n)
                z_cache[first_half_slicer] = quadrant_3

                # Get the fourth quadrant of z
                quadrant_4 = slice_tensor(z, half_n, None, row_dim)
                quadrant_4 = slice_tensor(quadrant_4, half_n, None, col_dim)

                # Insert said quadrant into the rotated z-cache
                quadrant_3_slicer = empty_slicer(z_cache)
                quadrant_3_slicer[col_dim] = slice(half_n, None)

                z_cache[quadrant_3_slicer] = quadrant_4

                return z_cache

            # Initialize the z cache to the left half of z.
            z_cache_shape = list(z.shape)
            z_cache_shape[col_dim] = half_n
            z_cache = z.new_zeros(z_cache_shape)
            z_cache_slicer = empty_slicer(z_cache)
            z_cache_slicer[col_dim] = slice(0, half_n)
            z_cache.copy_(z[z_cache_slicer])
            z_cache_rotated = False

            # We need to reorient the z-cache at the halfway point, and we
            # don't want a single chunk to straddle that point. We contract one
            # of the chunks in the middle to address that problem.
            i_range = list(range(0, half_n, inplace_chunk_size))
            initial_offsets = [i_2 - i_1 for i_1, i_2 in zip(i_range, i_range[1:] + [half_n])]
            after_half = list(range(half_n, n, inplace_chunk_size))
            after_half_offsets = [inplace_chunk_size for _ in after_half]
            combined_range_with_offsets = zip(i_range + after_half, initial_offsets + after_half_offsets)
            for i, offset in combined_range_with_offsets:
                if not z_cache_rotated and i >= half_n:
                    z_cache = flip_z_cache_(z_cache, z)
                    z_cache_rotated = True

                z_chunk_b = slice_tensor(z, i, i + offset, b_chunk_dim)
                mask_chunk = slice_tensor(mask, i, i + offset, b_chunk_dim)

                z_chunk_b = z_chunk_b.clone()
                if b_chunk_dim == col_dim:
                    z_chunk_b = slice_tensor(z, i, i + offset, col_dim)
                else:  # b_chunk_dim == row_dim
                    # In this case, the b-dimension (b_chunk_dim) is partially
                    # overwritten at the end of each iteration. We need to
                    # restore the missing component from the z-cache.
                    if not z_cache_rotated:
                        z_chunk_slicer = empty_slicer(z_chunk_b)
                        z_chunk_slicer[col_dim] = slice(0, half_n)
                        z_chunk_b[z_chunk_slicer] = slice_tensor(z_cache, i, i + offset, row_dim)
                    else:
                        z_cache_offset = i - half_n
                        z_chunk_b = slice_tensor(z_cache, z_cache_offset, z_cache_offset + offset, row_dim)

                b_chunk = compute_projection(z_chunk_b, mask_chunk, a=False, chunked=False)
                del z_chunk_b

                x_chunk = torch.matmul(a, b_chunk)
                x_chunk = permute_final_dims(x_chunk, (1, 2, 0))
                x_chunk = self.layer_norm_out(x_chunk)
                x_chunk = self.linear_z(x_chunk)

                # The g dimension (col_dim) is parallel to and ahead of the
                # overwrites in z. We can extract the g chunk normally.
                z_chunk_g = slice_tensor(z, i, i + offset, col_dim)
                g_chunk = self.linear_g(self.layer_norm_in(z_chunk_g))
                g_chunk.sigmoid_()
                del z_chunk_g

                x_chunk *= g_chunk

                # Write the columns into z in-place
                z_slicer = empty_slicer(z)
                z_slicer[col_dim] = slice(i, i + offset)
                if with_add:
                    z[z_slicer] += x_chunk
                else:
                    z[z_slicer] = x_chunk
        else:
            b = compute_projection(z, mask, False, False)
            x = torch.matmul(a, b)
            x = self.layer_norm_out(x)
            x = self.linear_z(x)
            g = self.linear_g(z)
            g.sigmoid_()
            x *= g
            if with_add:
                z += x
            else:
                z = x

        return z
    def forward(
        self,
        z: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        inplace_safe: bool = False,
        _add_with_inplace: bool = False,
        _inplace_chunk_size: Optional[int] = 256,
    ) -> torch.Tensor:
        """
        Args:
            z:
                [*, N_res, N_res, C_z] input tensor
            mask:
                [*, N_res, N_res] input mask
        Returns:
            [*, N_res, N_res, C_z] output tensor
        """
        if inplace_safe:
            x = self._inference_forward(
                z,
                mask,
                inplace_chunk_size=_inplace_chunk_size,
                with_add=_add_with_inplace,
            )
            return x

        if mask is None:
            mask = z.new_ones(z.shape[:-1])

        mask = mask.unsqueeze(-1)

        z = self.layer_norm_in(z)
        a = mask
        a = a * self.sigmoid(self.linear_a_g(z))
        a = a * self.linear_a_p(z)
        b = mask
        b = b * self.sigmoid(self.linear_b_g(z))
        b = b * self.linear_b_p(z)

        if is_fp16_enabled():
            with torch.cuda.amp.autocast(enabled=False):
                x = self._combine_projections(a.float(), b.float())
        else:
            x = self._combine_projections(a, b)

        del a, b
        x = self.layer_norm_out(x)
        x = self.linear_z(x)
        g = self.sigmoid(self.linear_g(z))
        x = x * g

        return x
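# Illustrative shapes (not part of the module): the triangular multiplicative
# update maps a pair representation to an update of the same shape. The config
# stub below is hypothetical and only provides the one field the constructor
# reads:
#
#     class _Cfg:
#         pairwise_state_dim = 32
#
#     tri_mul = EsmFoldTriangleMultiplicativeUpdate(_Cfg(), _outgoing=True)
#     z = torch.randn(1, 16, 16, 32)
#     tri_mul(z).shape  # torch.Size([1, 16, 16, 32])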
class EsmFoldPreTrainedModel(EsmPreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    # Subclass `EsmPreTrainedModel` to deal with special init
    def _init_weights(self, module):
        """Initialize the weights"""
        if isinstance(module, EsmFoldLinear):
            with torch.no_grad():
                if module.init_fn is not None:
                    module.init_fn(module.weight, module.bias)
                elif module.init == "default":
                    trunc_normal_init_(module.weight, scale=1.0)
                elif module.init == "relu":
                    trunc_normal_init_(module.weight, scale=2.0)
                elif module.init == "glorot":
                    nn.init.xavier_uniform_(module.weight, gain=1)
                elif module.init == "gating":
                    module.weight.fill_(0.0)
                    if module.bias is not None:
                        module.bias.fill_(1.0)
                elif module.init == "normal":
                    torch.nn.init.kaiming_normal_(module.weight, nonlinearity="linear")
                elif module.init == "final":
                    module.weight.fill_(0.0)
        elif isinstance(module, EsmFoldInvariantPointAttention):
            ipa_point_weights_init_(module.head_weights)
        elif isinstance(module, EsmFoldTriangularSelfAttentionBlock):
            torch.nn.init.zeros_(module.tri_mul_in.linear_z.weight)
            torch.nn.init.zeros_(module.tri_mul_in.linear_z.bias)
            torch.nn.init.zeros_(module.tri_mul_out.linear_z.weight)
            torch.nn.init.zeros_(module.tri_mul_out.linear_z.bias)
            torch.nn.init.zeros_(module.tri_att_start.mha.linear_o.weight)
            torch.nn.init.zeros_(module.tri_att_start.mha.linear_o.bias)
            torch.nn.init.zeros_(module.tri_att_end.mha.linear_o.weight)
            torch.nn.init.zeros_(module.tri_att_end.mha.linear_o.bias)

            torch.nn.init.zeros_(module.sequence_to_pair.o_proj.weight)
            torch.nn.init.zeros_(module.sequence_to_pair.o_proj.bias)
            torch.nn.init.zeros_(module.pair_to_sequence.linear.weight)
            torch.nn.init.zeros_(module.seq_attention.o_proj.weight)
            torch.nn.init.zeros_(module.seq_attention.o_proj.bias)
            torch.nn.init.zeros_(module.mlp_seq.mlp[-2].weight)
            torch.nn.init.zeros_(module.mlp_seq.mlp[-2].bias)
            torch.nn.init.zeros_(module.mlp_pair.mlp[-2].weight)
            torch.nn.init.zeros_(module.mlp_pair.mlp[-2].bias)
        else:
            super()._init_weights(module)
class EsmFoldSelfAttention(nn.Module):
    def __init__(self, embed_dim, num_heads, head_width, gated=False):
        super().__init__()
        assert embed_dim == num_heads * head_width

        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_width = head_width

        self.proj = nn.Linear(embed_dim, embed_dim * 3, bias=False)
        self.o_proj = nn.Linear(embed_dim, embed_dim, bias=True)
        self.gated = gated
        if gated:
            self.g_proj = nn.Linear(embed_dim, embed_dim)
            torch.nn.init.zeros_(self.g_proj.weight)
            torch.nn.init.ones_(self.g_proj.bias)

        self.rescale_factor = self.head_width**-0.5

        torch.nn.init.zeros_(self.o_proj.bias)

    def forward(self, x, mask=None, bias=None, indices=None):
        """
        Basic self attention with optional mask and external pairwise bias. To handle sequences of different lengths,
        use mask.

        Inputs:
            x: batch of input sequences (.. x L x C)
            mask: batch of boolean masks where 1=valid, 0=padding position (.. x L_k)
            bias: batch of scalar pairwise attention biases (.. x Lq x Lk x num_heads)

        Outputs:
            sequence projection (B x L x embed_dim), attention maps (B x L x L x num_heads)
        """
        t = self.proj(x).view(*x.shape[:2], self.num_heads, -1)
        t = t.permute(0, 2, 1, 3)
        q, k, v = t.chunk(3, dim=-1)

        q = self.rescale_factor * q
        a = torch.einsum("...qc,...kc->...qk", q, k)

        # Add external attention bias.
        if bias is not None:
            a = a + bias.permute(0, 3, 1, 2)

        # Do not attend to padding tokens.
        if mask is not None:
            mask = mask[:, None, None]
            a = a.masked_fill(mask == False, -np.inf)  # noqa: E712

        a = nn.functional.softmax(a, dim=-1)

        y = torch.einsum("...hqk,...hkc->...qhc", a, v)
        y = y.reshape(*y.shape[:2], -1)

        if self.gated:
            y = self.g_proj(x).sigmoid() * y
        y = self.o_proj(y)

        return y, a.permute(0, 3, 1, 2)
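# Illustrative shapes (not part of the module): gated self-attention over the
# sequence state preserves its shape; the pairwise bias argument is optional:
#
#     attn = EsmFoldSelfAttention(embed_dim=64, num_heads=4, head_width=16, gated=True)
#     x = torch.randn(2, 10, 64)
#     y, _ = attn(x)  # y.shape == (2, 10, 64)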
class EsmFoldDropout(nn.Module):
    """
    Implementation of dropout with the ability to share the dropout mask along a particular dimension.
    """

    def __init__(self, r: float, batch_dim: Union[int, List[int]]):
        super().__init__()

        self.r = r
        if isinstance(batch_dim, int):
            batch_dim = [batch_dim]
        self.batch_dim = batch_dim
        self.dropout = nn.Dropout(self.r)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        shape = list(x.shape)
        if self.batch_dim is not None:
            for bd in self.batch_dim:
                shape[bd] = 1
        return x * self.dropout(x.new_ones(shape))
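# Illustrative example (not part of the module): with batch_dim=1 the same
# dropout mask is broadcast along dimension 1, so (for a pair representation of
# shape B x L x L x C) whole rows are kept or dropped together, with survivors
# rescaled by 1 / (1 - r):
#
#     drop = EsmFoldDropout(r=0.5, batch_dim=1)
#     x = torch.ones(1, 4, 4, 8)
#     y = drop(x)  # in training mode, zeros occur at identical positions for every index along dim 1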
class EsmFoldSequenceToPair(nn.Module):
    def __init__(self, sequence_state_dim, inner_dim, pairwise_state_dim):
        super().__init__()

        self.layernorm = nn.LayerNorm(sequence_state_dim)
        self.proj = nn.Linear(sequence_state_dim, inner_dim * 2, bias=True)
        self.o_proj = nn.Linear(2 * inner_dim, pairwise_state_dim, bias=True)

        torch.nn.init.zeros_(self.proj.bias)
        torch.nn.init.zeros_(self.o_proj.bias)

    def forward(self, sequence_state):
        """
        Inputs:
            sequence_state: B x L x sequence_state_dim

        Output:
            pairwise_state: B x L x L x pairwise_state_dim

        Intermediate state:
            B x L x L x 2*inner_dim
        """
        assert len(sequence_state.shape) == 3

        s = self.layernorm(sequence_state)
        s = self.proj(s)
        q, k = s.chunk(2, dim=-1)

        prod = q[:, None, :, :] * k[:, :, None, :]
        diff = q[:, None, :, :] - k[:, :, None, :]

        x = torch.cat([prod, diff], dim=-1)
        x = self.o_proj(x)

        return x
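# Illustrative shapes (not part of the module): the outer product and outer
# difference of the projected sequence state produce the pairwise features:
#
#     s2p = EsmFoldSequenceToPair(sequence_state_dim=64, inner_dim=32, pairwise_state_dim=128)
#     seq = torch.randn(2, 10, 64)
#     s2p(seq).shape  # torch.Size([2, 10, 10, 128])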
class EsmFoldPairToSequence(nn.Module):
    def __init__(self, pairwise_state_dim, num_heads):
        super().__init__()

        self.layernorm = nn.LayerNorm(pairwise_state_dim)
        self.linear = nn.Linear(pairwise_state_dim, num_heads, bias=False)

    def forward(self, pairwise_state):
        """
        Inputs:
            pairwise_state: B x L x L x pairwise_state_dim

        Output:
            pairwise_bias: B x L x L x num_heads
        """
        assert len(pairwise_state.shape) == 4
        z = self.layernorm(pairwise_state)
        pairwise_bias = self.linear(z)
        return pairwise_bias


class EsmFoldResidueMLP(nn.Module):
    def __init__(self, embed_dim, inner_dim, dropout=0):
        super().__init__()

        self.mlp = nn.Sequential(
            nn.LayerNorm(embed_dim),
            nn.Linear(embed_dim, inner_dim),
            nn.ReLU(),
            nn.Linear(inner_dim, embed_dim),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        return x + self.mlp(x)
class EsmFoldTriangularSelfAttentionBlock(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config

        sequence_state_dim = config.sequence_state_dim
        pairwise_state_dim = config.pairwise_state_dim
        sequence_num_heads = sequence_state_dim // config.sequence_head_width
        pairwise_num_heads = pairwise_state_dim // config.pairwise_head_width

        self.layernorm_1 = nn.LayerNorm(sequence_state_dim)

        self.sequence_to_pair = EsmFoldSequenceToPair(sequence_state_dim, pairwise_state_dim // 2, pairwise_state_dim)
        self.pair_to_sequence = EsmFoldPairToSequence(pairwise_state_dim, sequence_num_heads)

        self.seq_attention = EsmFoldSelfAttention(
            sequence_state_dim, sequence_num_heads, config.sequence_head_width, gated=True
        )
        self.tri_mul_out = EsmFoldTriangleMultiplicativeUpdate(config, _outgoing=True)
        self.tri_mul_in = EsmFoldTriangleMultiplicativeUpdate(config, _outgoing=False)

        self.tri_att_start = EsmFoldTriangleAttention(
            pairwise_state_dim, config.pairwise_head_width, pairwise_num_heads, inf=1e9, starting=True
        )
        self.tri_att_end = EsmFoldTriangleAttention(
            pairwise_state_dim, config.pairwise_head_width, pairwise_num_heads, inf=1e9, starting=False
        )

        self.mlp_seq = EsmFoldResidueMLP(sequence_state_dim, 4 * sequence_state_dim, dropout=config.dropout)
        self.mlp_pair = EsmFoldResidueMLP(pairwise_state_dim, 4 * pairwise_state_dim, dropout=config.dropout)

        self.drop = nn.Dropout(config.dropout)
        self.row_drop = EsmFoldDropout(config.dropout * 2, 2)
        self.col_drop = EsmFoldDropout(config.dropout * 2, 1)

    def forward(self, sequence_state, pairwise_state, mask=None, chunk_size=None, **__kwargs):
        """
        Inputs:
            sequence_state: B x L x sequence_state_dim
            pairwise_state: B x L x L x pairwise_state_dim
            mask: B x L boolean tensor of valid positions

        Output:
            sequence_state: B x L x sequence_state_dim
            pairwise_state: B x L x L x pairwise_state_dim
        """
        if len(sequence_state.shape) != 3:
            raise ValueError(f"`sequence_state` should be a 3d-tensor, got {len(sequence_state.shape)} dims.")
        if len(pairwise_state.shape) != 4:
            raise ValueError(f"`pairwise_state` should be a 4d-tensor, got {len(pairwise_state.shape)} dims.")
        if mask is not None and len(mask.shape) != 2:
            raise ValueError(f"`mask` should be a 2d-tensor, got {len(mask.shape)} dims.")

        batch_dim, seq_dim, sequence_state_dim = sequence_state.shape
        pairwise_state_dim = pairwise_state.shape[3]

        if sequence_state_dim != self.config.sequence_state_dim:
            raise ValueError(
                "`sequence_state` last dimension should be equal to `self.sequence_state_dim`. Got "
                f"{sequence_state_dim} != {self.config.sequence_state_dim}."
            )
        if pairwise_state_dim != self.config.pairwise_state_dim:
            raise ValueError(
                "`pairwise_state` last dimension should be equal to `self.pairwise_state_dim`. Got "
                f"{pairwise_state_dim} != {self.config.pairwise_state_dim}."
            )
        if batch_dim != pairwise_state.shape[0]:
            raise ValueError(
                f"`sequence_state` and `pairwise_state` have inconsistent batch size: {batch_dim} != "
                f"{pairwise_state.shape[0]}."
            )
        if seq_dim != pairwise_state.shape[1] or seq_dim != pairwise_state.shape[2]:
            raise ValueError(
                f"`sequence_state` and `pairwise_state` have inconsistent sequence length: {seq_dim} != "
                f"{pairwise_state.shape[1]} or {pairwise_state.shape[2]}."
            )

        # Update sequence state
        bias = self.pair_to_sequence(pairwise_state)

        # Self attention with bias + mlp.
        y = self.layernorm_1(sequence_state)
        y, _ = self.seq_attention(y, mask=mask, bias=bias)
        sequence_state = sequence_state + self.drop(y)
        sequence_state = self.mlp_seq(sequence_state)

        # Update pairwise state
        pairwise_state = pairwise_state + self.sequence_to_pair(sequence_state)

        # Axial attention with triangular bias.
        tri_mask = mask.unsqueeze(2) * mask.unsqueeze(1) if mask is not None else None
        pairwise_state = pairwise_state + self.row_drop(self.tri_mul_out(pairwise_state, mask=tri_mask))
        pairwise_state = pairwise_state + self.col_drop(self.tri_mul_in(pairwise_state, mask=tri_mask))
        pairwise_state = pairwise_state + self.row_drop(
            self.tri_att_start(pairwise_state, mask=tri_mask, chunk_size=chunk_size)
        )
        pairwise_state = pairwise_state + self.col_drop(
            self.tri_att_end(pairwise_state, mask=tri_mask, chunk_size=chunk_size)
        )

        # MLP over pairs.
        pairwise_state = self.mlp_pair(pairwise_state)

        return sequence_state, pairwise_state
class EsmCategoricalMixture:
    def __init__(self, param, bins=50, start=0, end=1):
        # All tensors are of shape ..., bins.
        self.logits = param
        bins = torch.linspace(start, end, bins + 1, device=self.logits.device, dtype=self.logits.dtype)
        self.v_bins = (bins[:-1] + bins[1:]) / 2

    def log_prob(self, true):
        # Shapes are:
        #     self.probs: ... x bins
        #     true      : ...
        true_index = (true.unsqueeze(-1) - self.v_bins[[None] * true.ndim]).abs().argmin(-1)
        nll = self.logits.log_softmax(-1)
        return torch.take_along_dim(nll, true_index.unsqueeze(-1), dim=-1).squeeze(-1)

    def mean(self):
        return (self.logits.softmax(-1) @ self.v_bins.unsqueeze(1)).squeeze(-1)


def categorical_lddt(logits, bins=50):
    # Logits are ..., 37, bins.
    return EsmCategoricalMixture(logits, bins=bins).mean()
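# Illustrative example (not part of the module): the categorical mixture reads
# an expected value off per-bin logits over [0, 1]; this is how per-residue
# plddt is derived from the lddt head:
#
#     logits = torch.zeros(2, 10, 37, 50)      # uniform over 50 bins
#     categorical_lddt(logits, bins=50).shape  # torch.Size([2, 10, 37])
#     # With uniform logits, every entry equals the midpoint of [0, 1], i.e. 0.5.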
def get_axial_mask(mask):
    """
    Helper to convert B x L mask of valid positions to axial mask used in row column attentions.

    Input:
        mask: B x L tensor of booleans

    Output:
        mask: (B*L) x L tensor of booleans
    """
    if mask is None:
        return None

    if len(mask.shape) != 2:
        raise ValueError(f"`mask` should be a 2d-tensor, got {len(mask.shape)} dims.")
    batch_dim, seq_dim = mask.shape
    m = mask.unsqueeze(1).expand(batch_dim, seq_dim, seq_dim)
    m = m.reshape(batch_dim * seq_dim, seq_dim)
    return m
  1076. class EsmFoldRelativePosition(nn.Module):
  1077. def __init__(self, config):
  1078. super().__init__()
  1079. self.bins = config.position_bins
  1080. # Note an additional offset is used so that the 0th position
  1081. # is reserved for masked pairs.
  1082. self.embedding = torch.nn.Embedding(2 * self.bins + 2, config.pairwise_state_dim)
  1083. def forward(self, residue_index, mask=None):
  1084. """
  1085. Input:
  1086. residue_index: B x L tensor of indices (dtype=torch.long) mask: B x L tensor of booleans
  1087. Output:
  1088. pairwise_state: B x L x L x pairwise_state_dim tensor of embeddings
  1089. """
  1090. if residue_index.dtype != torch.long:
  1091. raise ValueError(f"`residue_index` has dtype {residue_index.dtype}, it should be `torch.long`.")
  1092. if mask is not None and residue_index.shape != mask.shape:
  1093. raise ValueError(
  1094. f"`residue_index` and `mask` have inconsistent shapes: {residue_index.shape} != {mask.shape}."
  1095. )
  1096. diff = residue_index[:, None, :] - residue_index[:, :, None]
  1097. diff = diff.clamp(-self.bins, self.bins)
  1098. diff = diff + self.bins + 1 # Add 1 to adjust for padding index.
  1099. if mask is not None:
  1100. mask = mask[:, None, :] * mask[:, :, None]
  1101. diff[mask == False] = 0 # noqa: E712
  1102. output = self.embedding(diff)
  1103. return output
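# Example (sketch): the same clamping/offsetting performed in EsmFoldRelativePosition.forward,
# shown standalone. The bin count of 32 is an assumption for illustration only; the real
# value comes from config.position_bins.
def _example_relative_position_bins(bins: int = 32):
    import torch

    residue_index = torch.arange(8, dtype=torch.long).unsqueeze(0)  # B=1, L=8
    diff = residue_index[:, None, :] - residue_index[:, :, None]    # pairwise offsets
    diff = diff.clamp(-bins, bins) + bins + 1                       # index 0 is reserved for masked pairs
    # `diff` now indexes a (2 * bins + 2)-row embedding table, one row per offset bucket.
    return diff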
  1104. class EsmFoldAngleResnetBlock(nn.Module):
  1105. def __init__(self, config):
  1106. super().__init__()
  1107. self.linear_1 = EsmFoldLinear(config.resnet_dim, config.resnet_dim, init="relu")
  1108. self.linear_2 = EsmFoldLinear(config.resnet_dim, config.resnet_dim, init="final")
  1109. self.relu = nn.ReLU()
  1110. def forward(self, a: torch.Tensor) -> torch.Tensor:
  1111. s_initial = a
  1112. a = self.relu(a)
  1113. a = self.linear_1(a)
  1114. a = self.relu(a)
  1115. a = self.linear_2(a)
  1116. return a + s_initial
  1117. class EsmFoldAngleResnet(nn.Module):
  1118. """
  1119. Implements Algorithm 20, lines 11-14
  1120. """
  1121. def __init__(self, config):
  1122. super().__init__()
  1123. self.config = config
  1124. self.linear_in = EsmFoldLinear(config.sequence_dim, config.resnet_dim)
  1125. self.linear_initial = EsmFoldLinear(config.sequence_dim, config.resnet_dim)
  1126. self.layers = nn.ModuleList()
  1127. for _ in range(config.num_resnet_blocks):
  1128. layer = EsmFoldAngleResnetBlock(config)
  1129. self.layers.append(layer)
  1130. self.linear_out = EsmFoldLinear(config.resnet_dim, config.num_angles * 2)
  1131. self.relu = nn.ReLU()
  1132. def forward(self, s: torch.Tensor, s_initial: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
  1133. """
  1134. Args:
  1135. s:
  1136. [*, C_hidden] single embedding
  1137. s_initial:
  1138. [*, C_hidden] single embedding as of the start of the StructureModule
  1139. Returns:
  1140. [*, no_angles, 2] predicted angles
  1141. """
# NOTE: The ReLUs applied to the inputs are absent from the supplement
  1143. # pseudocode but present in the source. For maximal compatibility with
  1144. # the pretrained weights, I'm going with the source.
  1145. # [*, C_hidden]
  1146. s_initial = self.relu(s_initial)
  1147. s_initial = self.linear_initial(s_initial)
  1148. s = self.relu(s)
  1149. s = self.linear_in(s)
  1150. s = s + s_initial
  1151. for l in self.layers:
  1152. s = l(s)
  1153. s = self.relu(s)
  1154. # [*, no_angles * 2]
  1155. s = self.linear_out(s)
  1156. # [*, no_angles, 2]
  1157. s = s.view(s.shape[:-1] + (-1, 2))
  1158. unnormalized_s = s
  1159. norm_denom = torch.sqrt(
  1160. torch.clamp(
  1161. torch.sum(s**2, dim=-1, keepdim=True),
  1162. min=self.config.epsilon,
  1163. )
  1164. )
  1165. s = s / norm_denom
  1166. return unnormalized_s, s
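# Example (sketch): the angle resnet's final step shown in isolation: each predicted
# (sin, cos)-like pair is normalized to unit length, and both the raw and normalized
# values are kept. The epsilon of 1e-8 is an assumption standing in for config.epsilon.
def _example_angle_normalization():
    import torch

    raw = torch.randn(4, 7, 2)                                      # [*, no_angles, 2]
    denom = torch.sqrt(torch.clamp((raw ** 2).sum(dim=-1, keepdim=True), min=1e-8))
    angles = raw / denom                                            # unit-norm (sin, cos) pairs
    assert torch.allclose(angles.norm(dim=-1), torch.ones(4, 7), atol=1e-4)
    return raw, angles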
  1167. class EsmFoldInvariantPointAttention(nn.Module):
  1168. """
  1169. Implements Algorithm 22.
  1170. """
  1171. def __init__(self, config):
  1172. super().__init__()
  1173. self.config = config
  1174. c_s = config.sequence_dim
  1175. c_z = config.pairwise_dim
  1176. self.hidden_dim = config.ipa_dim
  1177. self.num_heads = config.num_heads_ipa
  1178. self.num_qk_points = config.num_qk_points
  1179. self.num_v_points = config.num_v_points
  1180. # These linear layers differ from their specifications in the
  1181. # supplement. There, they lack bias and use Glorot initialization.
  1182. # Here as in the official source, they have bias and use the default
  1183. # Lecun initialization.
  1184. hc = config.ipa_dim * config.num_heads_ipa
  1185. self.linear_q = EsmFoldLinear(c_s, hc)
  1186. self.linear_kv = EsmFoldLinear(c_s, 2 * hc)
  1187. hpq = config.num_heads_ipa * config.num_qk_points * 3
  1188. self.linear_q_points = EsmFoldLinear(c_s, hpq)
  1189. hpkv = config.num_heads_ipa * (config.num_qk_points + config.num_v_points) * 3
  1190. self.linear_kv_points = EsmFoldLinear(c_s, hpkv)
  1191. self.linear_b = EsmFoldLinear(c_z, config.num_heads_ipa)
  1192. self.head_weights = nn.Parameter(torch.zeros((config.num_heads_ipa)))
  1193. concat_out_dim = config.num_heads_ipa * (c_z + config.ipa_dim + config.num_v_points * 4)
  1194. self.linear_out = EsmFoldLinear(concat_out_dim, c_s, init="final")
  1195. self.softmax = nn.Softmax(dim=-1)
  1196. self.softplus = nn.Softplus()
  1197. def forward(
  1198. self,
  1199. s: torch.Tensor,
  1200. z: Optional[torch.Tensor],
  1201. r: Rigid,
  1202. mask: torch.Tensor,
  1203. _offload_inference: bool = False,
  1204. _z_reference_list: Optional[Sequence[torch.Tensor]] = None,
  1205. ) -> torch.Tensor:
  1206. """
  1207. Args:
  1208. s:
  1209. [*, N_res, C_s] single representation
  1210. z:
  1211. [*, N_res, N_res, C_z] pair representation
  1212. r:
  1213. [*, N_res] transformation object
  1214. mask:
  1215. [*, N_res] mask
  1216. Returns:
  1217. [*, N_res, C_s] single representation update
  1218. """
  1219. z = [z]
  1220. #######################################
  1221. # Generate scalar and point activations
  1222. #######################################
  1223. # [*, N_res, H * C_hidden]
  1224. q = self.linear_q(s)
  1225. kv = self.linear_kv(s)
  1226. # [*, N_res, H, C_hidden]
  1227. q = q.view(q.shape[:-1] + (self.num_heads, -1))
  1228. # [*, N_res, H, 2 * C_hidden]
  1229. kv = kv.view(kv.shape[:-1] + (self.num_heads, -1))
  1230. # [*, N_res, H, C_hidden]
  1231. k, v = torch.split(kv, self.hidden_dim, dim=-1)
  1232. # [*, N_res, H * P_q * 3]
  1233. q_pts = self.linear_q_points(s)
  1234. # This is kind of clunky, but it's how the original does it
  1235. # [*, N_res, H * P_q, 3]
  1236. q_pts = torch.split(q_pts, q_pts.shape[-1] // 3, dim=-1)
  1237. q_pts = torch.stack(q_pts, dim=-1)
  1238. q_pts = r[..., None].apply(q_pts)
  1239. # [*, N_res, H, P_q, 3]
  1240. q_pts = q_pts.view(q_pts.shape[:-2] + (self.num_heads, self.num_qk_points, 3))
  1241. # [*, N_res, H * (P_q + P_v) * 3]
  1242. kv_pts = self.linear_kv_points(s)
  1243. # [*, N_res, H * (P_q + P_v), 3]
  1244. kv_pts = torch.split(kv_pts, kv_pts.shape[-1] // 3, dim=-1)
  1245. kv_pts = torch.stack(kv_pts, dim=-1)
  1246. kv_pts = r[..., None].apply(kv_pts)
  1247. # [*, N_res, H, (P_q + P_v), 3]
  1248. kv_pts = kv_pts.view(kv_pts.shape[:-2] + (self.num_heads, -1, 3))
  1249. # [*, N_res, H, P_q/P_v, 3]
  1250. k_pts, v_pts = torch.split(kv_pts, [self.num_qk_points, self.num_v_points], dim=-2)
  1251. ##########################
  1252. # Compute attention scores
  1253. ##########################
  1254. # [*, N_res, N_res, H]
  1255. b = self.linear_b(z[0])
  1256. if _offload_inference:
  1257. assert sys.getrefcount(z[0]) == 2
  1258. z[0] = z[0].cpu()
  1259. # [*, H, N_res, N_res]
  1260. if is_fp16_enabled():
  1261. with torch.cuda.amp.autocast(enabled=False):
  1262. a = torch.matmul(
  1263. permute_final_dims(q.float(), (1, 0, 2)), # [*, H, N_res, C_hidden]
  1264. permute_final_dims(k.float(), (1, 2, 0)), # [*, H, C_hidden, N_res]
  1265. )
  1266. else:
  1267. a = torch.matmul(
  1268. permute_final_dims(q, (1, 0, 2)), # [*, H, N_res, C_hidden]
  1269. permute_final_dims(k, (1, 2, 0)), # [*, H, C_hidden, N_res]
  1270. )
  1271. a *= math.sqrt(1.0 / (3 * self.hidden_dim))
  1272. a += math.sqrt(1.0 / 3) * permute_final_dims(b, (2, 0, 1))
  1273. # [*, N_res, N_res, H, P_q, 3]
  1274. pt_att = q_pts.unsqueeze(-4) - k_pts.unsqueeze(-5)
  1275. pt_att = pt_att**2
  1276. # [*, N_res, N_res, H, P_q]
  1277. pt_att = sum(torch.unbind(pt_att, dim=-1))
  1278. head_weights = self.softplus(self.head_weights).view(*((1,) * len(pt_att.shape[:-2]) + (-1, 1)))
  1279. head_weights = head_weights * math.sqrt(1.0 / (3 * (self.num_qk_points * 9.0 / 2)))
  1280. pt_att = pt_att * head_weights
  1281. # [*, N_res, N_res, H]
  1282. pt_att = torch.sum(pt_att, dim=-1) * (-0.5)
  1283. # [*, N_res, N_res]
  1284. square_mask = mask.unsqueeze(-1) * mask.unsqueeze(-2)
  1285. square_mask = self.config.inf * (square_mask - 1)
  1286. # [*, H, N_res, N_res]
  1287. pt_att = permute_final_dims(pt_att, (2, 0, 1))
  1288. a = a + pt_att
  1289. a = a + square_mask.unsqueeze(-3)
  1290. a = self.softmax(a)
  1291. ################
  1292. # Compute output
  1293. ################
  1294. # [*, N_res, H, C_hidden]
  1295. o = torch.matmul(a, v.transpose(-2, -3).to(dtype=a.dtype)).transpose(-2, -3)
  1296. # [*, N_res, H * C_hidden]
  1297. o = flatten_final_dims(o, 2)
  1298. # [*, H, 3, N_res, P_v]
  1299. o_pt = torch.sum(
  1300. (a[..., None, :, :, None] * permute_final_dims(v_pts, (1, 3, 0, 2))[..., None, :, :]),
  1301. dim=-2,
  1302. )
  1303. # [*, N_res, H, P_v, 3]
  1304. o_pt = permute_final_dims(o_pt, (2, 0, 3, 1))
  1305. o_pt = r[..., None, None].invert_apply(o_pt)
  1306. # [*, N_res, H * P_v]
  1307. o_pt_norm = flatten_final_dims(torch.sqrt(torch.sum(o_pt**2, dim=-1) + self.config.epsilon), 2)
  1308. # [*, N_res, H * P_v, 3]
  1309. o_pt = o_pt.reshape(*o_pt.shape[:-3], -1, 3)
  1310. if _offload_inference:
  1311. z[0] = z[0].to(o_pt.device)
  1312. # [*, N_res, H, C_z]
  1313. o_pair = torch.matmul(a.transpose(-2, -3), z[0].to(dtype=a.dtype))
  1314. # [*, N_res, H * C_z]
  1315. o_pair = flatten_final_dims(o_pair, 2)
  1316. # [*, N_res, C_s]
  1317. s = self.linear_out(
  1318. torch.cat((o, *torch.unbind(o_pt, dim=-1), o_pt_norm, o_pair), dim=-1).to(dtype=z[0].dtype)
  1319. )
  1320. return s
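# Example (sketch): the three attention-logit terms in the IPA forward pass above are
# balanced with fixed scale factors: sqrt(1 / (3 * c_hidden)) for the scalar term,
# sqrt(1 / 3) for the pair bias, and softplus(head_weights) * sqrt(1 / (3 * P_q * 9 / 2))
# for the point-distance term. The head, point, and hidden sizes below are illustrative.
def _example_ipa_scales(hidden_dim: int = 16, num_heads: int = 12, num_qk_points: int = 4):
    import math
    import torch

    scalar_scale = math.sqrt(1.0 / (3 * hidden_dim))
    bias_scale = math.sqrt(1.0 / 3)
    head_weights = torch.nn.functional.softplus(torch.zeros(num_heads))
    point_scale = head_weights * math.sqrt(1.0 / (3 * (num_qk_points * 9.0 / 2)))
    return scalar_scale, bias_scale, point_scale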
  1321. class EsmFoldBackboneUpdate(nn.Module):
  1322. """
  1323. Implements part of Algorithm 23.
  1324. """
  1325. def __init__(self, config):
  1326. super().__init__()
  1327. self.linear = EsmFoldLinear(config.sequence_dim, 6, init="final")
  1328. def forward(self, s: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Args:
s:
[*, N_res, C_s] single representation
Returns:
[*, N_res, 6] update vector
"""
# [*, N_res, 6]
  1336. update = self.linear(s)
  1337. return update
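# Example (sketch): how the 6-dim backbone update is split downstream. The first three
# values act as the vector part of a quaternion update and the last three as a
# translation update (consumed by Rigid.compose_q_update_vec in the structure module).
def _example_split_backbone_update():
    import torch

    update = torch.zeros(1, 10, 6)              # [*, N_res, 6], illustrative values
    quat_update, trans_update = update[..., :3], update[..., 3:]
    return quat_update, trans_update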
  1338. class EsmFoldStructureModuleTransitionLayer(nn.Module):
  1339. def __init__(self, config):
  1340. super().__init__()
  1341. self.linear_1 = EsmFoldLinear(config.sequence_dim, config.sequence_dim, init="relu")
  1342. self.linear_2 = EsmFoldLinear(config.sequence_dim, config.sequence_dim, init="relu")
  1343. self.linear_3 = EsmFoldLinear(config.sequence_dim, config.sequence_dim, init="final")
  1344. self.relu = nn.ReLU()
  1345. def forward(self, s):
  1346. s_initial = s
  1347. s = self.linear_1(s)
  1348. s = self.relu(s)
  1349. s = self.linear_2(s)
  1350. s = self.relu(s)
  1351. s = self.linear_3(s)
  1352. s = s + s_initial
  1353. return s
  1354. class EsmFoldStructureModuleTransition(nn.Module):
  1355. def __init__(self, config):
  1356. super().__init__()
  1357. self.config = config
  1358. self.layers = nn.ModuleList()
  1359. for _ in range(config.num_transition_layers):
  1360. l = EsmFoldStructureModuleTransitionLayer(config)
  1361. self.layers.append(l)
  1362. self.dropout = nn.Dropout(config.dropout_rate)
  1363. self.layer_norm = LayerNorm(config.sequence_dim)
  1364. def forward(self, s):
  1365. for l in self.layers:
  1366. s = l(s)
  1367. s = self.dropout(s)
  1368. s = self.layer_norm(s)
  1369. return s
  1370. class EsmFoldStructureModule(nn.Module):
  1371. def __init__(self, config):
  1372. super().__init__()
  1373. self.config = config
  1374. # Buffers to be lazily initialized later
  1375. # self.default_frames
  1376. # self.group_idx
  1377. # self.atom_mask
  1378. # self.lit_positions
  1379. self.layer_norm_s = LayerNorm(config.sequence_dim)
  1380. self.layer_norm_z = LayerNorm(config.pairwise_dim)
  1381. self.linear_in = EsmFoldLinear(config.sequence_dim, config.sequence_dim)
  1382. self.ipa = EsmFoldInvariantPointAttention(config)
  1383. self.ipa_dropout = nn.Dropout(config.dropout_rate)
  1384. self.layer_norm_ipa = LayerNorm(config.sequence_dim)
  1385. self.transition = EsmFoldStructureModuleTransition(config)
  1386. self.bb_update = EsmFoldBackboneUpdate(config)
  1387. self.angle_resnet = EsmFoldAngleResnet(config)
  1388. def forward(
  1389. self,
  1390. evoformer_output_dict,
  1391. aatype,
  1392. mask=None,
  1393. _offload_inference=False,
  1394. ):
  1395. """
  1396. Args:
  1397. evoformer_output_dict:
  1398. Dictionary containing:
  1399. "single":
  1400. [*, N_res, C_s] single representation
  1401. "pair":
  1402. [*, N_res, N_res, C_z] pair representation
  1403. aatype:
  1404. [*, N_res] amino acid indices
  1405. mask:
  1406. Optional [*, N_res] sequence mask
  1407. Returns:
  1408. A dictionary of outputs
  1409. """
  1410. s = evoformer_output_dict["single"]
  1411. if mask is None:
  1412. # [*, N]
  1413. mask = s.new_ones(s.shape[:-1])
  1414. # [*, N, C_s]
  1415. s = self.layer_norm_s(s)
  1416. # [*, N, N, C_z]
  1417. z = self.layer_norm_z(evoformer_output_dict["pair"])
  1418. z_reference_list = None
  1419. if _offload_inference:
  1420. assert sys.getrefcount(evoformer_output_dict["pair"]) == 2
  1421. evoformer_output_dict["pair"] = evoformer_output_dict["pair"].cpu()
  1422. z_reference_list = [z]
  1423. z = None
  1424. # [*, N, C_s]
  1425. s_initial = s
  1426. s = self.linear_in(s)
  1427. # [*, N]
  1428. rigids = Rigid.identity(
  1429. s.shape[:-1],
  1430. s.dtype,
  1431. s.device,
  1432. self.training,
  1433. fmt="quat",
  1434. )
  1435. outputs = []
  1436. for i in range(self.config.num_blocks):
  1437. # [*, N, C_s]
  1438. s = s + self.ipa(
  1439. s,
  1440. z,
  1441. rigids,
  1442. mask,
  1443. _offload_inference=_offload_inference,
  1444. _z_reference_list=z_reference_list,
  1445. )
  1446. s = self.ipa_dropout(s)
  1447. s = self.layer_norm_ipa(s)
  1448. s = self.transition(s)
  1449. # [*, N]
  1450. rigids = rigids.compose_q_update_vec(self.bb_update(s))
  1451. # To hew as closely as possible to AlphaFold, we convert our
  1452. # quaternion-based transformations to rotation-matrix ones
  1453. # here
  1454. backb_to_global = Rigid(
  1455. Rotation(rot_mats=rigids.get_rots().get_rot_mats(), quats=None),
  1456. rigids.get_trans(),
  1457. )
  1458. backb_to_global = backb_to_global.scale_translation(self.config.trans_scale_factor)
  1459. # [*, N, 7, 2]
  1460. unnormalized_angles, angles = self.angle_resnet(s, s_initial)
  1461. all_frames_to_global = self.torsion_angles_to_frames(backb_to_global, angles, aatype)
  1462. pred_xyz = self.frames_and_literature_positions_to_atom14_pos(all_frames_to_global, aatype)
  1463. scaled_rigids = rigids.scale_translation(self.config.trans_scale_factor)
  1464. preds = {
  1465. "frames": scaled_rigids.to_tensor_7(),
  1466. "sidechain_frames": all_frames_to_global.to_tensor_4x4(),
  1467. "unnormalized_angles": unnormalized_angles,
  1468. "angles": angles,
  1469. "positions": pred_xyz,
  1470. "states": s,
  1471. }
  1472. outputs.append(preds)
  1473. rigids = rigids.stop_rot_gradient()
  1474. del z, z_reference_list
  1475. if _offload_inference:
  1476. evoformer_output_dict["pair"] = evoformer_output_dict["pair"].to(s.device)
  1477. outputs = dict_multimap(torch.stack, outputs)
  1478. outputs["single"] = s
  1479. return outputs
  1480. def _init_residue_constants(self, float_dtype, device):
  1481. if not hasattr(self, "default_frames"):
  1482. self.register_buffer(
  1483. "default_frames",
  1484. torch.tensor(
  1485. residue_constants.restype_rigid_group_default_frame,
  1486. dtype=float_dtype,
  1487. device=device,
  1488. requires_grad=False,
  1489. ),
  1490. persistent=False,
  1491. )
  1492. if not hasattr(self, "group_idx"):
  1493. self.register_buffer(
  1494. "group_idx",
  1495. torch.tensor(
  1496. residue_constants.restype_atom14_to_rigid_group,
  1497. device=device,
  1498. requires_grad=False,
  1499. ),
  1500. persistent=False,
  1501. )
  1502. if not hasattr(self, "atom_mask"):
  1503. self.register_buffer(
  1504. "atom_mask",
  1505. torch.tensor(
  1506. residue_constants.restype_atom14_mask,
  1507. dtype=float_dtype,
  1508. device=device,
  1509. requires_grad=False,
  1510. ),
  1511. persistent=False,
  1512. )
  1513. if not hasattr(self, "lit_positions"):
  1514. self.register_buffer(
  1515. "lit_positions",
  1516. torch.tensor(
  1517. residue_constants.restype_atom14_rigid_group_positions,
  1518. dtype=float_dtype,
  1519. device=device,
  1520. requires_grad=False,
  1521. ),
  1522. persistent=False,
  1523. )
  1524. def torsion_angles_to_frames(self, r, alpha, f):
  1525. # Lazily initialize the residue constants on the correct device
  1526. self._init_residue_constants(alpha.dtype, alpha.device)
  1527. # Separated purely to make testing less annoying
  1528. return torsion_angles_to_frames(r, alpha, f, self.default_frames)
  1529. def frames_and_literature_positions_to_atom14_pos(self, r, f): # [*, N, 8] # [*, N]
  1530. # Lazily initialize the residue constants on the correct device
  1531. self._init_residue_constants(r.get_rots().dtype, r.get_rots().device)
  1532. return frames_and_literature_positions_to_atom14_pos(
  1533. r,
  1534. f,
  1535. self.default_frames,
  1536. self.group_idx,
  1537. self.atom_mask,
  1538. self.lit_positions,
  1539. )
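# Example (sketch): the per-block predictions above are stacked along a new leading
# dimension by dict_multimap(torch.stack, outputs), so for num_blocks blocks, a batch of
# size B, and length L one would expect shapes like the following (14 atoms per residue
# in the atom14 layout; the 7-dim frame encoding is a quaternion plus a translation).
def _example_structure_module_shapes(num_blocks: int = 8, batch: int = 1, length: int = 64):
    return {
        "frames": (num_blocks, batch, length, 7),
        "angles": (num_blocks, batch, length, 7, 2),
        "positions": (num_blocks, batch, length, 14, 3),
    }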
  1540. class EsmFoldingTrunk(nn.Module):
  1541. def __init__(self, config):
  1542. super().__init__()
  1543. self.config = config
  1544. c_s = config.sequence_state_dim
  1545. c_z = config.pairwise_state_dim
  1546. self.pairwise_positional_embedding = EsmFoldRelativePosition(config)
  1547. self.blocks = nn.ModuleList([EsmFoldTriangularSelfAttentionBlock(config) for _ in range(config.num_blocks)])
  1548. self.recycle_bins = 15
  1549. self.recycle_s_norm = nn.LayerNorm(c_s)
  1550. self.recycle_z_norm = nn.LayerNorm(c_z)
  1551. self.recycle_disto = nn.Embedding(self.recycle_bins, c_z)
  1552. self.recycle_disto.weight[0].detach().zero_()
  1553. self.structure_module = EsmFoldStructureModule(config.structure_module)
  1554. self.trunk2sm_s = nn.Linear(c_s, config.structure_module.sequence_dim)
  1555. self.trunk2sm_z = nn.Linear(c_z, config.structure_module.pairwise_dim)
  1556. self.chunk_size = config.chunk_size
  1557. def set_chunk_size(self, chunk_size):
# This parameter means the axial attention will be computed
# in a chunked manner. This should make the memory used more or less O(L) instead of O(L^2).
# It's equivalent to running a for loop over chunks of the dimension we're iterating over,
# where chunk_size is the size of each chunk, so 128 would mean processing 128-length chunks.
  1562. self.chunk_size = chunk_size
  1563. def forward(self, seq_feats, pair_feats, true_aa, residx, mask, no_recycles):
"""
Inputs:
seq_feats: B x L x C tensor of sequence features
pair_feats: B x L x L x C tensor of pair features
residx: B x L long tensor giving the position in the sequence
mask: B x L boolean tensor indicating valid residues
Output:
predicted_structure: B x L x (num_atoms_per_residue * 3) tensor wrapped in a Coordinates object
"""
  1571. device = seq_feats.device
  1572. s_s_0 = seq_feats
  1573. s_z_0 = pair_feats
  1574. if no_recycles is None:
  1575. no_recycles = self.config.max_recycles
  1576. else:
  1577. if no_recycles < 0:
  1578. raise ValueError("Number of recycles must not be negative.")
  1579. no_recycles += 1 # First 'recycle' is just the standard forward pass through the model.
  1580. def trunk_iter(s, z, residx, mask):
  1581. z = z + self.pairwise_positional_embedding(residx, mask=mask)
  1582. for block in self.blocks:
  1583. s, z = block(s, z, mask=mask, residue_index=residx, chunk_size=self.chunk_size)
  1584. return s, z
  1585. s_s = s_s_0
  1586. s_z = s_z_0
  1587. recycle_s = torch.zeros_like(s_s)
  1588. recycle_z = torch.zeros_like(s_z)
  1589. recycle_bins = torch.zeros(*s_z.shape[:-1], device=device, dtype=torch.int64)
  1590. for recycle_idx in range(no_recycles):
  1591. with ContextManagers([] if recycle_idx == no_recycles - 1 else [torch.no_grad()]):
  1592. # === Recycling ===
  1593. recycle_s = self.recycle_s_norm(recycle_s.detach()).to(device)
  1594. recycle_z = self.recycle_z_norm(recycle_z.detach()).to(device)
  1595. recycle_z += self.recycle_disto(recycle_bins.detach()).to(device)
  1596. s_s, s_z = trunk_iter(s_s_0 + recycle_s, s_z_0 + recycle_z, residx, mask)
  1597. # === Structure module ===
  1598. structure = self.structure_module(
  1599. {"single": self.trunk2sm_s(s_s), "pair": self.trunk2sm_z(s_z)},
  1600. true_aa,
  1601. mask.float(),
  1602. )
  1603. recycle_s = s_s
  1604. recycle_z = s_z
# The distogram needs the N, CA, C coordinates; the bin constants are the same as AlphaFold's.
  1606. recycle_bins = EsmFoldingTrunk.distogram(
  1607. structure["positions"][-1][:, :, :3],
  1608. 3.375,
  1609. 21.375,
  1610. self.recycle_bins,
  1611. )
  1612. structure["s_s"] = s_s
  1613. structure["s_z"] = s_z
  1614. return structure
  1615. @staticmethod
  1616. def distogram(coords, min_bin, max_bin, num_bins):
  1617. # Coords are [... L x 3 x 3], where it's [N, CA, C] x 3 coordinates.
  1618. boundaries = torch.linspace(
  1619. min_bin,
  1620. max_bin,
  1621. num_bins - 1,
  1622. device=coords.device,
  1623. )
  1624. boundaries = boundaries**2
  1625. N, CA, C = [x.squeeze(-2) for x in coords.chunk(3, dim=-2)]
  1626. # Infer CB coordinates.
  1627. b = CA - N
  1628. c = C - CA
  1629. a = b.cross(c, dim=-1)
  1630. CB = -0.58273431 * a + 0.56802827 * b - 0.54067466 * c + CA
dists = (CB[..., None, :, :] - CB[..., :, None, :]).pow(2).sum(dim=-1, keepdim=True)
  1632. bins = torch.sum(dists > boundaries, dim=-1) # [..., L, L]
  1633. return bins
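# Example (sketch): binning inferred CB-CB distances with the static distogram helper
# above, using the same bin range the trunk uses for recycling. Coordinates are random
# and purely illustrative.
def _example_distogram():
    import torch

    coords = torch.randn(1, 10, 3, 3)            # [B, L, (N, CA, C), xyz]
    bins = EsmFoldingTrunk.distogram(coords, min_bin=3.375, max_bin=21.375, num_bins=15)
    assert bins.shape == (1, 10, 10)              # one bin index per residue pair
    return bins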
  1634. # TODO Add information to the docstring about any methods that convert to PDB format, or otherwise prepare
  1635. # the outputs for downstream use.
  1636. @add_start_docstrings(
  1637. """
  1638. ESMForProteinFolding is the HuggingFace port of the original ESMFold model. It consists of an ESM-2 "stem" followed
  1639. by a protein folding "head", although unlike most other output heads, this "head" is similar in size and runtime to
  1640. the rest of the model combined! It outputs a dictionary containing predicted structural information about the input
  1641. protein(s).
  1642. """,
  1643. ESM_START_DOCSTRING,
  1644. )
  1645. class EsmForProteinFolding(EsmPreTrainedModel):
  1646. _no_split_modules = ["EsmFoldStructureModule", "EsmFoldTriangularSelfAttentionBlock"]
  1647. def __init__(self, config):
  1648. super().__init__(config)
  1649. self.config = config
  1650. self.distogram_bins = 64
  1651. self.esm = EsmModel(config, add_pooling_layer=False)
  1652. self.esm.requires_grad_(False)
  1653. if self.config.esmfold_config.fp16_esm:
  1654. self.esm.half()
  1655. self.esm_feats = self.config.hidden_size
  1656. self.esm_attns = self.config.num_hidden_layers * self.config.num_attention_heads
  1657. self.esm_layers = self.config.num_hidden_layers
  1658. self.register_buffer("af2_to_esm", self._af2_to_esm_from_vocab_list(config.vocab_list))
  1659. self.esm_s_combine = nn.Parameter(torch.zeros(self.esm_layers + 1))
  1660. trunk_config = self.config.esmfold_config.trunk
  1661. c_s = trunk_config.sequence_state_dim
  1662. c_z = trunk_config.pairwise_state_dim
  1663. self.esm_s_mlp = nn.Sequential(
  1664. LayerNorm(self.esm_feats),
  1665. nn.Linear(self.esm_feats, c_s),
  1666. nn.ReLU(),
  1667. nn.Linear(c_s, c_s),
  1668. )
  1669. # 0 is padding, N is unknown residues, N + 1 is mask.
  1670. self.n_tokens_embed = residue_constants.restype_num + 3
  1671. self.pad_idx = 0
  1672. self.unk_idx = self.n_tokens_embed - 2
  1673. self.mask_idx = self.n_tokens_embed - 1
  1674. self.esm_dict_cls_idx = self.config.vocab_list.index("<cls>")
  1675. self.esm_dict_mask_idx = self.config.vocab_list.index("<mask>")
  1676. self.esm_dict_eos_idx = self.config.vocab_list.index("<eos>")
  1677. self.esm_dict_padding_idx = self.config.vocab_list.index("<pad>")
  1678. if self.config.esmfold_config.embed_aa:
  1679. self.embedding = nn.Embedding(self.n_tokens_embed, c_s, padding_idx=0)
  1680. self.trunk = EsmFoldingTrunk(trunk_config)
  1681. self.distogram_head = nn.Linear(c_z, self.distogram_bins)
  1682. self.ptm_head = nn.Linear(c_z, self.distogram_bins)
  1683. self.lm_head = nn.Linear(c_s, self.n_tokens_embed)
  1684. self.lddt_bins = 50
  1685. structure_module_config = trunk_config.structure_module
  1686. self.lddt_head = nn.Sequential(
  1687. nn.LayerNorm(structure_module_config.sequence_dim),
  1688. nn.Linear(structure_module_config.sequence_dim, self.config.esmfold_config.lddt_head_hid_dim),
  1689. nn.Linear(self.config.esmfold_config.lddt_head_hid_dim, self.config.esmfold_config.lddt_head_hid_dim),
  1690. nn.Linear(self.config.esmfold_config.lddt_head_hid_dim, 37 * self.lddt_bins),
  1691. )
  1692. @staticmethod
  1693. def _af2_to_esm_from_vocab_list(vocab_list: List[str]) -> torch.Tensor:
# Remember that the returned mapping is shifted from residue_constants by 1 (0 is padding).
  1695. esm_reorder = [vocab_list.index("<pad>")] + [vocab_list.index(v) for v in residue_constants.restypes_with_x]
  1696. return torch.tensor(esm_reorder)
  1697. @add_start_docstrings_to_model_forward(ESMFOLD_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
  1698. @replace_return_docstrings(output_type=EsmForProteinFoldingOutput, config_class=EsmConfig)
  1699. def forward(
  1700. self,
  1701. input_ids: torch.Tensor,
  1702. attention_mask: Optional[torch.Tensor] = None,
  1703. position_ids: Optional[torch.Tensor] = None,
  1704. masking_pattern: Optional[torch.Tensor] = None,
  1705. num_recycles: Optional[int] = None,
  1706. ) -> EsmForProteinFoldingOutput:
  1707. r"""
  1708. Returns:
  1709. Example:
  1710. ```python
  1711. >>> from transformers import AutoTokenizer, EsmForProteinFolding
  1712. >>> model = EsmForProteinFolding.from_pretrained("facebook/esmfold_v1")
  1713. >>> tokenizer = AutoTokenizer.from_pretrained("facebook/esmfold_v1")
  1714. >>> inputs = tokenizer(["MLKNVQVQLV"], return_tensors="pt", add_special_tokens=False) # A tiny random peptide
  1715. >>> outputs = model(**inputs)
  1716. >>> folded_positions = outputs.positions
  1717. ```
  1718. """
  1719. cfg = self.config.esmfold_config
  1720. aa = input_ids # B x L
  1721. B = aa.shape[0]
  1722. L = aa.shape[1]
  1723. device = input_ids.device
  1724. if attention_mask is None:
  1725. attention_mask = torch.ones_like(aa, device=device)
  1726. if position_ids is None:
  1727. position_ids = torch.arange(L, device=device).expand_as(input_ids)
  1728. # === ESM ===
  1729. esmaa = self.af2_idx_to_esm_idx(aa, attention_mask)
  1730. if masking_pattern is not None:
  1731. masked_aa, esmaa, mlm_targets = self.bert_mask(aa, esmaa, attention_mask, masking_pattern)
  1732. else:
  1733. masked_aa = aa
  1734. mlm_targets = None
  1735. # We get sequence and pair representations from whatever version of ESM /
  1736. # configuration we are using. The sequence representation esm_s is always
  1737. # present. The pair embedding esm_z may be present depending on the
  1738. # configuration of the model. If esm_z is not used by the model then it
  1739. # is returned as None here.
  1740. esm_s = self.compute_language_model_representations(esmaa)
  1741. # Convert esm_s and esm_z, if present, to the precision used by the trunk and
  1742. # the structure module. These tensors may be a lower precision if, for example,
  1743. # we're running the language model in fp16 precision.
  1744. esm_s = esm_s.to(self.esm_s_combine.dtype)
  1745. if cfg.esm_ablate_sequence:
  1746. esm_s = esm_s * 0
  1747. esm_s = esm_s.detach()
  1748. # === preprocessing ===
  1749. esm_s = (self.esm_s_combine.softmax(0).unsqueeze(0) @ esm_s).squeeze(2)
  1750. s_s_0 = self.esm_s_mlp(esm_s)
  1751. s_z_0 = s_s_0.new_zeros(B, L, L, cfg.trunk.pairwise_state_dim)
  1752. if self.config.esmfold_config.embed_aa:
  1753. s_s_0 += self.embedding(masked_aa)
  1754. structure: dict = self.trunk(s_s_0, s_z_0, aa, position_ids, attention_mask, no_recycles=num_recycles)
  1755. # Documenting what we expect:
  1756. structure = {
  1757. k: v
  1758. for k, v in structure.items()
  1759. if k
  1760. in [
  1761. "s_z",
  1762. "s_s",
  1763. "frames",
  1764. "sidechain_frames",
  1765. "unnormalized_angles",
  1766. "angles",
  1767. "positions",
  1768. "states",
  1769. ]
  1770. }
  1771. # Add BERT mask for the loss to use, if available.
if mlm_targets is not None:
  1773. structure["mlm_targets"] = mlm_targets
  1774. disto_logits = self.distogram_head(structure["s_z"])
  1775. disto_logits = (disto_logits + disto_logits.transpose(1, 2)) / 2
  1776. structure["distogram_logits"] = disto_logits
  1777. lm_logits = self.lm_head(structure["s_s"])
  1778. structure["lm_logits"] = lm_logits
  1779. structure["aatype"] = aa
  1780. make_atom14_masks(structure)
  1781. # Of course, this doesn't respect the true mask because it doesn't know about it...
  1782. # We're not going to properly mask change of index tensors:
  1783. # "residx_atom14_to_atom37",
  1784. # "residx_atom37_to_atom14",
  1785. for k in [
  1786. "atom14_atom_exists",
  1787. "atom37_atom_exists",
  1788. ]:
  1789. structure[k] *= attention_mask.unsqueeze(-1)
  1790. structure["residue_index"] = position_ids
  1791. lddt_head = self.lddt_head(structure["states"]).reshape(structure["states"].shape[0], B, L, -1, self.lddt_bins)
  1792. structure["lddt_head"] = lddt_head
  1793. plddt = categorical_lddt(lddt_head[-1], bins=self.lddt_bins)
  1794. structure["plddt"] = plddt
  1795. ptm_logits = self.ptm_head(structure["s_z"])
  1796. structure["ptm_logits"] = ptm_logits
  1797. structure["ptm"] = compute_tm(ptm_logits, max_bin=31, no_bins=self.distogram_bins)
  1798. structure.update(compute_predicted_aligned_error(ptm_logits, max_bin=31, no_bins=self.distogram_bins))
  1799. return EsmForProteinFoldingOutput(**structure)
  1800. def af2_idx_to_esm_idx(self, aa, mask):
  1801. # avoid indexing on different devices
  1802. if self.af2_to_esm.device != aa.device:
  1803. self.af2_to_esm = self.af2_to_esm.to(aa.device)
  1804. aa = (aa + 1).masked_fill(mask != 1, 0)
  1805. return self.af2_to_esm[aa]
  1806. def compute_language_model_representations(self, esmaa: torch.Tensor) -> torch.Tensor:
  1807. device = next(self.parameters()).device
  1808. B, L = esmaa.shape # B = batch size, L = sequence length.
  1809. if self.config.esmfold_config.bypass_lm:
esm_s = torch.zeros(B, L, self.esm_s_combine.size(0), self.esm_feats, device=device)
  1811. return esm_s
  1812. bosi, eosi = self.esm_dict_cls_idx, self.esm_dict_eos_idx
  1813. bos = esmaa.new_full((B, 1), bosi)
  1814. eos = esmaa.new_full((B, 1), self.esm_dict_padding_idx)
  1815. esmaa = torch.cat([bos, esmaa, eos], dim=1)
  1816. # Use the first padding index as eos during inference.
  1817. esmaa[range(B), (esmaa != 1).sum(1)] = eosi
  1818. # _, esm_z, esm_s = self.esm(esmaa, return_pairs=self.config.esmfold_config.use_esm_attn_map)
  1819. # Because we do not support use_esm_attn_map in the HF port as it is not used in any public models,
  1820. # esm_z is always None
  1821. esm_hidden_states = self.esm(esmaa, attention_mask=esmaa != 1, output_hidden_states=True)["hidden_states"]
  1822. esm_s = torch.stack(esm_hidden_states, dim=2)
  1823. esm_s = esm_s[:, 1:-1] # B, L, nLayers, C
  1824. return esm_s
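# Example (sketch): the stacked per-layer ESM-2 representations returned above are later
# collapsed with a learned softmax over layers (see esm_s_combine in forward). This shows
# the same weighted sum on toy tensors; the layer count and sizes are illustrative.
def _example_combine_layers(num_layers: int = 4, batch: int = 1, length: int = 8, feats: int = 16):
    import torch

    esm_s = torch.randn(batch, length, num_layers + 1, feats)    # [B, L, nLayers + 1, C]
    combine = torch.zeros(num_layers + 1)                        # plays the role of esm_s_combine
    combined = (combine.softmax(0).unsqueeze(0) @ esm_s).squeeze(2)
    assert combined.shape == (batch, length, feats)
    return combined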
  1825. def bert_mask(self, aa, esmaa, mask, pattern):
  1826. new_aa = aa.clone()
  1827. target = aa.clone()
  1828. new_esmaa = esmaa.clone()
  1829. new_aa[pattern == 1] = self.mask_idx
  1830. target[pattern != 1] = 0
  1831. new_esmaa[pattern == 1] = self.esm_dict_mask_idx
  1832. return new_aa, new_esmaa, target
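# Example (sketch): constructing a random masking_pattern of the kind consumed by
# bert_mask / forward. The 15% masking rate is illustrative, not a value taken from the config.
def _example_masking_pattern(batch: int = 1, length: int = 32, rate: float = 0.15):
    import torch

    pattern = torch.bernoulli(torch.full((batch, length), rate)).long()  # 1 = mask this position
    return pattern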
  1833. @torch.no_grad()
  1834. def infer(
  1835. self,
  1836. seqs: Union[str, List[str]],
  1837. position_ids=None,
  1838. ):
  1839. if isinstance(seqs, str):
  1840. lst = [seqs]
  1841. else:
  1842. lst = seqs
  1843. # Returns the raw outputs of the model given an input sequence.
  1844. device = next(self.parameters()).device
  1845. aatype = collate_dense_tensors(
  1846. [
  1847. torch.from_numpy(
  1848. residue_constants.sequence_to_onehot(
  1849. sequence=seq,
  1850. mapping=residue_constants.restype_order_with_x,
  1851. map_unknown_to_x=True,
  1852. )
  1853. )
  1854. .to(device)
  1855. .argmax(dim=1)
  1856. for seq in lst
  1857. ]
  1858. ) # B=1 x L
  1859. mask = collate_dense_tensors([aatype.new_ones(len(seq)) for seq in lst])
  1860. position_ids = (
  1861. torch.arange(aatype.shape[1], device=device).expand(len(lst), -1)
  1862. if position_ids is None
  1863. else position_ids.to(device)
  1864. )
  1865. if position_ids.ndim == 1:
  1866. position_ids = position_ids.unsqueeze(0)
  1867. return self.forward(
  1868. aatype,
  1869. mask,
  1870. position_ids=position_ids,
  1871. )
  1872. @staticmethod
  1873. def output_to_pdb(output: Dict) -> List[str]:
"""Returns the pdb (file) string from the model given the model output."""
  1875. output = {k: v.to("cpu").numpy() for k, v in output.items()}
  1876. pdbs = []
  1877. final_atom_positions = atom14_to_atom37(output["positions"][-1], output)
  1878. final_atom_mask = output["atom37_atom_exists"]
  1879. for i in range(output["aatype"].shape[0]):
  1880. aa = output["aatype"][i]
  1881. pred_pos = final_atom_positions[i]
  1882. mask = final_atom_mask[i]
  1883. resid = output["residue_index"][i] + 1
  1884. pred = OFProtein(
  1885. aatype=aa,
  1886. atom_positions=pred_pos,
  1887. atom_mask=mask,
  1888. residue_index=resid,
  1889. b_factors=output["plddt"][i],
  1890. )
  1891. pdbs.append(to_pdb(pred))
  1892. return pdbs
  1893. def infer_pdb(self, seqs, *args, **kwargs) -> str:
  1894. """Returns the pdb (file) string from the model given an input sequence."""
  1895. assert isinstance(seqs, str)
  1896. output = self.infer(seqs, *args, **kwargs)
  1897. return self.output_to_pdb(output)[0]
  1898. def infer_pdbs(self, seqs: List[str], *args, **kwargs) -> List[str]:
"""Returns the pdb (file) strings from the model given a list of input sequences."""
  1900. output = self.infer(seqs, *args, **kwargs)
  1901. return self.output_to_pdb(output)
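# Example (sketch): end-to-end folding of a single sequence into a PDB string using the
# convenience helpers above. Loading "facebook/esmfold_v1" downloads the full model, so
# this is illustrative rather than something to run at import time.
def _example_fold_to_pdb(sequence: str = "MLKNVQVQLV"):
    model = EsmForProteinFolding.from_pretrained("facebook/esmfold_v1")
    model.eval()
    pdb_string = model.infer_pdb(sequence)
    with open("prediction.pdb", "w") as f:
        f.write(pdb_string)
    return pdb_string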