| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
70370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232123312341235123612371238123912401241124212431244124512461247124812491250125112521253125412551256125712581259126012611262126312641265126612671268126912701271127212731274127512761
27712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432143314341435143614371438143914401441144214431444144514461447144814491450145114521453145414551456145714581459146014611462146314641465146614671468146914701471147214731474147514761477147814791480148114821483148414851486148714881489149014911492149314941495149614971498149915001501150215031504150515061507150815091510151115121513151415151516151715181519152015211522152315241525152615271528152915301531153215331534153515361537153815391540154115421543154415451546154715481549155015511552155315541555155615571558155915601561156215631564156515661567156815691570157115721573157415751576157715781579158015811582158315841585158615871588158915901591159215931594159515961597159815991600160116021603160416051606160716081609161016111612161316141615161616171618161916201621162216231624162516261627162816291630163116321633163416351636163716381639164016411642164316441645164616471648164916501651165216531654165516561657165816591660166116621663166416651666166716681669167016711672167316741675167616771678167916801681168216831684168516861687168816891690169116921693169416951696169716981699170017011702170317041705170617071708170917101711171217131714171517161717171817191720172117221723172417251726172717281729173017311732173317341735173617371738173917401741174217431744174517461747174817491750175117521753175417551756175717581759176017611762176317641765176617671768176917701771177217731774177517761
77717781779178017811782178317841785178617871788178917901791179217931794179517961797179817991800180118021803180418051806180718081809181018111812181318141815181618171818181918201821182218231824182518261827182818291830183118321833183418351836183718381839184018411842184318441845184618471848184918501851185218531854185518561857185818591860186118621863186418651866186718681869187018711872187318741875187618771878187918801881188218831884188518861887188818891890189118921893189418951896189718981899190019011902190319041905190619071908190919101911191219131914191519161917191819191920192119221923192419251926192719281929193019311932193319341935193619371938193919401941194219431944194519461947194819491950195119521953195419551956195719581959196019611962196319641965196619671968196919701971197219731974197519761977197819791980198119821983198419851986198719881989199019911992199319941995199619971998199920002001200220032004200520062007200820092010201120122013201420152016201720182019202020212022202320242025202620272028202920302031203220332034203520362037203820392040204120422043204420452046204720482049205020512052205320542055205620572058205920602061206220632064206520662067206820692070207120722073207420752076207720782079208020812082208320842085208620872088208920902091209220932094209520962097209820992100210121022103210421052106210721082109211021112112211321142115211621172118211921202121212221232124212521262127212821292130213121322133213421352136213721382139214021412142214321442145214621472148214921502151215221532154215521562157215821592160216121622163216421652166216721682169217021712172217321742175217621772178217921802181218221832184218521862187218821892190219121922193219421952196219721982199220022012202220322042205220622072208220922102211221222132214221522162217221822192220222122222223222422252226222722282229223022312232223322342235223622372238223922402241224222432244224522462247224822492250225122522253225422552256225722582259226022612262226322642265226622672268226922702271227222732274227522762
27722782279228022812282228322842285228622872288228922902291229222932294229522962297229822992300230123022303230423052306230723082309231023112312231323142315231623172318231923202321232223232324232523262327232823292330233123322333233423352336233723382339234023412342234323442345234623472348234923502351235223532354235523562357235823592360236123622363236423652366236723682369237023712372237323742375237623772378237923802381238223832384238523862387238823892390239123922393239423952396239723982399240024012402240324042405240624072408240924102411241224132414241524162417241824192420242124222423242424252426242724282429243024312432243324342435243624372438243924402441244224432444244524462447244824492450245124522453245424552456245724582459246024612462246324642465246624672468246924702471247224732474247524762477247824792480248124822483248424852486248724882489249024912492249324942495249624972498249925002501250225032504250525062507250825092510251125122513251425152516251725182519252025212522252325242525252625272528252925302531253225332534253525362537253825392540254125422543254425452546254725482549255025512552255325542555255625572558255925602561256225632564256525662567256825692570257125722573257425752576257725782579258025812582258325842585258625872588258925902591259225932594259525962597259825992600260126022603260426052606260726082609261026112612261326142615261626172618261926202621262226232624262526262627262826292630263126322633263426352636263726382639264026412642264326442645264626472648264926502651265226532654265526562657265826592660266126622663266426652666266726682669267026712672267326742675267626772678267926802681268226832684268526862687268826892690269126922693269426952696269726982699270027012702270327042705270627072708270927102711271227132714 |
- # coding=utf-8
- # Copyright 2021 Iz Beltagy, Matthew E. Peters, Arman Cohan and The HuggingFace Inc. team. All rights reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- """PyTorch LED model."""
- import math
- import warnings
- from dataclasses import dataclass
- from typing import List, Optional, Tuple, Union
- import torch
- import torch.utils.checkpoint
- from torch import nn
- from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
- from ...activations import ACT2FN
- from ...generation import GenerationMixin
- from ...modeling_attn_mask_utils import _create_4d_causal_attention_mask
- from ...modeling_outputs import (
- BaseModelOutputWithPastAndCrossAttentions,
- Seq2SeqLMOutput,
- Seq2SeqModelOutput,
- Seq2SeqQuestionAnsweringModelOutput,
- Seq2SeqSequenceClassifierOutput,
- )
- from ...modeling_utils import PreTrainedModel
- from ...utils import (
- ModelOutput,
- add_code_sample_docstrings,
- add_end_docstrings,
- add_start_docstrings,
- add_start_docstrings_to_model_forward,
- logging,
- replace_return_docstrings,
- )
- from .configuration_led import LEDConfig
# Module-level logger from the package's own `logging` utility (imported from `...utils` above).
logger = logging.get_logger(__name__)
# NOTE(review): presumably consumed by the docstring helpers imported above
# (e.g. `add_code_sample_docstrings`); their use is not visible in this chunk.
_CHECKPOINT_FOR_DOC = "allenai/led-base-16384"
_CONFIG_FOR_DOC = "LEDConfig"
def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int):
    """
    Shift input ids one token to the right.

    The first column becomes `decoder_start_token_id`, every other column takes the id that was one position to
    its left, and the last input column is dropped. Any `-100` placeholder that survives the shift is replaced by
    `pad_token_id`.

    Raises:
        ValueError: if `pad_token_id` is None (checked after the shift has been built).
    """
    shifted = input_ids.new_zeros(input_ids.shape)
    # The two column ranges below are disjoint, so their order does not matter.
    shifted[:, 1:] = input_ids[:, :-1].clone()
    shifted[:, 0] = decoder_start_token_id

    if pad_token_id is None:
        raise ValueError("config.pad_token_id has to be defined.")

    # Labels use -100 as "ignore"; the decoder needs a real token id there instead.
    ignore_positions = shifted.eq(-100)
    shifted.masked_fill_(ignore_positions, pad_token_id)
    return shifted
- def _prepare_4d_attention_mask_inverted(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
- """
- Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
- """
- bsz, src_len = mask.size()
- tgt_len = tgt_len if tgt_len is not None else src_len
- expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
- inverted_mask = 1.0 - expanded_mask
- expanded_attention_mask = inverted_mask.masked_fill(inverted_mask.bool(), torch.finfo(dtype).min)
- # make sure that global_attn_mask is positive
- expanded_attention_mask = expanded_attention_mask * inverted_mask
- return expanded_attention_mask
class LEDLearnedPositionalEmbedding(nn.Embedding):
    """
    Learned absolute positional embeddings for up to `num_embeddings` positions.

    Rather than token ids, `forward` receives the input *shape* and looks up the rows for positions
    `past_key_values_length .. past_key_values_length + seq_len - 1`.
    """

    def __init__(self, num_embeddings: int, embedding_dim: int):
        super().__init__(num_embeddings, embedding_dim)

    def forward(self, input_ids_shape: torch.Size, past_key_values_length: int = 0):
        """`input_ids_shape` is expected to be [bsz x seqlen]; returns a (seqlen, embedding_dim) tensor."""
        _, seq_len = input_ids_shape[:2]
        first = past_key_values_length
        # Positions continue after any cached prefix (e.g. during incremental decoding).
        position_ids = torch.arange(first, first + seq_len, dtype=torch.long, device=self.weight.device)
        return super().forward(position_ids)
- # Copied from transformers.models.longformer.modeling_longformer.LongformerSelfAttention with Longformer->LEDEncoder
- class LEDEncoderSelfAttention(nn.Module):
- def __init__(self, config, layer_id):
- super().__init__()
- if config.hidden_size % config.num_attention_heads != 0:
- raise ValueError(
- f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
- f"heads ({config.num_attention_heads})"
- )
- self.num_heads = config.num_attention_heads
- self.head_dim = int(config.hidden_size / config.num_attention_heads)
- self.embed_dim = config.hidden_size
- self.query = nn.Linear(config.hidden_size, self.embed_dim)
- self.key = nn.Linear(config.hidden_size, self.embed_dim)
- self.value = nn.Linear(config.hidden_size, self.embed_dim)
- # separate projection layers for tokens with global attention
- self.query_global = nn.Linear(config.hidden_size, self.embed_dim)
- self.key_global = nn.Linear(config.hidden_size, self.embed_dim)
- self.value_global = nn.Linear(config.hidden_size, self.embed_dim)
- self.dropout = config.attention_probs_dropout_prob
- self.layer_id = layer_id
- attention_window = config.attention_window[self.layer_id]
- assert (
- attention_window % 2 == 0
- ), f"`attention_window` for layer {self.layer_id} has to be an even value. Given {attention_window}"
- assert (
- attention_window > 0
- ), f"`attention_window` for layer {self.layer_id} has to be positive. Given {attention_window}"
- self.one_sided_attn_window_size = attention_window // 2
- self.config = config
def forward(
    self,
    hidden_states,
    attention_mask=None,
    layer_head_mask=None,
    is_index_masked=None,
    is_index_global_attn=None,
    is_global_attn=None,
    output_attentions=False,
):
    """
    Sliding-window (local) self-attention, plus global attention for the tokens flagged in `is_index_global_attn`.

    [`LEDEncoderSelfAttention`] expects *len(hidden_states)* to be a multiple of *attention_window*. Padding to
    *attention_window* happens in [`LEDEncoderModel.forward`] to avoid redoing the padding on each layer.

    Args:
        hidden_states: (batch_size, seq_len, embed_dim); transposed to sequence-first internally.
        attention_mask: (batch_size, seq_len). This method only uses `attention_mask != 0`, which excludes those
            positions from the sliding window. NOTE(review): an older description said the encoder turns 0/1/2
            into -10000/0/+10000. `_prepare_4d_attention_mask_inverted` in this file instead yields
            `finfo(dtype).min` / 0 / `-finfo(dtype).min`; confirm which one the caller uses.
        layer_head_mask: optional tensor of shape (num_heads,), multiplied into the attention probabilities.
        is_index_masked: bool (batch_size, seq_len); the probabilities of these query rows are zeroed.
        is_index_global_attn: bool (batch_size, seq_len) marking tokens with global attention.
        is_global_attn: truthy when any token in the batch has global attention.
        output_attentions: if True, attention probabilities are appended to the outputs.

    Returns:
        tuple: `(attn_output,)` with shape (batch_size, seq_len, embed_dim). When `output_attentions` is set,
        it is followed by the probabilities over [global keys + local window], and, if `is_global_attn` is also
        set, by the global attention probabilities.
    """
    hidden_states = hidden_states.transpose(0, 1)

    # project hidden states
    query_vectors = self.query(hidden_states)
    key_vectors = self.key(hidden_states)
    value_vectors = self.value(hidden_states)

    seq_len, batch_size, embed_dim = hidden_states.size()
    assert (
        embed_dim == self.embed_dim
    ), f"hidden_states should have embed_dim = {self.embed_dim}, but has {embed_dim}"

    # scale queries by 1 / sqrt(head_dim) (standard scaled dot-product attention)
    query_vectors /= math.sqrt(self.head_dim)

    # -> (batch_size, seq_len, num_heads, head_dim)
    query_vectors = query_vectors.view(seq_len, batch_size, self.num_heads, self.head_dim).transpose(0, 1)
    key_vectors = key_vectors.view(seq_len, batch_size, self.num_heads, self.head_dim).transpose(0, 1)

    # local scores: (batch_size, seq_len, num_heads, 2 * one_sided_window + 1)
    attn_scores = self._sliding_chunks_query_key_matmul(
        query_vectors, key_vectors, self.one_sided_attn_window_size
    )

    # values to pad for attention probs: any non-zero mask entry is removed from the local window
    remove_from_windowed_attention_mask = (attention_mask != 0)[:, :, None, None]

    # cast to fp32/fp16 then replace 1's with the most negative finite value of the dtype
    float_mask = remove_from_windowed_attention_mask.type_as(query_vectors).masked_fill(
        remove_from_windowed_attention_mask, torch.finfo(query_vectors.dtype).min
    )
    # diagonal mask with zeros everywhere and large negative values in place of padding; reuses the
    # sliding-window matmul with an all-ones "query" so each window slot picks up its key's mask value
    diagonal_mask = self._sliding_chunks_query_key_matmul(
        float_mask.new_ones(size=float_mask.size()), float_mask, self.one_sided_attn_window_size
    )

    # pad local attention probs
    attn_scores += diagonal_mask

    assert list(attn_scores.size()) == [
        batch_size,
        seq_len,
        self.num_heads,
        self.one_sided_attn_window_size * 2 + 1,
    ], (
        f"local_attn_probs should be of size ({batch_size}, {seq_len}, {self.num_heads},"
        f" {self.one_sided_attn_window_size * 2 + 1}), but is of size {attn_scores.size()}"
    )

    # compute local attention probs from global attention keys and concatenate over the window dim
    if is_global_attn:
        # compute global attn indices required through out forward fn
        (
            max_num_global_attn_indices,
            is_index_global_attn_nonzero,
            is_local_index_global_attn_nonzero,
            is_local_index_no_global_attn_nonzero,
        ) = self._get_global_attn_indices(is_index_global_attn)
        # calculate global attn probs from global key
        global_key_attn_scores = self._concat_with_global_key_attn_probs(
            query_vectors=query_vectors,
            key_vectors=key_vectors,
            max_num_global_attn_indices=max_num_global_attn_indices,
            is_index_global_attn_nonzero=is_index_global_attn_nonzero,
            is_local_index_global_attn_nonzero=is_local_index_global_attn_nonzero,
            is_local_index_no_global_attn_nonzero=is_local_index_no_global_attn_nonzero,
        )
        # concat to local_attn_probs
        # (batch_size, seq_len, num_heads, extra attention count + 2*window+1)
        attn_scores = torch.cat((global_key_attn_scores, attn_scores), dim=-1)

        # free memory
        del global_key_attn_scores

    attn_probs = nn.functional.softmax(
        attn_scores, dim=-1, dtype=torch.float32
    )  # use fp32 for numerical stability

    if layer_head_mask is not None:
        assert layer_head_mask.size() == (
            self.num_heads,
        ), f"Head mask for a single layer should be of size {(self.num_heads,)}, but is {layer_head_mask.size()}"
        attn_probs = layer_head_mask.view(1, 1, -1, 1) * attn_probs

    # softmax sometimes inserts NaN if all positions are masked, replace them with 0
    attn_probs = torch.masked_fill(attn_probs, is_index_masked[:, :, None, None], 0.0)
    # cast the fp32 probabilities back to the scores' dtype
    attn_probs = attn_probs.type_as(attn_scores)

    # free memory
    del attn_scores

    # apply dropout
    attn_probs = nn.functional.dropout(attn_probs, p=self.dropout, training=self.training)

    value_vectors = value_vectors.view(seq_len, batch_size, self.num_heads, self.head_dim).transpose(0, 1)

    # compute local attention output with global attention value and add
    if is_global_attn:
        # compute sum of global and local attn
        attn_output = self._compute_attn_output_with_global_indices(
            value_vectors=value_vectors,
            attn_probs=attn_probs,
            max_num_global_attn_indices=max_num_global_attn_indices,
            is_index_global_attn_nonzero=is_index_global_attn_nonzero,
            is_local_index_global_attn_nonzero=is_local_index_global_attn_nonzero,
        )
    else:
        # compute local attn only
        attn_output = self._sliding_chunks_matmul_attn_probs_value(
            attn_probs, value_vectors, self.one_sided_attn_window_size
        )

    assert attn_output.size() == (batch_size, seq_len, self.num_heads, self.head_dim), "Unexpected size"
    # back to sequence-first (seq_len, batch_size, embed_dim)
    attn_output = attn_output.transpose(0, 1).reshape(seq_len, batch_size, embed_dim).contiguous()

    # compute value for global attention and overwrite to attention output
    # TODO: remove the redundant computation
    if is_global_attn:
        global_attn_output, global_attn_probs = self._compute_global_attn_output_from_hidden(
            hidden_states=hidden_states,
            max_num_global_attn_indices=max_num_global_attn_indices,
            layer_head_mask=layer_head_mask,
            is_local_index_global_attn_nonzero=is_local_index_global_attn_nonzero,
            is_index_global_attn_nonzero=is_index_global_attn_nonzero,
            is_local_index_no_global_attn_nonzero=is_local_index_no_global_attn_nonzero,
            is_index_masked=is_index_masked,
        )

        # get only non zero global attn output
        nonzero_global_attn_output = global_attn_output[
            is_local_index_global_attn_nonzero[0], :, is_local_index_global_attn_nonzero[1]
        ]

        # overwrite values with global attention; indices are reversed to (seq, batch) because
        # attn_output is sequence-first here
        attn_output[is_index_global_attn_nonzero[::-1]] = nonzero_global_attn_output.view(
            len(is_local_index_global_attn_nonzero[0]), -1
        )
        # The attention weights for tokens with global attention are
        # just filler values, they were never used to compute the output.
        # Fill with 0 now, the correct values are in 'global_attn_probs'.
        attn_probs[is_index_global_attn_nonzero] = 0

    # return batch-first (batch_size, seq_len, embed_dim)
    outputs = (attn_output.transpose(0, 1),)

    if output_attentions:
        outputs += (attn_probs,)

    return outputs + (global_attn_probs,) if (is_global_attn and output_attentions) else outputs
- @staticmethod
- def _pad_and_transpose_last_two_dims(hidden_states_padded, padding):
- """pads rows and then flips rows and columns"""
- hidden_states_padded = nn.functional.pad(
- hidden_states_padded, padding
- ) # padding value is not important because it will be overwritten
- hidden_states_padded = hidden_states_padded.view(
- *hidden_states_padded.size()[:-2], hidden_states_padded.size(-1), hidden_states_padded.size(-2)
- )
- return hidden_states_padded
- @staticmethod
- def _pad_and_diagonalize(chunked_hidden_states):
- """
- shift every row 1 step right, converting columns into diagonals.
- Example:
- ```python
- chunked_hidden_states: [
- 0.4983,
- 2.6918,
- -0.0071,
- 1.0492,
- -1.8348,
- 0.7672,
- 0.2986,
- 0.0285,
- -0.7584,
- 0.4206,
- -0.0405,
- 0.1599,
- 2.0514,
- -1.1600,
- 0.5372,
- 0.2629,
- ]
- window_overlap = num_rows = 4
- ```
- (pad & diagonalize) => [ 0.4983, 2.6918, -0.0071, 1.0492, 0.0000, 0.0000, 0.0000
- 0.0000, -1.8348, 0.7672, 0.2986, 0.0285, 0.0000, 0.0000 0.0000, 0.0000, -0.7584, 0.4206,
- -0.0405, 0.1599, 0.0000 0.0000, 0.0000, 0.0000, 2.0514, -1.1600, 0.5372, 0.2629 ]
- """
- total_num_heads, num_chunks, window_overlap, hidden_dim = chunked_hidden_states.size()
- chunked_hidden_states = nn.functional.pad(
- chunked_hidden_states, (0, window_overlap + 1)
- ) # total_num_heads x num_chunks x window_overlap x (hidden_dim+window_overlap+1). Padding value is not important because it'll be overwritten
- chunked_hidden_states = chunked_hidden_states.view(
- total_num_heads, num_chunks, -1
- ) # total_num_heads x num_chunks x window_overlap*window_overlap+window_overlap
- chunked_hidden_states = chunked_hidden_states[
- :, :, :-window_overlap
- ] # total_num_heads x num_chunks x window_overlap*window_overlap
- chunked_hidden_states = chunked_hidden_states.view(
- total_num_heads, num_chunks, window_overlap, window_overlap + hidden_dim
- )
- chunked_hidden_states = chunked_hidden_states[:, :, :, :-1]
- return chunked_hidden_states
- @staticmethod
- def _chunk(hidden_states, window_overlap, onnx_export: bool = False):
- """convert into overlapping chunks. Chunk size = 2w, overlap size = w"""
- if not onnx_export:
- # non-overlapping chunks of size = 2w
- hidden_states = hidden_states.view(
- hidden_states.size(0),
- torch.div(hidden_states.size(1), (window_overlap * 2), rounding_mode="trunc"),
- window_overlap * 2,
- hidden_states.size(2),
- )
- # use `as_strided` to make the chunks overlap with an overlap size = window_overlap
- chunk_size = list(hidden_states.size())
- chunk_size[1] = chunk_size[1] * 2 - 1
- chunk_stride = list(hidden_states.stride())
- chunk_stride[1] = chunk_stride[1] // 2
- return hidden_states.as_strided(size=chunk_size, stride=chunk_stride)
- # When exporting to ONNX, use this separate logic
- # have to use slow implementation since as_strided, unfold and 2d-tensor indexing aren't supported (yet) in ONNX export
- # TODO replace this with
- # > return hidden_states.unfold(dimension=1, size=window_overlap * 2, step=window_overlap).transpose(2, 3)
- # once `unfold` is supported
- # the case hidden_states.size(1) == window_overlap * 2 can also simply return hidden_states.unsqueeze(1), but that's control flow
- chunk_size = [
- hidden_states.size(0),
- torch.div(hidden_states.size(1), window_overlap, rounding_mode="trunc") - 1,
- window_overlap * 2,
- hidden_states.size(2),
- ]
- overlapping_chunks = torch.empty(chunk_size, device=hidden_states.device)
- for chunk in range(chunk_size[1]):
- overlapping_chunks[:, chunk, :, :] = hidden_states[
- :, chunk * window_overlap : chunk * window_overlap + 2 * window_overlap, :
- ]
- return overlapping_chunks
- @staticmethod
- def _mask_invalid_locations(input_tensor, affected_seq_len) -> torch.Tensor:
- beginning_mask_2d = input_tensor.new_ones(affected_seq_len, affected_seq_len + 1).tril().flip(dims=[0])
- beginning_mask = beginning_mask_2d[None, :, None, :]
- ending_mask = beginning_mask.flip(dims=(1, 3))
- beginning_input = input_tensor[:, :affected_seq_len, :, : affected_seq_len + 1]
- beginning_mask = beginning_mask.expand(beginning_input.size())
- input_tensor[:, :affected_seq_len, :, : affected_seq_len + 1] = torch.full_like(
- beginning_input, -float("inf")
- ).where(beginning_mask.bool(), beginning_input)
- ending_input = input_tensor[:, -affected_seq_len:, :, -(affected_seq_len + 1) :]
- ending_mask = ending_mask.expand(ending_input.size())
- input_tensor[:, -affected_seq_len:, :, -(affected_seq_len + 1) :] = torch.full_like(
- ending_input, -float("inf")
- ).where(ending_mask.bool(), ending_input)
def _sliding_chunks_query_key_matmul(self, query: torch.Tensor, key: torch.Tensor, window_overlap: int):
    """
    Matrix multiplication of query and key tensors with a sliding-window attention pattern.

    The sequence is split into overlapping chunks of size 2w with an overlap of w (w = `window_overlap`). Each
    query is scored only against the keys within w positions on either side.

    Args:
        query: (batch_size, seq_len, num_heads, head_dim). `seq_len` must be a multiple of 2w.
        key: same shape as `query`.
        window_overlap: one-sided window size w.

    Returns:
        (batch_size, seq_len, num_heads, 2w + 1) scores. Column w is a token's score with itself, columns 0..w-1
        are the preceding tokens, and columns w+1..2w are the following tokens. Slots that fall outside the
        sequence are set to -inf by `_mask_invalid_locations`.
    """
    batch_size, seq_len, num_heads, head_dim = query.size()
    assert (
        seq_len % (window_overlap * 2) == 0
    ), f"Sequence length should be multiple of {window_overlap * 2}. Given {seq_len}"
    assert query.size() == key.size()

    # number of overlapping 2w-chunks (seq_len // w - 1)
    chunks_count = torch.div(seq_len, window_overlap, rounding_mode="trunc") - 1

    # group batch_size and num_heads dimensions into one, then chunk seq_len into chunks of size window_overlap * 2
    query = query.transpose(1, 2).reshape(batch_size * num_heads, seq_len, head_dim)
    key = key.transpose(1, 2).reshape(batch_size * num_heads, seq_len, head_dim)

    # -> (batch_size * num_heads, chunks_count, 2w, head_dim)
    query = self._chunk(query, window_overlap, getattr(self.config, "onnx_export", False))
    key = self._chunk(key, window_overlap, getattr(self.config, "onnx_export", False))

    # matrix multiplication
    # bcxd: batch_size * num_heads x chunks x 2window_overlap x head_dim
    # bcyd: batch_size * num_heads x chunks x 2window_overlap x head_dim
    # bcxy: batch_size * num_heads x chunks x 2window_overlap x 2window_overlap
    diagonal_chunked_attention_scores = torch.einsum("bcxd,bcyd->bcxy", (query, key))  # multiply

    # convert diagonals into columns (pad one row, then re-view with the last two dims swapped)
    diagonal_chunked_attention_scores = self._pad_and_transpose_last_two_dims(
        diagonal_chunked_attention_scores, padding=(0, 0, 0, 1)
    )

    # allocate space for the overall attention matrix where the chunks are combined. The last dimension
    # has (window_overlap * 2 + 1) columns. The first (window_overlap) columns are the window_overlap lower triangles (attention from a word to
    # window_overlap previous words). The following column is attention score from each word to itself, then
    # followed by window_overlap columns for the upper triangle.

    diagonal_attention_scores = diagonal_chunked_attention_scores.new_zeros(
        (batch_size * num_heads, chunks_count + 1, window_overlap, window_overlap * 2 + 1)
    )

    # copy parts from diagonal_chunked_attention_scores into the combined matrix of attentions
    # - copying the main diagonal and the upper triangle
    diagonal_attention_scores[:, :-1, :, window_overlap:] = diagonal_chunked_attention_scores[
        :, :, :window_overlap, : window_overlap + 1
    ]
    diagonal_attention_scores[:, -1, :, window_overlap:] = diagonal_chunked_attention_scores[
        :, -1, window_overlap:, : window_overlap + 1
    ]
    # - copying the lower triangle
    diagonal_attention_scores[:, 1:, :, :window_overlap] = diagonal_chunked_attention_scores[
        :, :, -(window_overlap + 1) : -1, window_overlap + 1 :
    ]

    diagonal_attention_scores[:, 0, 1:window_overlap, 1:window_overlap] = diagonal_chunked_attention_scores[
        :, 0, : window_overlap - 1, 1 - window_overlap :
    ]

    # separate batch_size and num_heads dimensions again
    diagonal_attention_scores = diagonal_attention_scores.view(
        batch_size, num_heads, seq_len, 2 * window_overlap + 1
    ).transpose(2, 1)

    # in place: -inf for window slots before the start / after the end of the sequence
    self._mask_invalid_locations(diagonal_attention_scores, window_overlap)

    return diagonal_attention_scores
def _sliding_chunks_matmul_attn_probs_value(
    self, attn_probs: torch.Tensor, value: torch.Tensor, window_overlap: int
):
    """
    Same as _sliding_chunks_query_key_matmul but for attn_probs and value tensors: multiply the sliding-window
    probabilities by the value vectors.

    Args:
        attn_probs: (batch_size, seq_len, num_heads, 2 * window_overlap + 1).
        value: (batch_size, seq_len, num_heads, head_dim). `seq_len` must be a multiple of 2 * window_overlap.
        window_overlap: one-sided window size w.

    Returns:
        (batch_size, seq_len, num_heads, head_dim), i.e. the same shape as `value` (not `attn_probs`).
    """
    batch_size, seq_len, num_heads, head_dim = value.size()

    assert seq_len % (window_overlap * 2) == 0
    assert attn_probs.size()[:3] == value.size()[:3]
    assert attn_probs.size(3) == 2 * window_overlap + 1
    chunks_count = torch.div(seq_len, window_overlap, rounding_mode="trunc") - 1
    # group batch_size and num_heads dimensions into one, then split seq_len into seq_len // w
    # non-overlapping blocks of w query rows each

    chunked_attn_probs = attn_probs.transpose(1, 2).reshape(
        batch_size * num_heads,
        torch.div(seq_len, window_overlap, rounding_mode="trunc"),
        window_overlap,
        2 * window_overlap + 1,
    )

    # group batch_size and num_heads dimensions into one
    value = value.transpose(1, 2).reshape(batch_size * num_heads, seq_len, head_dim)

    # pad seq_len with w at the beginning of the sequence and another window overlap at the end.
    # NOTE(review): the -1 fill is presumably harmless because probabilities for out-of-range window slots are
    # zero after masking + softmax upstream; that is not checked here.
    padded_value = nn.functional.pad(value, (0, 0, window_overlap, window_overlap), value=-1)

    # chunk padded_value into chunks of 3 * window_overlap rows, advancing by window_overlap rows per chunk
    # (so consecutive chunks share 2 * window_overlap rows)
    chunked_value_size = (batch_size * num_heads, chunks_count + 1, 3 * window_overlap, head_dim)
    chunked_value_stride = padded_value.stride()
    chunked_value_stride = (
        chunked_value_stride[0],
        window_overlap * chunked_value_stride[1],
        chunked_value_stride[1],
        chunked_value_stride[2],
    )
    chunked_value = padded_value.as_strided(size=chunked_value_size, stride=chunked_value_stride)

    # skew each block of probabilities so row i's window lines up with rows i .. i + 2w of its value chunk
    # -> (batch_size * num_heads, seq_len // w, w, 3w)
    chunked_attn_probs = self._pad_and_diagonalize(chunked_attn_probs)

    context = torch.einsum("bcwd,bcdh->bcwh", (chunked_attn_probs, chunked_value))
    return context.view(batch_size, num_heads, seq_len, head_dim).transpose(1, 2)
@staticmethod
def _get_global_attn_indices(is_index_global_attn):
    """
    Build the index bookkeeping used to gather and scatter global-attention tokens.

    Args:
        is_index_global_attn: boolean-like tensor whose row ``b`` flags the globally-attending positions of
            sample ``b`` (assumed ``(batch_size, seq_len)`` -- TODO confirm against the caller).

    Returns:
        A 4-tuple:
          - largest per-row number of global positions (0-d tensor),
          - ``nonzero(as_tuple=True)`` coordinates of the global positions in ``is_index_global_attn``,
          - coordinates ``(b, j)`` of the slots in a ``(batch_size, max_count)`` grid with ``j < count[b]``
            (slots that hold a real global token),
          - coordinates of the remaining (padding) slots of that grid.
    """
    counts_per_row = is_index_global_attn.long().sum(dim=1)
    widest_row = counts_per_row.max()
    global_positions = is_index_global_attn.nonzero(as_tuple=True)

    # slot j of row b is "filled" exactly when j is below the number of global tokens in row b
    slot_ids = torch.arange(widest_row, device=is_index_global_attn.device)
    slot_is_filled = slot_ids < counts_per_row.unsqueeze(dim=-1)

    filled_slots = slot_is_filled.nonzero(as_tuple=True)
    empty_slots = slot_is_filled.logical_not().nonzero(as_tuple=True)

    return widest_row, global_positions, filled_slots, empty_slots
def _concat_with_global_key_attn_probs(
    self,
    key_vectors,
    query_vectors,
    max_num_global_attn_indices,
    is_index_global_attn_nonzero,
    is_local_index_global_attn_nonzero,
    is_local_index_no_global_attn_nonzero,
):
    """
    Score every query against the keys of the global-attention tokens.

    The keys of global tokens are packed into a dense, zero-initialised buffer of shape
    ``(batch_size, max_num_global_attn_indices, num_heads, head_dim)``. That buffer is scored against
    ``query_vectors`` with an einsum. Buffer slots that do not hold a real global token are then
    overwritten with the dtype's minimum value.

    Note: despite the variable name, the returned values are raw (pre-softmax) dot-product scores.

    Args:
        key_vectors: keys; assumed ``(batch_size, seq_len, num_heads, head_dim)`` given how they are
            indexed into the buffer and the ``"blhd,bshd->blhs"`` einsum -- TODO confirm against caller.
        query_vectors: queries, same assumed layout as ``key_vectors``.
        max_num_global_attn_indices: width of the packed global buffer.
        is_index_global_attn_nonzero: coordinates of the global tokens in the full sequence.
        is_local_index_global_attn_nonzero: matching coordinates inside the packed global buffer.
        is_local_index_no_global_attn_nonzero: coordinates of the padding slots in the packed buffer.

    Returns:
        Tensor of shape ``(batch_size, seq_len, num_heads, max_num_global_attn_indices)``.
    """
    batch_size = key_vectors.shape[0]

    # create only global key vectors (zeros in slots that have no global token)
    key_vectors_only_global = key_vectors.new_zeros(
        batch_size, max_num_global_attn_indices, self.num_heads, self.head_dim
    )

    # scatter each global token's key into its packed slot
    key_vectors_only_global[is_local_index_global_attn_nonzero] = key_vectors[is_index_global_attn_nonzero]

    # (batch_size, seq_len, num_heads, max_num_global_attn_indices)
    attn_probs_from_global_key = torch.einsum("blhd,bshd->blhs", (query_vectors, key_vectors_only_global))

    # need to transpose since ONNX export only supports consecutive indexing: https://pytorch.org/docs/stable/onnx.html#writes-sets
    attn_probs_from_global_key = attn_probs_from_global_key.transpose(1, 3)
    # padding slots get the most negative representable value, presumably so a later softmax ignores them
    attn_probs_from_global_key[
        is_local_index_no_global_attn_nonzero[0], is_local_index_no_global_attn_nonzero[1], :, :
    ] = torch.finfo(attn_probs_from_global_key.dtype).min
    # undo the transpose to restore (batch_size, seq_len, num_heads, max_num_global_attn_indices)
    attn_probs_from_global_key = attn_probs_from_global_key.transpose(1, 3)

    return attn_probs_from_global_key
def _compute_attn_output_with_global_indices(
    self,
    value_vectors,
    attn_probs,
    max_num_global_attn_indices,
    is_index_global_attn_nonzero,
    is_local_index_global_attn_nonzero,
):
    """
    Mix value vectors using attention probabilities that cover global tokens and the local window.

    The last axis of ``attn_probs`` is laid out as follows:
      - the first ``max_num_global_attn_indices`` columns are probabilities over the packed global tokens;
      - the remaining columns are the sliding-window probabilities.

    Each part is applied to its own value vectors, and the two contributions are summed.

    Args:
        value_vectors: values; assumed ``(batch_size, seq_len, num_heads, head_dim)`` given the packed-buffer
            indexing below -- TODO confirm against caller.
        attn_probs: assumed ``(batch_size, seq_len, num_heads, max_num_global_attn_indices + window)``.
        max_num_global_attn_indices: number of leading "global" columns in ``attn_probs``.
        is_index_global_attn_nonzero: coordinates of the global tokens in the full sequence.
        is_local_index_global_attn_nonzero: matching coordinates inside the packed global buffer.

    Returns:
        Sum of the global-token contribution and the sliding-window contribution, shaped like
        ``value_vectors``.
    """
    batch_size = attn_probs.shape[0]

    # cut local attn probs to global only (leading columns of the last axis)
    attn_probs_only_global = attn_probs.narrow(-1, 0, max_num_global_attn_indices)
    # get value vectors for global only, packed into a zero-initialised buffer
    value_vectors_only_global = value_vectors.new_zeros(
        batch_size, max_num_global_attn_indices, self.num_heads, self.head_dim
    )
    value_vectors_only_global[is_local_index_global_attn_nonzero] = value_vectors[is_index_global_attn_nonzero]

    # use `matmul` because `einsum` crashes sometimes with fp16
    # attn = torch.einsum('blhs,bshd->blhd', (selected_attn_probs, selected_v))
    # compute attn output only global: move heads next to batch, multiply, then move them back
    attn_output_only_global = torch.matmul(
        attn_probs_only_global.transpose(1, 2).clone(), value_vectors_only_global.transpose(1, 2).clone()
    ).transpose(1, 2)

    # drop the leading global columns, keeping only the sliding-window probabilities
    attn_probs_without_global = attn_probs.narrow(
        -1, max_num_global_attn_indices, attn_probs.size(-1) - max_num_global_attn_indices
    ).contiguous()

    # attention output contributed by the local sliding window (non-global part)
    attn_output_without_global = self._sliding_chunks_matmul_attn_probs_value(
        attn_probs_without_global, value_vectors, self.one_sided_attn_window_size
    )

    return attn_output_only_global + attn_output_without_global
def _compute_global_attn_output_from_hidden(
    self,
    hidden_states,
    max_num_global_attn_indices,
    layer_head_mask,
    is_local_index_global_attn_nonzero,
    is_index_global_attn_nonzero,
    is_local_index_no_global_attn_nonzero,
    is_index_masked,
):
    """
    Full attention from every global-attention token to the whole sequence.

    Queries come from the packed global tokens via ``self.query_global``. Keys and values come from all
    positions via ``self.key_global`` / ``self.value_global``.

    Args:
        hidden_states: sequence-first hidden states; the first two dims are ``(seq_len, batch_size)``
            (see the unpacking below), the last is assumed to be ``embed_dim``.
        max_num_global_attn_indices: width of the packed global buffer.
        layer_head_mask: optional per-head multiplier of shape ``(num_heads,)`` applied after softmax.
        is_local_index_global_attn_nonzero: ``(batch, slot)`` coordinates of real global tokens in the
            packed buffer.
        is_index_global_attn_nonzero: ``(batch, position)`` coordinates of the global tokens.
        is_local_index_no_global_attn_nonzero: ``(batch, slot)`` coordinates of padding slots.
        is_index_masked: boolean mask of key positions to exclude; assumed ``(batch_size, seq_len)``
            given the ``[:, None, None, :]`` broadcast -- TODO confirm.

    Returns:
        ``(global_attn_output, global_attn_probs)`` of shapes
        ``(batch_size, num_heads, max_num_global_attn_indices, head_dim)`` and
        ``(batch_size, num_heads, max_num_global_attn_indices, seq_len)``.

    NOTE(review): the shape checks below use ``assert``; they disappear under ``python -O``.
    """
    seq_len, batch_size = hidden_states.shape[:2]

    # prepare global hidden states
    global_attn_hidden_states = hidden_states.new_zeros(max_num_global_attn_indices, batch_size, self.embed_dim)
    # index tuples are (batch, position); [::-1] flips them to (position, batch) to match the seq-first layout
    global_attn_hidden_states[is_local_index_global_attn_nonzero[::-1]] = hidden_states[
        is_index_global_attn_nonzero[::-1]
    ]

    # global key, query, value
    global_query_vectors_only_global = self.query_global(global_attn_hidden_states)
    global_key_vectors = self.key_global(hidden_states)
    global_value_vectors = self.value_global(hidden_states)

    # normalize (scale queries by 1/sqrt(head_dim) before the dot product)
    global_query_vectors_only_global /= math.sqrt(self.head_dim)

    # reshape
    global_query_vectors_only_global = (
        global_query_vectors_only_global.contiguous()
        .view(max_num_global_attn_indices, batch_size * self.num_heads, self.head_dim)
        .transpose(0, 1)
    )  # (batch_size * self.num_heads, max_num_global_attn_indices, head_dim)
    global_key_vectors = (
        global_key_vectors.contiguous().view(-1, batch_size * self.num_heads, self.head_dim).transpose(0, 1)
    )  # (batch_size * self.num_heads, seq_len, head_dim)
    global_value_vectors = (
        global_value_vectors.contiguous().view(-1, batch_size * self.num_heads, self.head_dim).transpose(0, 1)
    )  # (batch_size * self.num_heads, seq_len, head_dim)

    # compute attn scores
    global_attn_scores = torch.bmm(global_query_vectors_only_global, global_key_vectors.transpose(1, 2))

    assert list(global_attn_scores.size()) == [
        batch_size * self.num_heads,
        max_num_global_attn_indices,
        seq_len,
    ], (
        "global_attn_scores have the wrong size. Size should be"
        f" {(batch_size * self.num_heads, max_num_global_attn_indices, seq_len)}, but is"
        f" {global_attn_scores.size()}."
    )

    global_attn_scores = global_attn_scores.view(batch_size, self.num_heads, max_num_global_attn_indices, seq_len)

    # need to transpose since ONNX export only supports consecutive indexing: https://pytorch.org/docs/stable/onnx.html#writes-sets
    global_attn_scores = global_attn_scores.transpose(1, 2)
    # rows belonging to padding slots (no real global token) get the dtype minimum everywhere
    global_attn_scores[
        is_local_index_no_global_attn_nonzero[0], is_local_index_no_global_attn_nonzero[1], :, :
    ] = torch.finfo(global_attn_scores.dtype).min
    global_attn_scores = global_attn_scores.transpose(1, 2)

    # masked key positions are excluded for every head and every global query
    global_attn_scores = global_attn_scores.masked_fill(
        is_index_masked[:, None, None, :],
        torch.finfo(global_attn_scores.dtype).min,
    )

    global_attn_scores = global_attn_scores.view(batch_size * self.num_heads, max_num_global_attn_indices, seq_len)

    # compute global attn probs
    global_attn_probs_float = nn.functional.softmax(
        global_attn_scores, dim=-1, dtype=torch.float32
    )  # use fp32 for numerical stability

    # apply layer head masking
    if layer_head_mask is not None:
        assert layer_head_mask.size() == (
            self.num_heads,
        ), f"Head mask for a single layer should be of size {(self.num_heads,)}, but is {layer_head_mask.size()}"
        global_attn_probs_float = layer_head_mask.view(1, -1, 1, 1) * global_attn_probs_float.view(
            batch_size, self.num_heads, max_num_global_attn_indices, seq_len
        )
        global_attn_probs_float = global_attn_probs_float.view(
            batch_size * self.num_heads, max_num_global_attn_indices, seq_len
        )

    # cast back to the scores' dtype before dropout
    global_attn_probs = nn.functional.dropout(
        global_attn_probs_float.type_as(global_attn_scores), p=self.dropout, training=self.training
    )

    # global attn output
    global_attn_output = torch.bmm(global_attn_probs, global_value_vectors)

    assert list(global_attn_output.size()) == [
        batch_size * self.num_heads,
        max_num_global_attn_indices,
        self.head_dim,
    ], (
        "global_attn_output tensor has the wrong size. Size should be"
        f" {(batch_size * self.num_heads, max_num_global_attn_indices, self.head_dim)}, but is"
        f" {global_attn_output.size()}."
    )

    global_attn_probs = global_attn_probs.view(batch_size, self.num_heads, max_num_global_attn_indices, seq_len)
    global_attn_output = global_attn_output.view(
        batch_size, self.num_heads, max_num_global_attn_indices, self.head_dim
    )
    return global_attn_output, global_attn_probs
class LEDEncoderAttention(nn.Module):
    """
    Encoder attention block: Longformer-style self-attention followed by a linear output projection.

    Args:
        config: model config; ``config.d_model`` sets the projection width.
        layer_id: index of this layer, forwarded to the self-attention module.
    """

    def __init__(self, config, layer_id):
        super().__init__()
        self.longformer_self_attn = LEDEncoderSelfAttention(config, layer_id=layer_id)
        self.output = nn.Linear(config.d_model, config.d_model)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        layer_head_mask: Optional[torch.Tensor] = None,
        is_index_masked: Optional[torch.Tensor] = None,
        is_index_global_attn: Optional[torch.Tensor] = None,
        is_global_attn: Optional[bool] = None,
        output_attentions: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """
        Run self-attention and project its first output; any extra outputs are passed through unchanged.

        Input shape: Batch x Time x Channel.
        """
        inner_outputs = self.longformer_self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            layer_head_mask=layer_head_mask,
            is_index_masked=is_index_masked,
            is_index_global_attn=is_index_global_attn,
            is_global_attn=is_global_attn,
            output_attentions=output_attentions,
        )
        context, passthrough = inner_outputs[0], inner_outputs[1:]
        return (self.output(context),) + passthrough
class LEDDecoderAttention(nn.Module):
    """
    Multi-headed attention from 'Attention Is All You Need' paper.

    The LED decoder uses it both for self-attention and for cross-attention over the encoder output.

    Args:
        embed_dim: model width; must be divisible by ``num_heads``.
        num_heads: number of attention heads.
        dropout: dropout probability applied to the attention probabilities (training only).
        is_decoder: when True, ``forward`` returns the projected ``(key_states, value_states)`` so callers
            can cache them.
        bias: whether the q/k/v/out projections have a bias term.

    Raises:
        ValueError: if ``embed_dim`` is not divisible by ``num_heads``.
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        dropout: float = 0.0,
        is_decoder: bool = False,
        bias: bool = True,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.dropout = dropout
        self.head_dim = embed_dim // num_heads
        if self.head_dim * num_heads != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
                f" {num_heads})."
            )
        self.scaling = self.head_dim**-0.5
        self.is_decoder = is_decoder

        self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        """Split heads: ``(bsz, seq_len, embed_dim) -> (bsz, num_heads, seq_len, head_dim)`` (contiguous)."""
        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

    def forward(
        self,
        hidden_states: torch.Tensor,
        key_value_states: Optional[torch.Tensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        layer_head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """
        Input shape: Batch x Time x Channel.

        Args:
            hidden_states: ``(bsz, tgt_len, embed_dim)`` queries; for self-attention also the keys/values.
            key_value_states: encoder output; when given, the layer acts as cross-attention.
            past_key_value: cached ``(key, value)``, each ``(bsz, num_heads, len, head_dim)``.
                - Cross-attention reuses the cache as-is.
                - Self-attention appends the new step's keys/values along the sequence axis.
            attention_mask: additive mask of shape ``(bsz, 1, tgt_len, src_len)``.
            layer_head_mask: per-head multiplier of shape ``(num_heads,)`` applied after the softmax.
            output_attentions: whether to also return the (pre-dropout) attention probabilities.

        Returns:
            ``(attn_output, attn_weights_reshaped, past_key_value)``:
              - ``attn_output`` is ``(bsz, tgt_len, embed_dim)``;
              - ``attn_weights_reshaped`` is ``(bsz, num_heads, tgt_len, src_len)``, or None when
                ``output_attentions`` is False;
              - ``past_key_value`` is the new ``(key_states, value_states)`` when ``is_decoder``, otherwise
                the argument passed in.

        Raises:
            ValueError: on unexpected attention-weight, attention-mask, head-mask or output sizes.
        """
        # if key_value_states are provided this layer is used as a cross-attention layer
        # for the decoder
        is_cross_attention = key_value_states is not None

        bsz, tgt_len, embed_dim = hidden_states.size()

        # get query proj; scaling here means the q.k scores below are already divided by sqrt(head_dim)
        query_states = self.q_proj(hidden_states) * self.scaling
        # get key, value proj
        if is_cross_attention and past_key_value is not None:
            # reuse k,v, cross_attentions
            key_states = past_key_value[0]
            value_states = past_key_value[1]
        elif is_cross_attention:
            # cross_attentions
            key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
            value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
        elif past_key_value is not None:
            # reuse k, v, self_attention: append this step's projections to the cache along the sequence axis
            key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
            value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
            key_states = torch.cat([past_key_value[0], key_states], dim=2)
            value_states = torch.cat([past_key_value[1], value_states], dim=2)
        else:
            # self_attention
            key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
            value_states = self._shape(self.v_proj(hidden_states), -1, bsz)

        if self.is_decoder:
            # Cross-attention: save all cross-attention key/value states so later calls can reuse them
            # (first "if" branch).
            # Decoder self-attention: save all previous decoder key/value states so later calls can
            # concatenate the newly projected ones onto them (third "elif" branch).
            # Encoder bi-directional self-attention: `past_key_value` is always `None`.
            past_key_value = (key_states, value_states)

        # fold heads into the batch dimension for bmm
        proj_shape = (bsz * self.num_heads, -1, self.head_dim)
        query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
        key_states = key_states.view(*proj_shape)
        value_states = value_states.view(*proj_shape)

        src_len = key_states.size(1)
        attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))

        if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
            raise ValueError(
                f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
                f" {attn_weights.size()}"
            )

        if attention_mask is not None:
            if attention_mask.size() != (bsz, 1, tgt_len, src_len):
                raise ValueError(
                    f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
                )
            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

        attn_weights = nn.functional.softmax(attn_weights, dim=-1)
        if layer_head_mask is not None:
            if layer_head_mask.size() != (self.num_heads,):
                raise ValueError(
                    f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
                    f" {layer_head_mask.size()}"
                )
            attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

        if output_attentions:
            # this operation is a bit awkward, but it's required to
            # make sure that attn_weights keeps its gradient.
            # In order to do so, attn_weights have to be reshaped
            # twice and have to be reused in the following
            attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
            attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
        else:
            attn_weights_reshaped = None

        attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)

        attn_output = torch.bmm(attn_probs, value_states)

        if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
            # report the same shape that is actually checked above
            raise ValueError(
                f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )

        # unfold heads and merge them back into the channel dimension
        attn_output = (
            attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
            .transpose(1, 2)
            .reshape(bsz, tgt_len, embed_dim)
        )

        attn_output = self.out_proj(attn_output)

        return attn_output, attn_weights_reshaped, past_key_value
class LEDEncoderLayer(nn.Module):
    """
    One post-norm encoder layer.

    Computation: attention -> dropout -> residual add -> LayerNorm, then
    fc1 -> activation -> dropout -> fc2 -> dropout -> residual add -> LayerNorm.
    """

    def __init__(self, config: LEDConfig, layer_id: int):
        super().__init__()
        width = config.d_model
        self.embed_dim = width
        self.self_attn = LEDEncoderAttention(config, layer_id)
        self.self_attn_layer_norm = nn.LayerNorm(width)
        self.dropout = config.dropout
        self.activation_fn = ACT2FN[config.activation_function]
        self.activation_dropout = config.activation_dropout
        self.fc1 = nn.Linear(width, config.encoder_ffn_dim)
        self.fc2 = nn.Linear(config.encoder_ffn_dim, width)
        self.final_layer_norm = nn.LayerNorm(width)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
        layer_head_mask: torch.Tensor,
        is_index_masked=None,
        is_index_global_attn=None,
        is_global_attn=None,
        output_attentions=False,
    ):
        """
        Args:
            hidden_states (`torch.FloatTensor`): layer input of shape *(batch, seq_len, embed_dim)*.
            attention_mask (`torch.FloatTensor`): attention mask of size *(batch, 1, tgt_len, src_len)*;
                padding elements are marked with very large negative values.
            layer_head_mask (`torch.FloatTensor`): per-head mask of size *(encoder_attention_heads,)*.

        Returns:
            Tuple whose first element is the new hidden states, followed by whatever extra outputs the
            attention block returned.
        """
        drop = nn.functional.dropout

        attn_outputs = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            layer_head_mask=layer_head_mask,
            is_index_masked=is_index_masked,
            is_index_global_attn=is_index_global_attn,
            is_global_attn=is_global_attn,
            output_attentions=output_attentions,
        )
        attended = drop(attn_outputs[0], p=self.dropout, training=self.training)
        normed = self.self_attn_layer_norm(hidden_states + attended)

        ffn = self.activation_fn(self.fc1(normed))
        ffn = drop(ffn, p=self.activation_dropout, training=self.training)
        ffn = drop(self.fc2(ffn), p=self.dropout, training=self.training)
        layer_out = self.final_layer_norm(normed + ffn)

        # fp16 can overflow to inf/nan; pull values back just inside the representable range
        overflowed = torch.isinf(layer_out).any() or torch.isnan(layer_out).any()
        if layer_out.dtype == torch.float16 and overflowed:
            limit = torch.finfo(layer_out.dtype).max - 1000
            layer_out = torch.clamp(layer_out, min=-limit, max=limit)

        return (layer_out,) + attn_outputs[1:]
class LEDDecoderLayer(nn.Module):
    """
    One post-norm decoder layer with three sub-blocks:
      1. causal self-attention;
      2. optional cross-attention over the encoder output;
      3. a feed-forward block.

    Each sub-block is followed by dropout, a residual add and a LayerNorm.
    """

    def __init__(self, config: LEDConfig):
        super().__init__()
        width = config.d_model
        heads = config.decoder_attention_heads
        attn_dropout = config.attention_dropout
        self.embed_dim = width

        self.self_attn = LEDDecoderAttention(
            embed_dim=width,
            num_heads=heads,
            dropout=attn_dropout,
            is_decoder=True,
        )
        self.dropout = config.dropout
        self.activation_fn = ACT2FN[config.activation_function]
        self.activation_dropout = config.activation_dropout

        self.self_attn_layer_norm = nn.LayerNorm(width)
        self.encoder_attn = LEDDecoderAttention(
            embed_dim=width,
            num_heads=heads,
            dropout=attn_dropout,
            is_decoder=True,
        )
        self.encoder_attn_layer_norm = nn.LayerNorm(width)
        self.fc1 = nn.Linear(width, config.decoder_ffn_dim)
        self.fc2 = nn.Linear(config.decoder_ffn_dim, width)
        self.final_layer_norm = nn.LayerNorm(width)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        layer_head_mask: Optional[torch.Tensor] = None,
        cross_attn_layer_head_mask: Optional[torch.Tensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = True,
    ):
        """
        Args:
            hidden_states (`torch.FloatTensor`): layer input of shape *(batch, seq_len, embed_dim)*.
            attention_mask (`torch.FloatTensor`): attention mask of size *(batch, 1, tgt_len, src_len)*;
                padding elements are marked with very large negative values.
            encoder_hidden_states (`torch.FloatTensor`): cross-attention input of shape
                *(batch, seq_len, embed_dim)*.
            encoder_attention_mask (`torch.FloatTensor`): encoder attention mask of size
                *(batch, 1, tgt_len, src_len)*; padding elements are marked with very large negative values.
            layer_head_mask (`torch.FloatTensor`): self-attention head mask of size *(decoder_attention_heads,)*.
            cross_attn_layer_head_mask (`torch.FloatTensor`): cross-attention head mask of size
                *(decoder_attention_heads,)*.
            past_key_value (`Tuple(torch.FloatTensor)`): cached key/value projections. The first two entries
                belong to self-attention, the last two to cross-attention.
            output_attentions (`bool`): whether to append the self- and cross-attention weights to the output.
            use_cache (`bool`): whether to append the updated key/value cache to the output.

        Returns:
            ``(hidden_states,)``, optionally extended with:
              - ``(self_attn_weights, cross_attn_weights)`` when ``output_attentions``;
              - ``(present_key_value,)`` when ``use_cache``.
        """
        drop = nn.functional.dropout

        # self-attention, fed the first two cached tensors
        shortcut = hidden_states
        self_cache = None if past_key_value is None else past_key_value[:2]
        attn_out, self_attn_weights, present_key_value = self.self_attn(
            hidden_states=hidden_states,
            past_key_value=self_cache,
            attention_mask=attention_mask,
            layer_head_mask=layer_head_mask,
            output_attentions=output_attentions,
        )
        attn_out = drop(attn_out, p=self.dropout, training=self.training)
        hidden_states = self.self_attn_layer_norm(shortcut + attn_out)

        # cross-attention, fed the last two cached tensors; its cache is appended to the self-attn cache
        cross_attn_weights = None
        if encoder_hidden_states is not None:
            shortcut = hidden_states
            cross_cache = None if past_key_value is None else past_key_value[-2:]
            attn_out, cross_attn_weights, cross_present = self.encoder_attn(
                hidden_states=hidden_states,
                key_value_states=encoder_hidden_states,
                attention_mask=encoder_attention_mask,
                layer_head_mask=cross_attn_layer_head_mask,
                past_key_value=cross_cache,
                output_attentions=output_attentions,
            )
            attn_out = drop(attn_out, p=self.dropout, training=self.training)
            hidden_states = self.encoder_attn_layer_norm(shortcut + attn_out)
            present_key_value = present_key_value + cross_present

        # feed-forward
        shortcut = hidden_states
        ffn = self.activation_fn(self.fc1(hidden_states))
        ffn = drop(ffn, p=self.activation_dropout, training=self.training)
        ffn = drop(self.fc2(ffn), p=self.dropout, training=self.training)
        hidden_states = self.final_layer_norm(shortcut + ffn)

        extras = ()
        if output_attentions:
            extras += (self_attn_weights, cross_attn_weights)
        if use_cache:
            extras += (present_key_value,)
        return (hidden_states,) + extras
class LEDClassificationHead(nn.Module):
    """
    Sentence-level classification head.

    Computation: dropout -> dense -> tanh -> dropout -> linear projection to ``num_classes`` logits.
    The same dropout module is applied at both dropout steps.
    """

    def __init__(
        self,
        input_dim: int,
        inner_dim: int,
        num_classes: int,
        pooler_dropout: float,
    ):
        super().__init__()
        self.dense = nn.Linear(input_dim, inner_dim)
        self.dropout = nn.Dropout(p=pooler_dropout)
        self.out_proj = nn.Linear(inner_dim, num_classes)

    def forward(self, hidden_states: torch.Tensor):
        """Map features of size ``input_dim`` to ``num_classes`` unnormalised scores."""
        pooled = torch.tanh(self.dense(self.dropout(hidden_states)))
        return self.out_proj(self.dropout(pooled))
class LEDPreTrainedModel(PreTrainedModel):
    """Shared base for LED models: config binding, weight initialisation and dummy inputs."""

    config_class = LEDConfig
    base_model_prefix = "led"
    supports_gradient_checkpointing = True

    def _init_weights(self, module):
        """
        Initialise ``nn.Linear`` / ``nn.Embedding`` weights from N(0, ``config.init_std``).

        Linear biases and the embedding row at ``padding_idx`` (when set) are zeroed.
        """
        std = self.config.init_std
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
            return
        if isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()

    @property
    def dummy_inputs(self):
        """Small fixed batch whose second row ends in a pad token; the mask is True on non-pad positions."""
        pad_id = self.config.pad_token_id
        token_ids = torch.tensor([[0, 6, 10, 4, 2], [0, 8, 12, 2, pad_id]], device=self.device)
        return {"attention_mask": token_ids.ne(pad_id), "input_ids": token_ids}
# NOTE(review): the class below carries a "Copied from" marker, so its body presumably has to stay in sync
# with LongformerBaseModelOutput. Change the source class and re-propagate rather than editing this copy.
# NOTE(review): `_compute_global_attn_output_from_hidden` returns global probs viewed as
# (batch_size, num_heads, max_num_global_attn_indices, seq_len). Confirm that the
# `(batch_size, num_heads, sequence_length, x)` shape documented for `global_attentions` matches what the
# encoder actually returns.
@dataclass
# Copied from transformers.models.longformer.modeling_longformer.LongformerBaseModelOutput with Longformer->LEDEncoder
class LEDEncoderBaseModelOutput(ModelOutput):
    """
    Base class for LEDEncoder's outputs, with potential hidden states, local and global attentions.

    Args:
        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the output of the last layer of the model.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
            shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, x +
            attention_window + 1)`, where `x` is the number of tokens with global attention mask.

            Local attentions weights after the attention softmax, used to compute the weighted average in the
            self-attention heads. Those are the attention weights from every token in the sequence to every token with
            global attention (first `x` values) and to every token in the attention window (remaining `attention_window
            + 1` values). Note that the first `x` values refer to tokens with fixed positions in the text, but the
            remaining `attention_window + 1` values refer to tokens with relative positions: the attention weight of a
            token to itself is located at index `x + attention_window / 2` and the `attention_window / 2` preceding
            (succeeding) values are the attention weights to the `attention_window / 2` preceding (succeeding) tokens.
            If the attention window contains a token with global attention, the attention weight at the corresponding
            index is set to 0; the value should be accessed from the first `x` attention weights. If a token has global
            attention, the attention weights to all other tokens in `attentions` is set to 0, the values should be
            accessed from `global_attentions`.
        global_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, x)`,
            where `x` is the number of tokens with global attention mask.

            Global attentions weights after the attention softmax, used to compute the weighted average in the
            self-attention heads. Those are the attention weights from every token with global attention to every token
            in the sequence.
    """

    last_hidden_state: torch.FloatTensor
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
    global_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
# Field order below is part of the output's positional layout; keep it stable.
@dataclass
class LEDSeq2SeqModelOutput(ModelOutput):
    """
    Base class for model encoder's outputs that also contains pre-computed hidden states that can speed up
    sequential decoding.

    Args:
        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the output of the last layer of the decoder of the model.

            If `past_key_values` is used only the last hidden-state of the sequences of shape `(batch_size, 1,
            hidden_size)` is output.
        past_key_values (`List[torch.FloatTensor]`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
            List of `torch.FloatTensor` of length `config.n_layers`, with each tensor of shape `(2, batch_size,
            num_heads, sequence_length, embed_size_per_head)`.

            Contains pre-computed hidden-states (key and values in the attention blocks) of the decoder that can be
            used (see `past_key_values` input) to speed up sequential decoding.
        decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
            shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the decoder at the output of each layer plus the initial embedding outputs.
        decoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights of the decoder, after the attention softmax, used to compute the weighted average in the
            self-attention heads.
        cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the
            weighted average in the cross-attention heads.
        encoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Sequence of hidden-states at the output of the last layer of the encoder of the model.
        encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
            shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the encoder at the output of each layer plus the initial embedding outputs.
        encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights of the encoder, after the attention softmax, used to compute the weighted average in the
            self-attention heads.
        encoder_global_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, x)`,
            where `x` is the number of tokens with global attention mask.

            Global attentions weights after the attention softmax, used to compute the weighted average in the
            self-attention heads. Those are the attention weights from every token with global attention to every token
            in the sequence.
    """

    last_hidden_state: torch.FloatTensor = None
    past_key_values: Optional[List[torch.FloatTensor]] = None
    decoder_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    decoder_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
    cross_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
    encoder_last_hidden_state: Optional[torch.FloatTensor] = None
    encoder_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    encoder_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
    encoder_global_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
@dataclass
class LEDSeq2SeqLMOutput(ModelOutput):
    """
    Base class for sequence-to-sequence language model outputs.

    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
            Language modeling loss.
        logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
        past_key_values (`List[torch.FloatTensor]`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
            List of `torch.FloatTensor` of length `config.n_layers`, with each tensor of shape `(2, batch_size,
            num_heads, sequence_length, embed_size_per_head)`.

            Contains pre-computed hidden-states (key and values in the attention blocks) of the decoder that can be
            used (see `past_key_values` input) to speed up sequential decoding.
        decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
            shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the decoder at the output of each layer plus the initial embedding outputs.
        decoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attention weights of the decoder, after the attention softmax, used to compute the weighted average in the
            self-attention heads.
        cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attention weights of the decoder's cross-attention layer, after the attention softmax, used to compute the
            weighted average in the cross-attention heads.
        encoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Sequence of hidden-states at the output of the last layer of the encoder of the model.
        encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
            shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the encoder at the output of each layer plus the initial embedding outputs.
        encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attention weights of the encoder, after the attention softmax, used to compute the weighted average in the
            self-attention heads.
        encoder_global_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, x)`,
            where `x` is the number of tokens with global attention mask.

            Global attention weights after the attention softmax, used to compute the weighted average in the
            self-attention heads. Those are the attention weights from every token with global attention to every token
            in the sequence.
    """

    # Field order matters: it defines the order of the tuple form of this output.
    loss: Optional[torch.FloatTensor] = None
    logits: Optional[torch.FloatTensor] = None
    past_key_values: Optional[List[torch.FloatTensor]] = None
    decoder_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    decoder_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
    cross_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
    encoder_last_hidden_state: Optional[torch.FloatTensor] = None
    encoder_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    encoder_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
    encoder_global_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
@dataclass
class LEDSeq2SeqSequenceClassifierOutput(ModelOutput):
    """
    Base class for outputs of sequence-to-sequence sentence classification models.

    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `label` is provided):
            Classification (or regression if config.num_labels==1) loss.
        logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`):
            Classification (or regression if config.num_labels==1) scores (before SoftMax).
        past_key_values (`List[torch.FloatTensor]`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
            List of `torch.FloatTensor` of length `config.n_layers`, with each tensor of shape `(2, batch_size,
            num_heads, sequence_length, embed_size_per_head)`.

            Contains pre-computed hidden-states (key and values in the attention blocks) of the decoder that can be
            used (see `past_key_values` input) to speed up sequential decoding.
        decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
            shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the decoder at the output of each layer plus the initial embedding outputs.
        decoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attention weights of the decoder, after the attention softmax, used to compute the weighted average in the
            self-attention heads.
        cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attention weights of the decoder's cross-attention layer, after the attention softmax, used to compute the
            weighted average in the cross-attention heads.
        encoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Sequence of hidden-states at the output of the last layer of the encoder of the model.
        encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
            shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the encoder at the output of each layer plus the initial embedding outputs.
        encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attention weights of the encoder, after the attention softmax, used to compute the weighted average in the
            self-attention heads.
        encoder_global_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, x)`,
            where `x` is the number of tokens with global attention mask.

            Global attention weights after the attention softmax, used to compute the weighted average in the
            self-attention heads. Those are the attention weights from every token with global attention to every token
            in the sequence.
    """

    # Field order matters: it defines the order of the tuple form of this output.
    loss: Optional[torch.FloatTensor] = None
    logits: Optional[torch.FloatTensor] = None
    past_key_values: Optional[List[torch.FloatTensor]] = None
    decoder_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    decoder_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
    cross_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
    encoder_last_hidden_state: Optional[torch.FloatTensor] = None
    encoder_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    encoder_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
    encoder_global_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
@dataclass
class LEDSeq2SeqQuestionAnsweringModelOutput(ModelOutput):
    """
    Base class for outputs of sequence-to-sequence question answering models.

    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
            Total span extraction loss is the sum of a Cross-Entropy for the start and end positions.
        start_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
            Span-start scores (before SoftMax).
        end_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
            Span-end scores (before SoftMax).
        past_key_values (`List[torch.FloatTensor]`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
            List of `torch.FloatTensor` of length `config.n_layers`, with each tensor of shape `(2, batch_size,
            num_heads, sequence_length, embed_size_per_head)`.

            Contains pre-computed hidden-states (key and values in the attention blocks) of the decoder that can be
            used (see `past_key_values` input) to speed up sequential decoding.
        decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
            shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the decoder at the output of each layer plus the initial embedding outputs.
        decoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attention weights of the decoder, after the attention softmax, used to compute the weighted average in the
            self-attention heads.
        cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attention weights of the decoder's cross-attention layer, after the attention softmax, used to compute the
            weighted average in the cross-attention heads.
        encoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Sequence of hidden-states at the output of the last layer of the encoder of the model.
        encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
            shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the encoder at the output of each layer plus the initial embedding outputs.
        encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attention weights of the encoder, after the attention softmax, used to compute the weighted average in the
            self-attention heads.
        encoder_global_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, x)`,
            where `x` is the number of tokens with global attention mask.

            Global attention weights after the attention softmax, used to compute the weighted average in the
            self-attention heads. Those are the attention weights from every token with global attention to every token
            in the sequence.
    """

    # Field order matters: it defines the order of the tuple form of this output.
    loss: Optional[torch.FloatTensor] = None
    start_logits: Optional[torch.FloatTensor] = None
    end_logits: Optional[torch.FloatTensor] = None
    past_key_values: Optional[List[torch.FloatTensor]] = None
    decoder_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    decoder_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
    cross_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
    encoder_last_hidden_state: Optional[torch.FloatTensor] = None
    encoder_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    encoder_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
    encoder_global_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
# Shared model-level documentation header for the LED model classes.
# NOTE(review): presumably injected into class docstrings via the library's docstring decorators;
# the usage site is outside this chunk — confirm before relying on it.
# The string is raw (r"""...""") so the markdown/backtick content is kept verbatim.
LED_START_DOCSTRING = r"""
    This model inherits from [`PreTrainedModel`]. See the superclass documentation for the generic methods the library
    implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads etc.)
    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for general usage and behavior.
    Parameters:
        config ([`LEDConfig`]):
            Model configuration class with all the parameters of the model. Initializing with a config file does not
            load the weights associated with the model, only the configuration. Check out the
            [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
# Doctest-formatted summarization example for LED generation.
# It shows placing global attention on the first token (via `global_attention_mask`) before calling
# `model.generate`.
# NOTE(review): presumably appended to the generation model's documentation elsewhere in the file — the usage
# site is not visible in this chunk. The example downloads a pretrained checkpoint, so it is not a hermetic test.
LED_GENERATION_EXAMPLE = r"""
    Summarization example:
    ```python
    >>> import torch
    >>> from transformers import AutoTokenizer, LEDForConditionalGeneration
    >>> model = LEDForConditionalGeneration.from_pretrained("allenai/led-large-16384-arxiv")
    >>> tokenizer = AutoTokenizer.from_pretrained("allenai/led-large-16384-arxiv")
    >>> ARTICLE_TO_SUMMARIZE = '''Transformers (Vaswani et al., 2017) have achieved state-of-the-art
    ... results in a wide range of natural language tasks including generative language modeling
    ... (Dai et al., 2019; Radford et al., 2019) and discriminative ... language understanding (Devlin et al., 2019).
    ... This success is partly due to the self-attention component which enables the network to capture contextual
    ... information from the entire sequence. While powerful, the memory and computational requirements of
    ... self-attention grow quadratically with sequence length, making it infeasible (or very expensive) to
    ... process long sequences. To address this limitation, we present Longformer, a modified Transformer
    ... architecture with a self-attention operation that scales linearly with the sequence length, making it
    ... versatile for processing long documents (Fig 1). This is an advantage for natural language tasks such as
    ... long document classification, question answering (QA), and coreference resolution, where existing approaches
    ... partition or shorten the long context into smaller sequences that fall within the typical 512 token limit
    ... of BERT-style pretrained models. Such partitioning could potentially result in loss of important
    ... cross-partition information, and to mitigate this problem, existing methods often rely on complex
    ... architectures to address such interactions. On the other hand, our proposed Longformer is able to build
    ... contextual representations of the entire context using multiple layers of attention, reducing the need for
    ... task-specific architectures.'''
    >>> inputs = tokenizer.encode(ARTICLE_TO_SUMMARIZE, return_tensors="pt")
    >>> # Global attention on the first token (cf. Beltagy et al. 2020)
    >>> global_attention_mask = torch.zeros_like(inputs)
    >>> global_attention_mask[:, 0] = 1
    >>> # Generate Summary
    >>> summary_ids = model.generate(inputs, global_attention_mask=global_attention_mask, num_beams=3, max_length=32)
    >>> print(tokenizer.decode(summary_ids[0], skip_special_tokens=True, clean_up_tokenization_spaces=True))
    ```
"""
# Shared documentation of the `forward` arguments for the LED encoder-decoder models.
# Fixes over the previous text: correct tokenizer class name (`LEDTokenizer`), subject/verb agreement,
# balanced type/shape annotations, and mention of the encoder's `global_attentions` tuple entry.
LED_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
            it.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)
        decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
            Indices of decoder input sequence tokens in the vocabulary.

            Indices can be obtained using [`LEDTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)

            LED uses the `eos_token_id` as the starting token for `decoder_input_ids` generation. If `past_key_values`
            is used, optionally only the last `decoder_input_ids` have to be input (see `past_key_values`).
        decoder_attention_mask (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
            Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
            be used by default.

            If you want to change padding behavior, you should read [`modeling_led._prepare_decoder_inputs`] and modify
            to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more information on the
            default strategy.
        global_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to decide the attention given on each token, local attention or global attention for the encoder.
            Tokens with global attention attend to all other tokens, and all other tokens attend to them. This is
            important for task-specific finetuning because it makes the model more flexible at representing the task.
            For example, for classification, the <s> token should be given global attention. For QA, all question
            tokens should also have global attention. Please refer to the [Longformer
            paper](https://arxiv.org/abs/2004.05150) for more details. Mask values selected in `[0, 1]`:

            - 0 for local attention (a sliding window attention),
            - 1 for global attention (tokens that attend to all other tokens, and all other tokens attend to them).
        head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*):
            Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in `[0, 1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.
        decoder_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
            Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in `[0, 1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.
        cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
            Mask to nullify selected heads of the cross-attention modules in the decoder. Mask values selected in `[0,
            1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.
        encoder_outputs (`tuple(tuple(torch.FloatTensor))`, *optional*):
            Tuple consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: `attentions`, *optional*:
            `global_attentions`). `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)` is a
            sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention of
            the decoder.
        past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
            Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
            `(batch_size, num_heads, sequence_length, embed_size_per_head)` and 2 additional tensors of shape
            `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.

            Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
            blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.

            If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
            don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
            `decoder_input_ids` of shape `(batch_size, sequence_length)`.
        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
            model's internal embedding lookup matrix.
        decoder_inputs_embeds (`torch.FloatTensor` of shape `(batch_size, target_sequence_length, hidden_size)`, *optional*):
            Optionally, instead of passing `decoder_input_ids` you can choose to directly pass an embedded
            representation. If `past_key_values` is used, optionally only the last `decoder_inputs_embeds` have to be
            input (see `past_key_values`). This is useful if you want more control over how to convert
            `decoder_input_ids` indices into associated vectors than the model's internal embedding lookup matrix.

            If `decoder_input_ids` and `decoder_inputs_embeds` are both unset, `decoder_inputs_embeds` takes the value
            of `inputs_embeds`.
        use_cache (`bool`, *optional*):
            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
            `past_key_values`).
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
class LEDEncoder(LEDPreTrainedModel):
    """
    Transformer encoder consisting of *config.encoder_layers* self-attention layers. Each layer is a
    [`LEDEncoderLayer`].

    Args:
        config: LEDConfig
        embed_tokens (nn.Embedding, *optional*): input token embedding. If `None`, a new
            `nn.Embedding(config.vocab_size, config.d_model, config.pad_token_id)` is created.
    """

    def __init__(self, config: LEDConfig, embed_tokens: Optional[nn.Embedding] = None):
        super().__init__(config)

        self.dropout = config.dropout
        self.layerdrop = config.encoder_layerdrop

        embed_dim = config.d_model
        self.padding_idx = config.pad_token_id
        self.max_source_positions = config.max_encoder_position_embeddings

        # Normalize `config.attention_window` to one window size per layer.
        # NOTE: when an int is given, the (possibly shared) config object is mutated in place.
        if isinstance(config.attention_window, int):
            if config.attention_window % 2 != 0:
                raise ValueError("`config.attention_window` has to be an even value")
            if config.attention_window <= 0:
                raise ValueError("`config.attention_window` has to be positive")
            config.attention_window = [config.attention_window] * config.num_hidden_layers  # one value per layer
        else:
            if len(config.attention_window) != config.num_hidden_layers:
                raise ValueError(
                    "`len(config.attention_window)` should equal `config.num_hidden_layers`. "
                    f"Expected {config.num_hidden_layers}, given {len(config.attention_window)}"
                )

        if embed_tokens is not None:
            self.embed_tokens = embed_tokens
        else:
            self.embed_tokens = nn.Embedding(config.vocab_size, embed_dim, self.padding_idx)

        self.embed_positions = LEDLearnedPositionalEmbedding(
            self.max_source_positions,
            embed_dim,
        )
        self.layers = nn.ModuleList([LEDEncoderLayer(config, i) for i in range(config.encoder_layers)])
        self.layernorm_embedding = nn.LayerNorm(embed_dim)

        self.gradient_checkpointing = False
        # Initialize weights and apply final processing
        self.post_init()

    def _merge_to_attention_mask(self, attention_mask: torch.Tensor, global_attention_mask: torch.Tensor):
        """
        Combine the padding mask and the global-attention mask into a single mask.

        Assuming both masks hold 0/1 values, the result is 0 (no attention), 1 (local attention) or
        2 (global attention). If `attention_mask` is `None`, every token is treated as attendable.
        """
        # longformer self-attention expects attention mask to have 0 (no attn), 1 (local attn), 2 (global attn)
        # (global_attention_mask + 1) => 1 for local attention, 2 for global attention
        # => final attention_mask => 0 for no attention, 1 for local attention 2 for global attention
        if attention_mask is not None:
            attention_mask = attention_mask * (global_attention_mask + 1)
        else:
            # simply use `global_attention_mask` as `attention_mask`
            # if no `attention_mask` is given
            attention_mask = global_attention_mask + 1
        return attention_mask

    def _pad_to_window_size(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        inputs_embeds: torch.Tensor,
        pad_token_id: int,
    ):
        """
        Right-pad inputs so the sequence length is a multiple of the (largest) attention window.

        Longformer-style self-attention requires this. Padded `input_ids` use `pad_token_id`, padded
        `inputs_embeds` are the embeddings of `pad_token_id`, and padded mask positions are 0 (no attention).

        Returns:
            `(padding_len, input_ids, attention_mask, inputs_embeds)`, where `padding_len` is the number of
            positions appended to the right.

        Raises:
            ValueError: if the attention window is odd.
        """
        attention_window = (
            self.config.attention_window
            if isinstance(self.config.attention_window, int)
            else max(self.config.attention_window)
        )

        if attention_window % 2 != 0:
            raise ValueError(f"`attention_window` should be an even value. Given {attention_window}")
        input_shape = input_ids.shape if input_ids is not None else inputs_embeds.shape
        batch_size, seq_len = input_shape[:2]

        padding_len = (attention_window - seq_len % attention_window) % attention_window

        if padding_len > 0:
            logger.warning_once(
                f"Input ids are automatically padded from {seq_len} to {seq_len + padding_len} to be a multiple of "
                f"`config.attention_window`: {attention_window}"
            )
            if input_ids is not None:
                input_ids = nn.functional.pad(input_ids, (0, padding_len), value=pad_token_id)
            if inputs_embeds is not None:
                # Use the same pad id as for `input_ids` so both padded representations agree.
                input_ids_padding = inputs_embeds.new_full(
                    (batch_size, padding_len),
                    pad_token_id,
                    dtype=torch.long,
                )
                inputs_embeds_padding = self.embed_tokens(input_ids_padding)
                inputs_embeds = torch.cat([inputs_embeds, inputs_embeds_padding], dim=-2)

            attention_mask = nn.functional.pad(
                attention_mask, (0, padding_len), value=False
            )  # no attention on the padding tokens

        return padding_len, input_ids, attention_mask, inputs_embeds

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        global_attention_mask=None,
        head_mask=None,
        inputs_embeds=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        r"""
        Args:
            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
                Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
                provide it.

                Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
                [`PreTrainedTokenizer.__call__`] for details.

                [What are input IDs?](../glossary#input-ids)
            attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
                Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

                - 1 for tokens that are **not masked**,
                - 0 for tokens that are **masked**.

                [What are attention masks?](../glossary#attention-mask)
            global_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Mask to decide the attention given on each token, local attention or global attention for the encoder.
                Tokens with global attention attend to all other tokens, and all other tokens attend to them. This is
                important for task-specific finetuning because it makes the model more flexible at representing the
                task. For example, for classification, the <s> token should be given global attention. For QA, all
                question tokens should also have global attention. Please refer to the [Longformer
                paper](https://arxiv.org/abs/2004.05150) for more details. Mask values selected in `[0, 1]`:

                - 0 for local attention (a sliding window attention),
                - 1 for global attention (tokens that attend to all other tokens, and all other tokens attend to them).
            head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*):
                Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:

                - 1 indicates the head is **not masked**,
                - 0 indicates the head is **masked**.
            inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
                Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
                This is useful if you want more control over how to convert `input_ids` indices into associated vectors
                than the model's internal embedding lookup matrix.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail. Layers skipped by LayerDrop during training contribute `None`.
            output_hidden_states (`bool`, *optional*):
                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
                for more detail.
            return_dict (`bool`, *optional*):
                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.

        Raises:
            ValueError: if both or neither of `input_ids`/`inputs_embeds` are given, or if `head_mask` does not
                have one row per encoder layer.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # check input_ids and inputs_embeds
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is None and inputs_embeds is None:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        # create default attention_mask
        if attention_mask is None:
            attention_mask = torch.ones(inputs_embeds.size()[:-1], device=inputs_embeds.device, dtype=torch.long)

        # merge `global_attention_mask` and `attention_mask`
        if global_attention_mask is not None:
            attention_mask = self._merge_to_attention_mask(attention_mask, global_attention_mask)

        # pad input if necessary
        padding_len, input_ids, attention_mask, inputs_embeds = self._pad_to_window_size(
            input_ids=input_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            pad_token_id=self.config.pad_token_id,
        )

        # retrieve the (padded) input_shape; `inputs_embeds` is always set at this point
        if input_ids is not None:
            input_shape = input_ids.size()
        else:
            input_shape = inputs_embeds.size()[:-1]

        # convert attention_mask to float (it is always set here: defaulted above and preserved by padding)
        # [bsz, seq_len] -> [bsz, seq_len]; 1 -> 0.0; 0 -> "-inf"
        attention_mask = _prepare_4d_attention_mask_inverted(attention_mask, inputs_embeds.dtype)[:, 0, 0, :]

        # get masking tensors
        is_index_masked = attention_mask < 0
        is_index_global_attn = attention_mask > 0
        is_global_attn = is_index_global_attn.flatten().any().item()

        embed_pos = self.embed_positions(input_shape)

        hidden_states = inputs_embeds + embed_pos
        hidden_states = self.layernorm_embedding(hidden_states)
        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)

        encoder_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None
        all_global_attentions = () if (output_attentions and is_global_attn) else None

        # check if head_mask has a correct number of layers specified if desired
        if head_mask is not None:
            if head_mask.size()[0] != len(self.layers):
                raise ValueError(
                    f"The head_mask should be specified for {len(self.layers)} layers, but it is for"
                    f" {head_mask.size()[0]}."
                )
        for idx, encoder_layer in enumerate(self.layers):
            if output_hidden_states:
                encoder_states = encoder_states + (hidden_states,)
            # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
            dropout_probability = torch.rand([])

            if self.training and (dropout_probability < self.layerdrop):  # skip the layer
                # A dropped layer leaves `hidden_states` unchanged and yields no attention maps.
                layer_outputs = (None, None, None)
            else:
                if self.gradient_checkpointing and self.training:
                    layer_outputs = self._gradient_checkpointing_func(
                        encoder_layer.__call__,
                        hidden_states,
                        attention_mask,
                        head_mask[idx] if head_mask is not None else None,
                        is_index_masked,
                        is_index_global_attn,
                        is_global_attn,
                        output_attentions,
                    )
                else:
                    layer_outputs = encoder_layer(
                        hidden_states,
                        attention_mask=attention_mask,
                        layer_head_mask=(head_mask[idx] if head_mask is not None else None),
                        is_index_masked=is_index_masked,
                        is_index_global_attn=is_index_global_attn,
                        is_global_attn=is_global_attn,
                        output_attentions=output_attentions,
                    )
                hidden_states = layer_outputs[0]

            if output_attentions:
                # bzs x seq_len x num_attn_heads x (num_global_attn + attention_window_len + 1) => bzs x num_attn_heads x seq_len x (num_global_attn + attention_window_len + 1)
                # A skipped layer contributes `None` instead of crashing on `None.transpose`.
                local_attn = layer_outputs[1]
                all_attentions = all_attentions + (local_attn.transpose(1, 2) if local_attn is not None else None,)

                if is_global_attn:
                    # bzs x num_attn_heads x num_global_attn x seq_len => bzs x num_attn_heads x seq_len x num_global_attn
                    global_attn = layer_outputs[2]
                    all_global_attentions = all_global_attentions + (
                        global_attn.transpose(2, 3) if global_attn is not None else None,
                    )

        if output_hidden_states:
            encoder_states = encoder_states + (hidden_states,)

        # undo padding
        if padding_len > 0:
            # unpad `hidden_states` because the calling function is expecting a length == input_ids.size(1)
            hidden_states = hidden_states[:, :-padding_len]
            if output_hidden_states:
                encoder_states = tuple(state[:, :-padding_len] for state in encoder_states)

            if output_attentions:
                all_attentions = tuple(
                    state[:, :, :-padding_len, :] if state is not None else None for state in all_attentions
                )
            # NOTE(review): `all_global_attentions` is not unpadded, so its seq_len dimension still includes the
            # padding positions — confirm whether callers rely on this before changing it.

        if not return_dict:
            return tuple(
                v for v in [hidden_states, encoder_states, all_attentions, all_global_attentions] if v is not None
            )
        return LEDEncoderBaseModelOutput(
            last_hidden_state=hidden_states,
            hidden_states=encoder_states,
            attentions=all_attentions,
            global_attentions=all_global_attentions,
        )
class LEDDecoder(LEDPreTrainedModel):
    """
    Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a [`LEDDecoderLayer`]

    Args:
        config: LEDConfig
        embed_tokens (nn.Embedding, *optional*): token embedding to use; when omitted a new
            `nn.Embedding(config.vocab_size, config.d_model, config.pad_token_id)` is created.
    """

    def __init__(self, config: LEDConfig, embed_tokens: Optional[nn.Embedding] = None):
        super().__init__(config)
        self.dropout = config.dropout
        self.layerdrop = config.decoder_layerdrop
        self.padding_idx = config.pad_token_id
        self.max_target_positions = config.max_decoder_position_embeddings

        if embed_tokens is not None:
            self.embed_tokens = embed_tokens
        else:
            self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model, self.padding_idx)

        self.embed_positions = LEDLearnedPositionalEmbedding(
            self.max_target_positions,
            config.d_model,
        )
        self.layers = nn.ModuleList([LEDDecoderLayer(config) for _ in range(config.decoder_layers)])
        self.layernorm_embedding = nn.LayerNorm(config.d_model)

        self.gradient_checkpointing = False
        # Initialize weights and apply final processing
        self.post_init()

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        global_attention_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        head_mask=None,
        cross_attn_head_mask=None,
        past_key_values=None,
        inputs_embeds=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        r"""
        Args:
            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
                Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
                provide it.

                Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
                [`PreTrainedTokenizer.__call__`] for details.

                [What are input IDs?](../glossary#input-ids)
            attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
                Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

                - 1 for tokens that are **not masked**,
                - 0 for tokens that are **masked**.

                Only applied together with the causal mask, i.e. when more than one decoder position is fed.

                [What are attention masks?](../glossary#attention-mask)
            global_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Mask to decide the attention given on each token, local attention or global attention. Tokens with
                global attention attends to all other tokens, and all other tokens attend to them. Please refer to
                the [Longformer paper](https://arxiv.org/abs/2004.05150) for more details.

                Note: this decoder's forward pass does not read this argument.
            encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*):
                Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention
                of the decoder.
            encoder_attention_mask (`torch.LongTensor` of shape `(batch_size, encoder_sequence_length)`, *optional*):
                Mask to avoid performing cross-attention on padding tokens indices of encoder input_ids. Mask values
                selected in `[0, 1]`:

                - 1 for tokens that are **not masked**,
                - 0 for tokens that are **masked**.

                [What are attention masks?](../glossary#attention-mask)
            head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
                Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:

                - 1 indicates the head is **not masked**,
                - 0 indicates the head is **masked**.

            cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
                Mask to nullify selected heads of the cross-attention modules. Mask values selected in `[0, 1]`:

                - 1 indicates the head is **not masked**,
                - 0 indicates the head is **masked**.

            past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
                Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
                shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of
                shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.

                Contains pre-computed hidden-states (key and values in the self-attention blocks and in the
                cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.

                If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those
                that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of
                all `decoder_input_ids` of shape `(batch_size, sequence_length)`.
            inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
                Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
                This is useful if you want more control over how to convert `input_ids` indices into associated vectors
                than the model's internal embedding lookup matrix.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            output_hidden_states (`bool`, *optional*):
                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
                for more detail.
            return_dict (`bool`, *optional*):
                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.

        Returns:
            [`BaseModelOutputWithPastAndCrossAttentions`] when `return_dict` is true, otherwise a tuple of the
            non-`None` values among `(last_hidden_state, past_key_values, hidden_states, attentions,
            cross_attentions)`.

        Raises:
            ValueError: if both or neither of `input_ids` / `inputs_embeds` are given, or if `head_mask` /
                `cross_attn_head_mask` does not have one entry per decoder layer.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # retrieve input_ids and inputs_embeds
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
        elif input_ids is not None:
            input_shape = input_ids.size()
            input_ids = input_ids.view(-1, input_shape[-1])
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
        else:
            raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")

        # Number of positions already cached, read from dim 2 of the first layer's first cached tensor.
        past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        # create causal mask
        # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
        # A single new position needs no causal mask, so it is only built for multi-token inputs.
        combined_attention_mask = None
        if input_shape[-1] > 1:
            combined_attention_mask = _create_4d_causal_attention_mask(
                input_shape, inputs_embeds.dtype, inputs_embeds.device, past_key_values_length=past_key_values_length
            )

        if attention_mask is not None and combined_attention_mask is not None:
            # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
            combined_attention_mask = combined_attention_mask + _prepare_4d_attention_mask_inverted(
                attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
            )

        # expand encoder attention mask
        if encoder_hidden_states is not None and encoder_attention_mask is not None:
            # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
            encoder_attention_mask = _prepare_4d_attention_mask_inverted(
                encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
            )

        # embed positions (offset by the number of already-cached positions)
        positions = self.embed_positions(input_shape, past_key_values_length)

        hidden_states = inputs_embeds + positions
        hidden_states = self.layernorm_embedding(hidden_states)
        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)

        if self.gradient_checkpointing and self.training:
            if use_cache:
                logger.warning_once(
                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                )
                use_cache = False

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        all_cross_attentions = () if output_attentions else None
        next_decoder_cache = () if use_cache else None

        # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired
        for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]):
            if attn_mask is not None:
                if attn_mask.size()[0] != len(self.layers):
                    # Report the size of the mask that actually failed validation (it may be the
                    # cross-attention mask while `head_mask` is None).
                    raise ValueError(
                        f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for"
                        f" {attn_mask.size()[0]}."
                    )

        for idx, decoder_layer in enumerate(self.layers):
            if output_hidden_states:
                all_hidden_states += (hidden_states,)
            # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
            if self.training:
                dropout_probability = torch.rand([])
                if dropout_probability < self.layerdrop:
                    continue

            past_key_value = past_key_values[idx] if past_key_values is not None else None

            if self.gradient_checkpointing and self.training:
                # Positional arguments; the cache slot is passed as None because use_cache is
                # forced off under gradient checkpointing above.
                layer_outputs = self._gradient_checkpointing_func(
                    decoder_layer.__call__,
                    hidden_states,
                    combined_attention_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                    head_mask[idx] if head_mask is not None else None,
                    cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None,
                    None,
                    output_attentions,
                    use_cache,
                )
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=combined_attention_mask,
                    encoder_hidden_states=encoder_hidden_states,
                    encoder_attention_mask=encoder_attention_mask,
                    layer_head_mask=(head_mask[idx] if head_mask is not None else None),
                    cross_attn_layer_head_mask=(
                        cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None
                    ),
                    past_key_value=past_key_value,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                )

            hidden_states = layer_outputs[0]

            if use_cache:
                # The cache entry follows the two attention tensors when they are returned.
                next_decoder_cache += (layer_outputs[3 if output_attentions else 1],)

            if output_attentions:
                all_self_attns += (layer_outputs[1],)
                all_cross_attentions += (layer_outputs[2],)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        next_cache = next_decoder_cache if use_cache else None
        if not return_dict:
            return tuple(
                v
                for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_cross_attentions]
                if v is not None
            )
        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
            cross_attentions=all_cross_attentions,
        )
@add_start_docstrings(
    "The bare LED Model outputting raw hidden-states without any specific head on top.",
    LED_START_DOCSTRING,
)
class LEDModel(LEDPreTrainedModel):
    # Both sub-modules reuse `self.shared` as their token embedding.
    _tied_weights_keys = ["decoder.embed_tokens.weight", "encoder.embed_tokens.weight"]

    def __init__(self, config: LEDConfig):
        super().__init__(config)

        padding_idx, vocab_size = config.pad_token_id, config.vocab_size
        # One embedding matrix, handed to both the encoder and the decoder.
        self.shared = nn.Embedding(vocab_size, config.d_model, padding_idx)

        self.encoder = LEDEncoder(config, self.shared)
        self.decoder = LEDDecoder(config, self.shared)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.shared

    def set_input_embeddings(self, value):
        # Re-point encoder and decoder at the new module so they stay shared.
        self.shared = value
        self.encoder.embed_tokens = self.shared
        self.decoder.embed_tokens = self.shared

    def get_encoder(self):
        return self.encoder

    def get_decoder(self):
        return self.decoder

    @add_start_docstrings_to_model_forward(LED_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=Seq2SeqModelOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        decoder_input_ids: Optional[torch.LongTensor] = None,
        decoder_attention_mask: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        decoder_head_mask: Optional[torch.Tensor] = None,
        cross_attn_head_mask: Optional[torch.Tensor] = None,
        encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        global_attention_mask: Optional[torch.FloatTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], LEDSeq2SeqModelOutput]:
        # Runs the encoder (unless `encoder_outputs` is supplied) and then the decoder, cross-attending
        # to the encoder's last hidden state. Returns an `LEDSeq2SeqModelOutput` when `return_dict` is
        # true, otherwise `decoder_outputs + encoder_outputs` as one flat tuple.
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Using this like Bart, as LED is derived from it. So far
        # No checkpoint on the hub exists that uses that in practice.
        # https://github.com/huggingface/transformers/blob/ac3cb660cad283163f7c73cad511124e845ca388/src/transformers/models/bart/modeling_bart.py#L1153
        if decoder_input_ids is None and decoder_inputs_embeds is None:
            # Decoder inputs are derived from `input_ids` here, so fail early with an actionable
            # message instead of handing None to `shift_tokens_right`.
            if input_ids is None:
                raise ValueError(
                    "If no `decoder_input_ids` or `decoder_inputs_embeds` are "
                    "passed, `input_ids` cannot be `None`. Please pass either "
                    "`input_ids` or `decoder_input_ids` or `decoder_inputs_embeds`."
                )
            decoder_input_ids = shift_tokens_right(
                input_ids, self.config.pad_token_id, self.config.decoder_start_token_id
            )

        if encoder_outputs is None:
            encoder_outputs = self.encoder(
                input_ids=input_ids,
                attention_mask=attention_mask,
                global_attention_mask=global_attention_mask,
                head_mask=head_mask,
                inputs_embeds=inputs_embeds,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )
        # If the user passed a tuple for encoder_outputs, we wrap it in a LEDEncoderBaseModelOutput when return_dict=True
        elif return_dict and not isinstance(encoder_outputs, LEDEncoderBaseModelOutput):
            encoder_outputs = LEDEncoderBaseModelOutput(
                last_hidden_state=encoder_outputs[0],
                hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
                attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
                global_attentions=encoder_outputs[3] if len(encoder_outputs) > 3 else None,
            )

        # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn)
        decoder_outputs = self.decoder(
            input_ids=decoder_input_ids,
            attention_mask=decoder_attention_mask,
            encoder_hidden_states=encoder_outputs[0],
            encoder_attention_mask=attention_mask,
            head_mask=decoder_head_mask,
            cross_attn_head_mask=cross_attn_head_mask,
            past_key_values=past_key_values,
            inputs_embeds=decoder_inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        if not return_dict:
            # A caller-supplied output object cannot be concatenated onto a tuple; flatten it first
            # (assumes LEDEncoderBaseModelOutput is a ModelOutput exposing `to_tuple`).
            if isinstance(encoder_outputs, LEDEncoderBaseModelOutput):
                encoder_outputs = encoder_outputs.to_tuple()
            return decoder_outputs + encoder_outputs

        return LEDSeq2SeqModelOutput(
            last_hidden_state=decoder_outputs.last_hidden_state,
            past_key_values=decoder_outputs.past_key_values,
            decoder_hidden_states=decoder_outputs.hidden_states,
            decoder_attentions=decoder_outputs.attentions,
            cross_attentions=decoder_outputs.cross_attentions,
            encoder_last_hidden_state=encoder_outputs.last_hidden_state,
            encoder_hidden_states=encoder_outputs.hidden_states,
            encoder_attentions=encoder_outputs.attentions,
            encoder_global_attentions=encoder_outputs.global_attentions,
        )
@add_start_docstrings(
    "The LED Model with a language modeling head. Can be used for summarization.", LED_START_DOCSTRING
)
class LEDForConditionalGeneration(LEDPreTrainedModel, GenerationMixin):
    base_model_prefix = "led"
    _keys_to_ignore_on_load_missing = ["final_logits_bias"]
    _tied_weights_keys = ["decoder.embed_tokens.weight", "encoder.embed_tokens.weight", "lm_head.weight"]

    def __init__(self, config: LEDConfig):
        super().__init__(config)
        self.led = LEDModel(config)
        # Per-token additive bias on the LM logits, registered as a buffer of shape (1, vocab).
        self.register_buffer("final_logits_bias", torch.zeros((1, self.led.shared.num_embeddings)))
        self.lm_head = nn.Linear(config.d_model, self.led.shared.num_embeddings, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_encoder(self):
        return self.led.get_encoder()

    def get_decoder(self):
        return self.led.get_decoder()

    def resize_token_embeddings(self, new_num_tokens: int, pad_to_multiple_of: Optional[int] = None) -> nn.Embedding:
        # Keep `final_logits_bias` sized to the (possibly padded) new embedding matrix.
        new_embeddings = super().resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
        self._resize_final_logits_bias(new_embeddings.weight.shape[0])
        return new_embeddings

    def _resize_final_logits_bias(self, new_num_tokens: int) -> None:
        """Truncate or zero-pad `final_logits_bias` along its last dimension to `new_num_tokens`."""
        old_num_tokens = self.final_logits_bias.shape[-1]
        if new_num_tokens <= old_num_tokens:
            new_bias = self.final_logits_bias[:, :new_num_tokens]
        else:
            # Match the existing buffer's dtype so the concatenation does not promote it (e.g. fp16 -> fp32).
            extra_bias = torch.zeros(
                (1, new_num_tokens - old_num_tokens),
                device=self.final_logits_bias.device,
                dtype=self.final_logits_bias.dtype,
            )
            new_bias = torch.cat([self.final_logits_bias, extra_bias], dim=1)
        self.register_buffer("final_logits_bias", new_bias)

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    @add_start_docstrings_to_model_forward(LED_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
    @add_end_docstrings(LED_GENERATION_EXAMPLE)
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        decoder_input_ids: Optional[torch.LongTensor] = None,
        decoder_attention_mask: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        decoder_head_mask: Optional[torch.Tensor] = None,
        cross_attn_head_mask: Optional[torch.Tensor] = None,
        encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        global_attention_mask: Optional[torch.FloatTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], LEDSeq2SeqLMOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        Returns:

        Conditional generation example:

        ```python
        >>> from transformers import AutoTokenizer, LEDForConditionalGeneration

        >>> tokenizer = AutoTokenizer.from_pretrained("allenai/led-base-16384")
        >>> TXT = "My friends are <mask> but they eat too many carbs."

        >>> model = LEDForConditionalGeneration.from_pretrained("allenai/led-base-16384")
        >>> input_ids = tokenizer([TXT], return_tensors="pt")["input_ids"]

        >>> prediction = model.generate(input_ids)[0]
        >>> print(tokenizer.decode(prediction, skip_special_tokens=True))
        ```"""
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if labels is not None:
            if use_cache:
                logger.warning("The `use_cache` argument is changed to `False` since `labels` is provided.")
            use_cache = False
            # Teacher forcing: derive decoder inputs from the labels when none are given.
            if decoder_input_ids is None and decoder_inputs_embeds is None:
                decoder_input_ids = shift_tokens_right(
                    labels, self.config.pad_token_id, self.config.decoder_start_token_id
                )

        outputs = self.led(
            input_ids,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            decoder_attention_mask=decoder_attention_mask,
            encoder_outputs=encoder_outputs,
            global_attention_mask=global_attention_mask,
            head_mask=head_mask,
            decoder_head_mask=decoder_head_mask,
            cross_attn_head_mask=cross_attn_head_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            decoder_inputs_embeds=decoder_inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        lm_logits = self.lm_head(outputs[0]) + self.final_logits_bias

        masked_lm_loss = None
        if labels is not None:
            # The labels may sit on a different device than the logits (e.g. a model split across devices).
            labels = labels.to(lm_logits.device)
            loss_fct = CrossEntropyLoss()
            masked_lm_loss = loss_fct(lm_logits.view(-1, self.config.vocab_size), labels.view(-1))

        if not return_dict:
            output = (lm_logits,) + outputs[1:]
            return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output

        return LEDSeq2SeqLMOutput(
            loss=masked_lm_loss,
            logits=lm_logits,
            past_key_values=outputs.past_key_values,
            decoder_hidden_states=outputs.decoder_hidden_states,
            decoder_attentions=outputs.decoder_attentions,
            cross_attentions=outputs.cross_attentions,
            encoder_last_hidden_state=outputs.encoder_last_hidden_state,
            encoder_hidden_states=outputs.encoder_hidden_states,
            encoder_attentions=outputs.encoder_attentions,
            encoder_global_attentions=outputs.encoder_global_attentions,
        )

    def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
        return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id)

    @staticmethod
    def _reorder_cache(past_key_values, beam_idx):
        """Reorder the first two (self-attention) cache tensors of every layer along dim 0 by `beam_idx`."""
        reordered_past = ()
        for layer_past in past_key_values:
            # cached cross_attention states don't have to be reordered -> they are always the same
            reordered_past += (
                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past[:2])
                + layer_past[2:],
            )
        return reordered_past
@add_start_docstrings(
    """
    LED model with a sequence classification/head on top (a linear layer on top of the pooled output) e.g. for GLUE
    tasks.
    """,
    LED_START_DOCSTRING,
)
class LEDForSequenceClassification(LEDPreTrainedModel):
    # Embedding weights tied to the shared embedding of the wrapped LEDModel.
    _tied_weights_keys = ["decoder.embed_tokens.weight", "encoder.embed_tokens.weight"]

    def __init__(self, config: LEDConfig, **kwargs):
        # This head is deprecated; every instantiation emits a FutureWarning.
        warnings.warn(
            "The `transformers.LEDForSequenceClassification` class is deprecated and will be removed in version 5 of"
            " Transformers. No actual method were provided in the original paper on how to perfom"
            " sequence classification.",
            FutureWarning,
        )
        super().__init__(config, **kwargs)
        self.led = LEDModel(config)
        self.classification_head = LEDClassificationHead(
            config.d_model,
            config.d_model,
            config.num_labels,
            config.classifier_dropout,
        )

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(LED_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=Seq2SeqSequenceClassifierOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        decoder_input_ids: Optional[torch.LongTensor] = None,
        decoder_attention_mask: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        decoder_head_mask: Optional[torch.Tensor] = None,
        cross_attn_head_mask: Optional[torch.Tensor] = None,
        encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        global_attention_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], LEDSeq2SeqSequenceClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        return_dict = self.config.use_return_dict if return_dict is None else return_dict
        if labels is not None:
            use_cache = False

        # The pooled representation is located via EOS tokens in `input_ids`, so embeddings alone won't do.
        if inputs_embeds is not None and input_ids is None:
            raise NotImplementedError(
                f"Passing input embeddings is currently not supported for {self.__class__.__name__}"
            )

        model_outputs = self.led(
            input_ids,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            decoder_attention_mask=decoder_attention_mask,
            global_attention_mask=global_attention_mask,
            head_mask=head_mask,
            decoder_head_mask=decoder_head_mask,
            cross_attn_head_mask=cross_attn_head_mask,
            encoder_outputs=encoder_outputs,
            inputs_embeds=inputs_embeds,
            decoder_inputs_embeds=decoder_inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # Pool: take the decoder hidden state at the last EOS position of each example.
        last_hidden = model_outputs[0]
        is_eos = input_ids.eq(self.config.eos_token_id).to(last_hidden.device)
        if len(torch.unique_consecutive(is_eos.sum(1))) > 1:
            raise ValueError("All examples must have the same number of <eos> tokens.")
        batch_size = last_hidden.size(0)
        hidden_size = last_hidden.size(-1)
        eos_states = last_hidden[is_eos, :].view(batch_size, -1, hidden_size)
        pooled = eos_states[:, -1, :]
        logits = self.classification_head(pooled)

        loss = None
        if labels is not None:
            cfg = self.config
            # Infer (and persist on the config) the problem type the first time labels are seen.
            if cfg.problem_type is None:
                if cfg.num_labels == 1:
                    cfg.problem_type = "regression"
                elif cfg.num_labels > 1 and labels.dtype in (torch.long, torch.int):
                    cfg.problem_type = "single_label_classification"
                else:
                    cfg.problem_type = "multi_label_classification"

            problem = cfg.problem_type
            if problem == "regression":
                mse = MSELoss()
                if cfg.num_labels == 1:
                    loss = mse(logits.squeeze(), labels.squeeze())
                else:
                    loss = mse(logits, labels)
            elif problem == "single_label_classification":
                loss = CrossEntropyLoss()(logits.view(-1, cfg.num_labels), labels.view(-1))
            elif problem == "multi_label_classification":
                loss = BCEWithLogitsLoss()(logits, labels)

        if not return_dict:
            tuple_output = (logits,) + model_outputs[1:]
            return tuple_output if loss is None else (loss,) + tuple_output

        return LEDSeq2SeqSequenceClassifierOutput(
            loss=loss,
            logits=logits,
            past_key_values=model_outputs.past_key_values,
            decoder_hidden_states=model_outputs.decoder_hidden_states,
            decoder_attentions=model_outputs.decoder_attentions,
            cross_attentions=model_outputs.cross_attentions,
            encoder_last_hidden_state=model_outputs.encoder_last_hidden_state,
            encoder_hidden_states=model_outputs.encoder_hidden_states,
            encoder_attentions=model_outputs.encoder_attentions,
            encoder_global_attentions=model_outputs.encoder_global_attentions,
        )
@add_start_docstrings(
    """
    LED Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear layer
    on top of the hidden-states output to compute `span start logits` and `span end logits`).
    """,
    LED_START_DOCSTRING,
)
class LEDForQuestionAnswering(LEDPreTrainedModel):
    # Embedding weights tied to the shared embedding of the wrapped LEDModel.
    _tied_weights_keys = ["decoder.embed_tokens.weight", "encoder.embed_tokens.weight"]

    def __init__(self, config):
        super().__init__(config)

        # Two outputs per position: span-start and span-end logits (written back onto the config).
        config.num_labels = 2
        self.num_labels = config.num_labels

        self.led = LEDModel(config)
        self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(LED_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=Seq2SeqQuestionAnsweringModelOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        decoder_input_ids: Optional[torch.LongTensor] = None,
        decoder_attention_mask: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        decoder_head_mask: Optional[torch.Tensor] = None,
        cross_attn_head_mask: Optional[torch.Tensor] = None,
        encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        global_attention_mask: Optional[torch.FloatTensor] = None,
        start_positions: Optional[torch.LongTensor] = None,
        end_positions: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], LEDSeq2SeqQuestionAnsweringModelOutput]:
        r"""
        start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for position (index) of the start of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (*sequence_length*). Position outside of the sequence
            are not taken into account for computing the loss.
        end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for position (index) of the end of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (*sequence_length*). Position outside of the sequence
            are not taken into account for computing the loss.
        """
        return_dict = self.config.use_return_dict if return_dict is None else return_dict
        has_span_labels = start_positions is not None and end_positions is not None
        if has_span_labels:
            use_cache = False

        outputs = self.led(
            input_ids,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            decoder_attention_mask=decoder_attention_mask,
            global_attention_mask=global_attention_mask,
            head_mask=head_mask,
            decoder_head_mask=decoder_head_mask,
            cross_attn_head_mask=cross_attn_head_mask,
            encoder_outputs=encoder_outputs,
            inputs_embeds=inputs_embeds,
            decoder_inputs_embeds=decoder_inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # Project every decoder position to two scores and split them into start / end logits.
        span_logits = self.qa_outputs(outputs[0])
        start_logits, end_logits = (part.squeeze(-1).contiguous() for part in span_logits.split(1, dim=-1))

        total_loss = None
        if has_span_labels:
            # On multi-GPU the positions can arrive with an extra trailing dimension; drop it.
            if start_positions.dim() > 1:
                start_positions = start_positions.squeeze(-1)
            if end_positions.dim() > 1:
                end_positions = end_positions.squeeze(-1)
            # Out-of-range positions are clamped to `seq_len`, which the loss then ignores.
            seq_len = start_logits.size(1)
            start_positions = start_positions.clamp(0, seq_len)
            end_positions = end_positions.clamp(0, seq_len)
            criterion = CrossEntropyLoss(ignore_index=seq_len)
            total_loss = (criterion(start_logits, start_positions) + criterion(end_logits, end_positions)) / 2

        if not return_dict:
            tuple_output = (start_logits, end_logits) + outputs[1:]
            return tuple_output if total_loss is None else (total_loss,) + tuple_output

        return LEDSeq2SeqQuestionAnsweringModelOutput(
            loss=total_loss,
            start_logits=start_logits,
            end_logits=end_logits,
            past_key_values=outputs.past_key_values,
            decoder_hidden_states=outputs.decoder_hidden_states,
            decoder_attentions=outputs.decoder_attentions,
            cross_attentions=outputs.cross_attentions,
            encoder_last_hidden_state=outputs.encoder_last_hidden_state,
            encoder_hidden_states=outputs.encoder_hidden_states,
            encoder_attentions=outputs.encoder_attentions,
            encoder_global_attentions=outputs.encoder_global_attentions,
        )
|