modeling_prophetnet.py

  1. # coding=utf-8
  2. # Copyright 2020 The Microsoft Authors and The HuggingFace Inc. team.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. """PyTorch ProphetNet model, ported from ProphetNet repo(fairsequery_states version)."""
  16. import copy
  17. import math
  18. import warnings
  19. from dataclasses import dataclass
  20. from typing import Optional, Tuple, Union
  21. import torch
  22. import torch.utils.checkpoint
  23. from torch import Tensor, nn
  24. from torch.nn import LayerNorm
  25. from ...activations import ACT2FN
  26. from ...generation import GenerationMixin
  27. from ...modeling_outputs import BaseModelOutput
  28. from ...modeling_utils import PreTrainedModel
  29. from ...utils import (
  30. ModelOutput,
  31. add_start_docstrings,
  32. add_start_docstrings_to_model_forward,
  33. logging,
  34. replace_return_docstrings,
  35. )
  36. from .configuration_prophetnet import ProphetNetConfig
  37. logger = logging.get_logger(__name__)
  38. _CONFIG_FOR_DOC = "ProphetNetConfig"
  39. _CHECKPOINT_FOR_DOC = "microsoft/prophetnet-large-uncased"
  40. PROPHETNET_START_DOCSTRING = r"""
  41. This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
  42. library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
  43. etc.)
  44. Original ProphetNet code can be found [here](https://github.com/microsoft/ProphetNet). Checkpoints were converted
  45. from original Fairseq checkpoints. For more information on the checkpoint conversion, please take a look at the
  46. file `convert_prophetnet_original_pytorch_checkpoint_to_pytorch.py`.
  47. This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use
  48. it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage and
  49. behavior.
  50. Parameters:
  51. config ([`ProphetNetConfig`]): Model configuration class with all the parameters of the model.
  52. Initializing with a config file does not load the weights associated with the model, only the
  53. configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
  54. """
  55. PROPHETNET_INPUTS_DOCSTRING = r"""
  56. Args:
  57. input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
  58. Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
  59. it.
  60. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  61. [`PreTrainedTokenizer.__call__`] for details.
  62. [What are input IDs?](../glossary#input-ids)
  63. attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
  64. Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
  65. - 1 for tokens that are **not masked**,
  66. - 0 for tokens that are **masked**.
  67. [What are attention masks?](../glossary#attention-mask)
  68. decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
  69. Indices of decoder input sequence tokens in the vocabulary.
  70. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  71. [`PreTrainedTokenizer.__call__`] for details.
  72. [What are decoder input IDs?](../glossary#decoder-input-ids)
  73. ProphetNet uses the `eos_token_id` as the starting token for `decoder_input_ids` generation. If
  74. `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
  75. `past_key_values`).
  76. decoder_attention_mask (`torch.BoolTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
  77. Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
  78. be used by default.
  79. head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*):
  80. Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in `[0, 1]`:
  81. - 1 indicates the head is **not masked**,
  82. - 0 indicates the head is **masked**.
  83. decoder_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
  84. Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in `[0, 1]`:
  85. - 1 indicates the head is **not masked**,
  86. - 0 indicates the head is **masked**.
  87. cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
  88. Mask to nullify selected heads of the cross-attention modules. Mask values selected in `[0, 1]`:
  89. - 1 indicates the head is **not masked**,
  90. - 0 indicates the head is **masked**.
  91. encoder_outputs (`tuple(tuple(torch.FloatTensor))`, *optional*):
  92. Tuple consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: `attentions`)
  93. `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*, is a sequence of
  94. hidden-states at the output of the last layer of the encoder. Used in the cross-attention of the decoder.
  95. past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
  96. Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up decoding.
  97. If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
  98. don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
  99. `decoder_input_ids` of shape `(batch_size, sequence_length)`.
  100. use_cache (`bool`, *optional*):
  101. If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
  102. `past_key_values`).
  103. output_attentions (`bool`, *optional*):
  104. Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
  105. tensors for more detail.
  106. output_hidden_states (`bool`, *optional*):
  107. Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
  108. more detail.
  109. return_dict (`bool`, *optional*):
  110. Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
  111. """
  112. PROPHETNET_STANDALONE_INPUTS_DOCSTRING = r"""
  113. Args:
  114. input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
  115. Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
  116. it.
  117. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  118. [`PreTrainedTokenizer.__call__`] for details.
  119. [What are input IDs?](../glossary#input-ids)
  120. attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
  121. Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
  122. - 1 for tokens that are **not masked**,
  123. - 0 for tokens that are **masked**.
  124. [What are attention masks?](../glossary#attention-mask)
  125. head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*):
  126. Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in `[0, 1]`:
  127. - 1 indicates the head is **not masked**,
  128. - 0 indicates the head is **masked**.
  129. output_attentions (`bool`, *optional*):
  130. Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
  131. tensors for more detail.
  132. output_hidden_states (`bool`, *optional*):
  133. Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
  134. more detail.
  135. return_dict (`bool`, *optional*):
  136. Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
  137. """
  138. def softmax(hidden_state, dim, onnx_trace=False):
  139. if onnx_trace:
  140. return nn.functional.softmax(hidden_state.float(), dim=dim)
  141. else:
  142. return nn.functional.softmax(hidden_state, dim=dim, dtype=torch.float32)
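# Hedged sketch (not part of the original module): the helper above runs the softmax in
# float32 even for half-precision logits, which keeps the normalization numerically stable.
def _example_softmax_in_float32():
    logits = torch.tensor([[1.0, 2.0, 3.0]], dtype=torch.float16)
    probs = softmax(logits, dim=-1)  # computed (and returned) in float32
    return probs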
  143. def ngram_attention_bias(sequence_length, ngram, device, dtype):
  144. """
  145. This function computes the bias for the predict stream
  146. """
  147. left_block = (
  148. torch.ones((ngram, sequence_length, sequence_length), device=device, dtype=dtype) * torch.finfo(dtype).min
  149. )
  150. right_block = left_block.detach().clone()
  151. # create bias
  152. for stream_idx in range(ngram):
  153. right_block[stream_idx].fill_diagonal_(0, wrap=False)
  154. left_block[stream_idx].triu_(-stream_idx + 1)
  155. left_block[:, :, 0] = 0
  156. return torch.cat([left_block, right_block], dim=2)
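# Hedged sketch (not part of the original module): the predict-stream bias has shape
# (ngram, sequence_length, 2 * sequence_length). The left half lets each predict position
# look at past main-stream positions (plus position 0); the right half exposes only the
# position currently being predicted.
def _example_ngram_attention_bias_shape():
    bias = ngram_attention_bias(sequence_length=4, ngram=2, device="cpu", dtype=torch.float32)
    assert bias.shape == (2, 4, 8)
    return bias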
  157. def compute_relative_buckets(num_buckets, max_distance, relative_positions, is_bidirectional=False):
  158. """
  159. This function computes individual parts of the relative position buckets. For more detail, see paper.
  160. """
  161. inv_relative_positions = -relative_positions
  162. rel_positions_bucket = 0
  163. if is_bidirectional:
  164. num_buckets = num_buckets // 2
  165. rel_positions_bucket = (
  166. rel_positions_bucket
  167. + torch.lt(inv_relative_positions, torch.zeros_like(inv_relative_positions)).int() * num_buckets
  168. )
  169. inv_relative_positions = torch.abs(inv_relative_positions)
  170. else:
  171. inv_relative_positions = torch.max(inv_relative_positions, torch.zeros_like(inv_relative_positions))
  172. max_exact = num_buckets // 2
  173. is_small = torch.lt(inv_relative_positions, max_exact)
  174. val_if_large = max_exact + torch.log(inv_relative_positions.float() / max_exact) / math.log(
  175. max_distance / max_exact
  176. ) * (num_buckets - max_exact)
  177. val_if_large = torch.min(val_if_large, torch.ones_like(val_if_large) * (num_buckets - 1)).int()
  178. rel_positions_bucket = rel_positions_bucket + torch.where(is_small, inv_relative_positions.int(), val_if_large)
  179. return rel_positions_bucket
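# Hedged sketch (not part of the original module): small relative distances map to their own
# bucket, while larger ones are binned logarithmically and clipped to the last bucket.
def _example_relative_buckets():
    relative_positions = torch.arange(-4, 5).unsqueeze(0)  # shape (1, 9)
    buckets = compute_relative_buckets(
        num_buckets=32, max_distance=128, relative_positions=relative_positions, is_bidirectional=True
    )
    return buckets  # same shape as `relative_positions`, bucket ids in [0, num_buckets - 1]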
  180. def compute_all_stream_relative_buckets(num_buckets, max_distance, position_ids):
  181. """
  182. This function computes both main and predict relative position buckets. For more detail, see paper.
  183. """
  184. # main stream
  185. main_stream_relative_positions = position_ids.unsqueeze(1).repeat(1, position_ids.size(-1), 1)
  186. main_stream_relative_positions = main_stream_relative_positions - position_ids.unsqueeze(-1)
  187. # predicting stream
  188. predicting_stream_relative_positions = torch.cat((position_ids - 1, position_ids), dim=-1).unsqueeze(1)
  189. predicting_stream_relative_positions = predicting_stream_relative_positions.repeat(1, position_ids.size(-1), 1)
  190. predicting_stream_relative_positions = predicting_stream_relative_positions - position_ids.unsqueeze(-1)
  191. # get both position buckets
  192. main_relative_position_buckets = compute_relative_buckets(
  193. num_buckets, max_distance, main_stream_relative_positions, is_bidirectional=False
  194. )
  195. predict_relative_position_buckets = compute_relative_buckets(
  196. num_buckets, max_distance, predicting_stream_relative_positions, is_bidirectional=False
  197. )
  198. return main_relative_position_buckets, predict_relative_position_buckets
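# Hedged sketch (not part of the original module): given decoder position ids, the helper
# above returns one bucket matrix per stream; the predict stream attends over both streams,
# hence the doubled key length in its bucket matrix.
def _example_all_stream_buckets():
    position_ids = torch.arange(1, 6).unsqueeze(0)  # shape (1, 5)
    main_buckets, predict_buckets = compute_all_stream_relative_buckets(
        num_buckets=32, max_distance=128, position_ids=position_ids
    )
    return main_buckets.shape, predict_buckets.shape  # (1, 5, 5) and (1, 5, 10)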
  199. @dataclass
  200. class ProphetNetSeq2SeqLMOutput(ModelOutput):
  201. """
  202. Base class for sequence-to-sequence language models outputs.
  203. Args:
  204. loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
  205. Language modeling loss.
  206. logits (`torch.FloatTensor` of shape `(batch_size, decoder_sequence_length, config.vocab_size)`):
  207. Prediction scores of the main stream language modeling head (scores for each vocabulary token before
  208. SoftMax).
  209. logits_ngram (`torch.FloatTensor` of shape `(batch_size, ngram * decoder_sequence_length, config.vocab_size)`):
  210. Prediction scores of the predict stream language modeling head (scores for each vocabulary token before
  211. SoftMax).
  212. past_key_values (`List[torch.FloatTensor]`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
  213. List of `torch.FloatTensor` of length `config.n_layers`, with each tensor of shape `(2, batch_size,
  214. num_attn_heads, decoder_sequence_length, embed_size_per_head)`.
  215. Contains pre-computed hidden-states (key and values in the attention blocks) of the decoder that can be
  216. used (see `past_key_values` input) to speed up sequential decoding.
  217. decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
  218. Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
  219. shape `(batch_size, decoder_sequence_length, hidden_size)`.
  220. Hidden-states of main stream of the decoder at the output of each layer plus the initial embedding outputs.
  221. decoder_ngram_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
  222. Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
  223. shape `(batch_size, ngram * decoder_sequence_length, hidden_size)`.
  224. Hidden-states of the predict stream of the decoder at the output of each layer plus the initial embedding
  225. outputs.
  226. decoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
  227. Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_attn_heads,
  228. decoder_sequence_length, decoder_sequence_length)`.
  229. Attentions weights of the decoder, after the attention softmax, used to compute the weighted average in the
  230. self-attention heads.
  231. decoder_ngram_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
  232. Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_attn_heads,
  233. decoder_sequence_length, decoder_sequence_length)`.
  234. Attentions weights of the predict stream of the decoder, after the attention softmax, used to compute the
  235. weighted average in the self-attention heads.
  236. cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
  237. Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_attn_heads,
  238. encoder_sequence_length, decoder_sequence_length)`.
  239. Attentions weights of the cross-attention layer of the decoder, after the attention softmax, used to
  240. compute the weighted average in the cross-attention heads.
  241. encoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*):
  242. Sequence of hidden-states at the output of the last layer of the encoder of the model.
  243. encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
  244. Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
  245. shape `(batch_size, encoder_sequence_length, hidden_size)`.
  246. Hidden-states of the encoder at the output of each layer plus the initial embedding outputs.
  247. encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
  248. Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_attn_heads,
  249. encoder_sequence_length, encoder_sequence_length)`. Attentions weights of the encoder, after the attention
  250. softmax, used to compute the weighted average in the self-attention heads.
  251. """
  252. loss: Optional[torch.FloatTensor] = None
  253. logits: torch.FloatTensor = None
  254. logits_ngram: Optional[torch.FloatTensor] = None
  255. past_key_values: Optional[Tuple[torch.FloatTensor]] = None
  256. decoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
  257. decoder_ngram_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
  258. decoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
  259. decoder_ngram_attentions: Optional[Tuple[torch.FloatTensor]] = None
  260. cross_attentions: Optional[Tuple[torch.FloatTensor]] = None
  261. encoder_last_hidden_state: Optional[torch.FloatTensor] = None
  262. encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
  263. encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
  264. @property
  265. def decoder_cross_attentions(self):
  266. warnings.warn(
  267. "`decoder_cross_attentions` is deprecated and will be removed soon. Please use `cross_attentions`"
  268. " instead.",
  269. FutureWarning,
  270. )
  271. return self.cross_attentions
  272. @dataclass
  273. class ProphetNetSeq2SeqModelOutput(ModelOutput):
  274. """
  275. Base class for model encoder's outputs that also contains: pre-computed hidden states that can speed up sequential
  276. decoding.
  277. Args:
  278. last_hidden_state (`torch.FloatTensor` of shape `(batch_size, decoder_sequence_length, hidden_size)`):
  279. Sequence of main stream hidden-states at the output of the last layer of the decoder of the model.
  280. If `past_key_values` is used only the last hidden-state of the sequences of shape `(batch_size, 1,
  281. hidden_size)` is output.
  282. last_hidden_state_ngram (`torch.FloatTensor` of shape `(batch_size, ngram * decoder_sequence_length, hidden_size)`, *optional*):
  283. Sequence of predict stream hidden-states at the output of the last layer of the decoder of the model.
  284. past_key_values (`List[torch.FloatTensor]`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
  285. List of `torch.FloatTensor` of length `config.n_layers`, with each tensor of shape `(2, batch_size,
  286. num_attn_heads, decoder_sequence_length, embed_size_per_head)`.
  287. Contains pre-computed hidden-states (key and values in the attention blocks) of the decoder that can be
  288. used (see `past_key_values` input) to speed up sequential decoding.
  289. decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
  290. Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
  291. shape `(batch_size, decoder_sequence_length, hidden_size)`.
  292. Hidden-states of main stream of the decoder at the output of each layer plus the initial embedding outputs.
  293. decoder_ngram_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
  294. Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
  295. shape `(batch_size, ngram * decoder_sequence_length, hidden_size)`.
  296. Hidden-states of the predict stream of the decoder at the output of each layer plus the initial embedding
  297. outputs.
  298. decoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
  299. Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_attn_heads,
  300. decoder_sequence_length, decoder_sequence_length)`.
  301. Attentions weights of the decoder, after the attention softmax, used to compute the weighted average in the
  302. self-attention heads.
  303. decoder_ngram_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
  304. Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_attn_heads,
  305. decoder_sequence_length, decoder_sequence_length)`.
  306. Attentions weights of the predict stream of the decoder, after the attention softmax, used to compute the
  307. weighted average in the self-attention heads.
  308. cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
  309. Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_attn_heads,
  310. encoder_sequence_length, decoder_sequence_length)`.
  311. Attentions weights of the cross-attention layer of the decoder, after the attention softmax, used to
  312. compute the weighted average in the cross-attention heads.
  313. encoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*):
  314. Sequence of hidden-states at the output of the last layer of the encoder of the model.
  315. encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
  316. Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
  317. shape `(batch_size, encoder_sequence_length, hidden_size)`.
  318. Hidden-states of the encoder at the output of each layer plus the initial embedding outputs.
  319. encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
  320. Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_attn_heads,
  321. encoder_sequence_length, encoder_sequence_length)`.
  322. Attentions weights of the encoder, after the attention softmax, used to compute the weighted average in the
  323. self-attention heads.
  324. """
  325. last_hidden_state: torch.FloatTensor
  326. last_hidden_state_ngram: Optional[torch.FloatTensor] = None
  327. past_key_values: Optional[Tuple[torch.FloatTensor]] = None
  328. decoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
  329. decoder_ngram_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
  330. decoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
  331. decoder_ngram_attentions: Optional[Tuple[torch.FloatTensor]] = None
  332. cross_attentions: Optional[Tuple[torch.FloatTensor]] = None
  333. encoder_last_hidden_state: Optional[torch.FloatTensor] = None
  334. encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
  335. encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
  336. @property
  337. def decoder_cross_attentions(self):
  338. warnings.warn(
  339. "`decoder_cross_attentions` is deprecated and will be removed soon. Please use `cross_attentions`"
  340. " instead.",
  341. FutureWarning,
  342. )
  343. return self.cross_attentions
  344. @dataclass
  345. class ProphetNetDecoderModelOutput(ModelOutput):
  346. """
  347. Base class for model's outputs that may also contain a past key/values (to speed up sequential decoding).
  348. Args:
  349. last_hidden_state (`torch.FloatTensor` of shape `(batch_size, decoder_sequence_length, hidden_size)`):
  350. Sequence of main stream hidden-states at the output of the last layer of the decoder of the model.
  351. If `past_key_values` is used only the last hidden-state of the sequences of shape `(batch_size, 1,
  352. hidden_size)` is output.
  353. last_hidden_state_ngram (`torch.FloatTensor` of shape `(batch_size, ngram * decoder_sequence_length, hidden_size)`):
  354. Sequence of predict stream hidden-states at the output of the last layer of the decoder of the model.
  355. past_key_values (`List[torch.FloatTensor]`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
  356. List of `torch.FloatTensor` of length `config.n_layers`, with each tensor of shape `(2, batch_size,
  357. num_attn_heads, decoder_sequence_length, embed_size_per_head)`.
  358. Contains pre-computed hidden-states (key and values in the attention blocks) of the decoder that can be
  359. used (see `past_key_values` input) to speed up sequential decoding.
  360. hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
  361. Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
  362. shape `(batch_size, decoder_sequence_length, hidden_size)`.
  363. Hidden-states of main stream of the decoder at the output of each layer plus the initial embedding outputs.
  364. hidden_states_ngram (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
  365. Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
  366. shape `(batch_size, ngram * decoder_sequence_length, hidden_size)`.
  367. Hidden-states of the predict stream of the decoder at the output of each layer plus the initial embedding
  368. outputs.
  369. attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
  370. Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_attn_heads,
  371. decoder_sequence_length, decoder_sequence_length)`.
  372. Attentions weights of the decoder, after the attention softmax, used to compute the weighted average in the
  373. self-attention heads.
  374. ngram_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
  375. Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_attn_heads,
  376. decoder_sequence_length, decoder_sequence_length)`.
  377. Attentions weights of the predict stream of the decoder, after the attention softmax, used to compute the
  378. weighted average in the self-attention heads.
  379. cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
  380. Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_attn_heads,
  381. encoder_sequence_length, decoder_sequence_length)`.
  382. Attentions weights of the cross-attention layer of the decoder, after the attention softmax, used to
  383. compute the weighted average in the cross-attention heads.
  384. """
  385. last_hidden_state: torch.FloatTensor
  386. last_hidden_state_ngram: Optional[torch.FloatTensor] = None
  387. past_key_values: Optional[Tuple[torch.FloatTensor]] = None
  388. hidden_states: Optional[Tuple[torch.FloatTensor]] = None
  389. hidden_states_ngram: Optional[Tuple[torch.FloatTensor]] = None
  390. attentions: Optional[Tuple[torch.FloatTensor]] = None
  391. ngram_attentions: Optional[Tuple[torch.FloatTensor]] = None
  392. cross_attentions: Optional[Tuple[torch.FloatTensor]] = None
  393. @dataclass
  394. class ProphetNetDecoderLMOutput(ModelOutput):
  395. """
  396. Base class for model's outputs that may also contain a past key/values (to speed up sequential decoding).
  397. Args:
  398. loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
  399. Language modeling loss.
  400. logits (`torch.FloatTensor` of shape `(batch_size, decoder_sequence_length, config.vocab_size)`):
  401. Prediction scores of the main stream language modeling head (scores for each vocabulary token before
  402. SoftMax).
  403. logits_ngram (`torch.FloatTensor` of shape `(batch_size, ngram * decoder_sequence_length, config.vocab_size)`):
  404. Prediction scores of the predict stream language modeling head (scores for each vocabulary token before
  405. SoftMax).
  406. past_key_values (`List[torch.FloatTensor]`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
  407. List of `torch.FloatTensor` of length `config.n_layers`, with each tensor of shape `(2, batch_size,
  408. num_attn_heads, decoder_sequence_length, embed_size_per_head)`.
  409. Contains pre-computed hidden-states (key and values in the attention blocks) of the decoder that can be
  410. used (see `past_key_values` input) to speed up sequential decoding.
  411. hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
  412. Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
  413. shape `(batch_size, decoder_sequence_length, hidden_size)`.
  414. Hidden-states of main stream of the decoder at the output of each layer plus the initial embedding outputs.
  415. hidden_states_ngram (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
  416. Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
  417. shape `(batch_size, ngram * decoder_sequence_length, hidden_size)`.
  418. Hidden-states of the predict stream of the decoder at the output of each layer plus the initial embedding
  419. outputs.
  420. attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
  421. Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_attn_heads,
  422. decoder_sequence_length, decoder_sequence_length)`.
  423. Attentions weights of the decoder, after the attention softmax, used to compute the weighted average in the
  424. self-attention heads.
  425. ngram_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
  426. Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_attn_heads,
  427. decoder_sequence_length, decoder_sequence_length)`.
  428. Attentions weights of the predict stream of the decoder, after the attention softmax, used to compute the
  429. weighted average in the self-attention heads.
  430. cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
  431. Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_attn_heads,
  432. encoder_sequence_length, decoder_sequence_length)`.
  433. Attentions weights of the cross-attention layer of the decoder, after the attention softmax, used to
  434. compute the weighted average in the cross-attention heads.
  435. """
  436. loss: Optional[torch.FloatTensor] = None
  437. logits: torch.FloatTensor = None
  438. logits_ngram: Optional[torch.FloatTensor] = None
  439. past_key_values: Optional[Tuple[torch.FloatTensor]] = None
  440. hidden_states: Optional[Tuple[torch.FloatTensor]] = None
  441. hidden_states_ngram: Optional[Tuple[torch.FloatTensor]] = None
  442. attentions: Optional[Tuple[torch.FloatTensor]] = None
  443. ngram_attentions: Optional[Tuple[torch.FloatTensor]] = None
  444. cross_attentions: Optional[Tuple[torch.FloatTensor]] = None
  445. class ProphetNetPreTrainedModel(PreTrainedModel):
  446. config_class = ProphetNetConfig
  447. base_model_prefix = "prophetnet"
  448. supports_gradient_checkpointing = True
  449. def _init_weights(self, module):
  450. if isinstance(module, nn.Linear):
  451. module.weight.data.normal_(mean=0.0, std=self.config.init_std)
  452. if module.bias is not None:
  453. module.bias.data.zero_()
  454. elif isinstance(module, nn.Embedding):
  455. module.weight.data.normal_(mean=0.0, std=self.config.init_std)
  456. if module.padding_idx is not None:
  457. module.weight.data[module.padding_idx].zero_()
  458. def _shift_right(self, input_ids):
  459. decoder_start_token_id = self.config.decoder_start_token_id
  460. pad_token_id = self.config.pad_token_id
  461. assert decoder_start_token_id is not None, (
  462. "self.model.config.decoder_start_token_id has to be defined. In ProphetNet it is usually set to the"
  463. " pad_token_id. See ProphetNet docs for more information"
  464. )
  465. # shift inputs to the right
  466. shifted_input_ids = input_ids.new_zeros(input_ids.shape)
  467. shifted_input_ids[..., 1:] = input_ids[..., :-1].clone()
  468. shifted_input_ids[..., 0] = decoder_start_token_id
  469. assert pad_token_id is not None, "self.model.config.pad_token_id has to be defined."
  470. # replace possible -100 values in labels by `pad_token_id`
  471. shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)
  472. assert torch.all(shifted_input_ids >= 0).item(), "Verify that `shifted_input_ids` has no negative values"
  473. return shifted_input_ids
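# Hedged standalone sketch (not part of the original module) mirroring the `_shift_right`
# logic above: labels are shifted one position to the right, the decoder start token is
# prepended, and -100 ignore-index entries are replaced by the pad token.
def _example_shift_right(labels: torch.Tensor, decoder_start_token_id: int, pad_token_id: int) -> torch.Tensor:
    shifted = labels.new_zeros(labels.shape)
    shifted[..., 1:] = labels[..., :-1].clone()
    shifted[..., 0] = decoder_start_token_id
    shifted.masked_fill_(shifted == -100, pad_token_id)
    return shifted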
  474. class ProphetNetPositionalEmbeddings(nn.Embedding):
  475. """
  476. This module learns positional embeddings up to a fixed maximum size. Padding ids are ignored by either offsetting
  477. based on padding_idx or by setting padding_idx to None and ensuring that the appropriate position ids are passed to
  478. the forward function.
  479. """
  480. def __init__(self, config: ProphetNetConfig) -> None:
  481. self.max_length = config.max_position_embeddings
  482. super().__init__(config.max_position_embeddings, config.hidden_size, config.pad_token_id)
  483. def forward(self, inputs_shape, device, attention_mask=None, past_key_values=None, position_ids=None):
  484. assert (position_ids is None) or (
  485. self.padding_idx is None
  486. ), "If position_ids is pre-computed then padding_idx should not be set."
  487. if position_ids is None:
  488. if past_key_values is not None:
  489. # position_ids is the same for every token when decoding a single step
  490. # Without the int() cast, it doesn't work in some cases when exporting to ONNX
  491. prev_num_input_ids = past_key_values[0][0].shape[2]
  492. num_input_ids = inputs_shape[1] + prev_num_input_ids
  493. position_ids = torch.ones((1, 1), dtype=torch.long, device=device) * (
  494. int(self.padding_idx + num_input_ids)
  495. )
  496. else:
  497. if attention_mask is None:
  498. attention_mask = torch.ones(inputs_shape, dtype=torch.long, device=device)
  499. # retrieve position_ids from input_ids / attention_mask
  500. position_ids = (
  501. torch.cumsum(attention_mask, dim=1).type_as(attention_mask) * attention_mask
  502. ).long() + self.padding_idx
  503. # make sure position_ids are not bigger than max_length
  504. position_ids = position_ids.clamp(0, self.max_length - 1)
  505. return super().forward(position_ids), position_ids
  506. def _forward(self, position_ids):
  507. return super().forward(position_ids)
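# Hedged standalone sketch (not part of the original module): position ids are derived from
# the attention mask as a cumulative sum over non-padding tokens, offset by `padding_idx`,
# so every padded position maps to the padding embedding.
def _example_position_ids_from_mask(attention_mask: torch.Tensor, padding_idx: int = 1) -> torch.Tensor:
    position_ids = (
        torch.cumsum(attention_mask, dim=1).type_as(attention_mask) * attention_mask
    ).long() + padding_idx
    return position_ids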
  508. class ProphetNetAttention(nn.Module):
  509. """Multi-headed attention from 'Attention Is All You Need' paper"""
  510. def __init__(
  511. self,
  512. config: ProphetNetConfig,
  513. num_attn_heads: int,
  514. ):
  515. super().__init__()
  516. hidden_size = config.hidden_size
  517. self.attention_dropout = config.attention_dropout
  518. self.dropout = config.dropout
  519. self.num_attn_heads = num_attn_heads
  520. self.head_dim = hidden_size // num_attn_heads
  521. assert self.head_dim * num_attn_heads == hidden_size, (
  522. "`config.hidden_size` must be divisible by `config.num_encoder_attention_heads` and"
  523. " `config.num_decoder_attention_heads`"
  524. )
  525. self.key_proj = nn.Linear(hidden_size, hidden_size)
  526. self.value_proj = nn.Linear(hidden_size, hidden_size)
  527. self.query_proj = nn.Linear(hidden_size, hidden_size)
  528. self.out_proj = nn.Linear(hidden_size, hidden_size)
  529. def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
  530. return tensor.view(bsz, seq_len, self.num_attn_heads, self.head_dim).transpose(1, 2).contiguous()
  531. def forward(
  532. self,
  533. hidden_states,
  534. key_value_states: Optional[Tensor] = None,
  535. attention_mask: Optional[Tensor] = None,
  536. layer_head_mask: Optional[Tensor] = None,
  537. past_key_value: Optional[Tuple[Tensor]] = None,
  538. output_attentions: bool = False,
  539. ) -> Tuple[Tensor, Optional[Tensor]]:
  540. batch_size, tgt_len, hidden_size = hidden_states.size()
  541. # if key_value_states are provided this layer is used as a cross-attention layer
  542. # for the decoder
  543. is_cross_attention = key_value_states is not None
  544. assert list(hidden_states.size()) == [
  545. batch_size,
  546. tgt_len,
  547. hidden_size,
  548. ], f"Size of hidden states should be {batch_size, tgt_len, hidden_size}, but is {hidden_states.size()}"
  549. # previous time steps are cached - no need to recompute key and value if they are static
  550. query_states = self.query_proj(hidden_states) / (self.head_dim**0.5)
  551. if is_cross_attention and past_key_value is not None:
  552. # reuse k,v, cross_attentions
  553. key_states = past_key_value[0]
  554. value_states = past_key_value[1]
  555. elif is_cross_attention:
  556. # cross_attentions
  557. key_states = self._shape(self.key_proj(key_value_states), -1, batch_size)
  558. value_states = self._shape(self.value_proj(key_value_states), -1, batch_size)
  559. else:
  560. # self_attention
  561. key_states = self._shape(self.key_proj(hidden_states), -1, batch_size)
  562. value_states = self._shape(self.value_proj(hidden_states), -1, batch_size)
  563. if is_cross_attention:
  564. # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
  565. # Further calls to cross_attention layer can then reuse all cross-attention
  566. # key/value_states (first "if" case)
  567. # if encoder bi-directional self-attention `past_key_value` is always `None`
  568. past_key_value = (key_states, value_states)
  569. # project states into the correct shape
  570. proj_shape = (batch_size, self.num_attn_heads, -1, self.head_dim)
  571. query_states = self._shape(query_states, tgt_len, batch_size).view(*proj_shape)
  572. key_states = key_states.view(*proj_shape)
  573. value_states = value_states.view(*proj_shape)
  574. src_len = key_states.size(2)
  575. attn_weights = torch.einsum("bsij,bsjk->bsik", query_states, key_states.transpose(2, 3))
  576. expected_shape = (batch_size, self.num_attn_heads, tgt_len, src_len)
  577. if attn_weights.size() != expected_shape:
  578. raise ValueError(f"Attention weights should have size {expected_shape}, but is {attn_weights.size()}")
  579. # This is part of a workaround to get around fork/join parallelism not supporting Optional types.
  580. if attention_mask is not None and attention_mask.dim() == 0:
  581. attention_mask = None
  582. expected_shape = (batch_size, self.num_attn_heads, 1, src_len)
  583. if attention_mask is not None and attention_mask.size() != expected_shape:
  584. raise ValueError(f"Attention mask should have size {expected_shape}, but is {attention_mask.size()}")
  585. if attention_mask is not None: # don't attend to padding symbols
  586. attn_weights = attn_weights + attention_mask
  587. if output_attentions:
  588. attn_weights_reshaped = attn_weights
  589. else:
  590. attn_weights_reshaped = None
  591. attn_weights = nn.functional.softmax(attn_weights, dim=-1)
  592. if layer_head_mask is not None:
  593. assert layer_head_mask.size() == (self.num_attn_heads,), (
  594. f"Head mask for a single layer should be of size {(self.num_attn_heads,)}, but is"
  595. f" {layer_head_mask.size()}"
  596. )
  597. attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(
  598. batch_size, self.num_attn_heads, tgt_len, src_len
  599. )
  600. # apply head_mask also on attn_weights_reshaped which is used for n-gram attention inside the model
  601. attn_weights_reshaped = layer_head_mask.view(1, -1, 1, 1) * attn_weights_reshaped
  602. attn_probs = nn.functional.dropout(
  603. attn_weights,
  604. p=self.attention_dropout,
  605. training=self.training,
  606. )
  607. attn_output = torch.einsum("bsij,bsjk->bsik", attn_probs, value_states)
  608. expected_shape = (batch_size, self.num_attn_heads, tgt_len, self.head_dim)
  609. if attn_output.size() != expected_shape:
  610. raise ValueError(f"`attn_output` should have shape {expected_shape}, but is of shape {attn_output.size()}")
  611. attn_output = attn_output.transpose(1, 2).reshape(batch_size, tgt_len, hidden_size)
  612. attn_output = self.out_proj(attn_output)
  613. attn_output = nn.functional.dropout(attn_output, p=self.dropout, training=self.training)
  614. return attn_output, attn_weights_reshaped, past_key_value
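# Hedged standalone sketch (not part of the original module): the `_shape` helper used above
# splits the hidden dimension into heads and moves the head axis in front of the sequence
# axis, producing the (batch_size, num_attn_heads, seq_len, head_dim) layout consumed by the
# einsum calls in `forward`.
def _example_split_heads(hidden_states: torch.Tensor, num_attn_heads: int) -> torch.Tensor:
    batch_size, seq_len, hidden_size = hidden_states.size()
    head_dim = hidden_size // num_attn_heads
    return hidden_states.view(batch_size, seq_len, num_attn_heads, head_dim).transpose(1, 2).contiguous()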
  615. class ProphetNetFeedForward(nn.Module):
  616. """
  617. This is the residual two feed-forward layer block based on the original Transformer implementation.
  618. """
  619. def __init__(self, config: ProphetNetConfig, ffn_dim: int):
  620. super().__init__()
  621. self.activation_fn = ACT2FN[config.activation_function]
  622. self.intermediate = nn.Linear(config.hidden_size, ffn_dim)
  623. self.output = nn.Linear(ffn_dim, config.hidden_size)
  624. self.activation_dropout = config.activation_dropout
  625. self.dropout = config.dropout
  626. def forward(self, hidden_states):
  627. hidden_states = self.intermediate(hidden_states)
  628. hidden_states = self.activation_fn(hidden_states)
  629. hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
  630. hidden_states = self.output(hidden_states)
  631. hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
  632. return hidden_states
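# Hedged sketch (not part of the original module): the feed-forward block can be exercised
# in isolation; the sizes below are toy values chosen only for illustration.
def _example_feed_forward():
    config = ProphetNetConfig(hidden_size=16, encoder_ffn_dim=32, decoder_ffn_dim=32)
    ffn = ProphetNetFeedForward(config, ffn_dim=config.encoder_ffn_dim)
    hidden_states = torch.zeros(1, 4, config.hidden_size)
    return ffn(hidden_states)  # shape is preserved: (1, 4, hidden_size)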
  633. class ProphetNetNgramSelfAttention(nn.Module):
  634. def __init__(self, config: ProphetNetConfig):
  635. super().__init__()
  636. self.hidden_size = config.hidden_size
  637. self.num_buckets = config.num_buckets
  638. self.relative_max_distance = config.relative_max_distance
  639. self.num_attn_heads = config.num_decoder_attention_heads
  640. self.dropout = config.dropout
  641. self.attention_dropout = config.attention_dropout
  642. self.head_dim = config.hidden_size // self.num_attn_heads
  643. self.ngram = config.ngram
  644. assert (
  645. self.head_dim * self.num_attn_heads == config.hidden_size
  646. ), "config.hidden_size must be divisible by num_attn_heads"
  647. # key, value, query projection
  648. self.key_proj = nn.Linear(config.hidden_size, config.hidden_size)
  649. self.value_proj = nn.Linear(config.hidden_size, config.hidden_size)
  650. self.query_proj = nn.Linear(config.hidden_size, config.hidden_size)
  651. # out projection
  652. self.out_proj = nn.Linear(config.hidden_size, config.hidden_size)
  653. # rel position embeddings
  654. self.relative_pos_embeddings = nn.Linear(config.hidden_size, self.num_buckets * self.num_attn_heads)
  655. # for onnx runtime
  656. self.onnx_trace = False
  657. def _shape(self, tensor, seq_len, batch_size):
  658. return tensor.view(batch_size, seq_len, self.num_attn_heads, self.head_dim).transpose(1, 2).contiguous()
  659. def prepare_for_onnx_export_(self):
  660. self.onnx_trace = True
  661. def forward(
  662. self,
  663. hidden_states,
  664. past_key_value: Optional[Tuple[Tensor]] = None,
  665. attention_mask=None,
  666. layer_head_mask=None,
  667. extended_predict_attention_mask=None,
  668. main_relative_position_buckets=None,
  669. predict_relative_position_buckets=None,
  670. position_ids=None,
  671. ):
  672. batch_size, ngram_sequence_length, hidden_size = hidden_states.size()
  673. assert list(hidden_states.size()) == [batch_size, ngram_sequence_length, hidden_size], (
  674. f"`hidden_states` should be of shape {batch_size, ngram_sequence_length, hidden_size}, but is of shape"
  675. f" {hidden_states.shape}"
  676. )
  677. # project
  678. query_states = self.query_proj(hidden_states)
  679. key_states = self.key_proj(hidden_states)
  680. value_states = self.value_proj(hidden_states)
  681. # normalize
  682. query_states = query_states / (self.head_dim**0.5)
  683. # reshape
  684. query_states = self._shape(query_states, ngram_sequence_length, batch_size)
  685. key_states = self._shape(key_states, -1, batch_size)
  686. value_states = self._shape(value_states, -1, batch_size)
  687. proj_shape = (batch_size, self.num_attn_heads, -1, self.head_dim)
  688. query_states = query_states.view(*proj_shape)
  689. key_states = key_states.view(*proj_shape)
  690. value_states = value_states.view(*proj_shape)
  691. # chunk into main stream and predict stream
  692. hidden_states_list = hidden_states.chunk(1 + self.ngram, dim=1)
  693. query_states_list = query_states.chunk(1 + self.ngram, dim=2)
  694. key_states_list = key_states.chunk(1 + self.ngram, dim=2)
  695. value_states_list = value_states.chunk(1 + self.ngram, dim=2)
  696. main_hidden_states, hidden_states_predict_list = hidden_states_list[0], hidden_states_list[1:]
  697. main_query_states, predict_query_states_list = query_states_list[0], query_states_list[1:]
  698. main_key_states, predict_key_states_list = key_states_list[0], key_states_list[1:]
  699. main_value_states, predict_value_states_list = value_states_list[0], value_states_list[1:]
  700. # saved states are stored with shape (batch_size, num_attn_heads, seq_len, head_dim)
  701. if past_key_value is not None:
  702. prev_main_key_states = past_key_value[0]
  703. main_key_states = torch.cat((prev_main_key_states, main_key_states), dim=2)
  704. prev_main_value_states = past_key_value[1]
  705. main_value_states = torch.cat((prev_main_value_states, main_value_states), dim=2)
  706. # Update cache
  707. past_key_value = (main_key_states, main_value_states)
  708. # get seq_length of main stream only
  709. sequence_length = ngram_sequence_length // (1 + self.ngram)
  710. # MAIN-STREAM
  711. # main attn weights
  712. # [batch_size, number_heads, sequence_length, head_dimension]
  713. # x [batch_size, number_heads, head_dimension, sequence_length]
  714. # -> [batch_size, number_heads, sequence_length, sequence_length]
  715. main_attn_weights = torch.einsum("bntc,bncs->bnts", main_query_states, main_key_states.transpose(2, 3))
  716. # retrieve relative position embeddings for each layer -> see paper for more details
  717. main_relative_pos_embeddings = self.get_main_relative_pos_embeddings(
  718. main_hidden_states, main_attn_weights, position_ids, main_relative_position_buckets
  719. )
  720. main_attn_weights = main_attn_weights + main_relative_pos_embeddings
  721. if attention_mask is not None:
  722. main_attn_weights = main_attn_weights + attention_mask
  723. main_attn_probs = softmax(
  724. main_attn_weights,
  725. dim=-1,
  726. onnx_trace=self.onnx_trace,
  727. ).type_as(main_attn_weights)
  728. if layer_head_mask is not None:
  729. assert layer_head_mask.size() == (self.num_attn_heads,), (
  730. f"Head mask for a single layer should be of size {(self.num_attn_heads,)}, but is"
  731. f" {layer_head_mask.size()}"
  732. )
  733. main_attn_probs = layer_head_mask.view(1, -1, 1, 1) * main_attn_probs.view(
  734. batch_size, self.num_attn_heads, -1, sequence_length
  735. )
  736. main_attn_probs = nn.functional.dropout(main_attn_probs, p=self.attention_dropout, training=self.training)
  737. # project to attn_output
  738. # [batch_size, number_heads, sequence_length, sequence_length]
  739. # x [batch_size, number_heads, sequence_length, head_dimension]
  740. # -> [batch_size, number_heads, sequence_length, head_dimension]
  741. main_attn_output = torch.einsum("bntc,bncs->bnts", main_attn_probs, main_value_states)
  742. # reshape so that num_heads dim is merged into last `head_dim` axis
  743. main_attn_output = main_attn_output.transpose(1, 2).reshape(batch_size, 1, sequence_length, hidden_size)
  744. main_attn_output = self.out_proj(main_attn_output)
  745. # PREDICT-STREAM
  746. # [batch_size, ngram, number_heads, sequence_length, head_dimension]
  747. predict_query_states = torch.stack(predict_query_states_list, 1).view(
  748. batch_size, self.ngram, self.num_attn_heads, sequence_length, self.head_dim
  749. )
750. # [batch_size, ngram, number_heads, 2*sequence_length, head_dimension]
  751. predict_key_states = torch.stack([torch.cat([main_key_states, key], 2) for key in predict_key_states_list], 1)
  752. # [batch_size, sequence_length, ngram, hidden_size]
  753. predict_hidden_states = torch.stack(hidden_states_predict_list, dim=2)
754. # [batch_size, number_heads, ngram, 2*sequence_length, head_dimension]
  755. predict_value_states = torch.cat(
  756. [torch.cat([main_value_states, v_p], 2).unsqueeze(2) for v_p in predict_value_states_list], 2
  757. )
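# Key/value layout sketch for the predicting streams (ignoring any cached past_key_value length):
# each predicting stream attends to the main-stream keys/values *and* to its own stream, hence the
# concatenation to 2*sequence_length. With the illustrative values above:
#     predict_key_states:   [2, 2, 16, 10, 64]   # ngram stacked at dim=1
#     predict_value_states: [2, 16, 2, 10, 64]   # ngram concatenated at dim=2, heads stay at dim=1
# The different ngram positions explain the `.transpose(1, 2)` applied to the values in the
# attention-output einsum further below.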
758. # [batch_size, ngram, number_heads, sequence_length, head_dimension]
759. # x [batch_size, ngram, number_heads, 2*sequence_length, head_dimension]
  760. # -> [batch_size, ngram, number_heads, sequence_length, 2*sequence_length]
  761. predict_attn_weights = torch.einsum("bnhtc,bnhsc->bnhts", (predict_query_states, predict_key_states))
  762. # retrieve relative position embeddings for each layer -> see paper for more details
  763. # [batch_size, ngram, number_heads, sequence_length, predict_relative_pos_embeddings]
  764. predict_relative_pos_embeddings = self.get_predict_relative_pos_embeddings(
  765. predict_hidden_states, predict_attn_weights, position_ids, predict_relative_position_buckets
  766. )
  767. # [batch_size, ngram, number_heads, sequence_length, 2*sequence_length]
  768. predict_attn_weights = predict_attn_weights + predict_relative_pos_embeddings
  769. if extended_predict_attention_mask is not None:
  770. # Permuting Predict attention mask to [batch_size, ngram, number_heads, sequence_length, 2*sequence_length]
  771. extended_predict_attention_mask = extended_predict_attention_mask.permute(0, 2, 1, 3, 4)
  772. extended_predict_attention_mask = extended_predict_attention_mask.to(predict_attn_weights.dtype)
  773. predict_attn_weights = predict_attn_weights + extended_predict_attention_mask
  774. predict_attn_probs = softmax(
  775. predict_attn_weights,
  776. dim=-1,
  777. onnx_trace=self.onnx_trace,
  778. ).type_as(predict_attn_weights)
  779. if layer_head_mask is not None:
  780. assert layer_head_mask.size() == (self.num_attn_heads,), (
  781. f"Head mask for a single layer should be of size {(self.num_attn_heads,)}, but is"
  782. f" {layer_head_mask.size()}"
  783. )
  784. predict_attn_probs = layer_head_mask.view(1, 1, -1, 1, 1) * predict_attn_probs
  785. predict_attn_probs = nn.functional.dropout(
  786. predict_attn_probs, p=self.attention_dropout, training=self.training
  787. )
  788. # project to attention output
  789. # [batch_size, ngram, number_heads, sequence_length, 2*sequence_length]
790. # x [batch_size, ngram, number_heads, 2*sequence_length, head_dimension]
791. # -> [batch_size, ngram, number_heads, sequence_length, head_dimension]
  792. predict_attn_output = torch.einsum(
  793. "bnhts,bnhsc->bnhtc", (predict_attn_probs, predict_value_states.transpose(1, 2))
  794. )
  795. # reshape so that num_heads dim is merged into last `head_dim` axis
796. # [batch_size, ngram, number_heads, sequence_length, head_dimension] -> [batch_size, ngram, sequence_length, hidden_size]
  797. predict_attn_output = predict_attn_output.transpose(2, 3)
  798. predict_attn_output = predict_attn_output.reshape(batch_size, self.ngram, sequence_length, hidden_size)
  799. predict_attn_output = self.out_proj(predict_attn_output)
  800. # concat to single attn output
  801. # [batch_size, (1+ngram)*sequence_length, hidden_size]
  802. attn_output = torch.cat([main_attn_output, predict_attn_output], 1).view(batch_size, -1, hidden_size)
  803. # reshape into better form for `config.output_attentions`
  804. main_attn_probs = main_attn_probs.view(batch_size, self.num_attn_heads, sequence_length, -1)
  805. attn_output = nn.functional.dropout(attn_output, p=self.dropout, training=self.training)
  806. return attn_output, main_attn_probs, predict_attn_probs, past_key_value
  807. def get_main_relative_pos_embeddings(
  808. self, hidden_states, attn_weights, position_ids, main_relative_position_buckets
  809. ):
  810. # input hidden_states [batch_size, sequence_length, hidden_size]
  811. # input attn_weights [batch_size, num_heads, sequence_length, sequence_length]
  812. # input position_ids [batch_size, sequence_length] or [1,1]
  813. batch_size, num_attn_heads, tgt_len, src_len = attn_weights.shape
  814. attn_weights = attn_weights.view(batch_size, num_attn_heads, tgt_len, src_len)
  815. if main_relative_position_buckets is None:
  816. batch_size, sequence_length = hidden_states.shape[:2]
  817. relative_positions = (
  818. torch.arange(1, attn_weights.shape[-1] + 1)
  819. .unsqueeze(0)
  820. .unsqueeze(0)
  821. .repeat(batch_size, sequence_length, 1)
  822. .to(position_ids.device)
  823. )
  824. # [batch_size, sequence_length, sequence_length+1]
  825. relative_positions = relative_positions - position_ids.unsqueeze(0).repeat(batch_size, sequence_length, 1)
  826. main_relative_position_buckets = compute_relative_buckets(
  827. self.num_buckets, self.relative_max_distance, relative_positions, False
  828. )
  829. # [batch_size, sequence_length, num_buckets * num_heads]
  830. rel_pos_embeddings = self.relative_pos_embeddings(hidden_states)
  831. rel_pos_embeddings = rel_pos_embeddings.view(
  832. rel_pos_embeddings.shape[:2] + (self.num_buckets, self.num_attn_heads)
  833. )
  834. rel_pos_embeddings = rel_pos_embeddings.permute(0, 3, 1, 2)
  835. # [batch_size, num_heads, sequence_length, num_buckets]
  836. rel_pos_embeddings = rel_pos_embeddings.reshape(attn_weights.shape[:3] + (-1,))
  837. main_relative_position_buckets = main_relative_position_buckets.repeat(1, self.num_attn_heads, 1)
  838. # [batch_size * num_heads * sequence_length, sequence_length]
  839. main_relative_position_buckets = main_relative_position_buckets.view(
  840. -1, main_relative_position_buckets.shape[-1]
  841. )
  842. main_relative_position_buckets = main_relative_position_buckets.long()
  843. # [batch_size * num_heads * sequence_length, sequence_length]
  844. rel_pos_embeddings = rel_pos_embeddings.reshape(-1, rel_pos_embeddings.size(-1))
  845. main_relative_pos_embeddings = torch.gather(rel_pos_embeddings, dim=1, index=main_relative_position_buckets)
  846. main_relative_pos_embeddings = main_relative_pos_embeddings.view(batch_size, num_attn_heads, tgt_len, -1)
  847. return main_relative_pos_embeddings
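# Gather sketch (tiny illustrative values): after the reshapes, every (batch, head, query position)
# owns a row of `num_buckets` learned biases, and `main_relative_position_buckets` holds one bucket
# id per key position, e.g. with num_buckets=4:
#     bias row:        [e_0, e_1, e_2, e_3]
#     bucket ids:      [2, 2, 0, 1]
#     gathered biases: [e_2, e_2, e_0, e_1]
# The gathered tensor is viewed back to [batch_size, num_heads, tgt_len, src_len] and added to the
# main-stream attention weights above.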
  848. def get_predict_relative_pos_embeddings(
  849. self, hidden_states, attn_weights, position_ids, predict_relative_position_buckets
  850. ):
  851. # input hidden_states [batch_size, sequence_length, ngram, hidden_size]
  852. # input attn_weights [batch_size, ngram, num_heads, sequence_length, 2*sequence_length]
  853. # input position_ids [batch_size, sequence_length] or [1,1]
  854. # input predict_relative_position_buckets [batch_size, sequence_length, 2*sequence_length] or None
  855. batch_size, sequence_length = hidden_states.shape[0:2]
  856. if predict_relative_position_buckets is None:
  857. key_sequence_length = attn_weights.shape[-1]
  858. assert (
  859. position_ids[0][0] == key_sequence_length - 1
  860. ), "`position_ids` are incorrect. They should be of the format 1 2 3 4 5 ... (key_sequence_length - 1)"
  861. relative_positions = (
  862. torch.arange(0, key_sequence_length)
  863. .unsqueeze(0)
  864. .unsqueeze(0)
  865. .repeat(batch_size, sequence_length, 1)
  866. .to(position_ids.device)
  867. )
  868. relative_positions = relative_positions - position_ids.unsqueeze(0).repeat(batch_size, sequence_length, 1)
  869. predict_relative_position_buckets = compute_relative_buckets(
  870. self.num_buckets, self.relative_max_distance, relative_positions, False
  871. )
  872. # [batch_size, ngram, sequence_length, hidden_size]
  873. hidden_states = hidden_states.transpose(1, 2)
  874. rel_pos_embeddings = self.relative_pos_embeddings(hidden_states)
  875. # [batch_size, ngram, sequence_length, num_buckets, num_heads]
  876. rel_pos_embeddings = rel_pos_embeddings.view(
  877. hidden_states.shape[:-1] + (self.num_buckets, self.num_attn_heads)
  878. )
  879. rel_pos_embeddings = rel_pos_embeddings.permute(0, 2, 1, 4, 3)
  880. # [batch_size * ngram * sequence_length * num_heads, num_buckets]
  881. rel_pos_embeddings = rel_pos_embeddings.reshape(-1, self.num_buckets)
  882. # [ngram, batch_size, num_heads * sequence_length, -1]
  883. predict_relative_position_buckets = predict_relative_position_buckets.unsqueeze(0)
  884. predict_relative_position_buckets = predict_relative_position_buckets.repeat(
  885. self.ngram, 1, self.num_attn_heads, 1
  886. )
  887. # [ngram * batch_size * num_heads * sequence_length, -1]
  888. predict_relative_position_buckets = predict_relative_position_buckets.view(
  889. -1, predict_relative_position_buckets.size(-1)
  890. ).long()
  891. predict_relative_pos_embeddings = torch.gather(
  892. rel_pos_embeddings, dim=1, index=predict_relative_position_buckets
  893. )
894. # [batch_size, ngram, num_heads, sequence_length, -1]
  895. predict_relative_pos_embeddings = predict_relative_pos_embeddings.view(
  896. batch_size, self.ngram, self.num_attn_heads, sequence_length, -1
  897. )
  898. return predict_relative_pos_embeddings
  899. class ProphetNetEncoderLayer(nn.Module):
  900. """
901. Encoder block for ProphetNet
  902. """
  903. def __init__(self, config: ProphetNetConfig):
  904. super().__init__()
  905. # 1st residual block
  906. self.self_attn = ProphetNetAttention(config, config.num_encoder_attention_heads)
  907. self.self_attn_layer_norm = LayerNorm(config.hidden_size)
  908. # 2nd residual block
  909. self.feed_forward = ProphetNetFeedForward(config, config.encoder_ffn_dim)
  910. self.feed_forward_layer_norm = LayerNorm(config.hidden_size)
  911. def forward(
  912. self,
  913. hidden_states,
  914. attention_mask,
  915. layer_head_mask,
  916. output_attentions: bool = False,
  917. ):
  918. # 1st residual block
  919. attention_output, attn_weights, _ = self.self_attn(
  920. hidden_states=hidden_states,
  921. attention_mask=attention_mask,
  922. layer_head_mask=layer_head_mask,
  923. output_attentions=output_attentions,
  924. )
  925. hidden_states = self.self_attn_layer_norm(attention_output + hidden_states)
  926. # 2nd residual block
  927. feed_forward_output = self.feed_forward(hidden_states)
  928. hidden_states = self.feed_forward_layer_norm(feed_forward_output + hidden_states)
  929. outputs = (hidden_states,)
  930. if output_attentions:
  931. outputs += (attn_weights,)
  932. return outputs
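# Minimal usage sketch for this layer (illustrative only; the mask arguments may be None and
# `config` is assumed to be a ProphetNetConfig):
#     layer = ProphetNetEncoderLayer(config)
#     hidden = torch.randn(2, 7, config.hidden_size)
#     (hidden,) = layer(hidden, attention_mask=None, layer_head_mask=None)
# Both sublayers follow a post-LayerNorm residual pattern, i.e. LayerNorm(sublayer(x) + x).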
  933. class ProphetNetDecoderLayer(nn.Module):
  934. """
935. Decoder block for ProphetNet
  936. """
  937. def __init__(self, config: ProphetNetConfig):
  938. super().__init__()
  939. # 1st residual block
  940. self.self_attn = ProphetNetNgramSelfAttention(config)
  941. self.self_attn_layer_norm = LayerNorm(config.hidden_size)
  942. # 2nd residual block
  943. if config.add_cross_attention:
  944. self.cross_attn = ProphetNetAttention(config, config.num_decoder_attention_heads)
  945. self.cross_attn_layer_norm = LayerNorm(config.hidden_size)
  946. # 3rd residual block
  947. self.feed_forward = ProphetNetFeedForward(config, config.decoder_ffn_dim)
  948. self.feed_forward_layer_norm = LayerNorm(config.hidden_size)
  949. def forward(
  950. self,
  951. hidden_states,
  952. attention_mask=None,
  953. encoder_hidden_states=None,
  954. encoder_attn_mask=None,
  955. layer_head_mask=None,
  956. cross_attn_layer_head_mask=None,
  957. extended_predict_attention_mask=None,
  958. main_relative_position_buckets=None,
  959. predict_relative_position_buckets=None,
  960. position_ids=None,
  961. past_key_value=None,
  962. use_cache: bool = True,
  963. output_attentions: bool = False,
  964. ):
  965. # 1st residual block
  966. # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
  967. self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
  968. ngram_attention_output, self_attn_weights, self_attn_weights_ngram, present_key_value = self.self_attn(
  969. hidden_states=hidden_states,
  970. past_key_value=self_attn_past_key_value,
  971. attention_mask=attention_mask,
  972. layer_head_mask=layer_head_mask,
  973. extended_predict_attention_mask=extended_predict_attention_mask,
  974. main_relative_position_buckets=main_relative_position_buckets,
  975. predict_relative_position_buckets=predict_relative_position_buckets,
  976. position_ids=position_ids,
  977. )
  978. hidden_states = self.self_attn_layer_norm(hidden_states + ngram_attention_output)
  979. # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple
  980. cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
  981. cross_attn_weights = None
  982. if encoder_hidden_states is not None:
  983. # 2nd residual block
  984. attention_output, cross_attn_weights, cross_attn_present_key_value = self.cross_attn(
  985. hidden_states=hidden_states,
  986. key_value_states=encoder_hidden_states,
  987. attention_mask=encoder_attn_mask,
  988. layer_head_mask=cross_attn_layer_head_mask,
  989. past_key_value=cross_attn_past_key_value,
  990. output_attentions=output_attentions,
  991. )
  992. hidden_states = self.cross_attn_layer_norm(attention_output + hidden_states)
  993. # add cross-attn to positions 3,4 of present_key_value tuple
  994. present_key_value = present_key_value + cross_attn_present_key_value
  995. # 3rd residual block
  996. feed_forward_output = self.feed_forward(hidden_states)
  997. hidden_states = self.feed_forward_layer_norm(feed_forward_output + hidden_states)
  998. outputs = (hidden_states,)
  999. if output_attentions:
  1000. outputs += (self_attn_weights, self_attn_weights_ngram, cross_attn_weights)
  1001. if use_cache:
  1002. outputs += (present_key_value,)
  1003. return outputs
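# Cache layout sketch: `present_key_value` is a flat tuple,
#     present_key_value[0:2] -> self-attention (key_states, value_states)
#     present_key_value[2:4] -> cross-attention (key_states, value_states), only appended when
#                               encoder_hidden_states is passed
# which matches the `past_key_value[:2]` / `past_key_value[-2:]` slicing at the top of `forward`.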
  1004. @add_start_docstrings(
  1005. "The standalone encoder part of the ProphetNetModel.",
  1006. PROPHETNET_START_DOCSTRING,
  1007. )
  1008. class ProphetNetEncoder(ProphetNetPreTrainedModel):
  1009. r"""
1010. word_embeddings (`torch.nn.Embedding` of shape `(config.vocab_size, config.hidden_size)`, *optional*):
  1011. The word embedding parameters. This can be used to initialize [`ProphetNetEncoder`] with pre-defined word
  1012. embeddings instead of randomly initialized word embeddings.
  1013. """
1014. def __init__(self, config: ProphetNetConfig, word_embeddings: Optional[nn.Embedding] = None):
  1015. super().__init__(config)
  1016. self.word_embeddings = (
  1017. word_embeddings
  1018. if word_embeddings is not None
  1019. else nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
  1020. )
  1021. self.position_embeddings = ProphetNetPositionalEmbeddings(config)
  1022. self.embeddings_layer_norm = LayerNorm(config.hidden_size)
  1023. self.layers = nn.ModuleList([ProphetNetEncoderLayer(config) for _ in range(config.num_encoder_layers)])
  1024. self.gradient_checkpointing = False
  1025. # Initialize weights and apply final processing
  1026. self.post_init()
  1027. def get_input_embeddings(self):
  1028. return self.word_embeddings
  1029. def set_input_embeddings(self, value):
  1030. self.word_embeddings = value
  1031. @add_start_docstrings_to_model_forward(PROPHETNET_STANDALONE_INPUTS_DOCSTRING)
  1032. @replace_return_docstrings(output_type=BaseModelOutput, config_class=_CONFIG_FOR_DOC)
  1033. def forward(
  1034. self,
  1035. input_ids: Optional[torch.Tensor] = None,
  1036. attention_mask: Optional[torch.Tensor] = None,
  1037. head_mask: Optional[torch.Tensor] = None,
  1038. inputs_embeds: Optional[torch.Tensor] = None,
  1039. output_attentions: Optional[bool] = None,
  1040. output_hidden_states: Optional[bool] = None,
  1041. return_dict: Optional[bool] = None,
  1042. ) -> Union[Tuple, BaseModelOutput]:
  1043. r"""
  1044. Returns:
  1045. Example:
  1046. ```python
  1047. >>> from transformers import AutoTokenizer, ProphetNetEncoder
  1048. >>> import torch
  1049. >>> tokenizer = AutoTokenizer.from_pretrained("microsoft/prophetnet-large-uncased")
  1050. >>> model = ProphetNetEncoder.from_pretrained("patrickvonplaten/prophetnet-large-uncased-standalone")
  1051. >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
  1052. >>> outputs = model(**inputs)
  1053. >>> last_hidden_states = outputs.last_hidden_state
  1054. ```"""
  1055. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  1056. output_hidden_states = (
  1057. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  1058. )
  1059. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  1060. if input_ids is None and inputs_embeds is None:
  1061. raise ValueError("Either input_ids or inputs_embeds has to be passed.")
  1062. elif input_ids is not None and inputs_embeds is not None:
  1063. raise ValueError("Make sure to only pass input_ids or inputs_embeds.")
  1064. elif input_ids is not None and inputs_embeds is None:
  1065. inputs_embeds = self.word_embeddings(input_ids)
  1066. # prepare attention mask
  1067. if attention_mask is not None:
  1068. extended_attention_mask = (
  1069. 1.0 - attention_mask[:, None, None, :].repeat(1, self.config.num_encoder_attention_heads, 1, 1)
  1070. ) * torch.finfo(self.dtype).min
  1071. extended_attention_mask = extended_attention_mask.to(inputs_embeds.dtype)
  1072. else:
  1073. extended_attention_mask = None
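# Additive mask sketch (illustrative row): an `attention_mask` row of [1, 1, 0] becomes
# [0.0, 0.0, torch.finfo(dtype).min], broadcast over heads and query positions, so padded key
# positions receive a very large negative bias before the softmax inside the attention layers.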
  1074. position_embeddings, position_ids = self.position_embeddings(inputs_embeds.shape[:2], inputs_embeds.device)
  1075. hidden_states = inputs_embeds + position_embeddings
  1076. hidden_states = self.embeddings_layer_norm(hidden_states)
  1077. hidden_states = nn.functional.dropout(hidden_states, p=self.config.dropout, training=self.training)
  1078. encoder_hidden_states = () if output_hidden_states else None
  1079. all_attentions = () if output_attentions else None
  1080. # check if head_mask has a correct number of layers specified if desired
  1081. if head_mask is not None:
  1082. assert head_mask.size()[0] == (
  1083. len(self.layers)
  1084. ), f"The head_mask should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}."
  1085. for idx, encoder_layer in enumerate(self.layers):
  1086. if output_hidden_states:
  1087. encoder_hidden_states = encoder_hidden_states + (hidden_states,)
  1088. if self.gradient_checkpointing and self.training:
  1089. layer_outputs = self._gradient_checkpointing_func(
  1090. encoder_layer.__call__,
  1091. hidden_states,
  1092. extended_attention_mask,
  1093. (head_mask[idx] if head_mask is not None else None),
  1094. output_attentions,
  1095. )
  1096. else:
  1097. layer_outputs = encoder_layer(
  1098. hidden_states,
  1099. attention_mask=extended_attention_mask,
  1100. layer_head_mask=(head_mask[idx] if head_mask is not None else None),
  1101. output_attentions=output_attentions,
  1102. )
  1103. hidden_states = layer_outputs[0]
  1104. if output_attentions:
  1105. all_attentions = all_attentions + (layer_outputs[1],)
  1106. if output_hidden_states:
  1107. encoder_hidden_states = encoder_hidden_states + (hidden_states,)
  1108. if not return_dict:
  1109. return tuple(v for v in [hidden_states, encoder_hidden_states, all_attentions] if v is not None)
  1110. return BaseModelOutput(
  1111. last_hidden_state=hidden_states, hidden_states=encoder_hidden_states, attentions=all_attentions
  1112. )
  1113. @add_start_docstrings(
  1114. "The standalone decoder part of the ProphetNetModel.",
  1115. PROPHETNET_START_DOCSTRING,
  1116. )
  1117. class ProphetNetDecoder(ProphetNetPreTrainedModel):
  1118. r"""
1119. word_embeddings (`torch.nn.Embedding` of shape `(config.vocab_size, config.hidden_size)`, *optional*):
1120. The word embedding parameters. This can be used to initialize [`ProphetNetDecoder`] with pre-defined word
  1121. embeddings instead of randomly initialized word embeddings.
  1122. """
  1123. def __init__(self, config: ProphetNetConfig, word_embeddings: Optional[nn.Embedding] = None):
  1124. super().__init__(config)
  1125. self.ngram = config.ngram
  1126. self.num_buckets = config.num_buckets
  1127. self.relative_max_distance = config.relative_max_distance
  1128. self.dropout = config.dropout
  1129. self.max_target_positions = config.max_position_embeddings
  1130. self.word_embeddings = (
  1131. word_embeddings
  1132. if word_embeddings is not None
  1133. else nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
  1134. )
  1135. self.position_embeddings = ProphetNetPositionalEmbeddings(config)
  1136. self.ngram_embeddings = nn.Embedding(self.ngram, config.hidden_size, None)
  1137. self.layers = nn.ModuleList([ProphetNetDecoderLayer(config) for _ in range(config.num_decoder_layers)])
  1138. self.embeddings_layer_norm = LayerNorm(config.hidden_size)
  1139. self.gradient_checkpointing = False
  1140. # Initialize weights and apply final processing
  1141. self.post_init()
  1142. def get_input_embeddings(self):
  1143. return self.word_embeddings
  1144. def set_input_embeddings(self, value):
  1145. self.word_embeddings = value
  1146. @add_start_docstrings_to_model_forward(PROPHETNET_STANDALONE_INPUTS_DOCSTRING)
  1147. @replace_return_docstrings(output_type=ProphetNetDecoderModelOutput, config_class=_CONFIG_FOR_DOC)
  1148. def forward(
  1149. self,
  1150. input_ids: Optional[torch.Tensor] = None,
  1151. attention_mask: Optional[torch.Tensor] = None,
  1152. encoder_hidden_states: Optional[torch.Tensor] = None,
  1153. encoder_attention_mask: Optional[torch.Tensor] = None,
  1154. head_mask: Optional[torch.Tensor] = None,
  1155. cross_attn_head_mask: Optional[torch.Tensor] = None,
  1156. past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
  1157. inputs_embeds: Optional[torch.Tensor] = None,
  1158. use_cache: Optional[bool] = None,
  1159. output_attentions: Optional[bool] = None,
  1160. output_hidden_states: Optional[bool] = None,
  1161. return_dict: Optional[bool] = None,
  1162. ) -> Union[Tuple, ProphetNetDecoderModelOutput]:
  1163. r"""
  1164. encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
  1165. Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
  1166. the model is configured as a decoder.
  1167. encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
  1168. Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
1169. the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
  1170. cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
  1171. Mask to nullify selected heads of the cross-attention modules. Mask values selected in `[0, 1]`:
  1172. - 1 indicates the head is **not masked**,
  1173. - 0 indicates the head is **masked**.
  1174. past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
  1175. Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up decoding.
  1176. If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
  1177. don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
  1178. `decoder_input_ids` of shape `(batch_size, sequence_length)`.
  1179. use_cache (`bool`, *optional*):
  1180. If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
  1181. `past_key_values`).
  1184. Returns:
  1185. Example:
  1186. ```python
  1187. >>> from transformers import AutoTokenizer, ProphetNetDecoder
  1188. >>> import torch
  1189. >>> tokenizer = AutoTokenizer.from_pretrained("microsoft/prophetnet-large-uncased")
  1190. >>> model = ProphetNetDecoder.from_pretrained("microsoft/prophetnet-large-uncased", add_cross_attention=False)
  1191. >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
  1192. >>> outputs = model(**inputs)
  1193. >>> last_hidden_states = outputs.last_hidden_state
  1194. ```"""
  1195. use_cache = use_cache if use_cache is not None else self.config.use_cache
  1196. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  1197. output_hidden_states = (
  1198. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  1199. )
  1200. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  1201. if input_ids is None and inputs_embeds is None:
  1202. raise ValueError("Either `decoder_input_ids` or `decoder_inputs_embeds` has to be passed.")
  1203. elif input_ids is not None and inputs_embeds is not None:
  1204. raise ValueError("Make sure to only pass `decoder_input_ids` or `decoder_inputs_embeds`.")
  1205. elif input_ids is not None and inputs_embeds is None:
  1206. inputs_embeds = self.word_embeddings(input_ids)
  1207. batch_size, sequence_length = inputs_embeds.shape[:2]
  1208. main_stream_pos_embed, position_ids = self.position_embeddings(
  1209. (batch_size, sequence_length),
  1210. device=inputs_embeds.device,
  1211. past_key_values=past_key_values,
  1212. )
  1213. if past_key_values is not None:
  1214. main_relative_position_buckets, predict_relative_position_buckets = None, None
  1215. else:
  1216. (
  1217. main_relative_position_buckets,
  1218. predict_relative_position_buckets,
  1219. ) = self.compute_buffered_relative_buckets(position_ids)
  1220. predicting_stream_pos_embed = self.position_embeddings._forward(position_ids + 1)
  1221. # add position embeddings
  1222. hidden_states = inputs_embeds + main_stream_pos_embed
  1223. ngram_embeddings = self.ngram_embeddings.weight
  1224. # prepare attention mask
  1225. if past_key_values is not None:
  1226. assert (
  1227. hidden_states.size(1) == 1
  1228. ), "At the moment `use_cache` is only supported for `decoder_input_ids` of length 1"
  1229. ngram_hidden_states = [
  1230. (ngram_embeddings[ngram - 1] + predicting_stream_pos_embed).repeat(batch_size, 1, 1)
  1231. for ngram in range(self.ngram)
  1232. ]
  1233. extended_attention_mask = None
  1234. extended_predict_attention_mask = None
  1235. else:
  1236. ngram_hidden_states = [
  1237. (ngram_embeddings[ngram - 1] + predicting_stream_pos_embed) for ngram in range(self.ngram)
  1238. ]
  1239. extended_attention_mask = self.prepare_attention_mask(hidden_states, attention_mask)
  1240. extended_predict_attention_mask = self.prepare_predict_attention_mask(hidden_states, attention_mask)
  1241. # prepare encoder attention mask
  1242. if encoder_attention_mask is not None:
  1243. extended_encoder_attention_mask = (
  1244. 1.0 - encoder_attention_mask[:, None, None, :].repeat(1, self.config.num_decoder_attention_heads, 1, 1)
  1245. ) * torch.finfo(self.dtype).min
  1246. extended_encoder_attention_mask = extended_encoder_attention_mask.to(inputs_embeds.dtype)
  1247. else:
  1248. extended_encoder_attention_mask = None
  1249. hidden_states = torch.cat([hidden_states] + ngram_hidden_states, 1)
  1250. if self.embeddings_layer_norm:
  1251. hidden_states = self.embeddings_layer_norm(hidden_states)
  1252. hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
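# Layout sketch of `hidden_states` at this point: main stream followed by the `ngram` predicting
# streams along the sequence axis (S = sequence_length),
#     [ main_0 .. main_{S-1} | stream1_0 .. stream1_{S-1} | ... | streamN_0 .. streamN_{S-1} ]
# i.e. shape [batch_size, (1 + ngram) * S, hidden_size]; ProphetNetNgramSelfAttention chunks this
# back into its 1 + ngram streams.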
  1253. # init attentions, hidden_states and cache with empty tuples
  1254. all_main_stream_hidden_states = () if output_hidden_states else None
  1255. all_ngram_stream_hidden_states = () if output_hidden_states and self.config.ngram > 0 else None
  1256. all_main_stream_attns = () if output_attentions else None
  1257. all_ngram_stream_attns = () if output_attentions else None
  1258. all_cross_attns = () if output_attentions and self.config.add_cross_attention else None
  1259. if self.gradient_checkpointing and self.training:
  1260. if use_cache:
  1261. logger.warning_once(
  1262. "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
  1263. )
  1264. use_cache = False
  1265. present_key_values = () if use_cache else None
  1266. # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired
  1267. for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]):
  1268. if attn_mask is not None:
  1269. assert attn_mask.size()[0] == (len(self.layers)), (
  1270. f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for"
  1271. f" {head_mask.size()[0]}."
  1272. )
  1273. for idx, decoder_layer in enumerate(self.layers):
  1274. if output_hidden_states:
  1275. # grad cannot be kept because tensor is sliced
  1276. all_main_stream_hidden_states += (hidden_states[:, :sequence_length],)
  1277. if self.config.ngram > 0:
  1278. all_ngram_stream_hidden_states += (hidden_states[:, sequence_length:],)
  1279. past_key_value = past_key_values[idx] if past_key_values is not None else None
  1280. if self.gradient_checkpointing and self.training:
  1281. layer_outputs = self._gradient_checkpointing_func(
  1282. decoder_layer.__call__,
  1283. hidden_states,
  1284. extended_attention_mask,
  1285. encoder_hidden_states,
  1286. extended_encoder_attention_mask,
  1287. (head_mask[idx] if head_mask is not None else None),
  1288. (cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None),
  1289. extended_predict_attention_mask,
  1290. main_relative_position_buckets,
  1291. predict_relative_position_buckets,
  1292. position_ids,
  1293. None,
  1294. use_cache,
  1295. output_attentions,
  1296. )
  1297. else:
  1298. layer_outputs = decoder_layer(
  1299. hidden_states,
  1300. attention_mask=extended_attention_mask,
  1301. encoder_hidden_states=encoder_hidden_states,
  1302. encoder_attn_mask=extended_encoder_attention_mask,
  1303. layer_head_mask=(head_mask[idx] if head_mask is not None else None),
  1304. cross_attn_layer_head_mask=(
  1305. cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None
  1306. ),
  1307. extended_predict_attention_mask=extended_predict_attention_mask,
  1308. main_relative_position_buckets=main_relative_position_buckets,
  1309. predict_relative_position_buckets=predict_relative_position_buckets,
  1310. position_ids=position_ids,
  1311. past_key_value=past_key_value,
  1312. use_cache=use_cache,
  1313. output_attentions=output_attentions,
  1314. )
  1315. hidden_states = layer_outputs[0]
  1316. if use_cache:
  1317. present_key_values += (layer_outputs[4 if output_attentions else 1],)
  1318. if output_attentions:
  1319. all_main_stream_attns += (layer_outputs[1],)
  1320. all_ngram_stream_attns += (layer_outputs[2],)
  1321. if self.config.add_cross_attention:
  1322. all_cross_attns += (layer_outputs[3],)
  1323. if output_hidden_states:
  1324. all_main_stream_hidden_states += (hidden_states[:, :sequence_length],)
  1325. if self.config.ngram > 0:
  1326. all_ngram_stream_hidden_states += (hidden_states[:, sequence_length:],)
  1327. # split last_hidden_state for return
  1328. last_hidden_state = hidden_states[:, :sequence_length]
  1329. last_hidden_state_ngram = hidden_states[:, sequence_length:] if self.config.ngram > 0 else None
  1330. if not return_dict:
  1331. return tuple(
  1332. v
  1333. for v in [
  1334. last_hidden_state,
  1335. last_hidden_state_ngram,
  1336. present_key_values,
  1337. all_main_stream_hidden_states,
  1338. all_ngram_stream_hidden_states,
  1339. all_main_stream_attns,
  1340. all_ngram_stream_attns,
  1341. all_cross_attns,
  1342. ]
  1343. if v is not None
  1344. )
  1345. return ProphetNetDecoderModelOutput(
  1346. last_hidden_state=last_hidden_state,
  1347. last_hidden_state_ngram=last_hidden_state_ngram,
  1348. past_key_values=present_key_values,
  1349. hidden_states=all_main_stream_hidden_states,
  1350. hidden_states_ngram=all_ngram_stream_hidden_states,
  1351. attentions=all_main_stream_attns,
  1352. ngram_attentions=all_ngram_stream_attns,
  1353. cross_attentions=all_cross_attns,
  1354. )
  1355. def compute_buffered_relative_buckets(self, position_ids):
  1356. batch_size, sequence_length = position_ids.shape
  1357. position_ids = torch.arange(1, self.max_target_positions).to(position_ids.device).repeat(1, 1)
  1358. main_relative_buckets, predict_relative_buckets = compute_all_stream_relative_buckets(
  1359. self.num_buckets, self.relative_max_distance, position_ids
  1360. )
  1361. # buffer relative buckets
  1362. main_relative_buckets = main_relative_buckets[:, :sequence_length, :sequence_length].repeat(batch_size, 1, 1)
  1363. predict_relative_buckets = torch.cat(
  1364. [
  1365. predict_relative_buckets[:, :sequence_length, :sequence_length],
  1366. predict_relative_buckets[
  1367. :, :sequence_length, self.max_target_positions : self.max_target_positions + sequence_length
  1368. ],
  1369. ],
  1370. 2,
  1371. ).repeat(batch_size, 1, 1)
  1372. return main_relative_buckets, predict_relative_buckets
  1373. def prepare_attention_mask(self, hidden_states, attention_mask):
  1374. batch_size, seq_length = hidden_states.shape[:2]
  1375. # get causal mask
  1376. causal_mask = torch.full(
  1377. (seq_length, seq_length),
  1378. torch.finfo(hidden_states.dtype).min,
  1379. dtype=hidden_states.dtype,
  1380. device=hidden_states.device,
  1381. )
  1382. causal_mask = torch.triu(causal_mask, 1)
  1383. extended_causal_mask = causal_mask[:seq_length, :seq_length][None, None, :, :].expand(
  1384. (batch_size, self.config.num_decoder_attention_heads) + causal_mask.shape
  1385. )
  1386. # add usual attention mask
  1387. if attention_mask is not None:
  1388. extended_attention_mask = (1.0 - attention_mask[:, None, None, :]) * torch.finfo(self.dtype).min
  1389. extended_attention_mask = extended_causal_mask + extended_attention_mask
  1390. else:
  1391. extended_attention_mask = extended_causal_mask
  1392. return extended_attention_mask.to(hidden_states.dtype)
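# Causal mask sketch (illustrative, seq_length=3): with m = torch.finfo(dtype).min,
# torch.triu(causal_mask, 1) keeps only the strictly upper triangle:
#     [[0, m, m],
#      [0, 0, m],
#      [0, 0, 0]]
# Added to the attention weights, this blocks attention to future positions while leaving current
# and past positions untouched.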
  1393. def prepare_predict_attention_mask(self, hidden_states, attention_mask):
  1394. batch_size, seq_length = hidden_states.shape[:2]
  1395. # get causal mask
  1396. predict_causal_mask = ngram_attention_bias(
  1397. self.max_target_positions, self.ngram, hidden_states.device, hidden_states.dtype
  1398. )
  1399. predict_causal_mask = torch.cat(
  1400. [
  1401. predict_causal_mask[:, :seq_length, :seq_length],
  1402. predict_causal_mask[
  1403. :, :seq_length, self.max_target_positions : self.max_target_positions + seq_length
  1404. ],
  1405. ],
  1406. dim=-1,
  1407. )
  1408. extended_predict_causal_mask = predict_causal_mask[None, None, :, :, :].expand(
  1409. (batch_size, self.config.num_decoder_attention_heads) + predict_causal_mask.shape
  1410. )
  1411. # add usual attention mask
  1412. if attention_mask is not None:
  1413. extended_attention_mask = (1.0 - attention_mask[:, None, None, None, :]) * torch.finfo(self.dtype).min
  1414. extended_attention_mask = extended_attention_mask.expand(
  1415. (batch_size, self.config.num_decoder_attention_heads, self.ngram, seq_length, seq_length)
  1416. )
  1417. # predicted stream attention_mask should always be 0
  1418. extended_attention_mask = torch.cat(
  1419. [extended_attention_mask, torch.zeros_like(extended_attention_mask)], dim=-1
  1420. )
  1421. extended_predict_attention_mask = extended_predict_causal_mask + extended_attention_mask
  1422. else:
  1423. extended_predict_attention_mask = extended_predict_causal_mask
  1424. return extended_predict_attention_mask.to(hidden_states.dtype)
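# Predict-mask shape sketch: the last dimension is 2 * seq_length because each predicting-stream
# query can attend both to the main-stream positions (first half) and to its own stream's
# positions (second half); the padding mask is only applied to the first half and is kept at 0 for
# the second half, as noted above. The returned tensor has shape
# [batch_size, num_decoder_attention_heads, ngram, seq_length, 2 * seq_length].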
  1425. @add_start_docstrings(
  1426. "The bare ProphetNet Model outputting raw hidden-states without any specific head on top.",
  1427. PROPHETNET_START_DOCSTRING,
  1428. )
  1429. class ProphetNetModel(ProphetNetPreTrainedModel):
  1430. _tied_weights_keys = ["encoder.word_embeddings.weight", "decoder.word_embeddings.weight"]
  1431. def __init__(self, config: ProphetNetConfig):
  1432. super().__init__(config)
  1433. self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
  1434. encoder_config = copy.deepcopy(config)
  1435. encoder_config.is_encoder_decoder = False
  1436. encoder_config.use_cache = False
  1437. self.encoder = ProphetNetEncoder(encoder_config, self.word_embeddings)
  1438. decoder_config = copy.deepcopy(config)
  1439. decoder_config.is_decoder = True
  1440. decoder_config.is_encoder_decoder = False
  1441. self.decoder = ProphetNetDecoder(decoder_config, self.word_embeddings)
  1442. # Initialize weights and apply final processing
  1443. self.post_init()
  1444. def get_input_embeddings(self):
  1445. return self.word_embeddings
  1446. def set_input_embeddings(self, value):
  1447. self.word_embeddings = value
  1448. self.encoder.word_embeddings = self.word_embeddings
  1449. self.decoder.word_embeddings = self.word_embeddings
  1450. def _tie_weights(self):
  1451. if self.config.tie_word_embeddings:
  1452. self._tie_or_clone_weights(self.encoder.word_embeddings, self.word_embeddings)
  1453. self._tie_or_clone_weights(self.decoder.word_embeddings, self.word_embeddings)
  1454. def get_encoder(self):
  1455. return self.encoder
  1456. def get_decoder(self):
  1457. return self.decoder
  1458. @add_start_docstrings_to_model_forward(PROPHETNET_INPUTS_DOCSTRING)
  1459. @replace_return_docstrings(output_type=ProphetNetSeq2SeqModelOutput, config_class=_CONFIG_FOR_DOC)
  1460. def forward(
  1461. self,
  1462. input_ids: Optional[torch.Tensor] = None,
  1463. attention_mask: Optional[torch.Tensor] = None,
  1464. decoder_input_ids: Optional[torch.Tensor] = None,
  1465. decoder_attention_mask: Optional[torch.BoolTensor] = None,
  1466. head_mask: Optional[torch.Tensor] = None,
  1467. decoder_head_mask: Optional[torch.Tensor] = None,
  1468. cross_attn_head_mask: Optional[torch.Tensor] = None,
  1469. encoder_outputs: Optional[Tuple] = None,
  1470. past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
  1471. inputs_embeds: Optional[torch.Tensor] = None,
  1472. decoder_inputs_embeds: Optional[torch.Tensor] = None,
  1473. use_cache: Optional[bool] = None,
  1474. output_attentions: Optional[bool] = None,
  1475. output_hidden_states: Optional[bool] = None,
  1476. return_dict: Optional[bool] = None,
  1477. ) -> Union[Tuple, ProphetNetSeq2SeqModelOutput]:
  1478. r"""
  1479. Returns:
  1480. Example:
  1481. ```python
  1482. >>> from transformers import AutoTokenizer, ProphetNetModel
  1483. >>> tokenizer = AutoTokenizer.from_pretrained("microsoft/prophetnet-large-uncased")
  1484. >>> model = ProphetNetModel.from_pretrained("microsoft/prophetnet-large-uncased")
  1485. >>> input_ids = tokenizer(
  1486. ... "Studies have been shown that owning a dog is good for you", return_tensors="pt"
  1487. ... ).input_ids # Batch size 1
  1488. >>> decoder_input_ids = tokenizer("Studies show that", return_tensors="pt").input_ids # Batch size 1
  1489. >>> outputs = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids)
  1490. >>> last_hidden_states = outputs.last_hidden_state # main stream hidden states
  1491. >>> last_hidden_states_ngram = outputs.last_hidden_state_ngram # predict hidden states
  1492. ```"""
  1493. use_cache = use_cache if use_cache is not None else self.config.use_cache
  1494. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  1495. output_hidden_states = (
  1496. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  1497. )
  1498. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  1499. if encoder_outputs is None:
  1500. encoder_outputs = self.encoder(
  1501. input_ids=input_ids,
  1502. attention_mask=attention_mask,
  1503. head_mask=head_mask,
  1504. inputs_embeds=inputs_embeds,
  1505. output_attentions=output_attentions,
  1506. output_hidden_states=output_hidden_states,
  1507. return_dict=return_dict,
  1508. )
  1509. # decoder outputs consists of (dec_features, past_key_values, dec_hidden, dec_attn)
  1510. decoder_outputs = self.decoder(
  1511. input_ids=decoder_input_ids,
  1512. attention_mask=decoder_attention_mask,
  1513. encoder_hidden_states=encoder_outputs[0],
  1514. encoder_attention_mask=attention_mask,
  1515. head_mask=decoder_head_mask,
  1516. cross_attn_head_mask=cross_attn_head_mask,
  1517. past_key_values=past_key_values,
  1518. inputs_embeds=decoder_inputs_embeds,
  1519. output_attentions=output_attentions,
  1520. output_hidden_states=output_hidden_states,
  1521. use_cache=use_cache,
  1522. return_dict=return_dict,
  1523. )
  1524. if not return_dict:
  1525. return decoder_outputs + encoder_outputs
  1526. return ProphetNetSeq2SeqModelOutput(
  1527. last_hidden_state=decoder_outputs.last_hidden_state,
  1528. last_hidden_state_ngram=decoder_outputs.last_hidden_state_ngram,
  1529. past_key_values=decoder_outputs.past_key_values,
  1530. decoder_hidden_states=decoder_outputs.hidden_states,
  1531. decoder_ngram_hidden_states=decoder_outputs.hidden_states_ngram,
  1532. decoder_attentions=decoder_outputs.attentions,
  1533. decoder_ngram_attentions=decoder_outputs.ngram_attentions,
  1534. cross_attentions=decoder_outputs.cross_attentions,
  1535. encoder_last_hidden_state=encoder_outputs.last_hidden_state,
  1536. encoder_hidden_states=encoder_outputs.hidden_states,
  1537. encoder_attentions=encoder_outputs.attentions,
  1538. )
  1539. @add_start_docstrings(
  1540. "The ProphetNet Model with a language modeling head. Can be used for sequence generation tasks.",
  1541. PROPHETNET_START_DOCSTRING,
  1542. )
  1543. class ProphetNetForConditionalGeneration(ProphetNetPreTrainedModel, GenerationMixin):
  1544. _tied_weights_keys = ["encoder.word_embeddings.weight", "decoder.word_embeddings.weight", "lm_head.weight"]
  1545. def __init__(self, config: ProphetNetConfig):
  1546. super().__init__(config)
  1547. self.prophetnet = ProphetNetModel(config)
  1548. self.padding_idx = config.pad_token_id
  1549. self.disable_ngram_loss = config.disable_ngram_loss
  1550. self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
  1551. # Initialize weights and apply final processing
  1552. self.post_init()
  1553. def get_output_embeddings(self):
  1554. return self.lm_head
  1555. def set_output_embeddings(self, new_embeddings):
  1556. self.lm_head = new_embeddings
  1557. def _tie_weights(self):
  1558. if self.config.tie_word_embeddings:
  1559. self._tie_or_clone_weights(self.prophetnet.word_embeddings, self.lm_head)
  1560. def get_input_embeddings(self):
  1561. return self.prophetnet.word_embeddings
  1562. @add_start_docstrings_to_model_forward(PROPHETNET_INPUTS_DOCSTRING)
  1563. @replace_return_docstrings(output_type=ProphetNetSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
  1564. def forward(
  1565. self,
  1566. input_ids: Optional[torch.Tensor] = None,
  1567. attention_mask: Optional[torch.Tensor] = None,
  1568. decoder_input_ids: Optional[torch.Tensor] = None,
  1569. decoder_attention_mask: Optional[torch.BoolTensor] = None,
  1570. head_mask: Optional[torch.Tensor] = None,
  1571. decoder_head_mask: Optional[torch.Tensor] = None,
  1572. cross_attn_head_mask: Optional[torch.Tensor] = None,
  1573. encoder_outputs: Optional[torch.Tensor] = None,
  1574. past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
  1575. inputs_embeds: Optional[torch.Tensor] = None,
  1576. decoder_inputs_embeds: Optional[torch.Tensor] = None,
  1577. labels: Optional[torch.Tensor] = None,
  1578. use_cache: Optional[bool] = None,
  1579. output_attentions: Optional[bool] = None,
  1580. output_hidden_states: Optional[bool] = None,
  1581. return_dict: Optional[bool] = None,
  1582. ) -> Union[Tuple, ProphetNetSeq2SeqLMOutput]:
  1583. r"""
1584. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1585. Labels for computing the sequence-to-sequence language modeling loss. Indices should be in `[-100, 0, ...,
1586. config.vocab_size - 1]`. All labels set to `-100` are ignored (masked), the loss is only computed for
1587. labels in `[0, ..., config.vocab_size - 1]`
  1588. Returns:
  1589. Example:
  1590. ```python
  1591. >>> from transformers import AutoTokenizer, ProphetNetForConditionalGeneration
  1592. >>> tokenizer = AutoTokenizer.from_pretrained("microsoft/prophetnet-large-uncased")
  1593. >>> model = ProphetNetForConditionalGeneration.from_pretrained("microsoft/prophetnet-large-uncased")
  1594. >>> input_ids = tokenizer(
  1595. ... "Studies have been shown that owning a dog is good for you", return_tensors="pt"
  1596. ... ).input_ids # Batch size 1
  1597. >>> decoder_input_ids = tokenizer("Studies show that", return_tensors="pt").input_ids # Batch size 1
  1598. >>> outputs = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids)
  1599. >>> logits_next_token = outputs.logits # logits to predict next token as usual
  1600. >>> logits_ngram_next_tokens = outputs.logits_ngram # logits to predict 2nd, 3rd, ... next tokens
  1601. ```"""
  1602. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  1603. if labels is not None and decoder_input_ids is None and decoder_inputs_embeds is None:
  1604. # get decoder inputs from shifting lm labels to the right
  1605. decoder_input_ids = self._shift_right(labels)
  1606. outputs = self.prophetnet(
  1607. input_ids=input_ids,
  1608. attention_mask=attention_mask,
  1609. decoder_input_ids=decoder_input_ids,
  1610. decoder_attention_mask=decoder_attention_mask,
  1611. head_mask=head_mask,
  1612. decoder_head_mask=decoder_head_mask,
  1613. cross_attn_head_mask=cross_attn_head_mask,
  1614. encoder_outputs=encoder_outputs,
  1615. past_key_values=past_key_values,
  1616. inputs_embeds=inputs_embeds,
  1617. decoder_inputs_embeds=decoder_inputs_embeds,
  1618. use_cache=use_cache,
  1619. output_attentions=output_attentions,
  1620. output_hidden_states=output_hidden_states,
  1621. return_dict=return_dict,
  1622. )
  1623. batch_size, sequence_length = (
  1624. decoder_input_ids.shape if decoder_input_ids is not None else decoder_inputs_embeds.shape[:2]
  1625. )
  1626. predicting_streams = outputs[1].view(batch_size, self.config.ngram, sequence_length, -1)
  1627. predict_logits = self.lm_head(predicting_streams)
  1628. logits = predict_logits[:, 0]
  1629. logits_ngram = predict_logits[:, 1:] if self.config.ngram > 1 else None
  1630. # To use .view in loss computation, make sure that logits is contiguous.
  1631. if not logits.is_contiguous():
  1632. logits = logits.contiguous()
  1633. loss = None
  1634. if labels is not None:
  1635. loss = self._compute_loss(predict_logits, labels)
  1636. if not return_dict:
  1637. all_logits = tuple(v for v in [logits, logits_ngram] if v is not None)
  1638. return (loss,) + all_logits + outputs[2:] if loss is not None else all_logits + outputs[2:]
  1639. else:
  1640. return ProphetNetSeq2SeqLMOutput(
  1641. loss=loss,
  1642. logits=logits,
  1643. logits_ngram=logits_ngram,
  1644. past_key_values=outputs.past_key_values,
  1645. decoder_hidden_states=outputs.decoder_hidden_states,
  1646. decoder_ngram_hidden_states=outputs.decoder_ngram_hidden_states,
  1647. decoder_attentions=outputs.decoder_attentions,
  1648. decoder_ngram_attentions=outputs.decoder_ngram_attentions,
  1649. cross_attentions=outputs.cross_attentions,
  1650. encoder_last_hidden_state=outputs.encoder_last_hidden_state,
  1651. encoder_hidden_states=outputs.encoder_hidden_states,
  1652. encoder_attentions=outputs.encoder_attentions,
  1653. )
  1654. def _compute_loss(self, logits, labels, ignore_index=-100):
  1655. expend_targets = labels.new_zeros(self.config.ngram, labels.size(0), labels.size(1)).fill_(ignore_index)
  1656. for i in range(self.config.ngram):
  1657. if i > 0 and self.disable_ngram_loss:
  1658. break
  1659. expend_targets[i, :, :] = labels
  1660. logits = logits.transpose(0, 1).contiguous()
  1661. lprobs = nn.functional.log_softmax(
  1662. logits.view(-1, logits.size(-1)),
  1663. dim=-1,
  1664. dtype=torch.float32,
  1665. )
  1666. loss = nn.functional.nll_loss(lprobs, expend_targets.view(-1), reduction="mean")
  1667. if self.config.eps > 0.0:
  1668. smooth_loss = -lprobs.sum(dim=-1, keepdim=True)
  1669. non_masked_tokens = expend_targets.ne(ignore_index).view(-1)
  1670. smooth_loss = smooth_loss[non_masked_tokens]
  1671. smooth_loss = smooth_loss.mean()
  1672. eps_i = self.config.eps / lprobs.size(-1)
  1673. loss = (1.0 - self.config.eps) * loss + eps_i * smooth_loss
  1674. return loss
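# Loss sketch: with label smoothing enabled (config.eps > 0) the returned value is
#     loss = (1 - eps) * nll_loss + (eps / vocab_size) * mean(-sum_over_vocab(log_softmax))
# where the smooth term is averaged over non-ignored target positions, and the targets are tiled
# over the `ngram` predicting streams unless `disable_ngram_loss` is set.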
  1675. def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
  1676. return self._shift_right(labels)
  1677. @staticmethod
  1678. # Copied from transformers.models.bart.modeling_bart.BartForConditionalGeneration._reorder_cache
  1679. def _reorder_cache(past_key_values, beam_idx):
  1680. reordered_past = ()
  1681. for layer_past in past_key_values:
  1682. # cached cross_attention states don't have to be reordered -> they are always the same
  1683. reordered_past += (
  1684. tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past[:2])
  1685. + layer_past[2:],
  1686. )
  1687. return reordered_past
  1688. def get_encoder(self):
  1689. return self.prophetnet.encoder
  1690. def get_decoder(self):
  1691. return self.prophetnet.decoder
  1692. @add_start_docstrings(
  1693. "The standalone decoder part of the ProphetNetModel with a lm head on top. The model can be used for causal"
  1694. " language modeling.",
  1695. PROPHETNET_START_DOCSTRING,
  1696. )
  1697. class ProphetNetForCausalLM(ProphetNetPreTrainedModel, GenerationMixin):
  1698. _tied_weights_keys = [
  1699. "prophetnet.word_embeddings.weight",
  1700. "prophetnet.decoder.word_embeddings.weight",
  1701. "lm_head.weight",
  1702. ]
  1703. def __init__(self, config: ProphetNetConfig):
  1704. # set config for CLM
  1705. config = copy.deepcopy(config)
  1706. config.is_decoder = True
  1707. config.is_encoder_decoder = False
  1708. super().__init__(config)
  1709. self.prophetnet = ProphetNetDecoderWrapper(config)
  1710. self.padding_idx = config.pad_token_id
  1711. self.disable_ngram_loss = config.disable_ngram_loss
  1712. self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
  1713. # Initialize weights and apply final processing
  1714. self.post_init()
  1715. def get_input_embeddings(self):
  1716. return self.prophetnet.decoder.word_embeddings
  1717. def set_input_embeddings(self, value):
  1718. self.prophetnet.decoder.word_embeddings = value
  1719. def get_output_embeddings(self):
  1720. return self.lm_head
  1721. def set_output_embeddings(self, new_embeddings):
  1722. self.lm_head = new_embeddings
  1723. def _tie_weights(self):
  1724. if self.config.tie_word_embeddings:
  1725. self._tie_or_clone_weights(self.prophetnet.decoder.word_embeddings, self.lm_head)
  1726. def set_decoder(self, decoder):
  1727. self.prophetnet.decoder = decoder
  1728. def get_decoder(self):
  1729. return self.prophetnet.decoder
  1730. @add_start_docstrings_to_model_forward(PROPHETNET_STANDALONE_INPUTS_DOCSTRING)
  1731. @replace_return_docstrings(output_type=ProphetNetDecoderLMOutput, config_class=_CONFIG_FOR_DOC)
  1732. def forward(
  1733. self,
  1734. input_ids: Optional[torch.Tensor] = None,
  1735. attention_mask: Optional[torch.Tensor] = None,
  1736. encoder_hidden_states: Optional[torch.Tensor] = None,
  1737. encoder_attention_mask: Optional[torch.Tensor] = None,
  1738. head_mask: Optional[torch.Tensor] = None,
  1739. cross_attn_head_mask: Optional[torch.Tensor] = None,
  1740. past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
  1741. inputs_embeds: Optional[torch.Tensor] = None,
  1742. labels: Optional[torch.Tensor] = None,
  1743. use_cache: Optional[bool] = None,
  1744. output_attentions: Optional[bool] = None,
  1745. output_hidden_states: Optional[bool] = None,
  1746. return_dict: Optional[bool] = None,
  1747. ) -> Union[Tuple, ProphetNetDecoderLMOutput]:
  1748. r"""
  1749. encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
  1750. Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
  1751. the model is configured as a decoder.
  1752. encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
  1753. Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
1754. the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
  1755. cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
  1756. Mask to nullify selected heads of the cross-attention modules. Mask values selected in `[0, 1]`:
  1757. - 1 indicates the head is **not masked**,
  1758. - 0 indicates the head is **masked**.
  1759. past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
  1760. Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up decoding.
  1761. If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
  1762. don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
  1763. `decoder_input_ids` of shape `(batch_size, sequence_length)`.
  1764. use_cache (`bool`, *optional*):
  1765. If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
  1766. `past_key_values`).
  1769. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  1770. Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in
  1771. `[-100, 0, ..., config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are
1772. ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size - 1]`
        Returns:

        Example:

        ```python
        >>> from transformers import AutoTokenizer, ProphetNetForCausalLM
        >>> import torch

        >>> tokenizer = AutoTokenizer.from_pretrained("microsoft/prophetnet-large-uncased")
        >>> model = ProphetNetForCausalLM.from_pretrained("microsoft/prophetnet-large-uncased")
        >>> assert model.config.is_decoder, f"{model.__class__} has to be configured as a decoder."
        >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
        >>> outputs = model(**inputs)

        >>> logits = outputs.logits

        >>> # Model can also be used with EncoderDecoder framework
        >>> from transformers import BertTokenizer, EncoderDecoderModel, AutoTokenizer
        >>> import torch

        >>> tokenizer_enc = BertTokenizer.from_pretrained("google-bert/bert-large-uncased")
        >>> tokenizer_dec = AutoTokenizer.from_pretrained("microsoft/prophetnet-large-uncased")
        >>> model = EncoderDecoderModel.from_encoder_decoder_pretrained(
        ...     "google-bert/bert-large-uncased", "microsoft/prophetnet-large-uncased"
        ... )

        >>> ARTICLE = (
        ...     "the us state department said wednesday it had received no "
        ...     "formal word from bolivia that it was expelling the us ambassador there "
        ...     "but said the charges made against him are `` baseless ."
        ... )
        >>> input_ids = tokenizer_enc(ARTICLE, return_tensors="pt").input_ids
        >>> labels = tokenizer_dec(
        ...     "us rejects charges against its ambassador in bolivia", return_tensors="pt"
        ... ).input_ids
        >>> outputs = model(input_ids=input_ids, decoder_input_ids=labels[:, :-1], labels=labels[:, 1:])

        >>> loss = outputs.loss
        ```"""
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # decoder outputs consist of (dec_features, past_key_values, dec_hidden, dec_attn)
        outputs = self.prophetnet.decoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            head_mask=head_mask,
            cross_attn_head_mask=cross_attn_head_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        batch_size, sequence_length = input_ids.shape if input_ids is not None else inputs_embeds.shape[:2]

        predicting_streams = outputs[1].view(batch_size, self.config.ngram, sequence_length, -1)
        predict_logits = self.lm_head(predicting_streams)

        logits = predict_logits[:, 0]
        logits_ngram = predict_logits[:, 1:] if self.config.ngram > 1 else None

        loss = None
        if labels is not None:
            loss = self._compute_loss(predict_logits, labels)

        if not return_dict:
            all_logits = tuple(v for v in [logits, logits_ngram] if v is not None)
            return (loss,) + all_logits + outputs[2:] if loss is not None else all_logits + outputs[2:]
        else:
            return ProphetNetDecoderLMOutput(
                loss=loss,
                logits=logits,
                logits_ngram=logits_ngram,
                past_key_values=outputs.past_key_values,
                hidden_states=outputs.hidden_states,
                hidden_states_ngram=outputs.hidden_states_ngram,
                attentions=outputs.attentions,
                ngram_attentions=outputs.ngram_attentions,
                cross_attentions=outputs.cross_attentions,
            )
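
    # Editor's sketch (not part of the original file): shape walk-through for the stream split in
    # `forward` above. The decoder's n-gram hidden states (`outputs[1]`) come back flattened as
    # `(batch, ngram * seq_len, hidden)`; the `.view` reshapes them to `(batch, ngram, seq_len, hidden)`,
    # so after the LM head `logits` (stream 0) scores the next token, while `logits_ngram`
    # (streams 1..ngram-1, present only when `config.ngram > 1`) scores the tokens further ahead,
    # which is ProphetNet's future n-gram prediction objective.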

    def _compute_loss(self, logits, labels, ignore_index=-100):
        expend_targets = labels.new_zeros(self.config.ngram, labels.size(0), labels.size(1)).fill_(ignore_index)

        for i in range(self.config.ngram):
            if i > 0 and self.disable_ngram_loss:
                break
            expend_targets[i, :, :] = labels

        logits = logits.transpose(0, 1).contiguous()
        lprobs = nn.functional.log_softmax(
            logits.view(-1, logits.size(-1)),
            dim=-1,
            dtype=torch.float32,
        )

        loss = nn.functional.nll_loss(lprobs, expend_targets.view(-1), reduction="mean")

        if self.config.eps > 0.0:
            smooth_loss = -lprobs.sum(dim=-1, keepdim=True)
            non_masked_tokens = expend_targets.ne(ignore_index).view(-1)
            smooth_loss = smooth_loss[non_masked_tokens]
            smooth_loss = smooth_loss.mean()

            eps_i = self.config.eps / lprobs.size(-1)
            loss = (1.0 - self.config.eps) * loss + eps_i * smooth_loss

        return loss
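
    # Editor's note (illustrative, not from the original file): with `eps = config.eps` and
    # `V = lprobs.size(-1)` (the vocabulary size), the smoothing branch above implements standard
    # label smoothing over the flattened prediction streams,
    #
    #     loss = (1 - eps) * NLL(lprobs, targets) + (eps / V) * mean_t( sum_v -lprobs[t, v] )
    #
    # where the mean runs only over non-ignored target positions. When `disable_ngram_loss` is set,
    # streams other than the next-token stream keep `ignore_index` targets and therefore do not
    # contribute to the NLL term.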

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        attention_mask=None,
        head_mask=None,
        use_cache=None,
        **kwargs,
    ):
        # Overwritten -- our tests complain if we use GenerationMixin.prepare_inputs_for_generation

        # if the model is used as a decoder in an encoder-decoder model, the decoder attention mask
        # is created on the fly
        if attention_mask is None:
            attention_mask = input_ids.new_ones(input_ids.shape)

        # the cache is empty on the first generation step; once it is filled, only the last token
        # has to be passed through the decoder
        if past_key_values:
            input_ids = input_ids[:, -1:]

        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "head_mask": head_mask,
            "past_key_values": past_key_values,
            "use_cache": use_cache,
        }
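
    # Editor's sketch (hedged, not from the original file) of what the hook above returns during
    # generation: on the first step `past_key_values` is None and the whole prompt is forwarded; on
    # later steps the cache is populated, so only the most recent token survives the slicing, e.g.
    #
    #     inputs = model.prepare_inputs_for_generation(prompt_ids, past_key_values=cache)
    #     # inputs["input_ids"].shape[-1] == prompt_ids.shape[-1] if cache is None else 1
    #
    # `model`, `prompt_ids` and `cache` are placeholder names used only for this illustration.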

    @staticmethod
    # Copied from transformers.models.bart.modeling_bart.BartForCausalLM._reorder_cache
    def _reorder_cache(past_key_values, beam_idx):
        reordered_past = ()
        for layer_past in past_key_values:
            reordered_past += (
                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
            )
        return reordered_past


class ProphetNetDecoderWrapper(ProphetNetPreTrainedModel):
    """
    This is a wrapper class, so that [`ProphetNetForCausalLM`] can correctly load weights from pretrained
    ProphetNet checkpoints.
    """

    def __init__(self, config: ProphetNetConfig):
        super().__init__(config)

        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
        self.decoder = ProphetNetDecoder(config, word_embeddings=self.word_embeddings)

        # Initialize weights and apply final processing
        self.post_init()

    def _tie_weights(self):
        self._tie_or_clone_weights(self.word_embeddings, self.decoder.get_input_embeddings())

    def forward(self, *args, **kwargs):
        return self.decoder(*args, **kwargs)
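

# Editor's note (hedged sketch, inferred from this file rather than stated in it): `ProphetNetForCausalLM`
# calls `self.prophetnet.decoder(...)` in its `forward` above, which suggests it holds an instance of
# `ProphetNetDecoderWrapper` as `self.prophetnet`. Keeping the decoder nested under that attribute keeps the
# state-dict keys aligned with the full encoder-decoder `ProphetNetModel` (`prophetnet.decoder.*`), so a call
# like the one in the docstring example,
#
#     model = ProphetNetForCausalLM.from_pretrained("microsoft/prophetnet-large-uncased")
#
# can reuse the decoder weights of a full ProphetNet checkpoint without renaming parameters.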