modeling_esm.py 54 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232123312341235123612371238123912401241124212431244124512461247124812491250125112521253125412551256125712581259126012611262
  1. # coding=utf-8
  2. # Copyright 2022 Meta and The HuggingFace Inc. team. All rights reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. """PyTorch ESM model."""
  16. import math
  17. from typing import List, Optional, Tuple, Union
  18. import torch
  19. import torch.utils.checkpoint
  20. from torch import nn
  21. from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
  22. from ...file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward
  23. from ...modeling_outputs import (
  24. BaseModelOutputWithPastAndCrossAttentions,
  25. BaseModelOutputWithPoolingAndCrossAttentions,
  26. MaskedLMOutput,
  27. SequenceClassifierOutput,
  28. TokenClassifierOutput,
  29. )
  30. from ...modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer
  31. from ...utils import logging
  32. from .configuration_esm import EsmConfig
# Module-level logger from the transformers logging wrapper.
logger = logging.get_logger(__name__)
# Checkpoint / config names for generated documentation; presumably consumed by the docstring decorators
# further down the file (not visible here).
_CHECKPOINT_FOR_DOC = "facebook/esm2_t6_8M_UR50D"
_CONFIG_FOR_DOC = "EsmConfig"
  36. def rotate_half(x):
  37. x1, x2 = x.chunk(2, dim=-1)
  38. return torch.cat((-x2, x1), dim=-1)
  39. def apply_rotary_pos_emb(x, cos, sin):
  40. cos = cos[:, :, : x.shape[-2], :]
  41. sin = sin[:, :, : x.shape[-2], :]
  42. return (x * cos) + (rotate_half(x) * sin)
  43. def gelu(x):
  44. """
  45. This is the gelu implementation from the original ESM repo. Using F.gelu yields subtly wrong results.
  46. """
  47. return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
  48. def symmetrize(x):
  49. "Make layer symmetric in final two dimensions, used for contact prediction."
  50. return x + x.transpose(-1, -2)
  51. def average_product_correct(x):
  52. "Perform average product correct, used for contact prediction."
  53. a1 = x.sum(-1, keepdims=True)
  54. a2 = x.sum(-2, keepdims=True)
  55. a12 = x.sum((-1, -2), keepdims=True)
  56. avg = a1 * a2
  57. avg.div_(a12) # in-place to reduce memory
  58. normalized = x - avg
  59. return normalized
  60. class RotaryEmbedding(torch.nn.Module):
  61. """
  62. Rotary position embeddings based on those in
  63. [RoFormer](https://huggingface.co/docs/transformers/model_doc/roformer). Query and keys are transformed by rotation
  64. matrices which depend on their relative positions.
  65. """
  66. def __init__(self, dim: int):
  67. super().__init__()
  68. # Generate and save the inverse frequency buffer (non trainable)
  69. inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2, dtype=torch.int64).float() / dim))
  70. inv_freq = inv_freq
  71. self.register_buffer("inv_freq", inv_freq)
  72. self._seq_len_cached = None
  73. self._cos_cached = None
  74. self._sin_cached = None
  75. def _update_cos_sin_tables(self, x, seq_dimension=2):
  76. seq_len = x.shape[seq_dimension]
  77. # Reset the tables if the sequence length has changed,
  78. # or if we're on a new device (possibly due to tracing for instance)
  79. if seq_len != self._seq_len_cached or self._cos_cached.device != x.device:
  80. self._seq_len_cached = seq_len
  81. t = torch.arange(x.shape[seq_dimension], device=x.device).type_as(self.inv_freq)
  82. freqs = torch.outer(t, self.inv_freq)
  83. emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
  84. self._cos_cached = emb.cos()[None, None, :, :]
  85. self._sin_cached = emb.sin()[None, None, :, :]
  86. return self._cos_cached, self._sin_cached
  87. def forward(self, q: torch.Tensor, k: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
  88. self._cos_cached, self._sin_cached = self._update_cos_sin_tables(k, seq_dimension=-2)
  89. return (
  90. apply_rotary_pos_emb(q, self._cos_cached, self._sin_cached),
  91. apply_rotary_pos_emb(k, self._cos_cached, self._sin_cached),
  92. )
  93. class EsmContactPredictionHead(nn.Module):
  94. """Performs symmetrization, apc, and computes a logistic regression on the output features"""
  95. def __init__(
  96. self,
  97. in_features: int,
  98. bias=True,
  99. eos_idx: int = 2,
  100. ):
  101. super().__init__()
  102. self.in_features = in_features
  103. self.eos_idx = eos_idx
  104. self.regression = nn.Linear(in_features, 1, bias)
  105. self.activation = nn.Sigmoid()
  106. def forward(self, tokens, attentions):
  107. # remove eos token attentions
  108. eos_mask = tokens.ne(self.eos_idx).to(attentions)
  109. eos_mask = eos_mask.unsqueeze(1) * eos_mask.unsqueeze(2)
  110. attentions = attentions * eos_mask[:, None, None, :, :]
  111. attentions = attentions[..., :-1, :-1]
  112. # remove cls token attentions
  113. attentions = attentions[..., 1:, 1:]
  114. batch_size, layers, heads, seqlen, _ = attentions.size()
  115. attentions = attentions.view(batch_size, layers * heads, seqlen, seqlen)
  116. # features: batch x channels x tokens x tokens (symmetric)
  117. attentions = attentions.to(
  118. self.regression.weight.device
  119. ) # attentions always float32, may need to convert to float16
  120. attentions = average_product_correct(symmetrize(attentions))
  121. attentions = attentions.permute(0, 2, 3, 1)
  122. return self.activation(self.regression(attentions).squeeze(3))
  123. class EsmEmbeddings(nn.Module):
  124. """
  125. Same as BertEmbeddings with a tiny tweak for positional embeddings indexing.
  126. """
  127. def __init__(self, config):
  128. super().__init__()
  129. self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
  130. if config.emb_layer_norm_before:
  131. self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  132. else:
  133. self.layer_norm = None
  134. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  135. # position_ids (1, len position emb) is contiguous in memory and exported when serialized
  136. self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
  137. self.register_buffer(
  138. "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
  139. )
  140. self.padding_idx = config.pad_token_id
  141. self.position_embeddings = nn.Embedding(
  142. config.max_position_embeddings, config.hidden_size, padding_idx=self.padding_idx
  143. )
  144. self.token_dropout = config.token_dropout
  145. self.mask_token_id = config.mask_token_id
  146. def forward(
  147. self, input_ids=None, attention_mask=None, position_ids=None, inputs_embeds=None, past_key_values_length=0
  148. ):
  149. if position_ids is None:
  150. if input_ids is not None:
  151. # Create the position ids from the input token ids. Any padded tokens remain padded.
  152. position_ids = create_position_ids_from_input_ids(input_ids, self.padding_idx, past_key_values_length)
  153. else:
  154. position_ids = self.create_position_ids_from_inputs_embeds(inputs_embeds)
  155. if inputs_embeds is None:
  156. inputs_embeds = self.word_embeddings(input_ids)
  157. # Note that if we want to support ESM-1 (not 1b!) in future then we need to support an
  158. # embedding_scale factor here.
  159. embeddings = inputs_embeds
  160. # Matt: ESM has the option to handle masking in MLM in a slightly unusual way. If the token_dropout
  161. # flag is False then it is handled in the same was as BERT/RoBERTa. If it is set to True, however,
  162. # masked tokens are treated as if they were selected for input dropout and zeroed out.
  163. # This "mask-dropout" is compensated for when masked tokens are not present, by scaling embeddings by
  164. # a factor of (fraction of unmasked tokens during training) / (fraction of unmasked tokens in sample).
  165. # This is analogous to the way that dropout layers scale down outputs during evaluation when not
  166. # actually dropping out values (or, equivalently, scale up their un-dropped outputs in training).
  167. if self.token_dropout:
  168. embeddings = embeddings.masked_fill((input_ids == self.mask_token_id).unsqueeze(-1), 0.0)
  169. mask_ratio_train = 0.15 * 0.8 # Hardcoded as the ratio used in all ESM model training runs
  170. src_lengths = attention_mask.sum(-1)
  171. mask_ratio_observed = (input_ids == self.mask_token_id).sum(-1).float() / src_lengths
  172. embeddings = (embeddings * (1 - mask_ratio_train) / (1 - mask_ratio_observed)[:, None, None]).to(
  173. embeddings.dtype
  174. )
  175. if self.position_embedding_type == "absolute":
  176. position_embeddings = self.position_embeddings(position_ids)
  177. embeddings = embeddings + position_embeddings
  178. if self.layer_norm is not None:
  179. embeddings = self.layer_norm(embeddings)
  180. if attention_mask is not None:
  181. embeddings = (embeddings * attention_mask.unsqueeze(-1)).to(embeddings.dtype)
  182. # Matt: I think this line was copied incorrectly from BERT, disabling it for now.
  183. # embeddings = self.dropout(embeddings)
  184. return embeddings
  185. def create_position_ids_from_inputs_embeds(self, inputs_embeds):
  186. """
  187. We are provided embeddings directly. We cannot infer which are padded so just generate sequential position ids.
  188. Args:
  189. inputs_embeds: torch.Tensor
  190. Returns: torch.Tensor
  191. """
  192. input_shape = inputs_embeds.size()[:-1]
  193. sequence_length = input_shape[1]
  194. position_ids = torch.arange(
  195. self.padding_idx + 1, sequence_length + self.padding_idx + 1, dtype=torch.long, device=inputs_embeds.device
  196. )
  197. return position_ids.unsqueeze(0).expand(input_shape)
  198. class EsmSelfAttention(nn.Module):
  199. def __init__(self, config, position_embedding_type=None):
  200. super().__init__()
  201. if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
  202. raise ValueError(
  203. f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
  204. f"heads ({config.num_attention_heads})"
  205. )
  206. self.num_attention_heads = config.num_attention_heads
  207. self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
  208. self.all_head_size = self.num_attention_heads * self.attention_head_size
  209. self.query = nn.Linear(config.hidden_size, self.all_head_size)
  210. self.key = nn.Linear(config.hidden_size, self.all_head_size)
  211. self.value = nn.Linear(config.hidden_size, self.all_head_size)
  212. self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
  213. self.position_embedding_type = position_embedding_type or getattr(
  214. config, "position_embedding_type", "absolute"
  215. )
  216. self.rotary_embeddings = None
  217. if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
  218. self.max_position_embeddings = config.max_position_embeddings
  219. self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
  220. elif self.position_embedding_type == "rotary":
  221. self.rotary_embeddings = RotaryEmbedding(dim=self.attention_head_size)
  222. self.is_decoder = config.is_decoder
  223. def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
  224. new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
  225. x = x.view(new_x_shape)
  226. return x.permute(0, 2, 1, 3)
  227. def forward(
  228. self,
  229. hidden_states: torch.Tensor,
  230. attention_mask: Optional[torch.FloatTensor] = None,
  231. head_mask: Optional[torch.FloatTensor] = None,
  232. encoder_hidden_states: Optional[torch.FloatTensor] = None,
  233. encoder_attention_mask: Optional[torch.FloatTensor] = None,
  234. past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
  235. output_attentions: Optional[bool] = False,
  236. ) -> Tuple[torch.Tensor]:
  237. mixed_query_layer = self.query(hidden_states)
  238. # If this is instantiated as a cross-attention module, the keys
  239. # and values come from an encoder; the attention mask needs to be
  240. # such that the encoder's padding tokens are not attended to.
  241. is_cross_attention = encoder_hidden_states is not None
  242. if is_cross_attention and past_key_value is not None:
  243. # reuse k,v, cross_attentions
  244. key_layer = past_key_value[0]
  245. value_layer = past_key_value[1]
  246. attention_mask = encoder_attention_mask
  247. elif is_cross_attention:
  248. key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
  249. value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
  250. attention_mask = encoder_attention_mask
  251. elif past_key_value is not None:
  252. key_layer = self.transpose_for_scores(self.key(hidden_states))
  253. value_layer = self.transpose_for_scores(self.value(hidden_states))
  254. key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
  255. value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
  256. else:
  257. key_layer = self.transpose_for_scores(self.key(hidden_states))
  258. value_layer = self.transpose_for_scores(self.value(hidden_states))
  259. query_layer = self.transpose_for_scores(mixed_query_layer)
  260. # Matt: Our BERT model (which this code was derived from) scales attention logits down by sqrt(head_dim).
  261. # ESM scales the query down by the same factor instead. Modulo numerical stability these are equivalent,
  262. # but not when rotary embeddings get involved. Therefore, we scale the query here to match the original
  263. # ESM code and fix rotary embeddings.
  264. query_layer = query_layer * self.attention_head_size**-0.5
  265. if self.is_decoder:
  266. # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
  267. # Further calls to cross_attention layer can then reuse all cross-attention
  268. # key/value_states (first "if" case)
  269. # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
  270. # all previous decoder key/value_states. Further calls to uni-directional self-attention
  271. # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
  272. # if encoder bi-directional self-attention `past_key_value` is always `None`
  273. past_key_value = (key_layer, value_layer)
  274. if self.position_embedding_type == "rotary":
  275. query_layer, key_layer = self.rotary_embeddings(query_layer, key_layer)
  276. # Take the dot product between "query" and "key" to get the raw attention scores.
  277. attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
  278. if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
  279. seq_length = hidden_states.size()[1]
  280. position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
  281. position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
  282. distance = position_ids_l - position_ids_r
  283. positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
  284. positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility
  285. if self.position_embedding_type == "relative_key":
  286. relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
  287. attention_scores = attention_scores + relative_position_scores
  288. elif self.position_embedding_type == "relative_key_query":
  289. relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
  290. relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
  291. attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key
  292. if attention_mask is not None:
  293. # Apply the attention mask is (precomputed for all layers in EsmModel forward() function)
  294. attention_scores = attention_scores + attention_mask
  295. # Normalize the attention scores to probabilities.
  296. attention_probs = nn.functional.softmax(attention_scores, dim=-1)
  297. # This is actually dropping out entire tokens to attend to, which might
  298. # seem a bit unusual, but is taken from the original Transformer paper.
  299. attention_probs = self.dropout(attention_probs)
  300. # Mask heads if we want to
  301. if head_mask is not None:
  302. attention_probs = attention_probs * head_mask
  303. context_layer = torch.matmul(attention_probs.to(value_layer.dtype), value_layer)
  304. context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
  305. new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
  306. context_layer = context_layer.view(new_context_layer_shape)
  307. outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
  308. if self.is_decoder:
  309. outputs = outputs + (past_key_value,)
  310. return outputs
  311. class EsmSelfOutput(nn.Module):
  312. def __init__(self, config):
  313. super().__init__()
  314. self.dense = nn.Linear(config.hidden_size, config.hidden_size)
  315. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  316. def forward(self, hidden_states, input_tensor):
  317. hidden_states = self.dense(hidden_states)
  318. hidden_states = self.dropout(hidden_states)
  319. hidden_states = hidden_states + input_tensor
  320. return hidden_states
  321. class EsmAttention(nn.Module):
  322. def __init__(self, config):
  323. super().__init__()
  324. self.self = EsmSelfAttention(config)
  325. self.output = EsmSelfOutput(config)
  326. self.pruned_heads = set()
  327. self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  328. def prune_heads(self, heads):
  329. if len(heads) == 0:
  330. return
  331. heads, index = find_pruneable_heads_and_indices(
  332. heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
  333. )
  334. # Prune linear layers
  335. self.self.query = prune_linear_layer(self.self.query, index)
  336. self.self.key = prune_linear_layer(self.self.key, index)
  337. self.self.value = prune_linear_layer(self.self.value, index)
  338. self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
  339. # Update hyper params and store pruned heads
  340. self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
  341. self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
  342. self.pruned_heads = self.pruned_heads.union(heads)
  343. def forward(
  344. self,
  345. hidden_states,
  346. attention_mask=None,
  347. head_mask=None,
  348. encoder_hidden_states=None,
  349. encoder_attention_mask=None,
  350. past_key_value=None,
  351. output_attentions=False,
  352. ):
  353. hidden_states_ln = self.LayerNorm(hidden_states)
  354. self_outputs = self.self(
  355. hidden_states_ln,
  356. attention_mask,
  357. head_mask,
  358. encoder_hidden_states,
  359. encoder_attention_mask,
  360. past_key_value,
  361. output_attentions,
  362. )
  363. attention_output = self.output(self_outputs[0], hidden_states)
  364. outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
  365. return outputs
  366. class EsmIntermediate(nn.Module):
  367. def __init__(self, config):
  368. super().__init__()
  369. self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
  370. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  371. hidden_states = self.dense(hidden_states)
  372. hidden_states = gelu(hidden_states)
  373. return hidden_states
  374. class EsmOutput(nn.Module):
  375. def __init__(self, config):
  376. super().__init__()
  377. self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
  378. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  379. def forward(self, hidden_states, input_tensor):
  380. hidden_states = self.dense(hidden_states)
  381. hidden_states = self.dropout(hidden_states)
  382. hidden_states = hidden_states + input_tensor
  383. return hidden_states
class EsmLayer(nn.Module):
    """
    One ESM transformer block.

    It has three sub-blocks: pre-LN self-attention, optional pre-LN cross-attention (decoder only), and a pre-LN
    feed-forward network. Each adds its result back onto its input as a residual.
    """

    def __init__(self, config):
        super().__init__()
        # NOTE(review): `chunk_size_feed_forward` and `seq_len_dim` are stored but never read in this class;
        # the feed-forward below always runs on the whole sequence.
        self.chunk_size_feed_forward = config.chunk_size_feed_forward
        self.seq_len_dim = 1
        self.attention = EsmAttention(config)
        self.is_decoder = config.is_decoder
        self.add_cross_attention = config.add_cross_attention
        if self.add_cross_attention:
            # Cross-attention is only allowed in decoder configurations.
            if not self.is_decoder:
                raise RuntimeError(f"{self} should be used as a decoder model if cross attention is added")
            self.crossattention = EsmAttention(config)
        self.intermediate = EsmIntermediate(config)
        self.output = EsmOutput(config)
        # Pre-LN for the feed-forward sub-block (applied in feed_forward_chunk).
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_value=None,
        output_attentions=False,
    ):
        """
        Returns:
            A tuple containing:
            - layer_output
            - self-attention probs, if output_attentions
            - cross-attention probs, if output_attentions and cross-attention ran
            - present_key_value, if is_decoder

            present_key_value holds the self-attention (key, value) pair, followed by the cross-attention pair
            when cross-attention ran.
        """
        # decoder uni-directional self-attention cached key/values are the first two entries of past_key_value
        self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
        self_attention_outputs = self.attention(
            hidden_states,
            attention_mask,
            head_mask,
            output_attentions=output_attentions,
            past_key_value=self_attn_past_key_value,
        )
        attention_output = self_attention_outputs[0]

        # if decoder, the last output is tuple of self-attn cache
        if self.is_decoder:
            outputs = self_attention_outputs[1:-1]
            present_key_value = self_attention_outputs[-1]
        else:
            outputs = self_attention_outputs[1:]  # add self attentions if we output attention weights

        cross_attn_present_key_value = None
        if self.is_decoder and encoder_hidden_states is not None:
            if not hasattr(self, "crossattention"):
                raise AttributeError(
                    f"If `encoder_hidden_states` are passed, {self} has to be instantiated"
                    " with cross-attention layers by setting `config.add_cross_attention=True`"
                )

            # cross_attn cached key/values are the last two entries of past_key_value
            cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
            cross_attention_outputs = self.crossattention(
                attention_output,
                attention_mask,
                head_mask,
                encoder_hidden_states,
                encoder_attention_mask,
                cross_attn_past_key_value,
                output_attentions,
            )
            attention_output = cross_attention_outputs[0]
            outputs = outputs + cross_attention_outputs[1:-1]  # add cross attentions if we output attention weights

            # append the cross-attn cache after the self-attn cache in present_key_value
            cross_attn_present_key_value = cross_attention_outputs[-1]
            present_key_value = present_key_value + cross_attn_present_key_value

        layer_output = self.feed_forward_chunk(attention_output)

        outputs = (layer_output,) + outputs

        # if decoder, return the attn key/values as the last output
        if self.is_decoder:
            outputs = outputs + (present_key_value,)
        return outputs

    def feed_forward_chunk(self, attention_output):
        """Pre-LN feed-forward: ``attention_output + Output(Intermediate(LayerNorm(attention_output)))``."""
        attention_output_ln = self.LayerNorm(attention_output)
        intermediate_output = self.intermediate(attention_output_ln)
        layer_output = self.output(intermediate_output, attention_output)
        return layer_output
class EsmEncoder(nn.Module):
    """
    Stack of `config.num_hidden_layers` EsmLayer blocks followed by a final LayerNorm (`emb_layer_norm_after`).
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.layer = nn.ModuleList([EsmLayer(config) for _ in range(config.num_hidden_layers)])
        self.emb_layer_norm_after = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        # Presumably toggled by PreTrainedModel's gradient-checkpointing helpers (not visible here).
        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_values=None,
        use_cache=None,
        output_attentions=False,
        output_hidden_states=False,
        return_dict=True,
    ):
        """
        Run every layer in sequence and apply the final LayerNorm.

        Returns:
            If return_dict, a BaseModelOutputWithPastAndCrossAttentions. Otherwise, a tuple of the non-None
            entries among (last_hidden_state, past_key_values, hidden_states, attentions, cross_attentions).
            `hidden_states` contains each layer's input plus the final, post-LayerNorm output.
        """
        if self.gradient_checkpointing and self.training:
            if use_cache:
                logger.warning_once(
                    "`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting "
                    "`use_cache=False`..."
                )
                use_cache = False
        all_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None
        all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None

        next_decoder_cache = () if use_cache else None
        for i, layer_module in enumerate(self.layer):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            layer_head_mask = head_mask[i] if head_mask is not None else None
            past_key_value = past_key_values[i] if past_key_values is not None else None

            if self.gradient_checkpointing and self.training:
                layer_outputs = self._gradient_checkpointing_func(
                    layer_module.__call__,
                    hidden_states,
                    attention_mask,
                    layer_head_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                    past_key_value,
                    output_attentions,
                )
            else:
                layer_outputs = layer_module(
                    hidden_states,
                    attention_mask,
                    layer_head_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                    past_key_value,
                    output_attentions,
                )

            hidden_states = layer_outputs[0]
            if use_cache:
                # NOTE(review): EsmLayer appends its key/value cache only when config.is_decoder. For an
                # encoder, layer_outputs[-1] is an attention map or the hidden state, not a cache.
                next_decoder_cache = next_decoder_cache + (layer_outputs[-1],)
            if output_attentions:
                all_self_attentions = all_self_attentions + (layer_outputs[1],)
                if self.config.add_cross_attention:
                    all_cross_attentions = all_cross_attentions + (layer_outputs[2],)

        # Always constructed in __init__, so this normalization always runs.
        if self.emb_layer_norm_after:
            hidden_states = self.emb_layer_norm_after(hidden_states)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(
                v
                for v in [
                    hidden_states,
                    next_decoder_cache,
                    all_hidden_states,
                    all_self_attentions,
                    all_cross_attentions,
                ]
                if v is not None
            )
        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=next_decoder_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
            cross_attentions=all_cross_attentions,
        )
  546. # Copied from transformers.models.bert.modeling_bert.BertPooler
  547. class EsmPooler(nn.Module):
  548. def __init__(self, config):
  549. super().__init__()
  550. self.dense = nn.Linear(config.hidden_size, config.hidden_size)
  551. self.activation = nn.Tanh()
  552. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  553. # We "pool" the model by simply taking the hidden state corresponding
  554. # to the first token.
  555. first_token_tensor = hidden_states[:, 0]
  556. pooled_output = self.dense(first_token_tensor)
  557. pooled_output = self.activation(pooled_output)
  558. return pooled_output
class EsmPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.

    `_init_weights` initialization scheme:
    - nn.Linear and nn.Embedding weights are drawn from N(0, config.initializer_range).
    - Linear biases are zeroed, and so is the embedding padding row.
    - LayerNorm weights are set to 1 and biases to 0.
    """

    config_class = EsmConfig
    base_model_prefix = "esm"
    supports_gradient_checkpointing = True
    # Presumably consumed by PreTrainedModel when placing the model across devices (not visible here).
    # EsmFoldTriangularSelfAttentionBlock is not defined in this part of the file.
    _no_split_modules = ["EsmLayer", "EsmFoldTriangularSelfAttentionBlock", "EsmEmbeddings"]

    # Copied from transformers.models.bert.modeling_bert.BertPreTrainedModel._init_weights
    def _init_weights(self, module):
        """Initialize the weights"""
        if isinstance(module, nn.Linear):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
  584. ESM_START_DOCSTRING = r"""
  585. This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
  586. library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
  587. etc.)
  588. This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
  589. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
  590. and behavior.
  591. Parameters:
  592. config ([`EsmConfig`]): Model configuration class with all the parameters of the
  593. model. Initializing with a config file does not load the weights associated with the model, only the
  594. configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
  595. """
  596. ESM_INPUTS_DOCSTRING = r"""
  597. Args:
  598. input_ids (`torch.LongTensor` of shape `({0})`):
  599. Indices of input sequence tokens in the vocabulary.
  600. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  601. [`PreTrainedTokenizer.__call__`] for details.
  602. [What are input IDs?](../glossary#input-ids)
  603. attention_mask (`torch.FloatTensor` of shape `({0})`, *optional*):
  604. Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
  605. - 1 for tokens that are **not masked**,
  606. - 0 for tokens that are **masked**.
  607. [What are attention masks?](../glossary#attention-mask)
  608. position_ids (`torch.LongTensor` of shape `({0})`, *optional*):
  609. Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
  610. config.max_position_embeddings - 1]`.
  611. [What are position IDs?](../glossary#position-ids)
  612. head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
  613. Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
  614. - 1 indicates the head is **not masked**,
  615. - 0 indicates the head is **masked**.
  616. inputs_embeds (`torch.FloatTensor` of shape `({0}, hidden_size)`, *optional*):
  617. Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
  618. is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
  619. model's internal embedding lookup matrix.
  620. output_attentions (`bool`, *optional*):
  621. Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
  622. tensors for more detail.
  623. output_hidden_states (`bool`, *optional*):
  624. Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
  625. more detail.
  626. return_dict (`bool`, *optional*):
  627. Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple.
  628. """
  629. @add_start_docstrings(
  630. "The bare ESM Model transformer outputting raw hidden-states without any specific head on top.",
  631. ESM_START_DOCSTRING,
  632. )
  633. class EsmModel(EsmPreTrainedModel):
  634. """
  635. The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
  636. cross-attention is added between the self-attention layers, following the architecture described in [Attention is
  637. all you need](https://arxiv.org/abs/1706.03762) by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
  638. Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.
  639. To behave as an decoder the model needs to be initialized with the `is_decoder` argument of the configuration set
  640. to `True`. To be used in a Seq2Seq model, the model needs to initialized with both `is_decoder` argument and
  641. `add_cross_attention` set to `True`; an `encoder_hidden_states` is then expected as an input to the forward pass.
  642. """
  643. def __init__(self, config, add_pooling_layer=True):
  644. super().__init__(config)
  645. self.config = config
  646. self.embeddings = EsmEmbeddings(config)
  647. self.encoder = EsmEncoder(config)
  648. self.pooler = EsmPooler(config) if add_pooling_layer else None
  649. self.contact_head = EsmContactPredictionHead(
  650. in_features=config.num_hidden_layers * config.num_attention_heads, bias=True
  651. )
  652. # Initialize weights and apply final processing
  653. self.post_init()
  654. def get_input_embeddings(self):
  655. return self.embeddings.word_embeddings
  656. def set_input_embeddings(self, value):
  657. self.embeddings.word_embeddings = value
  658. def _prune_heads(self, heads_to_prune):
  659. """
  660. Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
  661. class PreTrainedModel
  662. """
  663. for layer, heads in heads_to_prune.items():
  664. self.encoder.layer[layer].attention.prune_heads(heads)
  665. @add_start_docstrings_to_model_forward(ESM_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
  666. @add_code_sample_docstrings(
  667. checkpoint=_CHECKPOINT_FOR_DOC,
  668. output_type=BaseModelOutputWithPoolingAndCrossAttentions,
  669. config_class=_CONFIG_FOR_DOC,
  670. )
  671. def forward(
  672. self,
  673. input_ids: Optional[torch.Tensor] = None,
  674. attention_mask: Optional[torch.Tensor] = None,
  675. position_ids: Optional[torch.Tensor] = None,
  676. head_mask: Optional[torch.Tensor] = None,
  677. inputs_embeds: Optional[torch.Tensor] = None,
  678. encoder_hidden_states: Optional[torch.Tensor] = None,
  679. encoder_attention_mask: Optional[torch.Tensor] = None,
  680. past_key_values: Optional[List[torch.FloatTensor]] = None,
  681. use_cache: Optional[bool] = None,
  682. output_attentions: Optional[bool] = None,
  683. output_hidden_states: Optional[bool] = None,
  684. return_dict: Optional[bool] = None,
  685. ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]:
  686. r"""
  687. encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
  688. Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
  689. the model is configured as a decoder.
  690. encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
  691. Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
  692. the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`:
  693. - 1 for tokens that are **not masked**,
  694. - 0 for tokens that are **masked**.
  695. past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
  696. Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
  697. If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
  698. don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
  699. `decoder_input_ids` of shape `(batch_size, sequence_length)`.
  700. use_cache (`bool`, *optional*):
  701. If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
  702. `past_key_values`).
  703. """
  704. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  705. output_hidden_states = (
  706. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  707. )
  708. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  709. if self.config.is_decoder:
  710. use_cache = use_cache if use_cache is not None else self.config.use_cache
  711. else:
  712. use_cache = False
  713. if input_ids is not None and inputs_embeds is not None:
  714. raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
  715. elif input_ids is not None:
  716. self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
  717. input_shape = input_ids.size()
  718. elif inputs_embeds is not None:
  719. input_shape = inputs_embeds.size()[:-1]
  720. else:
  721. raise ValueError("You have to specify either input_ids or inputs_embeds")
  722. batch_size, seq_length = input_shape
  723. device = input_ids.device if input_ids is not None else inputs_embeds.device
  724. # past_key_values_length
  725. past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
  726. if attention_mask is None:
  727. attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)
  728. # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
  729. # ourselves in which case we just need to make it broadcastable to all heads.
  730. extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)
  731. # If a 2D or 3D attention mask is provided for the cross-attention
  732. # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
  733. if self.config.is_decoder and encoder_hidden_states is not None:
  734. encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
  735. encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
  736. if encoder_attention_mask is None:
  737. encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
  738. encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
  739. else:
  740. encoder_extended_attention_mask = None
  741. # Prepare head mask if needed
  742. # 1.0 in head_mask indicate we keep the head
  743. # attention_probs has shape bsz x n_heads x N x N
  744. # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
  745. # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
  746. head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
  747. embedding_output = self.embeddings(
  748. input_ids=input_ids,
  749. position_ids=position_ids,
  750. attention_mask=attention_mask,
  751. inputs_embeds=inputs_embeds,
  752. past_key_values_length=past_key_values_length,
  753. )
  754. encoder_outputs = self.encoder(
  755. embedding_output,
  756. attention_mask=extended_attention_mask,
  757. head_mask=head_mask,
  758. encoder_hidden_states=encoder_hidden_states,
  759. encoder_attention_mask=encoder_extended_attention_mask,
  760. past_key_values=past_key_values,
  761. use_cache=use_cache,
  762. output_attentions=output_attentions,
  763. output_hidden_states=output_hidden_states,
  764. return_dict=return_dict,
  765. )
  766. sequence_output = encoder_outputs[0]
  767. pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
  768. if not return_dict:
  769. return (sequence_output, pooled_output) + encoder_outputs[1:]
  770. return BaseModelOutputWithPoolingAndCrossAttentions(
  771. last_hidden_state=sequence_output,
  772. pooler_output=pooled_output,
  773. past_key_values=encoder_outputs.past_key_values,
  774. hidden_states=encoder_outputs.hidden_states,
  775. attentions=encoder_outputs.attentions,
  776. cross_attentions=encoder_outputs.cross_attentions,
  777. )
  778. def predict_contacts(self, tokens, attention_mask):
  779. attns = self(tokens, attention_mask=attention_mask, return_dict=True, output_attentions=True).attentions
  780. attns = torch.stack(attns, dim=1) # Matches the original model layout
  781. # In the original model, attentions for padding tokens are completely zeroed out.
  782. # This makes no difference most of the time because the other tokens won't attend to them,
  783. # but it does for the contact prediction task, which takes attentions as input,
  784. # so we have to mimic that here.
  785. attns *= attention_mask.unsqueeze(1).unsqueeze(2).unsqueeze(3)
  786. attns *= attention_mask.unsqueeze(1).unsqueeze(2).unsqueeze(4)
  787. return self.contact_head(tokens, attns)
  788. @add_start_docstrings("""ESM Model with a `language modeling` head on top.""", ESM_START_DOCSTRING)
  789. class EsmForMaskedLM(EsmPreTrainedModel):
  790. _tied_weights_keys = ["lm_head.decoder.weight"]
  791. def __init__(self, config):
  792. super().__init__(config)
  793. if config.is_decoder:
  794. logger.warning(
  795. "If you want to use `EsmForMaskedLM` make sure `config.is_decoder=False` for "
  796. "bi-directional self-attention."
  797. )
  798. self.esm = EsmModel(config, add_pooling_layer=False)
  799. self.lm_head = EsmLMHead(config)
  800. self.init_weights()
  801. def get_output_embeddings(self):
  802. return self.lm_head.decoder
  803. def set_output_embeddings(self, new_embeddings):
  804. self.lm_head.decoder = new_embeddings
  805. @add_start_docstrings_to_model_forward(ESM_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
  806. @add_code_sample_docstrings(
  807. checkpoint=_CHECKPOINT_FOR_DOC,
  808. output_type=MaskedLMOutput,
  809. config_class=_CONFIG_FOR_DOC,
  810. mask="<mask>",
  811. )
  812. def forward(
  813. self,
  814. input_ids: Optional[torch.LongTensor] = None,
  815. attention_mask: Optional[torch.Tensor] = None,
  816. position_ids: Optional[torch.LongTensor] = None,
  817. head_mask: Optional[torch.Tensor] = None,
  818. inputs_embeds: Optional[torch.FloatTensor] = None,
  819. encoder_hidden_states: Optional[torch.FloatTensor] = None,
  820. encoder_attention_mask: Optional[torch.Tensor] = None,
  821. labels: Optional[torch.LongTensor] = None,
  822. output_attentions: Optional[bool] = None,
  823. output_hidden_states: Optional[bool] = None,
  824. return_dict: Optional[bool] = None,
  825. ) -> Union[Tuple, MaskedLMOutput]:
  826. r"""
  827. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  828. Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
  829. config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
  830. loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
  831. kwargs (`Dict[str, any]`, *optional*, defaults to `{}`):
  832. Used to hide legacy arguments that have been deprecated.
  833. """
  834. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  835. outputs = self.esm(
  836. input_ids,
  837. attention_mask=attention_mask,
  838. position_ids=position_ids,
  839. head_mask=head_mask,
  840. inputs_embeds=inputs_embeds,
  841. encoder_hidden_states=encoder_hidden_states,
  842. encoder_attention_mask=encoder_attention_mask,
  843. output_attentions=output_attentions,
  844. output_hidden_states=output_hidden_states,
  845. return_dict=return_dict,
  846. )
  847. sequence_output = outputs[0]
  848. prediction_scores = self.lm_head(sequence_output)
  849. masked_lm_loss = None
  850. if labels is not None:
  851. loss_fct = CrossEntropyLoss()
  852. labels = labels.to(prediction_scores.device)
  853. masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
  854. if not return_dict:
  855. output = (prediction_scores,) + outputs[2:]
  856. return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
  857. return MaskedLMOutput(
  858. loss=masked_lm_loss,
  859. logits=prediction_scores,
  860. hidden_states=outputs.hidden_states,
  861. attentions=outputs.attentions,
  862. )
  863. def predict_contacts(self, tokens, attention_mask):
  864. return self.esm.predict_contacts(tokens, attention_mask=attention_mask)
  865. class EsmLMHead(nn.Module):
  866. """ESM Head for masked language modeling."""
  867. def __init__(self, config):
  868. super().__init__()
  869. self.dense = nn.Linear(config.hidden_size, config.hidden_size)
  870. self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  871. self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
  872. self.bias = nn.Parameter(torch.zeros(config.vocab_size))
  873. def forward(self, features, **kwargs):
  874. x = self.dense(features)
  875. x = gelu(x)
  876. x = self.layer_norm(x)
  877. # project back to size of vocabulary with bias
  878. x = self.decoder(x) + self.bias
  879. return x
  880. @add_start_docstrings(
  881. """
  882. ESM Model transformer with a sequence classification/regression head on top (a linear layer on top of the pooled
  883. output) e.g. for GLUE tasks.
  884. """,
  885. ESM_START_DOCSTRING,
  886. )
  887. class EsmForSequenceClassification(EsmPreTrainedModel):
  888. def __init__(self, config):
  889. super().__init__(config)
  890. self.num_labels = config.num_labels
  891. self.config = config
  892. self.esm = EsmModel(config, add_pooling_layer=False)
  893. self.classifier = EsmClassificationHead(config)
  894. self.init_weights()
  895. @add_start_docstrings_to_model_forward(ESM_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
  896. @add_code_sample_docstrings(
  897. checkpoint=_CHECKPOINT_FOR_DOC,
  898. output_type=SequenceClassifierOutput,
  899. config_class=_CONFIG_FOR_DOC,
  900. )
  901. def forward(
  902. self,
  903. input_ids: Optional[torch.LongTensor] = None,
  904. attention_mask: Optional[torch.Tensor] = None,
  905. position_ids: Optional[torch.LongTensor] = None,
  906. head_mask: Optional[torch.Tensor] = None,
  907. inputs_embeds: Optional[torch.FloatTensor] = None,
  908. labels: Optional[torch.LongTensor] = None,
  909. output_attentions: Optional[bool] = None,
  910. output_hidden_states: Optional[bool] = None,
  911. return_dict: Optional[bool] = None,
  912. ) -> Union[Tuple, SequenceClassifierOutput]:
  913. r"""
  914. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  915. Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
  916. config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
  917. `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
  918. """
  919. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  920. outputs = self.esm(
  921. input_ids,
  922. attention_mask=attention_mask,
  923. position_ids=position_ids,
  924. head_mask=head_mask,
  925. inputs_embeds=inputs_embeds,
  926. output_attentions=output_attentions,
  927. output_hidden_states=output_hidden_states,
  928. return_dict=return_dict,
  929. )
  930. sequence_output = outputs[0]
  931. logits = self.classifier(sequence_output)
  932. loss = None
  933. if labels is not None:
  934. labels = labels.to(logits.device)
  935. if self.config.problem_type is None:
  936. if self.num_labels == 1:
  937. self.config.problem_type = "regression"
  938. elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
  939. self.config.problem_type = "single_label_classification"
  940. else:
  941. self.config.problem_type = "multi_label_classification"
  942. if self.config.problem_type == "regression":
  943. loss_fct = MSELoss()
  944. if self.num_labels == 1:
  945. loss = loss_fct(logits.squeeze(), labels.squeeze())
  946. else:
  947. loss = loss_fct(logits, labels)
  948. elif self.config.problem_type == "single_label_classification":
  949. loss_fct = CrossEntropyLoss()
  950. loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
  951. elif self.config.problem_type == "multi_label_classification":
  952. loss_fct = BCEWithLogitsLoss()
  953. loss = loss_fct(logits, labels)
  954. if not return_dict:
  955. output = (logits,) + outputs[2:]
  956. return ((loss,) + output) if loss is not None else output
  957. return SequenceClassifierOutput(
  958. loss=loss,
  959. logits=logits,
  960. hidden_states=outputs.hidden_states,
  961. attentions=outputs.attentions,
  962. )
  963. @add_start_docstrings(
  964. """
  965. ESM Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for
  966. Named-Entity-Recognition (NER) tasks.
  967. """,
  968. ESM_START_DOCSTRING,
  969. )
  970. class EsmForTokenClassification(EsmPreTrainedModel):
  971. def __init__(self, config):
  972. super().__init__(config)
  973. self.num_labels = config.num_labels
  974. self.esm = EsmModel(config, add_pooling_layer=False)
  975. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  976. self.classifier = nn.Linear(config.hidden_size, config.num_labels)
  977. self.init_weights()
  978. @add_start_docstrings_to_model_forward(ESM_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
  979. @add_code_sample_docstrings(
  980. checkpoint=_CHECKPOINT_FOR_DOC,
  981. output_type=TokenClassifierOutput,
  982. config_class=_CONFIG_FOR_DOC,
  983. )
  984. def forward(
  985. self,
  986. input_ids: Optional[torch.LongTensor] = None,
  987. attention_mask: Optional[torch.Tensor] = None,
  988. position_ids: Optional[torch.LongTensor] = None,
  989. head_mask: Optional[torch.Tensor] = None,
  990. inputs_embeds: Optional[torch.FloatTensor] = None,
  991. labels: Optional[torch.LongTensor] = None,
  992. output_attentions: Optional[bool] = None,
  993. output_hidden_states: Optional[bool] = None,
  994. return_dict: Optional[bool] = None,
  995. ) -> Union[Tuple, TokenClassifierOutput]:
  996. r"""
  997. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  998. Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
  999. """
  1000. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  1001. outputs = self.esm(
  1002. input_ids,
  1003. attention_mask=attention_mask,
  1004. position_ids=position_ids,
  1005. head_mask=head_mask,
  1006. inputs_embeds=inputs_embeds,
  1007. output_attentions=output_attentions,
  1008. output_hidden_states=output_hidden_states,
  1009. return_dict=return_dict,
  1010. )
  1011. sequence_output = outputs[0]
  1012. sequence_output = self.dropout(sequence_output)
  1013. logits = self.classifier(sequence_output)
  1014. loss = None
  1015. if labels is not None:
  1016. loss_fct = CrossEntropyLoss()
  1017. labels = labels.to(logits.device)
  1018. loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
  1019. if not return_dict:
  1020. output = (logits,) + outputs[2:]
  1021. return ((loss,) + output) if loss is not None else output
  1022. return TokenClassifierOutput(
  1023. loss=loss,
  1024. logits=logits,
  1025. hidden_states=outputs.hidden_states,
  1026. attentions=outputs.attentions,
  1027. )
  1028. class EsmClassificationHead(nn.Module):
  1029. """Head for sentence-level classification tasks."""
  1030. def __init__(self, config):
  1031. super().__init__()
  1032. self.dense = nn.Linear(config.hidden_size, config.hidden_size)
  1033. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  1034. self.out_proj = nn.Linear(config.hidden_size, config.num_labels)
  1035. def forward(self, features, **kwargs):
  1036. x = features[:, 0, :] # take <s> token (equiv. to [CLS])
  1037. x = self.dropout(x)
  1038. x = self.dense(x)
  1039. x = torch.tanh(x)
  1040. x = self.dropout(x)
  1041. x = self.out_proj(x)
  1042. return x
  1043. def create_position_ids_from_input_ids(input_ids, padding_idx, past_key_values_length=0):
  1044. """
  1045. Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding symbols
  1046. are ignored. This is modified from fairseq's `utils.make_positions`.
  1047. Args:
  1048. x: torch.Tensor x:
  1049. Returns: torch.Tensor
  1050. """
  1051. # The series of casts and type-conversions here are carefully balanced to both work with ONNX export and XLA.
  1052. mask = input_ids.ne(padding_idx).int()
  1053. incremental_indices = (torch.cumsum(mask, dim=1).type_as(mask) + past_key_values_length) * mask
  1054. return incremental_indices.long() + padding_idx