modeling_mpnet.py 42 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052
  1. # coding=utf-8
  2. # Copyright 2018 The HuggingFace Inc. team, Microsoft Corporation.
  3. # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. """PyTorch MPNet model."""
  17. import math
  18. from typing import Optional, Tuple, Union
  19. import torch
  20. from torch import nn
  21. from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
  22. from ...activations import ACT2FN, gelu
  23. from ...modeling_outputs import (
  24. BaseModelOutput,
  25. BaseModelOutputWithPooling,
  26. MaskedLMOutput,
  27. MultipleChoiceModelOutput,
  28. QuestionAnsweringModelOutput,
  29. SequenceClassifierOutput,
  30. TokenClassifierOutput,
  31. )
  32. from ...modeling_utils import PreTrainedModel
  33. from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
  34. from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging
  35. from .configuration_mpnet import MPNetConfig
  # Module-level logger obtained through the library's logging wrapper.
  logger = logging.get_logger(__name__)

  # Checkpoint name and configuration class name handed to `add_code_sample_docstrings`
  # by the model classes in this file.
  _CHECKPOINT_FOR_DOC = "microsoft/mpnet-base"
  _CONFIG_FOR_DOC = "MPNetConfig"
  class MPNetPreTrainedModel(PreTrainedModel):
      """
      Common base for the MPNet models in this file.

      Declares the configuration class, the attribute prefix under which the base model is stored, and the
      per-module weight initialisation scheme.
      """

      config_class = MPNetConfig
      base_model_prefix = "mpnet"

      def _init_weights(self, module):
          """Initialise `module` in place according to its type; any other module type is left untouched."""
          if isinstance(module, nn.Linear):
              # Plain normal init; the TF version uses truncated_normal instead
              # (cf https://github.com/pytorch/pytorch/pull/5617).
              module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
              if module.bias is not None:
                  module.bias.data.zero_()
              return

          if isinstance(module, nn.Embedding):
              module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
              pad_row = module.padding_idx
              if pad_row is not None:
                  # The padding row is reset to zeros after the random draw.
                  module.weight.data[pad_row].zero_()
              return

          if isinstance(module, nn.LayerNorm):
              # Identity affine transform: zero shift, unit scale.
              module.bias.data.zero_()
              module.weight.data.fill_(1.0)
  class MPNetEmbeddings(nn.Module):
      """
      Sum of word embeddings and absolute position embeddings, followed by LayerNorm and dropout.

      Both embedding tables use index 1 as their padding index.
      """

      def __init__(self, config):
          super().__init__()
          self.padding_idx = 1
          self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=self.padding_idx)
          self.position_embeddings = nn.Embedding(
              config.max_position_embeddings, config.hidden_size, padding_idx=self.padding_idx
          )

          self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
          self.dropout = nn.Dropout(config.hidden_dropout_prob)
          # Kept as a (non-persistent) attribute for backward compatibility; `forward` does not read it.
          self.register_buffer(
              "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
          )

      def forward(self, input_ids=None, position_ids=None, inputs_embeds=None, **kwargs):
          """
          Embed the inputs.

          Args:
              input_ids: token indices; used for the word-embedding lookup when `inputs_embeds` is not given and
                  to derive position ids when `position_ids` is not given.
              position_ids: explicit indices into the position-embedding table. When omitted they are derived
                  from `input_ids` if available, otherwise generated sequentially from `inputs_embeds`.
              inputs_embeds: pre-computed word embeddings; bypasses the word-embedding lookup.

          Returns:
              The normalised, dropout-regularised sum of word and position embeddings.

          Raises:
              ValueError: if neither `input_ids` nor `inputs_embeds` is provided.
          """
          if input_ids is None and inputs_embeds is None:
              raise ValueError("You have to specify either input_ids or inputs_embeds")

          if position_ids is None:
              if input_ids is not None:
                  # Module-level helper defined elsewhere in this file.
                  position_ids = create_position_ids_from_input_ids(input_ids, self.padding_idx)
              else:
                  position_ids = self.create_position_ids_from_inputs_embeds(inputs_embeds)

          if inputs_embeds is None:
              inputs_embeds = self.word_embeddings(input_ids)
          position_embeddings = self.position_embeddings(position_ids)

          embeddings = inputs_embeds + position_embeddings
          embeddings = self.LayerNorm(embeddings)
          embeddings = self.dropout(embeddings)
          return embeddings

      def create_position_ids_from_inputs_embeds(self, inputs_embeds):
          """
          We are provided embeddings directly. We cannot infer which are padded so just generate sequential
          position ids, starting at `padding_idx + 1`.

          Args:
              inputs_embeds: torch.Tensor

          Returns: torch.Tensor of dtype long with the shape of `inputs_embeds` minus its last dimension.
          """
          input_shape = inputs_embeds.size()[:-1]
          sequence_length = input_shape[1]

          position_ids = torch.arange(
              self.padding_idx + 1, sequence_length + self.padding_idx + 1, dtype=torch.long, device=inputs_embeds.device
          )
          return position_ids.unsqueeze(0).expand(input_shape)
  class MPNetSelfAttention(nn.Module):
      """
      Multi-head scaled dot-product self-attention with an optional additive position bias and an output
      projection (`o`).
      """

      def __init__(self, config):
          super().__init__()
          if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
              raise ValueError(
                  f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
                  f"heads ({config.num_attention_heads})"
              )

          self.num_attention_heads = config.num_attention_heads
          self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
          self.all_head_size = self.num_attention_heads * self.attention_head_size

          self.q = nn.Linear(config.hidden_size, self.all_head_size)
          self.k = nn.Linear(config.hidden_size, self.all_head_size)
          self.v = nn.Linear(config.hidden_size, self.all_head_size)
          # The output projection consumes the concatenated heads, i.e. `all_head_size` features. This equals
          # `hidden_size` whenever the hidden size is divisible by the head count, and stays consistent when it
          # is not (allowed above when the config carries `embedding_size`).
          self.o = nn.Linear(self.all_head_size, config.hidden_size)

          self.dropout = nn.Dropout(config.attention_probs_dropout_prob)

      def transpose_for_scores(self, x):
          """Split the last dimension into heads and move the head axis to position 1: (..., heads, seq, head_size)."""
          new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
          x = x.view(*new_x_shape)
          return x.permute(0, 2, 1, 3)

      def forward(
          self,
          hidden_states,
          attention_mask=None,
          head_mask=None,
          position_bias=None,
          output_attentions=False,
          **kwargs,
      ):
          """
          Run self-attention over `hidden_states`.

          Args:
              hidden_states: input tensor whose last dimension is `hidden_size`.
              attention_mask: optional tensor added to the raw attention scores before the softmax.
              head_mask: optional tensor multiplied into the attention probabilities after dropout.
              position_bias: optional tensor added to the scaled attention scores.
              output_attentions: when true, the attention probabilities are returned as a second element.

          Returns:
              `(output,)` or `(output, attention_probs)`.
          """
          q = self.transpose_for_scores(self.q(hidden_states))
          k = self.transpose_for_scores(self.k(hidden_states))
          v = self.transpose_for_scores(self.v(hidden_states))

          # Take the dot product between "query" and "key" to get the raw attention scores.
          attention_scores = torch.matmul(q, k.transpose(-1, -2))
          attention_scores = attention_scores / math.sqrt(self.attention_head_size)

          # Apply relative position embedding (precomputed in MPNetEncoder) if provided.
          if position_bias is not None:
              attention_scores = attention_scores + position_bias

          if attention_mask is not None:
              attention_scores = attention_scores + attention_mask

          # Normalize the attention scores to probabilities.
          attention_probs = nn.functional.softmax(attention_scores, dim=-1)
          attention_probs = self.dropout(attention_probs)

          if head_mask is not None:
              attention_probs = attention_probs * head_mask

          c = torch.matmul(attention_probs, v)

          # Merge the heads back into a single feature dimension.
          c = c.permute(0, 2, 1, 3).contiguous()
          new_c_shape = c.size()[:-2] + (self.all_head_size,)
          c = c.view(*new_c_shape)

          o = self.o(c)

          return (o, attention_probs) if output_attentions else (o,)
  class MPNetAttention(nn.Module):
      """Self-attention wrapped with dropout, a residual connection and LayerNorm; also tracks pruned heads."""

      def __init__(self, config):
          super().__init__()
          self.attn = MPNetSelfAttention(config)
          self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
          self.dropout = nn.Dropout(config.hidden_dropout_prob)

          self.pruned_heads = set()

      def prune_heads(self, heads):
          """Remove the given attention heads from the q/k/v/o projections and update the head bookkeeping."""
          if len(heads) == 0:
              return

          core = self.attn
          heads, index = find_pruneable_heads_and_indices(
              heads, core.num_attention_heads, core.attention_head_size, self.pruned_heads
          )

          # q, k and v are pruned along the default dimension, the output projection along dim=1.
          for proj_name in ("q", "k", "v"):
              setattr(core, proj_name, prune_linear_layer(getattr(core, proj_name), index))
          core.o = prune_linear_layer(core.o, index, dim=1)

          core.num_attention_heads -= len(heads)
          core.all_head_size = core.attention_head_size * core.num_attention_heads
          self.pruned_heads = self.pruned_heads.union(heads)

      def forward(
          self,
          hidden_states,
          attention_mask=None,
          head_mask=None,
          position_bias=None,
          output_attentions=False,
          **kwargs,
      ):
          """Return `(attention_output, *extras)` where extras are whatever `self.attn` returned past its first item."""
          attn_result = self.attn(
              hidden_states,
              attention_mask,
              head_mask,
              position_bias,
              output_attentions=output_attentions,
          )
          context, extras = attn_result[0], attn_result[1:]
          # Residual connection around the attention, then LayerNorm.
          normed = self.LayerNorm(self.dropout(context) + hidden_states)
          return (normed,) + extras  # extras hold the attentions if we output them
  # Adapted from transformers.models.bert.modeling_bert.BertIntermediate
  class MPNetIntermediate(nn.Module):
      """Feed-forward expansion: a linear layer from `hidden_size` to `intermediate_size` plus an activation."""

      def __init__(self, config):
          super().__init__()
          self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
          activation = config.hidden_act
          # A string is looked up in ACT2FN; anything else is used directly as the activation.
          self.intermediate_act_fn = ACT2FN[activation] if isinstance(activation, str) else activation

      def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
          """Apply the linear projection followed by the activation."""
          return self.intermediate_act_fn(self.dense(hidden_states))
  # Adapted from transformers.models.bert.modeling_bert.BertOutput
  class MPNetOutput(nn.Module):
      """Feed-forward contraction back to `hidden_size`, with dropout, a residual connection and LayerNorm."""

      def __init__(self, config):
          super().__init__()
          self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
          self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
          self.dropout = nn.Dropout(config.hidden_dropout_prob)

      def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
          """Project `hidden_states`, apply dropout, add `input_tensor` as the residual and normalise."""
          projected = self.dropout(self.dense(hidden_states))
          return self.LayerNorm(projected + input_tensor)
  class MPNetLayer(nn.Module):
      """One encoder block: the attention sub-layer followed by the intermediate/output feed-forward pair."""

      def __init__(self, config):
          super().__init__()
          self.attention = MPNetAttention(config)
          self.intermediate = MPNetIntermediate(config)
          self.output = MPNetOutput(config)

      def forward(
          self,
          hidden_states,
          attention_mask=None,
          head_mask=None,
          position_bias=None,
          output_attentions=False,
          **kwargs,
      ):
          """Return `(layer_output, *extras)`; extras are the attention sub-layer's outputs past its first item."""
          attn_result = self.attention(
              hidden_states,
              attention_mask,
              head_mask,
              position_bias=position_bias,
              output_attentions=output_attentions,
          )
          attended = attn_result[0]
          extras = attn_result[1:]  # self attentions, if attention weights are being output

          # The attention output is both the feed-forward input and its residual.
          ffn_out = self.output(self.intermediate(attended), attended)
          return (ffn_out,) + extras
  class MPNetEncoder(nn.Module):
      """
      Stack of `MPNetLayer` modules that share one learned relative-position attention bias.

      The bias is computed once per forward pass from bucketed relative positions and passed to every layer.
      """

      def __init__(self, config):
          super().__init__()
          self.config = config
          self.n_heads = config.num_attention_heads
          self.layer = nn.ModuleList([MPNetLayer(config) for _ in range(config.num_hidden_layers)])
          # One learned bias value per (bucket, head) pair.
          self.relative_attention_bias = nn.Embedding(config.relative_attention_num_buckets, self.n_heads)

      def forward(
          self,
          hidden_states: torch.Tensor,
          attention_mask: Optional[torch.Tensor] = None,
          head_mask: Optional[torch.Tensor] = None,
          output_attentions: bool = False,
          output_hidden_states: bool = False,
          return_dict: bool = False,
          **kwargs,
      ):
          """
          Run all layers in order.

          Args:
              hidden_states: input to the first layer.
              attention_mask: forwarded unchanged to every layer.
              head_mask: optional per-layer head masks, indexed by layer number. `None` (the default) means no
                  head masking in any layer.
              output_attentions: collect each layer's second output (the attention probabilities).
              output_hidden_states: collect the input of every layer plus the final output.
              return_dict: return a `BaseModelOutput` instead of a plain tuple.
              **kwargs: forwarded to every layer.

          Returns:
              A `BaseModelOutput`, or a tuple of the non-`None` values among
              `(last_hidden_state, all_hidden_states, all_attentions)`.
          """
          position_bias = self.compute_position_bias(hidden_states)
          all_hidden_states = () if output_hidden_states else None
          all_attentions = () if output_attentions else None
          for i, layer_module in enumerate(self.layer):
              if output_hidden_states:
                  all_hidden_states = all_hidden_states + (hidden_states,)

              # `head_mask` is optional in the signature, so it must not be indexed when absent.
              layer_head_mask = head_mask[i] if head_mask is not None else None

              layer_outputs = layer_module(
                  hidden_states,
                  attention_mask,
                  layer_head_mask,
                  position_bias,
                  output_attentions=output_attentions,
                  **kwargs,
              )
              hidden_states = layer_outputs[0]

              if output_attentions:
                  all_attentions = all_attentions + (layer_outputs[1],)

          # Add last layer
          if output_hidden_states:
              all_hidden_states = all_hidden_states + (hidden_states,)

          if not return_dict:
              return tuple(v for v in [hidden_states, all_hidden_states, all_attentions] if v is not None)
          return BaseModelOutput(
              last_hidden_state=hidden_states,
              hidden_states=all_hidden_states,
              attentions=all_attentions,
          )

      def compute_position_bias(self, x, position_ids=None, num_buckets=32):
          """
          Build the additive attention bias of shape `(batch, n_heads, seq_len, seq_len)`.

          Args:
              x: tensor whose first two dimensions give the batch size and the sequence length.
              position_ids: optional `(batch, seq_len)` positions. When omitted, positions `0..seq_len-1` are
                  used and the same bias is shared by every batch element.
              num_buckets: number of relative-position buckets passed to `relative_position_bucket`.

          Returns:
              A contiguous tensor on `x.device`.
          """
          bsz, qlen, klen = x.size(0), x.size(1), x.size(1)
          if position_ids is not None:
              context_position = position_ids[:, :, None]
              memory_position = position_ids[:, None, :]
          else:
              # Created directly on the target device to avoid a host-to-device copy of the bucket indices.
              context_position = torch.arange(qlen, dtype=torch.long, device=x.device)[:, None]
              memory_position = torch.arange(klen, dtype=torch.long, device=x.device)[None, :]

          relative_position = memory_position - context_position

          rp_bucket = self.relative_position_bucket(relative_position, num_buckets=num_buckets)
          rp_bucket = rp_bucket.to(x.device)
          values = self.relative_attention_bias(rp_bucket)

          if values.dim() == 4:
              # Batched position ids: (batch, qlen, klen, heads) -> (batch, heads, qlen, klen).
              return values.permute(0, 3, 1, 2).expand(bsz, -1, -1, -1).contiguous()

          # Shared positions: (qlen, klen, heads) -> (1, heads, qlen, klen), then repeated over the batch.
          values = values.permute([2, 0, 1]).unsqueeze(0)
          values = values.expand((bsz, -1, qlen, klen)).contiguous()
          return values

      @staticmethod
      def relative_position_bucket(relative_position, num_buckets=32, max_distance=128):
          """
          Map signed relative positions (key position minus query position) to bucket indices.

          Half of the buckets serve each sign: keys located after the query are offset by `num_buckets // 2`.
          Within a half, the first `num_buckets // 4` distances get one bucket each; larger distances are
          spread logarithmically up to `max_distance` and clamped to the last bucket of the half.
          """
          ret = 0
          n = -relative_position

          num_buckets //= 2
          ret += (n < 0).to(torch.long) * num_buckets
          n = torch.abs(n)

          max_exact = num_buckets // 2
          is_small = n < max_exact

          # For n == 0 the log is -inf, but `torch.where` below selects `n` for all small distances.
          val_if_large = max_exact + (
              torch.log(n.float() / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact)
          ).to(torch.long)

          val_if_large = torch.min(val_if_large, torch.full_like(val_if_large, num_buckets - 1))
          ret += torch.where(is_small, n, val_if_large)
          return ret
  # Adapted from transformers.models.bert.modeling_bert.BertPooler
  class MPNetPooler(nn.Module):
      """Summarise a sequence by passing its first token's hidden state through a linear layer and tanh."""

      def __init__(self, config):
          super().__init__()
          self.dense = nn.Linear(config.hidden_size, config.hidden_size)
          self.activation = nn.Tanh()

      def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
          """Return the pooled representation built from position 0 along dimension 1."""
          # "Pooling" here is simply selecting the hidden state of the first token.
          leading_token = hidden_states[:, 0]
          return self.activation(self.dense(leading_token))
  337. MPNET_START_DOCSTRING = r"""
  338. This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
  339. library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
  340. etc.)
  341. This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
  342. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage
  343. and behavior.
  344. Parameters:
  345. config ([`MPNetConfig`]): Model configuration class with all the parameters of the model.
  346. Initializing with a config file does not load the weights associated with the model, only the
  347. configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
  348. """
  349. MPNET_INPUTS_DOCSTRING = r"""
  350. Args:
  351. input_ids (`torch.LongTensor` of shape `({0})`):
  352. Indices of input sequence tokens in the vocabulary.
  353. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  354. [`PreTrainedTokenizer.__call__`] for details.
  355. [What are input IDs?](../glossary#input-ids)
  356. attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
  357. Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
  358. - 1 for tokens that are **not masked**,
  359. - 0 for tokens that are **masked**.
  360. [What are attention masks?](../glossary#attention-mask)
  361. position_ids (`torch.LongTensor` of shape `({0})`, *optional*):
  362. Indices of positions of each input sequence token in the position embeddings. Selected in the range `[0,
  363. config.max_position_embeddings - 1]`.
  364. [What are position IDs?](../glossary#position-ids)
  365. head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
  366. Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
  367. - 1 indicates the head is **not masked**,
  368. - 0 indicates the head is **masked**.
  369. inputs_embeds (`torch.FloatTensor` of shape `({0}, hidden_size)`, *optional*):
  370. Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
  371. is useful if you want more control over how to convert *input_ids* indices into associated vectors than the
  372. model's internal embedding lookup matrix.
  373. output_attentions (`bool`, *optional*):
  374. Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
  375. tensors for more detail.
  376. output_hidden_states (`bool`, *optional*):
  377. Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
  378. more detail.
  379. return_dict (`bool`, *optional*):
  380. Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
  381. """
  @add_start_docstrings(
      "The bare MPNet Model transformer outputting raw hidden-states without any specific head on top.",
      MPNET_START_DOCSTRING,
  )
  class MPNetModel(MPNetPreTrainedModel):
      # NOTE(review): no class or `forward` docstring is written here on purpose; the surrounding decorators
      # appear to assemble the user-facing documentation, and an extra docstring would end up in it.

      def __init__(self, config, add_pooling_layer=True):
          super().__init__(config)
          self.config = config

          self.embeddings = MPNetEmbeddings(config)
          self.encoder = MPNetEncoder(config)
          if add_pooling_layer:
              self.pooler = MPNetPooler(config)
          else:
              self.pooler = None

          # Initialize weights and apply final processing
          self.post_init()

      def get_input_embeddings(self):
          return self.embeddings.word_embeddings

      def set_input_embeddings(self, value):
          self.embeddings.word_embeddings = value

      def _prune_heads(self, heads_to_prune):
          """
          Prune attention heads. `heads_to_prune` maps a layer index to the list of heads to remove in that
          layer. See base class PreTrainedModel.
          """
          for layer_idx, head_list in heads_to_prune.items():
              self.encoder.layer[layer_idx].attention.prune_heads(head_list)

      @add_start_docstrings_to_model_forward(MPNET_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
      @add_code_sample_docstrings(
          checkpoint=_CHECKPOINT_FOR_DOC,
          output_type=BaseModelOutputWithPooling,
          config_class=_CONFIG_FOR_DOC,
      )
      def forward(
          self,
          input_ids: Optional[torch.LongTensor] = None,
          attention_mask: Optional[torch.FloatTensor] = None,
          position_ids: Optional[torch.LongTensor] = None,
          head_mask: Optional[torch.FloatTensor] = None,
          inputs_embeds: Optional[torch.FloatTensor] = None,
          output_attentions: Optional[bool] = None,
          output_hidden_states: Optional[bool] = None,
          return_dict: Optional[bool] = None,
          **kwargs,
      ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPooling]:
          # Unset output flags fall back to the configuration.
          if output_attentions is None:
              output_attentions = self.config.output_attentions
          if output_hidden_states is None:
              output_hidden_states = self.config.output_hidden_states
          if return_dict is None:
              return_dict = self.config.use_return_dict

          # Exactly one of `input_ids` / `inputs_embeds` must be supplied.
          if input_ids is not None and inputs_embeds is not None:
              raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
          if input_ids is None and inputs_embeds is None:
              raise ValueError("You have to specify either input_ids or inputs_embeds")

          if input_ids is not None:
              self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
              input_shape = input_ids.size()
              device = input_ids.device
          else:
              input_shape = inputs_embeds.size()[:-1]
              device = inputs_embeds.device

          # Without a mask every position is attended to.
          if attention_mask is None:
              attention_mask = torch.ones(input_shape, device=device)
          extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)

          head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
          embedding_output = self.embeddings(input_ids=input_ids, position_ids=position_ids, inputs_embeds=inputs_embeds)
          encoder_outputs = self.encoder(
              embedding_output,
              attention_mask=extended_attention_mask,
              head_mask=head_mask,
              output_attentions=output_attentions,
              output_hidden_states=output_hidden_states,
              return_dict=return_dict,
          )
          sequence_output = encoder_outputs[0]
          pooled_output = None if self.pooler is None else self.pooler(sequence_output)

          if return_dict:
              return BaseModelOutputWithPooling(
                  last_hidden_state=sequence_output,
                  pooler_output=pooled_output,
                  hidden_states=encoder_outputs.hidden_states,
                  attentions=encoder_outputs.attentions,
              )
          return (sequence_output, pooled_output) + encoder_outputs[1:]
  class MPNetForMaskedLM(MPNetPreTrainedModel):
      # Declares the LM head's decoder as tied weights.
      _tied_weights_keys = ["lm_head.decoder"]

      def __init__(self, config):
          super().__init__(config)

          # The pooler is not needed for token-level prediction.
          self.mpnet = MPNetModel(config, add_pooling_layer=False)
          self.lm_head = MPNetLMHead(config)

          # Initialize weights and apply final processing
          self.post_init()

      def get_output_embeddings(self):
          return self.lm_head.decoder

      def set_output_embeddings(self, new_embeddings):
          # Keep the head's standalone bias pointing at the new decoder's bias.
          self.lm_head.decoder = new_embeddings
          self.lm_head.bias = new_embeddings.bias

      @add_start_docstrings_to_model_forward(MPNET_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
      @add_code_sample_docstrings(
          checkpoint=_CHECKPOINT_FOR_DOC,
          output_type=MaskedLMOutput,
          config_class=_CONFIG_FOR_DOC,
      )
      def forward(
          self,
          input_ids: Optional[torch.LongTensor] = None,
          attention_mask: Optional[torch.FloatTensor] = None,
          position_ids: Optional[torch.LongTensor] = None,
          head_mask: Optional[torch.FloatTensor] = None,
          inputs_embeds: Optional[torch.FloatTensor] = None,
          labels: Optional[torch.LongTensor] = None,
          output_attentions: Optional[bool] = None,
          output_hidden_states: Optional[bool] = None,
          return_dict: Optional[bool] = None,
      ) -> Union[Tuple[torch.Tensor], MaskedLMOutput]:
          r"""
          labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
              Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
              config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
              loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
          """
          if return_dict is None:
              return_dict = self.config.use_return_dict

          encoder_out = self.mpnet(
              input_ids,
              attention_mask=attention_mask,
              position_ids=position_ids,
              head_mask=head_mask,
              inputs_embeds=inputs_embeds,
              output_attentions=output_attentions,
              output_hidden_states=output_hidden_states,
              return_dict=return_dict,
          )
          prediction_scores = self.lm_head(encoder_out[0])

          masked_lm_loss = None
          if labels is not None:
              # Flatten to (tokens, vocab) against (tokens,) for the cross-entropy.
              criterion = CrossEntropyLoss()
              flat_scores = prediction_scores.view(-1, self.config.vocab_size)
              masked_lm_loss = criterion(flat_scores, labels.view(-1))

          if return_dict:
              return MaskedLMOutput(
                  loss=masked_lm_loss,
                  logits=prediction_scores,
                  hidden_states=encoder_out.hidden_states,
                  attentions=encoder_out.attentions,
              )

          # Tuple form: skip the first two base-model outputs and prepend the loss when there is one.
          tail = (prediction_scores,) + encoder_out[2:]
          return tail if masked_lm_loss is None else (masked_lm_loss,) + tail
  class MPNetLMHead(nn.Module):
      """MPNet Head for masked and permuted language modeling."""

      def __init__(self, config):
          super().__init__()
          hidden, vocab = config.hidden_size, config.vocab_size
          self.dense = nn.Linear(hidden, hidden)
          self.layer_norm = nn.LayerNorm(hidden, eps=config.layer_norm_eps)

          # The decoder is created without its own bias; a standalone parameter is attached instead.
          self.decoder = nn.Linear(hidden, vocab, bias=False)
          self.bias = nn.Parameter(torch.zeros(vocab))

          # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
          self.decoder.bias = self.bias

      def _tie_weights(self):
          """Re-attach the standalone bias parameter to the decoder."""
          self.decoder.bias = self.bias

      def forward(self, features, **kwargs):
          """Transform `features` (dense, gelu, layer norm) and project them to vocabulary size with bias."""
          transformed = self.layer_norm(gelu(self.dense(features)))
          return self.decoder(transformed)
  @add_start_docstrings(
      """
      MPNet Model transformer with a sequence classification/regression head on top (a linear layer on top of the pooled
      output) e.g. for GLUE tasks.
      """,
      MPNET_START_DOCSTRING,
  )
  class MPNetForSequenceClassification(MPNetPreTrainedModel):
      def __init__(self, config):
          super().__init__(config)

          self.num_labels = config.num_labels
          # The classification head consumes the sequence output, so the base model is built without a pooler.
          self.mpnet = MPNetModel(config, add_pooling_layer=False)
          self.classifier = MPNetClassificationHead(config)

          # Initialize weights and apply final processing
          self.post_init()

      @add_start_docstrings_to_model_forward(MPNET_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
      @add_code_sample_docstrings(
          checkpoint=_CHECKPOINT_FOR_DOC,
          output_type=SequenceClassifierOutput,
          config_class=_CONFIG_FOR_DOC,
      )
      def forward(
          self,
          input_ids: Optional[torch.LongTensor] = None,
          attention_mask: Optional[torch.FloatTensor] = None,
          position_ids: Optional[torch.LongTensor] = None,
          head_mask: Optional[torch.FloatTensor] = None,
          inputs_embeds: Optional[torch.FloatTensor] = None,
          labels: Optional[torch.LongTensor] = None,
          output_attentions: Optional[bool] = None,
          output_hidden_states: Optional[bool] = None,
          return_dict: Optional[bool] = None,
      ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]:
          r"""
          labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
              Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
              config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
              `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
          """
          if return_dict is None:
              return_dict = self.config.use_return_dict

          encoder_out = self.mpnet(
              input_ids,
              attention_mask=attention_mask,
              position_ids=position_ids,
              head_mask=head_mask,
              inputs_embeds=inputs_embeds,
              output_attentions=output_attentions,
              output_hidden_states=output_hidden_states,
              return_dict=return_dict,
          )
          logits = self.classifier(encoder_out[0])

          loss = None
          if labels is not None:
              cfg = self.config
              # On first use, infer the problem type from the label count and label dtype and store it on the config.
              if cfg.problem_type is None:
                  if self.num_labels == 1:
                      cfg.problem_type = "regression"
                  elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                      cfg.problem_type = "single_label_classification"
                  else:
                      cfg.problem_type = "multi_label_classification"

              problem_type = cfg.problem_type
              if problem_type == "regression":
                  criterion = MSELoss()
                  if self.num_labels == 1:
                      loss = criterion(logits.squeeze(), labels.squeeze())
                  else:
                      loss = criterion(logits, labels)
              elif problem_type == "single_label_classification":
                  criterion = CrossEntropyLoss()
                  loss = criterion(logits.view(-1, self.num_labels), labels.view(-1))
              elif problem_type == "multi_label_classification":
                  criterion = BCEWithLogitsLoss()
                  loss = criterion(logits, labels)

          if return_dict:
              return SequenceClassifierOutput(
                  loss=loss,
                  logits=logits,
                  hidden_states=encoder_out.hidden_states,
                  attentions=encoder_out.attentions,
              )

          # Tuple form: skip the first two base-model outputs and prepend the loss when there is one.
          tail = (logits,) + encoder_out[2:]
          return tail if loss is None else (loss,) + tail
  626. @add_start_docstrings(
  627. """
  628. MPNet Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a
  629. softmax) e.g. for RocStories/SWAG tasks.
  630. """,
  631. MPNET_START_DOCSTRING,
  632. )
  633. class MPNetForMultipleChoice(MPNetPreTrainedModel):
  634. def __init__(self, config):
  635. super().__init__(config)
  636. self.mpnet = MPNetModel(config)
  637. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  638. self.classifier = nn.Linear(config.hidden_size, 1)
  639. # Initialize weights and apply final processing
  640. self.post_init()
  641. @add_start_docstrings_to_model_forward(MPNET_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length"))
  642. @add_code_sample_docstrings(
  643. checkpoint=_CHECKPOINT_FOR_DOC,
  644. output_type=MultipleChoiceModelOutput,
  645. config_class=_CONFIG_FOR_DOC,
  646. )
  647. def forward(
  648. self,
  649. input_ids: Optional[torch.LongTensor] = None,
  650. attention_mask: Optional[torch.FloatTensor] = None,
  651. position_ids: Optional[torch.LongTensor] = None,
  652. head_mask: Optional[torch.FloatTensor] = None,
  653. inputs_embeds: Optional[torch.FloatTensor] = None,
  654. labels: Optional[torch.LongTensor] = None,
  655. output_attentions: Optional[bool] = None,
  656. output_hidden_states: Optional[bool] = None,
  657. return_dict: Optional[bool] = None,
  658. ) -> Union[Tuple[torch.Tensor], MultipleChoiceModelOutput]:
  659. r"""
  660. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  661. Labels for computing the multiple choice classification loss. Indices should be in `[0, ...,
  662. num_choices-1]` where `num_choices` is the size of the second dimension of the input tensors. (See
  663. `input_ids` above)
  664. """
  665. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  666. num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]
  667. flat_input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
  668. flat_position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
  669. flat_attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
  670. flat_inputs_embeds = (
  671. inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
  672. if inputs_embeds is not None
  673. else None
  674. )
  675. outputs = self.mpnet(
  676. flat_input_ids,
  677. position_ids=flat_position_ids,
  678. attention_mask=flat_attention_mask,
  679. head_mask=head_mask,
  680. inputs_embeds=flat_inputs_embeds,
  681. output_attentions=output_attentions,
  682. output_hidden_states=output_hidden_states,
  683. return_dict=return_dict,
  684. )
  685. pooled_output = outputs[1]
  686. pooled_output = self.dropout(pooled_output)
  687. logits = self.classifier(pooled_output)
  688. reshaped_logits = logits.view(-1, num_choices)
  689. loss = None
  690. if labels is not None:
  691. loss_fct = CrossEntropyLoss()
  692. loss = loss_fct(reshaped_logits, labels)
  693. if not return_dict:
  694. output = (reshaped_logits,) + outputs[2:]
  695. return ((loss,) + output) if loss is not None else output
  696. return MultipleChoiceModelOutput(
  697. loss=loss,
  698. logits=reshaped_logits,
  699. hidden_states=outputs.hidden_states,
  700. attentions=outputs.attentions,
  701. )
  702. @add_start_docstrings(
  703. """
  704. MPNet Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for
  705. Named-Entity-Recognition (NER) tasks.
  706. """,
  707. MPNET_START_DOCSTRING,
  708. )
  709. class MPNetForTokenClassification(MPNetPreTrainedModel):
  710. def __init__(self, config):
  711. super().__init__(config)
  712. self.num_labels = config.num_labels
  713. self.mpnet = MPNetModel(config, add_pooling_layer=False)
  714. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  715. self.classifier = nn.Linear(config.hidden_size, config.num_labels)
  716. # Initialize weights and apply final processing
  717. self.post_init()
  718. @add_start_docstrings_to_model_forward(MPNET_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
  719. @add_code_sample_docstrings(
  720. checkpoint=_CHECKPOINT_FOR_DOC,
  721. output_type=TokenClassifierOutput,
  722. config_class=_CONFIG_FOR_DOC,
  723. )
  724. def forward(
  725. self,
  726. input_ids: Optional[torch.LongTensor] = None,
  727. attention_mask: Optional[torch.FloatTensor] = None,
  728. position_ids: Optional[torch.LongTensor] = None,
  729. head_mask: Optional[torch.FloatTensor] = None,
  730. inputs_embeds: Optional[torch.FloatTensor] = None,
  731. labels: Optional[torch.LongTensor] = None,
  732. output_attentions: Optional[bool] = None,
  733. output_hidden_states: Optional[bool] = None,
  734. return_dict: Optional[bool] = None,
  735. ) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]:
  736. r"""
  737. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  738. Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
  739. """
  740. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  741. outputs = self.mpnet(
  742. input_ids,
  743. attention_mask=attention_mask,
  744. position_ids=position_ids,
  745. head_mask=head_mask,
  746. inputs_embeds=inputs_embeds,
  747. output_attentions=output_attentions,
  748. output_hidden_states=output_hidden_states,
  749. return_dict=return_dict,
  750. )
  751. sequence_output = outputs[0]
  752. sequence_output = self.dropout(sequence_output)
  753. logits = self.classifier(sequence_output)
  754. loss = None
  755. if labels is not None:
  756. loss_fct = CrossEntropyLoss()
  757. loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
  758. if not return_dict:
  759. output = (logits,) + outputs[2:]
  760. return ((loss,) + output) if loss is not None else output
  761. return TokenClassifierOutput(
  762. loss=loss,
  763. logits=logits,
  764. hidden_states=outputs.hidden_states,
  765. attentions=outputs.attentions,
  766. )
  767. class MPNetClassificationHead(nn.Module):
  768. """Head for sentence-level classification tasks."""
  769. def __init__(self, config):
  770. super().__init__()
  771. self.dense = nn.Linear(config.hidden_size, config.hidden_size)
  772. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  773. self.out_proj = nn.Linear(config.hidden_size, config.num_labels)
  774. def forward(self, features, **kwargs):
  775. x = features[:, 0, :] # take <s> token (equiv. to BERT's [CLS] token)
  776. x = self.dropout(x)
  777. x = self.dense(x)
  778. x = torch.tanh(x)
  779. x = self.dropout(x)
  780. x = self.out_proj(x)
  781. return x
  782. @add_start_docstrings(
  783. """
  784. MPNet Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear
  785. layers on top of the hidden-states output to compute `span start logits` and `span end logits`).
  786. """,
  787. MPNET_START_DOCSTRING,
  788. )
  789. class MPNetForQuestionAnswering(MPNetPreTrainedModel):
  790. def __init__(self, config):
  791. super().__init__(config)
  792. self.num_labels = config.num_labels
  793. self.mpnet = MPNetModel(config, add_pooling_layer=False)
  794. self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
  795. # Initialize weights and apply final processing
  796. self.post_init()
  797. @add_start_docstrings_to_model_forward(MPNET_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
  798. @add_code_sample_docstrings(
  799. checkpoint=_CHECKPOINT_FOR_DOC,
  800. output_type=QuestionAnsweringModelOutput,
  801. config_class=_CONFIG_FOR_DOC,
  802. )
  803. def forward(
  804. self,
  805. input_ids: Optional[torch.LongTensor] = None,
  806. attention_mask: Optional[torch.FloatTensor] = None,
  807. position_ids: Optional[torch.LongTensor] = None,
  808. head_mask: Optional[torch.FloatTensor] = None,
  809. inputs_embeds: Optional[torch.FloatTensor] = None,
  810. start_positions: Optional[torch.LongTensor] = None,
  811. end_positions: Optional[torch.LongTensor] = None,
  812. output_attentions: Optional[bool] = None,
  813. output_hidden_states: Optional[bool] = None,
  814. return_dict: Optional[bool] = None,
  815. ) -> Union[Tuple[torch.Tensor], QuestionAnsweringModelOutput]:
  816. r"""
  817. start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  818. Labels for position (index) of the start of the labelled span for computing the token classification loss.
  819. Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
  820. are not taken into account for computing the loss.
  821. end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  822. Labels for position (index) of the end of the labelled span for computing the token classification loss.
  823. Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
  824. are not taken into account for computing the loss.
  825. """
  826. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  827. outputs = self.mpnet(
  828. input_ids,
  829. attention_mask=attention_mask,
  830. position_ids=position_ids,
  831. head_mask=head_mask,
  832. inputs_embeds=inputs_embeds,
  833. output_attentions=output_attentions,
  834. output_hidden_states=output_hidden_states,
  835. return_dict=return_dict,
  836. )
  837. sequence_output = outputs[0]
  838. logits = self.qa_outputs(sequence_output)
  839. start_logits, end_logits = logits.split(1, dim=-1)
  840. start_logits = start_logits.squeeze(-1).contiguous()
  841. end_logits = end_logits.squeeze(-1).contiguous()
  842. total_loss = None
  843. if start_positions is not None and end_positions is not None:
  844. # If we are on multi-GPU, split add a dimension
  845. if len(start_positions.size()) > 1:
  846. start_positions = start_positions.squeeze(-1)
  847. if len(end_positions.size()) > 1:
  848. end_positions = end_positions.squeeze(-1)
  849. # sometimes the start/end positions are outside our model inputs, we ignore these terms
  850. ignored_index = start_logits.size(1)
  851. start_positions = start_positions.clamp(0, ignored_index)
  852. end_positions = end_positions.clamp(0, ignored_index)
  853. loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
  854. start_loss = loss_fct(start_logits, start_positions)
  855. end_loss = loss_fct(end_logits, end_positions)
  856. total_loss = (start_loss + end_loss) / 2
  857. if not return_dict:
  858. output = (start_logits, end_logits) + outputs[2:]
  859. return ((total_loss,) + output) if total_loss is not None else output
  860. return QuestionAnsweringModelOutput(
  861. loss=total_loss,
  862. start_logits=start_logits,
  863. end_logits=end_logits,
  864. hidden_states=outputs.hidden_states,
  865. attentions=outputs.attentions,
  866. )
  867. def create_position_ids_from_input_ids(input_ids, padding_idx):
  868. """
  869. Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding symbols
  870. are ignored. This is modified from fairseq's `utils.make_positions`. :param torch.Tensor x: :return torch.Tensor:
  871. """
  872. # The series of casts and type-conversions here are carefully balanced to both work with ONNX export and XLA.
  873. mask = input_ids.ne(padding_idx).int()
  874. incremental_indices = torch.cumsum(mask, dim=1).type_as(mask) * mask
  875. return incremental_indices.long() + padding_idx